Reduce the number of cow cells in idm (#1385)

* Reduce the number of cow cells in idm
This commit is contained in:
Firstyear 2023-02-19 09:51:36 +10:00 committed by GitHub
parent 0d8d9e1a62
commit 87b43d0c14
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 47 deletions

View file

@ -413,7 +413,7 @@ impl<'a> IdmServerProxyWriteTransaction<'a> {
OperationError::SerdeJsonError OperationError::SerdeJsonError
})?; })?;
let token_enc = self.token_enc_key.encrypt(&token_data); let token_enc = self.domain_keys.token_enc_key.encrypt(&token_data);
// Point of no return // Point of no return
@ -726,6 +726,7 @@ impl<'a> IdmServerProxyWriteTransaction<'a> {
OperationError, OperationError,
> { > {
let session_token: CredentialUpdateSessionTokenInner = self let session_token: CredentialUpdateSessionTokenInner = self
.domain_keys
.token_enc_key .token_enc_key
.decrypt(&cust.token_enc) .decrypt(&cust.token_enc)
.map_err(|e| { .map_err(|e| {
@ -944,6 +945,7 @@ impl<'a> IdmServerCredUpdateTransaction<'a> {
ct: Duration, ct: Duration,
) -> Result<CredentialUpdateSessionMutex, OperationError> { ) -> Result<CredentialUpdateSessionMutex, OperationError> {
let session_token: CredentialUpdateSessionTokenInner = self let session_token: CredentialUpdateSessionTokenInner = self
.domain_keys
.token_enc_key .token_enc_key
.decrypt(&cust.token_enc) .decrypt(&cust.token_enc)
.map_err(|e| { .map_err(|e| {

View file

@ -1,5 +1,4 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -60,6 +59,14 @@ use crate::value::{Oauth2Session, Session};
type AuthSessionMutex = Arc<Mutex<AuthSession>>; type AuthSessionMutex = Arc<Mutex<AuthSession>>;
type CredSoftLockMutex = Arc<Mutex<CredSoftLock>>; type CredSoftLockMutex = Arc<Mutex<CredSoftLock>>;
#[derive(Clone)]
pub struct DomainKeys {
pub(crate) uat_jwt_signer: JwsSigner,
pub(crate) uat_jwt_validator: JwsValidator,
pub(crate) token_enc_key: Fernet,
pub(crate) cookie_key: [u8; 32],
}
pub struct IdmServer { pub struct IdmServer {
// There is a good reason to keep this single thread - it // There is a good reason to keep this single thread - it
// means that limits to sessions can be easily applied and checked to // means that limits to sessions can be easily applied and checked to
@ -79,10 +86,7 @@ pub struct IdmServer {
webauthn: Webauthn, webauthn: Webauthn,
pw_badlist_cache: Arc<CowCell<HashSet<String>>>, pw_badlist_cache: Arc<CowCell<HashSet<String>>>,
oauth2rs: Arc<Oauth2ResourceServers>, oauth2rs: Arc<Oauth2ResourceServers>,
uat_jwt_signer: Arc<CowCell<JwsSigner>>, domain_keys: Arc<CowCell<DomainKeys>>,
uat_jwt_validator: Arc<CowCell<JwsValidator>>,
token_enc_key: Arc<CowCell<Fernet>>,
cookie_key: Arc<CowCell<[u8; 32]>>,
} }
/// Contains methods that require writes, but in the context of writing to the idm in memory structures (maybe the query server too). This is things like authentication. /// Contains methods that require writes, but in the context of writing to the idm in memory structures (maybe the query server too). This is things like authentication.
@ -98,8 +102,7 @@ pub struct IdmServerAuthTransaction<'a> {
async_tx: Sender<DelayedAction>, async_tx: Sender<DelayedAction>,
webauthn: &'a Webauthn, webauthn: &'a Webauthn,
pw_badlist_cache: CowCellReadTxn<HashSet<String>>, pw_badlist_cache: CowCellReadTxn<HashSet<String>>,
uat_jwt_signer: CowCellReadTxn<JwsSigner>, domain_keys: CowCellReadTxn<DomainKeys>,
uat_jwt_validator: CowCellReadTxn<JwsValidator>,
} }
pub struct IdmServerCredUpdateTransaction<'a> { pub struct IdmServerCredUpdateTransaction<'a> {
@ -108,14 +111,14 @@ pub struct IdmServerCredUpdateTransaction<'a> {
pub(crate) webauthn: &'a Webauthn, pub(crate) webauthn: &'a Webauthn,
pub(crate) pw_badlist_cache: CowCellReadTxn<HashSet<String>>, pub(crate) pw_badlist_cache: CowCellReadTxn<HashSet<String>>,
pub(crate) cred_update_sessions: BptreeMapReadTxn<'a, Uuid, CredentialUpdateSessionMutex>, pub(crate) cred_update_sessions: BptreeMapReadTxn<'a, Uuid, CredentialUpdateSessionMutex>,
pub(crate) token_enc_key: CowCellReadTxn<Fernet>, pub(crate) domain_keys: CowCellReadTxn<DomainKeys>,
pub(crate) crypto_policy: &'a CryptoPolicy, pub(crate) crypto_policy: &'a CryptoPolicy,
} }
/// This contains read-only methods, like getting users, groups and other structured content. /// This contains read-only methods, like getting users, groups and other structured content.
pub struct IdmServerProxyReadTransaction<'a> { pub struct IdmServerProxyReadTransaction<'a> {
pub qs_read: QueryServerReadTransaction<'a>, pub qs_read: QueryServerReadTransaction<'a>,
uat_jwt_validator: CowCellReadTxn<JwsValidator>, pub(crate) domain_keys: CowCellReadTxn<DomainKeys>,
pub(crate) oauth2rs: Oauth2ResourceServersReadTransaction, pub(crate) oauth2rs: Oauth2ResourceServersReadTransaction,
pub(crate) async_tx: Sender<DelayedAction>, pub(crate) async_tx: Sender<DelayedAction>,
} }
@ -130,10 +133,7 @@ pub struct IdmServerProxyWriteTransaction<'a> {
crypto_policy: &'a CryptoPolicy, crypto_policy: &'a CryptoPolicy,
webauthn: &'a Webauthn, webauthn: &'a Webauthn,
pw_badlist_cache: CowCellWriteTxn<'a, HashSet<String>>, pw_badlist_cache: CowCellWriteTxn<'a, HashSet<String>>,
uat_jwt_signer: CowCellWriteTxn<'a, JwsSigner>, pub(crate) domain_keys: CowCellWriteTxn<'a, DomainKeys>,
uat_jwt_validator: CowCellWriteTxn<'a, JwsValidator>,
cookie_key: CowCellWriteTxn<'a, [u8; 32]>,
pub(crate) token_enc_key: CowCellWriteTxn<'a, Fernet>,
pub(crate) oauth2rs: Oauth2ResourceServersWriteTransaction<'a>, pub(crate) oauth2rs: Oauth2ResourceServersWriteTransaction<'a>,
} }
@ -213,26 +213,27 @@ impl IdmServer {
})?; })?;
// Setup our auth token signing key. // Setup our auth token signing key.
let fernet_key = Fernet::new(&fernet_private_key).ok_or_else(|| { let token_enc_key = Fernet::new(&fernet_private_key).ok_or_else(|| {
admin_error!("Unable to load Fernet encryption key"); admin_error!("Unable to load Fernet encryption key");
OperationError::CryptographyError OperationError::CryptographyError
})?; })?;
let token_enc_key = Arc::new(CowCell::new(fernet_key));
let jwt_signer = JwsSigner::from_es256_der(&es256_private_key).map_err(|e| { let uat_jwt_signer = JwsSigner::from_es256_der(&es256_private_key).map_err(|e| {
admin_error!(err = ?e, "Unable to load ES256 JwsSigner from DER"); admin_error!(err = ?e, "Unable to load ES256 JwsSigner from DER");
OperationError::CryptographyError OperationError::CryptographyError
})?; })?;
let jwt_validator = jwt_signer.get_validator().map_err(|e| { let uat_jwt_validator = uat_jwt_signer.get_validator().map_err(|e| {
admin_error!(err = ?e, "Unable to load ES256 JwsValidator from JwsSigner"); admin_error!(err = ?e, "Unable to load ES256 JwsValidator from JwsSigner");
OperationError::CryptographyError OperationError::CryptographyError
})?; })?;
let uat_jwt_signer = Arc::new(CowCell::new(jwt_signer)); let domain_keys = Arc::new(CowCell::new(DomainKeys {
let uat_jwt_validator = Arc::new(CowCell::new(jwt_validator)); uat_jwt_signer,
uat_jwt_validator,
let cookie_key = Arc::new(CowCell::new(cookie_key)); token_enc_key,
cookie_key,
}));
let oauth2rs = let oauth2rs =
Oauth2ResourceServers::try_from((oauth2rs_set, origin_url)).map_err(|e| { Oauth2ResourceServers::try_from((oauth2rs_set, origin_url)).map_err(|e| {
@ -251,10 +252,7 @@ impl IdmServer {
async_tx, async_tx,
webauthn, webauthn,
pw_badlist_cache: Arc::new(CowCell::new(pw_badlist_set)), pw_badlist_cache: Arc::new(CowCell::new(pw_badlist_set)),
uat_jwt_signer, domain_keys,
uat_jwt_validator,
token_enc_key,
cookie_key,
oauth2rs: Arc::new(oauth2rs), oauth2rs: Arc::new(oauth2rs),
}, },
IdmServerDelayed { async_rx }, IdmServerDelayed { async_rx },
@ -262,7 +260,7 @@ impl IdmServer {
} }
pub fn get_cookie_key(&self) -> [u8; 32] { pub fn get_cookie_key(&self) -> [u8; 32] {
*self.cookie_key.read().deref() self.domain_keys.read().cookie_key
} }
#[cfg(test)] #[cfg(test)]
@ -286,8 +284,7 @@ impl IdmServer {
async_tx: self.async_tx.clone(), async_tx: self.async_tx.clone(),
webauthn: &self.webauthn, webauthn: &self.webauthn,
pw_badlist_cache: self.pw_badlist_cache.read(), pw_badlist_cache: self.pw_badlist_cache.read(),
uat_jwt_signer: self.uat_jwt_signer.read(), domain_keys: self.domain_keys.read(),
uat_jwt_validator: self.uat_jwt_validator.read(),
} }
} }
@ -296,7 +293,7 @@ impl IdmServer {
pub async fn proxy_read(&self) -> IdmServerProxyReadTransaction<'_> { pub async fn proxy_read(&self) -> IdmServerProxyReadTransaction<'_> {
IdmServerProxyReadTransaction { IdmServerProxyReadTransaction {
qs_read: self.qs.read().await, qs_read: self.qs.read().await,
uat_jwt_validator: self.uat_jwt_validator.read(), domain_keys: self.domain_keys.read(),
oauth2rs: self.oauth2rs.read(), oauth2rs: self.oauth2rs.read(),
async_tx: self.async_tx.clone(), async_tx: self.async_tx.clone(),
} }
@ -317,10 +314,7 @@ impl IdmServer {
crypto_policy: &self.crypto_policy, crypto_policy: &self.crypto_policy,
webauthn: &self.webauthn, webauthn: &self.webauthn,
pw_badlist_cache: self.pw_badlist_cache.write(), pw_badlist_cache: self.pw_badlist_cache.write(),
uat_jwt_signer: self.uat_jwt_signer.write(), domain_keys: self.domain_keys.write(),
uat_jwt_validator: self.uat_jwt_validator.write(),
token_enc_key: self.token_enc_key.write(),
cookie_key: self.cookie_key.write(),
oauth2rs: self.oauth2rs.write(), oauth2rs: self.oauth2rs.write(),
} }
} }
@ -337,7 +331,7 @@ impl IdmServer {
webauthn: &self.webauthn, webauthn: &self.webauthn,
pw_badlist_cache: self.pw_badlist_cache.read(), pw_badlist_cache: self.pw_badlist_cache.read(),
cred_update_sessions: self.cred_update_sessions.read(), cred_update_sessions: self.cred_update_sessions.read(),
token_enc_key: self.token_enc_key.read(), domain_keys: self.domain_keys.read(),
crypto_policy: &self.crypto_policy, crypto_policy: &self.crypto_policy,
} }
} }
@ -893,7 +887,7 @@ impl<'a> IdmServerTransaction<'a> for IdmServerAuthTransaction<'a> {
} }
fn get_uat_validator_txn(&self) -> &JwsValidator { fn get_uat_validator_txn(&self) -> &JwsValidator {
&self.uat_jwt_validator &self.domain_keys.uat_jwt_validator
} }
} }
@ -1167,7 +1161,7 @@ impl<'a> IdmServerAuthTransaction<'a> {
&self.async_tx, &self.async_tx,
self.webauthn, self.webauthn,
pw_badlist_cache, pw_badlist_cache,
&self.uat_jwt_signer, &self.domain_keys.uat_jwt_signer,
) )
.map(|aus| { .map(|aus| {
// Inspect the result: // Inspect the result:
@ -1436,7 +1430,7 @@ impl<'a> IdmServerTransaction<'a> for IdmServerProxyReadTransaction<'a> {
} }
fn get_uat_validator_txn(&self) -> &JwsValidator { fn get_uat_validator_txn(&self) -> &JwsValidator {
&self.uat_jwt_validator &self.domain_keys.uat_jwt_validator
} }
} }
@ -1539,7 +1533,7 @@ impl<'a> IdmServerTransaction<'a> for IdmServerProxyWriteTransaction<'a> {
} }
fn get_uat_validator_txn(&self) -> &JwsValidator { fn get_uat_validator_txn(&self) -> &JwsValidator {
&self.uat_jwt_validator &self.domain_keys.uat_jwt_validator
} }
} }
@ -2200,7 +2194,7 @@ impl<'a> IdmServerProxyWriteTransaction<'a> {
}) })
}) })
.map(|new_handle| { .map(|new_handle| {
*self.token_enc_key = new_handle; self.domain_keys.token_enc_key = new_handle;
})?; })?;
self.qs_write self.qs_write
.get_domain_es256_private_key() .get_domain_es256_private_key()
@ -2220,21 +2214,18 @@ impl<'a> IdmServerProxyWriteTransaction<'a> {
.map(|validator| (signer, validator)) .map(|validator| (signer, validator))
}) })
.map(|(new_signer, new_validator)| { .map(|(new_signer, new_validator)| {
*self.uat_jwt_signer = new_signer; self.domain_keys.uat_jwt_signer = new_signer;
*self.uat_jwt_validator = new_validator; self.domain_keys.uat_jwt_validator = new_validator;
})?; })?;
self.qs_write self.qs_write
.get_domain_cookie_key() .get_domain_cookie_key()
.map(|new_cookie_key| { .map(|new_cookie_key| {
*self.cookie_key = new_cookie_key; self.domain_keys.cookie_key = new_cookie_key;
})?; })?;
} }
// Commit everything. // Commit everything.
self.oauth2rs.commit(); self.oauth2rs.commit();
self.uat_jwt_signer.commit(); self.domain_keys.commit();
self.uat_jwt_validator.commit();
self.cookie_key.commit();
self.token_enc_key.commit();
self.pw_badlist_cache.commit(); self.pw_badlist_cache.commit();
self.cred_update_sessions.commit(); self.cred_update_sessions.commit();
trace!("cred_update_session.commit"); trace!("cred_update_session.commit");