Resolve future send issue with keystore (#2311)

This commit is contained in:
Firstyear 2023-11-20 12:46:52 +10:00 committed by GitHub
parent 2bb69f2544
commit 6dc8f1db3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 176 additions and 124 deletions

View file

@ -759,8 +759,8 @@ async fn main() -> ExitCode {
} }
}; };
// With the assistance of the db, setup the hsm and it's machine key. // With the assistance of the DB, setup the HSM and its machine key.
let db_txn = db.write().await; let mut db_txn = db.write().await;
let loadable_machine_key = match db_txn.get_hsm_machine_key() { let loadable_machine_key = match db_txn.get_hsm_machine_key() {
Ok(Some(lmk)) => lmk, Ok(Some(lmk)) => lmk,

View file

@ -25,6 +25,15 @@ pub trait Cache {
async fn write<'db>(&'db self) -> Self::Txn<'db>; async fn write<'db>(&'db self) -> Self::Txn<'db>;
} }
#[async_trait]
pub trait KeyStore {
type Txn<'db>
where
Self: 'db;
async fn write_keystore<'db>(&'db self) -> Self::Txn<'db>;
}
#[derive(Debug)] #[derive(Debug)]
pub enum CacheError { pub enum CacheError {
Cryptography, Cryptography,
@ -37,32 +46,35 @@ pub enum CacheError {
} }
pub trait CacheTxn { pub trait CacheTxn {
fn migrate(&self) -> Result<(), CacheError>; fn migrate(&mut self) -> Result<(), CacheError>;
fn commit(self) -> Result<(), CacheError>; fn commit(self) -> Result<(), CacheError>;
fn invalidate(&self) -> Result<(), CacheError>; fn invalidate(&mut self) -> Result<(), CacheError>;
fn clear(&self) -> Result<(), CacheError>; fn clear(&mut self) -> Result<(), CacheError>;
fn get_hsm_machine_key(&self) -> Result<Option<LoadableMachineKey>, CacheError>; fn get_hsm_machine_key(&mut self) -> Result<Option<LoadableMachineKey>, CacheError>;
fn insert_hsm_machine_key(&self, machine_key: &LoadableMachineKey) -> Result<(), CacheError>; fn insert_hsm_machine_key(
&mut self,
machine_key: &LoadableMachineKey,
) -> Result<(), CacheError>;
fn get_hsm_hmac_key(&self) -> Result<Option<LoadableHmacKey>, CacheError>; fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacKey>, CacheError>;
fn insert_hsm_hmac_key(&self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError>; fn insert_hsm_hmac_key(&mut self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError>;
fn get_account(&self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError>; fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError>;
fn get_accounts(&self) -> Result<Vec<UserToken>, CacheError>; fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError>;
fn update_account(&self, account: &UserToken, expire: u64) -> Result<(), CacheError>; fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError>;
fn delete_account(&self, a_uuid: Uuid) -> Result<(), CacheError>; fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError>;
fn update_account_password( fn update_account_password(
&self, &mut self,
a_uuid: Uuid, a_uuid: Uuid,
cred: &str, cred: &str,
hsm: &mut dyn Tpm, hsm: &mut dyn Tpm,
@ -70,22 +82,34 @@ pub trait CacheTxn {
) -> Result<(), CacheError>; ) -> Result<(), CacheError>;
fn check_account_password( fn check_account_password(
&self, &mut self,
a_uuid: Uuid, a_uuid: Uuid,
cred: &str, cred: &str,
hsm: &mut dyn Tpm, hsm: &mut dyn Tpm,
hmac_key: &HmacKey, hmac_key: &HmacKey,
) -> Result<bool, CacheError>; ) -> Result<bool, CacheError>;
fn get_group(&self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError>; fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError>;
fn get_group_members(&self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError>; fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError>;
fn get_groups(&self) -> Result<Vec<GroupToken>, CacheError>; fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError>;
fn update_group(&self, grp: &GroupToken, expire: u64) -> Result<(), CacheError>; fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError>;
fn delete_group(&self, g_uuid: Uuid) -> Result<(), CacheError>; fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError>;
}
pub trait KeyStoreTxn {
fn get_tagged_hsm_key<K: DeserializeOwned>(
&mut self,
tag: &str,
) -> Result<Option<K>, CacheError>;
fn insert_tagged_hsm_key<K: Serialize>(&mut self, tag: &str, key: &K)
-> Result<(), CacheError>;
fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError>;
} }
pub struct Db { pub struct Db {
@ -184,7 +208,10 @@ impl<'a> DbTxn<'a> {
CacheError::Sqlite CacheError::Sqlite
} }
fn get_account_data_name(&self, account_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> { fn get_account_data_name(
&mut self,
account_id: &str,
) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
let mut stmt = self.conn let mut stmt = self.conn
.prepare( .prepare(
"SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id" "SELECT token, expiry FROM account_t WHERE uuid = :account_id OR name = :account_id OR spn = :account_id"
@ -203,7 +230,7 @@ impl<'a> DbTxn<'a> {
data data
} }
fn get_account_data_gid(&self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> { fn get_account_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid") .prepare("SELECT token, expiry FROM account_t WHERE gidnumber = :gid")
@ -219,7 +246,7 @@ impl<'a> DbTxn<'a> {
data data
} }
fn get_group_data_name(&self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> { fn get_group_data_name(&mut self, grp_id: &str) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
let mut stmt = self.conn let mut stmt = self.conn
.prepare( .prepare(
"SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id" "SELECT token, expiry FROM group_t WHERE uuid = :grp_id OR name = :grp_id OR spn = :grp_id"
@ -238,7 +265,7 @@ impl<'a> DbTxn<'a> {
data data
} }
fn get_group_data_gid(&self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> { fn get_group_data_gid(&mut self, gid: u32) -> Result<Vec<(Vec<u8>, i64)>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid") .prepare("SELECT token, expiry FROM group_t WHERE gidnumber = :gid")
@ -253,9 +280,11 @@ impl<'a> DbTxn<'a> {
.collect(); .collect();
data data
} }
}
pub fn get_tagged_hsm_key<K: DeserializeOwned>( impl<'a> KeyStoreTxn for DbTxn<'a> {
&self, fn get_tagged_hsm_key<K: DeserializeOwned>(
&mut self,
tag: &str, tag: &str,
) -> Result<Option<K>, CacheError> { ) -> Result<Option<K>, CacheError> {
let mut stmt = self let mut stmt = self
@ -283,8 +312,8 @@ impl<'a> DbTxn<'a> {
} }
} }
pub fn insert_tagged_hsm_key<K: Serialize>( fn insert_tagged_hsm_key<K: Serialize>(
&self, &mut self,
tag: &str, tag: &str,
key: &K, key: &K,
) -> Result<(), CacheError> { ) -> Result<(), CacheError> {
@ -308,7 +337,7 @@ impl<'a> DbTxn<'a> {
.map_err(|e| self.sqlite_error("execute", &e)) .map_err(|e| self.sqlite_error("execute", &e))
} }
pub fn delete_tagged_hsm_key(&self, tag: &str) -> Result<(), CacheError> { fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), CacheError> {
self.conn self.conn
.execute( .execute(
"DELETE FROM hsm_data_t where key = :key", "DELETE FROM hsm_data_t where key = :key",
@ -322,7 +351,7 @@ impl<'a> DbTxn<'a> {
} }
impl<'a> CacheTxn for DbTxn<'a> { impl<'a> CacheTxn for DbTxn<'a> {
fn migrate(&self) -> Result<(), CacheError> { fn migrate(&mut self) -> Result<(), CacheError> {
self.conn.set_prepared_statement_cache_capacity(16); self.conn.set_prepared_statement_cache_capacity(16);
self.conn self.conn
.prepare("PRAGMA journal_mode=WAL;") .prepare("PRAGMA journal_mode=WAL;")
@ -423,7 +452,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.map_err(|e| self.sqlite_error("commit", &e)) .map_err(|e| self.sqlite_error("commit", &e))
} }
fn invalidate(&self) -> Result<(), CacheError> { fn invalidate(&mut self) -> Result<(), CacheError> {
self.conn self.conn
.execute("UPDATE group_t SET expiry = 0", []) .execute("UPDATE group_t SET expiry = 0", [])
.map_err(|e| self.sqlite_error("update group_t", &e))?; .map_err(|e| self.sqlite_error("update group_t", &e))?;
@ -435,7 +464,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
Ok(()) Ok(())
} }
fn clear(&self) -> Result<(), CacheError> { fn clear(&mut self) -> Result<(), CacheError> {
self.conn self.conn
.execute("DELETE FROM memberof_t", []) .execute("DELETE FROM memberof_t", [])
.map_err(|e| self.sqlite_error("delete memberof_t", &e))?; .map_err(|e| self.sqlite_error("delete memberof_t", &e))?;
@ -451,7 +480,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
Ok(()) Ok(())
} }
fn get_hsm_machine_key(&self) -> Result<Option<LoadableMachineKey>, CacheError> { fn get_hsm_machine_key(&mut self) -> Result<Option<LoadableMachineKey>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'") .prepare("SELECT value FROM hsm_int_t WHERE key = 'mk'")
@ -472,7 +501,10 @@ impl<'a> CacheTxn for DbTxn<'a> {
} }
} }
fn insert_hsm_machine_key(&self, machine_key: &LoadableMachineKey) -> Result<(), CacheError> { fn insert_hsm_machine_key(
&mut self,
machine_key: &LoadableMachineKey,
) -> Result<(), CacheError> {
let data = serde_json::to_vec(machine_key).map_err(|e| { let data = serde_json::to_vec(machine_key).map_err(|e| {
error!("insert_hsm_machine_key json error -> {:?}", e); error!("insert_hsm_machine_key json error -> {:?}", e);
CacheError::SerdeJson CacheError::SerdeJson
@ -493,7 +525,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.map_err(|e| self.sqlite_error("execute", &e)) .map_err(|e| self.sqlite_error("execute", &e))
} }
fn get_hsm_hmac_key(&self) -> Result<Option<LoadableHmacKey>, CacheError> { fn get_hsm_hmac_key(&mut self) -> Result<Option<LoadableHmacKey>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'") .prepare("SELECT value FROM hsm_int_t WHERE key = 'hmac'")
@ -514,7 +546,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
} }
} }
fn insert_hsm_hmac_key(&self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError> { fn insert_hsm_hmac_key(&mut self, hmac_key: &LoadableHmacKey) -> Result<(), CacheError> {
let data = serde_json::to_vec(hmac_key).map_err(|e| { let data = serde_json::to_vec(hmac_key).map_err(|e| {
error!("insert_hsm_hmac_key json error -> {:?}", e); error!("insert_hsm_hmac_key json error -> {:?}", e);
CacheError::SerdeJson CacheError::SerdeJson
@ -535,7 +567,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.map_err(|e| self.sqlite_error("execute", &e)) .map_err(|e| self.sqlite_error("execute", &e))
} }
fn get_account(&self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> { fn get_account(&mut self, account_id: &Id) -> Result<Option<(UserToken, u64)>, CacheError> {
let data = match account_id { let data = match account_id {
Id::Name(n) => self.get_account_data_name(n.as_str()), Id::Name(n) => self.get_account_data_name(n.as_str()),
Id::Gid(g) => self.get_account_data_gid(*g), Id::Gid(g) => self.get_account_data_gid(*g),
@ -569,7 +601,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
} }
} }
fn get_accounts(&self) -> Result<Vec<UserToken>, CacheError> { fn get_accounts(&mut self) -> Result<Vec<UserToken>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT token FROM account_t") .prepare("SELECT token FROM account_t")
@ -598,7 +630,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.collect()) .collect())
} }
fn update_account(&self, account: &UserToken, expire: u64) -> Result<(), CacheError> { fn update_account(&mut self, account: &UserToken, expire: u64) -> Result<(), CacheError> {
let data = serde_json::to_vec(account).map_err(|e| { let data = serde_json::to_vec(account).map_err(|e| {
error!("update_account json error -> {:?}", e); error!("update_account json error -> {:?}", e);
CacheError::SerdeJson CacheError::SerdeJson
@ -694,7 +726,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
}) })
} }
fn delete_account(&self, a_uuid: Uuid) -> Result<(), CacheError> { fn delete_account(&mut self, a_uuid: Uuid) -> Result<(), CacheError> {
let account_uuid = a_uuid.as_hyphenated().to_string(); let account_uuid = a_uuid.as_hyphenated().to_string();
self.conn self.conn
@ -715,7 +747,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
} }
fn update_account_password( fn update_account_password(
&self, &mut self,
a_uuid: Uuid, a_uuid: Uuid,
cred: &str, cred: &str,
hsm: &mut dyn Tpm, hsm: &mut dyn Tpm,
@ -746,7 +778,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
} }
fn check_account_password( fn check_account_password(
&self, &mut self,
a_uuid: Uuid, a_uuid: Uuid,
cred: &str, cred: &str,
hsm: &mut dyn Tpm, hsm: &mut dyn Tpm,
@ -796,7 +828,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
}) })
} }
fn get_group(&self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> { fn get_group(&mut self, grp_id: &Id) -> Result<Option<(GroupToken, u64)>, CacheError> {
let data = match grp_id { let data = match grp_id {
Id::Name(n) => self.get_group_data_name(n.as_str()), Id::Name(n) => self.get_group_data_name(n.as_str()),
Id::Gid(g) => self.get_group_data_gid(*g), Id::Gid(g) => self.get_group_data_gid(*g),
@ -830,7 +862,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
} }
} }
fn get_group_members(&self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> { fn get_group_members(&mut self, g_uuid: Uuid) -> Result<Vec<UserToken>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid") .prepare("SELECT account_t.token FROM (account_t, memberof_t) WHERE account_t.uuid = memberof_t.a_uuid AND memberof_t.g_uuid = :g_uuid")
@ -859,7 +891,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.collect() .collect()
} }
fn get_groups(&self) -> Result<Vec<GroupToken>, CacheError> { fn get_groups(&mut self) -> Result<Vec<GroupToken>, CacheError> {
let mut stmt = self let mut stmt = self
.conn .conn
.prepare("SELECT token FROM group_t") .prepare("SELECT token FROM group_t")
@ -888,7 +920,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.collect()) .collect())
} }
fn update_group(&self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> { fn update_group(&mut self, grp: &GroupToken, expire: u64) -> Result<(), CacheError> {
let data = serde_json::to_vec(grp).map_err(|e| { let data = serde_json::to_vec(grp).map_err(|e| {
error!("json error -> {:?}", e); error!("json error -> {:?}", e);
CacheError::SerdeJson CacheError::SerdeJson
@ -919,7 +951,7 @@ impl<'a> CacheTxn for DbTxn<'a> {
.map_err(|e| self.sqlite_error("execute", &e)) .map_err(|e| self.sqlite_error("execute", &e))
} }
fn delete_group(&self, g_uuid: Uuid) -> Result<(), CacheError> { fn delete_group(&mut self, g_uuid: Uuid) -> Result<(), CacheError> {
let group_uuid = g_uuid.as_hyphenated().to_string(); let group_uuid = g_uuid.as_hyphenated().to_string();
self.conn self.conn
.execute( .execute(
@ -968,7 +1000,7 @@ mod tests {
async fn test_cache_db_account_basic() { async fn test_cache_db_account_basic() {
sketching::test_init(); sketching::test_init();
let db = Db::new("").expect("failed to create."); let db = Db::new("").expect("failed to create.");
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut ut1 = UserToken { let mut ut1 = UserToken {
@ -1052,7 +1084,7 @@ mod tests {
async fn test_cache_db_group_basic() { async fn test_cache_db_group_basic() {
sketching::test_init(); sketching::test_init();
let db = Db::new("").expect("failed to create."); let db = Db::new("").expect("failed to create.");
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut gt1 = GroupToken { let mut gt1 = GroupToken {
@ -1127,7 +1159,7 @@ mod tests {
async fn test_cache_db_account_group_update() { async fn test_cache_db_account_group_update() {
sketching::test_init(); sketching::test_init();
let db = Db::new("").expect("failed to create."); let db = Db::new("").expect("failed to create.");
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let gt1 = GroupToken { let gt1 = GroupToken {
@ -1197,7 +1229,7 @@ mod tests {
let db = Db::new("").expect("failed to create."); let db = Db::new("").expect("failed to create.");
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
// Setup the hsm // Setup the hsm
@ -1283,7 +1315,7 @@ mod tests {
async fn test_cache_db_group_rename_duplicate() { async fn test_cache_db_group_rename_duplicate() {
sketching::test_init(); sketching::test_init();
let db = Db::new("").expect("failed to create."); let db = Db::new("").expect("failed to create.");
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut gt1 = GroupToken { let mut gt1 = GroupToken {
@ -1338,7 +1370,7 @@ mod tests {
async fn test_cache_db_account_rename_duplicate() { async fn test_cache_db_account_rename_duplicate() {
sketching::test_init(); sketching::test_init();
let db = Db::new("").expect("failed to create."); let db = Db::new("").expect("failed to create.");
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut ut1 = UserToken { let mut ut1 = UserToken {

View file

@ -1,10 +1,10 @@
use crate::db::DbTxn; use crate::db::KeyStoreTxn;
use crate::unix_proto::{DeviceAuthorizationResponse, PamAuthRequest, PamAuthResponse}; use crate::unix_proto::{DeviceAuthorizationResponse, PamAuthRequest, PamAuthResponse};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
pub use kanidm_hsm_crypto::{KeyAlgorithm, MachineKey, Tpm}; pub use kanidm_hsm_crypto as tpm;
/// Errors that the IdProvider may return. These drive the resolver state machine /// Errors that the IdProvider may return. These drive the resolver state machine
/// and should be carefully selected to match your expected errors. /// and should be carefully selected to match your expected errors.
@ -26,6 +26,8 @@ pub enum IdpError {
NotFound, NotFound,
/// The idp was unable to perform an operation on the underlying hsm keystorage /// The idp was unable to perform an operation on the underlying hsm keystorage
KeyStore, KeyStore,
/// The idp failed to interact with the configured TPM
Tpm,
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@ -91,48 +93,13 @@ pub enum AuthCacheAction {
PasswordHashUpdate { cred: String }, PasswordHashUpdate { cred: String },
} }
pub struct KeyStore<'a> {
dbtxn: &'a DbTxn<'a>,
}
impl<'a> KeyStore<'a> {
pub(crate) fn new(dbtxn: &'a DbTxn<'a>) -> Self {
KeyStore { dbtxn }
}
pub fn get_tagged_hsm_key<K: DeserializeOwned>(
&mut self,
tag: &str,
) -> Result<Option<K>, IdpError> {
self.dbtxn
.get_tagged_hsm_key(tag)
.map_err(|_err| IdpError::KeyStore)
}
pub fn insert_tagged_hsm_key<K: Serialize>(
&mut self,
tag: &str,
key: &K,
) -> Result<(), IdpError> {
self.dbtxn
.insert_tagged_hsm_key(tag, key)
.map_err(|_err| IdpError::KeyStore)
}
pub fn delete_tagged_hsm_key(&mut self, tag: &str) -> Result<(), IdpError> {
self.dbtxn
.delete_tagged_hsm_key(tag)
.map_err(|_err| IdpError::KeyStore)
}
}
#[async_trait] #[async_trait]
pub trait IdProvider { pub trait IdProvider {
async fn configure_hsm_keys( async fn configure_hsm_keys<D: KeyStoreTxn + Send>(
&self, &self,
_keystore: &mut KeyStore, _keystore: &mut D,
_tpm: &mut (dyn Tpm + Send), _tpm: &mut (dyn tpm::Tpm + Send),
_machine_key: &MachineKey, _machine_key: &tpm::MachineKey,
) -> Result<(), IdpError> { ) -> Result<(), IdpError> {
Ok(()) Ok(())
} }

View file

@ -1,14 +1,26 @@
use crate::db::KeyStoreTxn;
use async_trait::async_trait; use async_trait::async_trait;
use kanidm_client::{ClientError, KanidmClient, StatusCode}; use kanidm_client::{ClientError, KanidmClient, StatusCode};
use kanidm_proto::v1::{OperationError, UnixGroupToken, UnixUserToken}; use kanidm_proto::v1::{OperationError, UnixGroupToken, UnixUserToken};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use super::interface::{ use super::interface::{
AuthCacheAction, AuthCredHandler, AuthRequest, AuthResult, GroupToken, Id, IdProvider, // KeyStore,
IdpError, UserToken, tpm,
AuthCacheAction,
AuthCredHandler,
AuthRequest,
AuthResult,
GroupToken,
Id,
IdProvider,
IdpError,
UserToken,
}; };
use crate::unix_proto::PamAuthRequest; use crate::unix_proto::PamAuthRequest;
const TAG_IDKEY: &str = "idkey";
pub struct KanidmProvider { pub struct KanidmProvider {
client: RwLock<KanidmClient>, client: RwLock<KanidmClient>,
} }
@ -71,6 +83,37 @@ impl From<UnixGroupToken> for GroupToken {
#[async_trait] #[async_trait]
impl IdProvider for KanidmProvider { impl IdProvider for KanidmProvider {
async fn configure_hsm_keys<D: KeyStoreTxn + Send>(
&self,
keystore: &mut D,
tpm: &mut (dyn tpm::Tpm + Send),
machine_key: &tpm::MachineKey,
) -> Result<(), IdpError> {
let id_key: Option<tpm::LoadableIdentityKey> =
keystore.get_tagged_hsm_key(TAG_IDKEY).map_err(|ks_err| {
error!(?ks_err);
IdpError::KeyStore
})?;
if id_key.is_none() {
let lik = tpm
.identity_key_create(machine_key, tpm::KeyAlgorithm::Ecdsa256)
.map_err(|tpm_err| {
error!(?tpm_err);
IdpError::Tpm
})?;
keystore
.insert_tagged_hsm_key(TAG_IDKEY, &lik)
.map_err(|ks_err| {
error!(?ks_err);
IdpError::KeyStore
})?;
}
Ok(())
}
// Needs .read on all types except re-auth. // Needs .read on all types except re-auth.
async fn provider_authenticate(&self) -> Result<(), IdpError> { async fn provider_authenticate(&self) -> Result<(), IdpError> {
match self.client.write().await.auth_anonymous().await { match self.client.write().await.auth_anonymous().await {

View file

@ -13,7 +13,14 @@ use uuid::Uuid;
use crate::db::{Cache, CacheTxn, Db}; use crate::db::{Cache, CacheTxn, Db};
use crate::idprovider::interface::{ use crate::idprovider::interface::{
AuthCacheAction, AuthCredHandler, AuthResult, GroupToken, Id, IdProvider, IdpError, KeyStore, AuthCacheAction,
AuthCredHandler,
AuthResult,
GroupToken,
Id,
IdProvider,
IdpError,
// KeyStore,
UserToken, UserToken,
}; };
use crate::unix_config::{HomeAttr, UidAttr}; use crate::unix_config::{HomeAttr, UidAttr};
@ -104,12 +111,12 @@ where
let mut hsm_lock = hsm.lock().await; let mut hsm_lock = hsm.lock().await;
// setup and do a migrate. // setup and do a migrate.
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
dbtxn.migrate().map_err(|_| ())?; dbtxn.migrate().map_err(|_| ())?;
dbtxn.commit().map_err(|_| ())?; dbtxn.commit().map_err(|_| ())?;
// Setup our internal keys // Setup our internal keys
let dbtxn = db.write().await; let mut dbtxn = db.write().await;
let loadable_hmac_key = match dbtxn.get_hsm_hmac_key() { let loadable_hmac_key = match dbtxn.get_hsm_hmac_key() {
Ok(Some(hmk)) => hmk, Ok(Some(hmk)) => hmk,
@ -141,17 +148,20 @@ where
// Ask the client what keys it wants the HSM to configure. // Ask the client what keys it wants the HSM to configure.
// make a key store // make a key store
let mut ks = KeyStore::new(&dbtxn); // let mut ks = KeyStore::new(&mut dbtxn);
client let result = client
.configure_hsm_keys(&mut ks, &mut **hsm_lock.deref_mut(), &machine_key) // .configure_hsm_keys(&mut ks, &mut **hsm_lock.deref_mut(), &machine_key)
.await .configure_hsm_keys(&mut dbtxn, &mut **hsm_lock.deref_mut(), &machine_key)
.map_err(|err| { .await;
error!(?err, "Client was unable to configure hsm keys");
})?;
// drop(ks);
drop(hsm_lock); drop(hsm_lock);
result.map_err(|err| {
error!(?err, "Client was unable to configure hsm keys");
})?;
dbtxn.commit().map_err(|_| ())?; dbtxn.commit().map_err(|_| ())?;
if pam_allow_groups.is_empty() { if pam_allow_groups.is_empty() {
@ -204,14 +214,14 @@ where
pub async fn clear_cache(&self) -> Result<(), ()> { pub async fn clear_cache(&self) -> Result<(), ()> {
let mut nxcache_txn = self.nxcache.lock().await; let mut nxcache_txn = self.nxcache.lock().await;
nxcache_txn.clear(); nxcache_txn.clear();
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn.clear().and_then(|_| dbtxn.commit()).map_err(|_| ()) dbtxn.clear().and_then(|_| dbtxn.commit()).map_err(|_| ())
} }
pub async fn invalidate(&self) -> Result<(), ()> { pub async fn invalidate(&self) -> Result<(), ()> {
let mut nxcache_txn = self.nxcache.lock().await; let mut nxcache_txn = self.nxcache.lock().await;
nxcache_txn.clear(); nxcache_txn.clear();
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn dbtxn
.invalidate() .invalidate()
.and_then(|_| dbtxn.commit()) .and_then(|_| dbtxn.commit())
@ -219,12 +229,12 @@ where
} }
async fn get_cached_usertokens(&self) -> Result<Vec<UserToken>, ()> { async fn get_cached_usertokens(&self) -> Result<Vec<UserToken>, ()> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn.get_accounts().map_err(|_| ()) dbtxn.get_accounts().map_err(|_| ())
} }
async fn get_cached_grouptokens(&self) -> Result<Vec<GroupToken>, ()> { async fn get_cached_grouptokens(&self) -> Result<Vec<GroupToken>, ()> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn.get_groups().map_err(|_| ()) dbtxn.get_groups().map_err(|_| ())
} }
@ -268,7 +278,7 @@ where
// * spn // * spn
// * uuid // * uuid
// Attempt to search these in the db. // Attempt to search these in the db.
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
let r = dbtxn.get_account(account_id).map_err(|_| ())?; let r = dbtxn.get_account(account_id).map_err(|_| ())?;
match r { match r {
@ -316,7 +326,7 @@ where
// * spn // * spn
// * uuid // * uuid
// Attempt to search these in the db. // Attempt to search these in the db.
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
let r = dbtxn.get_group(grp_id).map_err(|_| ())?; let r = dbtxn.get_group(grp_id).map_err(|_| ())?;
match r { match r {
@ -398,7 +408,7 @@ where
}); });
} }
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
token token
.groups .groups
.iter() .iter()
@ -421,7 +431,7 @@ where
error!("time conversion error - ex_time less than epoch? {:?}", e); error!("time conversion error - ex_time less than epoch? {:?}", e);
})?; })?;
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn dbtxn
.update_group(token, offset.as_secs()) .update_group(token, offset.as_secs())
.and_then(|_| dbtxn.commit()) .and_then(|_| dbtxn.commit())
@ -429,7 +439,7 @@ where
} }
async fn delete_cache_usertoken(&self, a_uuid: Uuid) -> Result<(), ()> { async fn delete_cache_usertoken(&self, a_uuid: Uuid) -> Result<(), ()> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn dbtxn
.delete_account(a_uuid) .delete_account(a_uuid)
.and_then(|_| dbtxn.commit()) .and_then(|_| dbtxn.commit())
@ -437,7 +447,7 @@ where
} }
async fn delete_cache_grouptoken(&self, g_uuid: Uuid) -> Result<(), ()> { async fn delete_cache_grouptoken(&self, g_uuid: Uuid) -> Result<(), ()> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn dbtxn
.delete_group(g_uuid) .delete_group(g_uuid)
.and_then(|_| dbtxn.commit()) .and_then(|_| dbtxn.commit())
@ -445,7 +455,7 @@ where
} }
async fn set_cache_userpassword(&self, a_uuid: Uuid, cred: &str) -> Result<(), ()> { async fn set_cache_userpassword(&self, a_uuid: Uuid, cred: &str) -> Result<(), ()> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
let mut hsm_txn = self.hsm.lock().await; let mut hsm_txn = self.hsm.lock().await;
dbtxn dbtxn
.update_account_password(a_uuid, cred, &mut **hsm_txn, &self.hmac_key) .update_account_password(a_uuid, cred, &mut **hsm_txn, &self.hmac_key)
@ -454,7 +464,7 @@ where
} }
async fn check_cache_userpassword(&self, a_uuid: Uuid, cred: &str) -> Result<bool, ()> { async fn check_cache_userpassword(&self, a_uuid: Uuid, cred: &str) -> Result<bool, ()> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
let mut hsm_txn = self.hsm.lock().await; let mut hsm_txn = self.hsm.lock().await;
dbtxn dbtxn
.check_account_password(a_uuid, cred, &mut **hsm_txn, &self.hmac_key) .check_account_password(a_uuid, cred, &mut **hsm_txn, &self.hmac_key)
@ -505,7 +515,7 @@ where
Ok(None) Ok(None)
} }
Err(IdpError::KeyStore) | Err(IdpError::BadRequest) => { Err(IdpError::KeyStore) | Err(IdpError::BadRequest) | Err(IdpError::Tpm) => {
// Some other transient error, continue with the token. // Some other transient error, continue with the token.
Ok(token) Ok(token)
} }
@ -552,7 +562,7 @@ where
self.set_nxcache(grp_id).await; self.set_nxcache(grp_id).await;
Ok(None) Ok(None)
} }
Err(IdpError::KeyStore) | Err(IdpError::BadRequest) => { Err(IdpError::KeyStore) | Err(IdpError::BadRequest) | Err(IdpError::Tpm) => {
// Some other transient error, continue with the token. // Some other transient error, continue with the token.
Ok(token) Ok(token)
} }
@ -659,7 +669,7 @@ where
} }
async fn get_groupmembers(&self, g_uuid: Uuid) -> Vec<String> { async fn get_groupmembers(&self, g_uuid: Uuid) -> Vec<String> {
let dbtxn = self.db.write().await; let mut dbtxn = self.db.write().await;
dbtxn dbtxn
.get_group_members(g_uuid) .get_group_members(g_uuid)
@ -887,7 +897,7 @@ where
.await; .await;
Err(()) Err(())
} }
Err(IdpError::BadRequest) | Err(IdpError::KeyStore) => Err(()), Err(IdpError::BadRequest) | Err(IdpError::KeyStore) | Err(IdpError::Tpm) => Err(()),
} }
} }
@ -1040,7 +1050,7 @@ where
.await; .await;
Err(()) Err(())
} }
Err(IdpError::KeyStore) | Err(IdpError::BadRequest) => Err(()), Err(IdpError::KeyStore) | Err(IdpError::BadRequest) | Err(IdpError::Tpm) => Err(()),
} }
} }