diff --git a/Cargo.lock b/Cargo.lock index 0e028233a..036dcf24b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2873,6 +2873,7 @@ dependencies = [ "sketching", "tracing", "tss-esapi", + "uuid", ] [[package]] @@ -2953,6 +2954,7 @@ dependencies = [ "kanidm_proto", "kanidm_utils_users", "kanidmd_core", + "kanidmd_testkit", "libc", "libsqlite3-sys", "lru 0.8.1", @@ -2999,6 +3001,7 @@ dependencies = [ "http", "hyper", "kanidm_build_profiles", + "kanidm_lib_crypto", "kanidm_proto", "kanidm_utils_users", "kanidmd_lib", @@ -3021,6 +3024,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "url", "urlencoding", "uuid", ] @@ -3100,6 +3104,7 @@ dependencies = [ "hyper-tls", "kanidm_build_profiles", "kanidm_client", + "kanidm_lib_crypto", "kanidm_proto", "kanidmd_core", "kanidmd_lib", @@ -3113,8 +3118,10 @@ dependencies = [ "testkit-macros", "time", "tokio", + "tokio-openssl", "tracing", "url", + "uuid", "webauthn-authenticator-rs", ] diff --git a/Cargo.toml b/Cargo.toml index 4d5a4b1c3..1afcfd4e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,10 +74,11 @@ webauthn-rs-proto = { git = "https://github.com/kanidm/webauthn-rs.git", rev = " kanidmd_core = { path = "./server/core" } kanidmd_lib = { path = "./server/lib" } kanidmd_lib_macros = { path = "./server/lib-macros" } +kanidmd_testkit = { path = "./server/testkit" } kanidm_build_profiles = { path = "./libs/profiles", version = "1.1.0-rc.14-dev" } +kanidm_client = { path = "./libs/client", version = "1.1.0-rc.14-dev" } kanidm_lib_crypto = { path = "./libs/crypto" } kanidm_lib_file_permissions = { path = "./libs/file_permissions" } -kanidm_client = { path = "./libs/client", version = "1.1.0-rc.14-dev" } kanidm_proto = { path = "./proto", version = "1.1.0-rc.14-dev" } kanidm_unix_int = { path = "./unix_integration" } kanidm_utils_users = { path = "./libs/users" } diff --git a/Makefile b/Makefile index 0de7b485d..13c4a702c 100644 --- a/Makefile +++ b/Makefile @@ -133,7 +133,7 @@ install-tools: codespell: ## 
spell-check things. codespell: codespell -c \ - -L 'crate,unexpect,Pres,pres,ACI,aci,te,ue,unx,aNULL' \ + -L 'crate,unexpect,Pres,pres,ACI,aci,ser,te,ue,unx,aNULL' \ --skip='./target,./pykanidm/.venv,./pykanidm/.mypy_cache,./.mypy_cache,./pykanidm/poetry.lock' \ --skip='./book/book/*' \ --skip='./book/src/images/*' \ diff --git a/libs/crypto/Cargo.toml b/libs/crypto/Cargo.toml index 6a902ba66..df3df00fb 100644 --- a/libs/crypto/Cargo.toml +++ b/libs/crypto/Cargo.toml @@ -21,6 +21,7 @@ rand = { workspace = true } serde = { workspace = true, features = ["derive"] } tracing = { workspace = true } tss-esapi = { workspace = true, optional = true } +uuid = { workspace = true } [dev-dependencies] sketching = { workspace = true } diff --git a/libs/crypto/src/lib.rs b/libs/crypto/src/lib.rs index bbeddacda..cc9b0494c 100644 --- a/libs/crypto/src/lib.rs +++ b/libs/crypto/src/lib.rs @@ -23,11 +23,16 @@ use std::fmt; use std::time::{Duration, Instant}; use kanidm_proto::v1::OperationError; +use openssl::error::ErrorStack as OpenSSLErrorStack; use openssl::hash::{self, MessageDigest}; use openssl::nid::Nid; use openssl::pkcs5::pbkdf2_hmac; use openssl::sha::Sha512; +pub mod mtls; +pub mod prelude; +pub mod serialise; + #[cfg(feature = "tpm")] pub use tss_esapi::{handles::ObjectHandle as TpmHandle, Context as TpmContext, Error as TpmError}; #[cfg(not(feature = "tpm"))] @@ -68,13 +73,21 @@ pub enum CryptoError { Tpm2FeatureMissing, Tpm2InputExceeded, Tpm2ContextMissing, - OpenSSL, + OpenSSL(u64), Md4Disabled, Argon2, Argon2Version, Argon2Parameters, } +impl From for CryptoError { + fn from(ossl_err: OpenSSLErrorStack) -> Self { + error!(?ossl_err); + let code = ossl_err.errors().get(0).map(|e| e.code()).unwrap_or(0); + CryptoError::OpenSSL(code) + } +} + #[allow(clippy::from_over_into)] impl Into for CryptoError { fn into(self) -> OperationError { @@ -785,8 +798,8 @@ impl Password { // Turn key to a vec. 
Kdf::PBKDF2(pbkdf2_cost, salt, key) }) - .map_err(|_| CryptoError::OpenSSL) .map(|material| Password { material }) + .map_err(|e| e.into()) } pub fn new_argon2id(policy: &CryptoPolicy, cleartext: &str) -> Result { @@ -810,6 +823,7 @@ impl Password { }) .map_err(|_| CryptoError::Argon2) .map(|material| Password { material }) + .map_err(|e| e.into()) } pub fn new_argon2id_tpm( @@ -843,6 +857,7 @@ impl Password { key, }) .map(|material| Password { material }) + .map_err(|e| e.into()) } #[inline] @@ -962,11 +977,11 @@ impl Password { MessageDigest::sha256(), chal_key.as_mut_slice(), ) - .map_err(|_| CryptoError::OpenSSL) .map(|()| { // Actually compare the outputs. &chal_key == key }) + .map_err(|e| e.into()) } (Kdf::PBKDF2_SHA1(cost, salt, key), _) => { let key_len = key.len(); @@ -979,11 +994,11 @@ impl Password { MessageDigest::sha1(), chal_key.as_mut_slice(), ) - .map_err(|_| CryptoError::OpenSSL) .map(|()| { // Actually compare the outputs. &chal_key == key }) + .map_err(|e| e.into()) } (Kdf::PBKDF2_SHA512(cost, salt, key), _) => { let key_len = key.len(); @@ -996,11 +1011,11 @@ impl Password { MessageDigest::sha512(), chal_key.as_mut_slice(), ) - .map_err(|_| CryptoError::OpenSSL) .map(|()| { // Actually compare the outputs. 
&chal_key == key }) + .map_err(|e| e.into()) } (Kdf::SSHA512(salt, key), _) => { let mut hasher = Sha512::new(); diff --git a/libs/crypto/src/mtls.rs b/libs/crypto/src/mtls.rs new file mode 100644 index 000000000..203e510f2 --- /dev/null +++ b/libs/crypto/src/mtls.rs @@ -0,0 +1,89 @@ +use crate::CryptoError; + +use openssl::asn1; +use openssl::bn; +use openssl::ec; +use openssl::error::ErrorStack as OpenSSLError; +use openssl::hash; +use openssl::nid::Nid; +use openssl::pkey::{PKey, Private}; +use openssl::x509::extension::BasicConstraints; +use openssl::x509::extension::ExtendedKeyUsage; +use openssl::x509::extension::KeyUsage; +use openssl::x509::extension::SubjectAlternativeName; +use openssl::x509::extension::SubjectKeyIdentifier; +use openssl::x509::X509NameBuilder; +use openssl::x509::X509; + +use uuid::Uuid; + +/// Gets an [ec::EcGroup] for P-256 +pub fn get_group() -> Result { + ec::EcGroup::from_curve_name(Nid::X9_62_PRIME256V1) +} + +pub fn build_self_signed_server_and_client_identity( + cn: Uuid, + domain_name: &str, + expiration_days: u32, +) -> Result<(PKey, X509), CryptoError> { + let ecgroup = get_group()?; + let eckey = ec::EcKey::generate(&ecgroup)?; + let ca_key = PKey::from_ec_key(eckey)?; + let mut x509_name = X509NameBuilder::new()?; + + // x509_name.append_entry_by_text("C", "AU")?; + // x509_name.append_entry_by_text("ST", "QLD")?; + x509_name.append_entry_by_text("O", "Kanidm Replication")?; + x509_name.append_entry_by_text("CN", &cn.as_hyphenated().to_string())?; + let x509_name = x509_name.build(); + + let mut cert_builder = X509::builder()?; + // Yes, 2 actually means 3 here ... 
+ cert_builder.set_version(2)?; + + let serial_number = bn::BigNum::from_u32(1).and_then(|serial| serial.to_asn1_integer())?; + + cert_builder.set_serial_number(&serial_number)?; + cert_builder.set_subject_name(&x509_name)?; + cert_builder.set_issuer_name(&x509_name)?; + + let not_before = asn1::Asn1Time::days_from_now(0)?; + cert_builder.set_not_before(¬_before)?; + let not_after = asn1::Asn1Time::days_from_now(expiration_days)?; + cert_builder.set_not_after(¬_after)?; + + // Do we need pathlen 0? + cert_builder.append_extension(BasicConstraints::new().critical().build()?)?; + cert_builder.append_extension( + KeyUsage::new() + .critical() + .digital_signature() + .key_encipherment() + .build()?, + )?; + + cert_builder.append_extension( + ExtendedKeyUsage::new() + .server_auth() + .client_auth() + .build()?, + )?; + + let subject_key_identifier = + SubjectKeyIdentifier::new().build(&cert_builder.x509v3_context(None, None))?; + cert_builder.append_extension(subject_key_identifier)?; + + let subject_alt_name = SubjectAlternativeName::new() + .dns(domain_name) + .build(&cert_builder.x509v3_context(None, None))?; + + cert_builder.append_extension(subject_alt_name)?; + + cert_builder.set_pubkey(&ca_key)?; + + cert_builder.sign(&ca_key, hash::MessageDigest::sha256())?; + let ca_cert = cert_builder.build(); + + Ok((ca_key, ca_cert)) +} diff --git a/libs/crypto/src/prelude.rs b/libs/crypto/src/prelude.rs new file mode 100644 index 000000000..13762a393 --- /dev/null +++ b/libs/crypto/src/prelude.rs @@ -0,0 +1,2 @@ +pub use openssl::pkey::{PKey, Private, Public}; +pub use openssl::x509::X509; diff --git a/libs/crypto/src/serialise.rs b/libs/crypto/src/serialise.rs new file mode 100644 index 000000000..3e953555e --- /dev/null +++ b/libs/crypto/src/serialise.rs @@ -0,0 +1,85 @@ +pub mod pkeyb64 { + use base64::{engine::general_purpose, Engine as _}; + use openssl::pkey::{PKey, Private}; + use serde::{ + de::Error as DeError, ser::Error as SerError, Deserialize, Deserializer, 
Serializer, + }; + use tracing::error; + + pub fn serialize(key: &PKey, ser: S) -> Result + where + S: Serializer, + { + let der = key.private_key_to_der().map_err(|err| { + error!(?err, "openssl private_key_to_der"); + S::Error::custom("openssl private_key_to_der") + })?; + let s = general_purpose::URL_SAFE.encode(&der); + + ser.serialize_str(&s) + } + + pub fn deserialize<'de, D>(des: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let raw = <&str>::deserialize(des)?; + let s = general_purpose::URL_SAFE.decode(raw).map_err(|err| { + error!(?err, "base64 url-safe invalid"); + D::Error::custom("base64 url-safe invalid") + })?; + + PKey::private_key_from_der(&s).map_err(|err| { + error!(?err, "openssl pkey invalid der"); + D::Error::custom("openssl pkey invalid der") + }) + } +} + +pub mod x509b64 { + use crate::CryptoError; + use base64::{engine::general_purpose, Engine as _}; + use openssl::x509::X509; + use serde::{ + de::Error as DeError, ser::Error as SerError, Deserialize, Deserializer, Serializer, + }; + use tracing::error; + + pub fn cert_to_string(cert: &X509) -> Result { + cert.to_der() + .map_err(|err| { + error!(?err, "openssl cert to_der"); + err.into() + }) + .map(|der| general_purpose::URL_SAFE.encode(&der)) + } + + pub fn serialize(cert: &X509, ser: S) -> Result + where + S: Serializer, + { + let der = cert.to_der().map_err(|err| { + error!(?err, "openssl cert to_der"); + S::Error::custom("openssl private_key_to_der") + })?; + let s = general_purpose::URL_SAFE.encode(&der); + + ser.serialize_str(&s) + } + + pub fn deserialize<'de, D>(des: D) -> Result + where + D: Deserializer<'de>, + { + let raw = <&str>::deserialize(des)?; + let s = general_purpose::URL_SAFE.decode(raw).map_err(|err| { + error!(?err, "base64 url-safe invalid"); + D::Error::custom("base64 url-safe invalid") + })?; + + X509::from_der(&s).map_err(|err| { + error!(?err, "openssl x509 invalid der"); + D::Error::custom("openssl x509 invalid der") + }) + } +} diff --git 
a/server/core/Cargo.toml b/server/core/Cargo.toml index d3d017023..9a19403b8 100644 --- a/server/core/Cargo.toml +++ b/server/core/Cargo.toml @@ -29,6 +29,7 @@ hyper = { workspace = true } kanidm_proto = { workspace = true } kanidm_utils_users = { workspace = true } kanidmd_lib = { workspace = true } +kanidm_lib_crypto = { workspace = true } ldap3_proto = { workspace = true } libc = { workspace = true } openssl = { workspace = true } @@ -57,8 +58,8 @@ tracing = { workspace = true, features = ["attributes"] } tracing-subscriber = { workspace = true, features = ["time", "json"] } urlencoding = { workspace = true } tempfile = { workspace = true } +url = { workspace = true, features = ["serde"] } uuid = { workspace = true, features = ["serde", "v4"] } - [build-dependencies] kanidm_build_profiles = { workspace = true } diff --git a/server/core/src/actors/v1_write.rs b/server/core/src/actors/v1_write.rs index 2bba12938..db5808482 100644 --- a/server/core/src/actors/v1_write.rs +++ b/server/core/src/actors/v1_write.rs @@ -61,8 +61,8 @@ impl QueryServerWriteV1 { proto_ml: &ProtoModifyList, filter: Filter, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) @@ -109,8 +109,8 @@ impl QueryServerWriteV1 { ml: &ModifyList, filter: Filter, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) @@ -163,8 +163,8 @@ impl QueryServerWriteV1 { req: CreateRequest, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = 
duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) @@ -200,8 +200,8 @@ impl QueryServerWriteV1 { req: ModifyRequest, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) .map_err(|e| { @@ -236,8 +236,8 @@ impl QueryServerWriteV1 { req: DeleteRequest, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) .map_err(|e| { @@ -273,8 +273,8 @@ impl QueryServerWriteV1 { eventid: Uuid, ) -> Result<(), OperationError> { // Given a protoEntry, turn this into a modification set. 
- let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) .map_err(|e| { @@ -315,8 +315,8 @@ impl QueryServerWriteV1 { filter: Filter, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) .map_err(|e| { @@ -350,8 +350,8 @@ impl QueryServerWriteV1 { filter: Filter, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) .map_err(|e| { @@ -919,8 +919,8 @@ impl QueryServerWriteV1 { filter: Filter, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) .map_err(|e| { @@ -1196,8 +1196,8 @@ impl QueryServerWriteV1 { ) -> Result<(), OperationError> { // Because this is from internal, we can generate a real modlist, rather // than relying on the proto ones. 
- let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) @@ -1254,8 +1254,8 @@ impl QueryServerWriteV1 { filter: Filter, eventid: Uuid, ) -> Result<(), OperationError> { - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) @@ -1311,8 +1311,8 @@ impl QueryServerWriteV1 { ) -> Result<(), OperationError> { // Because this is from internal, we can generate a real modlist, rather // than relying on the proto ones. - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let ident = idms_prox_write .validate_and_parse_token_to_ident(uat.as_deref(), ct) @@ -1514,7 +1514,8 @@ impl QueryServerWriteV1 { )] pub async fn handle_purgerecycledevent(&self, msg: PurgeRecycledEvent) { trace!(?msg, "Begin purge recycled event"); - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; + let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let res = idms_prox_write .qs_write .purge_recycled() @@ -1551,7 +1552,8 @@ impl QueryServerWriteV1 { eventid: Uuid, ) -> Result { trace!(%name, "Begin admin recover account event"); - let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await; + let ct = duration_from_epoch_now(); + let mut idms_prox_write = self.idms.proxy_write(ct).await; let pw = idms_prox_write.recover_account(name.as_str(), None)?; idms_prox_write.commit().map(|()| pw) diff --git a/server/core/src/admin.rs b/server/core/src/admin.rs index 
bddd5ee33..beb1802e7 100644 --- a/server/core/src/admin.rs +++ b/server/core/src/admin.rs @@ -1,7 +1,9 @@ use crate::actors::v1_write::QueryServerWriteV1; +use crate::repl::ReplCtrl; use crate::CoreAction; use bytes::{BufMut, BytesMut}; use futures::{SinkExt, StreamExt}; +use kanidm_lib_crypto::serialise::x509b64; use kanidm_utils_users::get_current_uid; use serde::{Deserialize, Serialize}; use std::error::Error; @@ -9,18 +11,23 @@ use std::io; use std::path::Path; use tokio::net::{UnixListener, UnixStream}; use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio_util::codec::{Decoder, Encoder, Framed}; -use tracing::{span, Level}; +use tracing::{span, Instrument, Level}; use uuid::Uuid; #[derive(Serialize, Deserialize, Debug)] pub enum AdminTaskRequest { RecoverAccount { name: String }, + ShowReplicationCertificate, + RenewReplicationCertificate, } #[derive(Serialize, Deserialize, Debug)] pub enum AdminTaskResponse { RecoverAccount { password: String }, + ShowReplicationCertificate { cert: String }, Error, } @@ -99,6 +106,7 @@ impl AdminActor { sock_path: &str, server: &'static QueryServerWriteV1, mut broadcast_rx: broadcast::Receiver, + repl_ctrl_tx: Option>, ) -> Result, ()> { debug!("🧹 Cleaning up sockets from previous invocations"); rm_if_exist(sock_path); @@ -144,8 +152,9 @@ impl AdminActor { }; // spawn the worker. 
+ let task_repl_ctrl_tx = repl_ctrl_tx.clone(); tokio::spawn(async move { - if let Err(e) = handle_client(socket, server).await { + if let Err(e) = handle_client(socket, server, task_repl_ctrl_tx).await { error!(err = ?e, "admin client error"); } }); @@ -177,9 +186,61 @@ fn rm_if_exist(p: &str) { } } +async fn show_replication_certificate(ctrl_tx: &mut mpsc::Sender) -> AdminTaskResponse { + let (tx, rx) = oneshot::channel(); + + if ctrl_tx + .send(ReplCtrl::GetCertificate { respond: tx }) + .await + .is_err() + { + error!("replication control channel has shutdown"); + return AdminTaskResponse::Error; + } + + match rx.await { + Ok(cert) => x509b64::cert_to_string(&cert) + .map(|cert| AdminTaskResponse::ShowReplicationCertificate { cert }) + .unwrap_or(AdminTaskResponse::Error), + Err(_) => { + error!("replication control channel did not respond with certificate."); + AdminTaskResponse::Error + } + } +} + +async fn renew_replication_certificate(ctrl_tx: &mut mpsc::Sender) -> AdminTaskResponse { + let (tx, rx) = oneshot::channel(); + + if ctrl_tx + .send(ReplCtrl::RenewCertificate { respond: tx }) + .await + .is_err() + { + error!("replication control channel has shutdown"); + return AdminTaskResponse::Error; + } + + match rx.await { + Ok(success) => { + if success { + show_replication_certificate(ctrl_tx).await + } else { + error!("replication control channel indicated that certificate renewal failed."); + AdminTaskResponse::Error + } + } + Err(_) => { + error!("replication control channel did not respond with renewal status."); + AdminTaskResponse::Error + } + } +} + async fn handle_client( sock: UnixStream, server: &'static QueryServerWriteV1, + mut repl_ctrl_tx: Option>, ) -> Result<(), Box> { debug!("Accepted admin socket connection"); @@ -190,22 +251,40 @@ async fn handle_client( // Setup the logging span let eventid = Uuid::new_v4(); let nspan = span!(Level::INFO, "handle_admin_client_request", uuid = ?eventid); - let _span = nspan.enter(); + // let _span = 
nspan.enter(); - let resp = match req { - AdminTaskRequest::RecoverAccount { name } => { - match server.handle_admin_recover_account(name, eventid).await { - Ok(password) => AdminTaskResponse::RecoverAccount { password }, - Err(e) => { - error!(err = ?e, "error during recover-account"); - AdminTaskResponse::Error + let resp = async { + match req { + AdminTaskRequest::RecoverAccount { name } => { + match server.handle_admin_recover_account(name, eventid).await { + Ok(password) => AdminTaskResponse::RecoverAccount { password }, + Err(e) => { + error!(err = ?e, "error during recover-account"); + AdminTaskResponse::Error + } } } + AdminTaskRequest::ShowReplicationCertificate => match repl_ctrl_tx.as_mut() { + Some(ctrl_tx) => show_replication_certificate(ctrl_tx).await, + None => { + error!("replication not configured, unable to display certificate."); + AdminTaskResponse::Error + } + }, + AdminTaskRequest::RenewReplicationCertificate => match repl_ctrl_tx.as_mut() { + Some(ctrl_tx) => renew_replication_certificate(ctrl_tx).await, + None => { + error!("replication not configured, unable to renew certificate."); + AdminTaskResponse::Error + } + }, } - }; + } + .instrument(nspan) + .await; + reqs.send(resp).await?; reqs.flush().await?; - trace!("flushed response!"); } debug!("Disconnecting client ..."); diff --git a/server/core/src/config.rs b/server/core/src/config.rs index 474a4f6cc..8c965ed05 100644 --- a/server/core/src/config.rs +++ b/server/core/src/config.rs @@ -4,25 +4,25 @@ //! These components should be "per server". Any "per domain" config should be in the system //! or domain entries that are able to be replicated. 
+use std::collections::BTreeMap; use std::fmt; use std::fs::File; use std::io::Read; +use std::net::SocketAddr; use std::path::Path; - use std::str::FromStr; use kanidm_proto::constants::DEFAULT_SERVER_ADDRESS; use kanidm_proto::messages::ConsoleOutputMode; -use serde::{Deserialize, Serialize}; + +use kanidm_lib_crypto::prelude::X509; +use kanidm_lib_crypto::serialise::x509b64; + +use serde::Deserialize; use sketching::tracing_subscriber::EnvFilter; +use url::Url; -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct IntegrationTestConfig { - pub admin_user: String, - pub admin_password: String, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] pub struct OnlineBackup { pub path: String, #[serde(default = "default_online_backup_schedule")] @@ -39,13 +39,54 @@ fn default_online_backup_versions() -> usize { 7 } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] pub struct TlsConfiguration { pub chain: String, pub key: String, } +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum RepNodeConfig { + #[serde(rename = "allow-pull")] + AllowPull { + #[serde(with = "x509b64")] + consumer_cert: X509, + }, + #[serde(rename = "pull")] + Pull { + #[serde(with = "x509b64")] + supplier_cert: X509, + automatic_refresh: bool, + }, + #[serde(rename = "mutual-pull")] + MutualPull { + #[serde(with = "x509b64")] + partner_cert: X509, + automatic_refresh: bool, + }, + /* + AllowPush { + }, + Push { + }, + */ +} + +#[derive(Deserialize, Debug, Clone)] +pub struct ReplicationConfiguration { + pub origin: Url, + pub bindaddress: SocketAddr, + + #[serde(flatten)] + pub manual: BTreeMap, +} + +/// This is the Server Configuration as read from server.toml. Important to note +/// is that not all flags or values from Configuration are exposed via this structure +/// to prevent certain settings being set (e.g. 
integration test modes) #[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] pub struct ServerConfig { pub bindaddress: Option, pub ldapbindaddress: Option, @@ -59,10 +100,15 @@ pub struct ServerConfig { pub tls_key: Option, pub online_backup: Option, pub domain: String, + // TODO -this should be URL pub origin: String, + pub log_level: Option, #[serde(default)] pub role: ServerRole, - pub log_level: Option, + #[serde(default)] + pub i_acknowledge_that_replication_is_in_development: bool, + #[serde(rename = "replication")] + pub repl_config: Option, } impl ServerConfig { @@ -85,7 +131,7 @@ impl ServerConfig { } } -#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default, Eq, PartialEq)] +#[derive(Debug, Deserialize, Clone, Copy, Default, Eq, PartialEq)] pub enum ServerRole { #[default] WriteReplica, @@ -116,7 +162,7 @@ impl FromStr for ServerRole { } } -#[derive(Clone, Serialize, Deserialize, Debug, Default)] +#[derive(Clone, Deserialize, Debug, Default)] pub enum LogLevel { #[default] #[serde(rename = "info")] @@ -160,7 +206,22 @@ impl From for EnvFilter { } } -#[derive(Serialize, Deserialize, Debug, Default, Clone)] +#[derive(Debug, Clone)] +pub struct IntegrationTestConfig { + pub admin_user: String, + pub admin_password: String, +} + +#[derive(Debug, Clone)] +pub struct IntegrationReplConfig { + // We can bake in a private key for mTLS here. + // pub private_key: PKey + + // We might need some condition variables / timers to force replication + // events? Or a channel to submit with oneshot responses. +} + +#[derive(Debug, Clone)] pub struct Configuration { pub address: String, pub ldapaddress: Option, @@ -180,41 +241,64 @@ pub struct Configuration { pub role: ServerRole, pub output_mode: ConsoleOutputMode, pub log_level: LogLevel, + + /// Replication settings. + pub repl_config: Option, + /// This allows internally setting some unsafe options for replication. 
+ pub integration_repl_config: Option>, } impl fmt::Display for Configuration { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "address: {}, ", self.address)?; - write!(f, "domain: {}, ", self.domain) - .and_then(|_| match &self.ldapaddress { - Some(la) => write!(f, "ldap address: {}, ", la), - None => write!(f, "ldap address: disabled, "), - }) - .and_then(|_| write!(f, "admin bind path: {}, ", self.adminbindpath)) - .and_then(|_| write!(f, "thread count: {}, ", self.threads)) - .and_then(|_| write!(f, "dbpath: {}, ", self.db_path)) - .and_then(|_| match self.db_arc_size { - Some(v) => write!(f, "arcsize: {}, ", v), - None => write!(f, "arcsize: AUTO, "), - }) - .and_then(|_| write!(f, "max request size: {}b, ", self.maximum_request)) - .and_then(|_| write!(f, "trust X-Forwarded-For: {}, ", self.trust_x_forward_for)) - .and_then(|_| write!(f, "with TLS: {}, ", self.tls_config.is_some())) - // TODO: include the backup timings - .and_then(|_| match &self.online_backup { - Some(_) => write!(f, "online_backup: enabled, "), - None => write!(f, "online_backup: disabled, "), - }) - .and_then(|_| write!(f, "role: {}, ", self.role.to_string())) - .and_then(|_| { + write!(f, "domain: {}, ", self.domain)?; + match &self.ldapaddress { + Some(la) => write!(f, "ldap address: {}, ", la), + None => write!(f, "ldap address: disabled, "), + }?; + write!(f, "origin: {} ", self.origin)?; + write!(f, "admin bind path: {}, ", self.adminbindpath)?; + write!(f, "thread count: {}, ", self.threads)?; + write!(f, "dbpath: {}, ", self.db_path)?; + match self.db_arc_size { + Some(v) => write!(f, "arcsize: {}, ", v), + None => write!(f, "arcsize: AUTO, "), + }?; + write!(f, "max request size: {}b, ", self.maximum_request)?; + write!(f, "trust X-Forwarded-For: {}, ", self.trust_x_forward_for)?; + write!(f, "with TLS: {}, ", self.tls_config.is_some())?; + match &self.online_backup { + Some(bck) => write!( + f, + "online_backup: enabled - schedule: {} versions: {}, ", + 
bck.schedule, bck.versions + ), + None => write!(f, "online_backup: disabled, "), + }?; + write!( + f, + "integration mode: {}, ", + self.integration_test_config.is_some() + )?; + write!(f, "console output format: {:?} ", self.output_mode)?; + write!(f, "log_level: {}", self.log_level.clone().to_string())?; + write!(f, "role: {}, ", self.role.to_string())?; + match &self.repl_config { + Some(repl) => { + write!(f, "replication: enabled")?; + write!(f, "repl_origin: {} ", repl.origin)?; + write!(f, "repl_address: {} ", repl.bindaddress)?; write!( f, - "integration mode: {}, ", - self.integration_test_config.is_some() - ) - }) - .and_then(|_| write!(f, "console output format: {:?} ", self.output_mode)) - .and_then(|_| write!(f, "log_level: {}", self.log_level.clone().to_string())) + "integration repl config mode: {}, ", + self.integration_repl_config.is_some() + )?; + } + None => { + write!(f, "replication: disabled, ")?; + } + } + Ok(()) } } @@ -240,9 +324,11 @@ impl Configuration { online_backup: None, domain: "idm.example.com".to_string(), origin: "https://idm.example.com".to_string(), - role: ServerRole::WriteReplica, output_mode: ConsoleOutputMode::default(), log_level: Default::default(), + role: ServerRole::WriteReplica, + repl_config: None, + integration_repl_config: None, } } @@ -335,6 +421,10 @@ impl Configuration { self.output_mode = om; } + pub fn update_replication_config(&mut self, repl_config: Option) { + self.repl_config = repl_config; + } + pub fn update_tls(&mut self, chain: &Option, key: &Option) { match (chain, key) { (None, None) => {} diff --git a/server/core/src/crypto.rs b/server/core/src/crypto.rs index 03086f467..bcd5f492f 100644 --- a/server/core/src/crypto.rs +++ b/server/core/src/crypto.rs @@ -460,7 +460,6 @@ pub(crate) fn build_cert( cert_builder.append_extension( KeyUsage::new() .critical() - // .non_repudiation() .digital_signature() .key_encipherment() .build()?, diff --git a/server/core/src/https/mod.rs b/server/core/src/https/mod.rs 
index f9282f711..2f4930fd5 100644 --- a/server/core/src/https/mod.rs +++ b/server/core/src/https/mod.rs @@ -2,14 +2,16 @@ mod extractors; mod generic; mod javascript; mod manifest; -mod middleware; +pub(crate) mod middleware; mod oauth2; mod tests; -mod trace; +pub(crate) mod trace; mod ui; mod v1; mod v1_scim; +use self::generic::*; +use self::javascript::*; use crate::actors::v1_read::QueryServerReadV1; use crate::actors::v1_write::QueryServerWriteV1; use crate::config::{Configuration, ServerRole, TlsConfiguration}; @@ -21,13 +23,11 @@ use axum::Router; use axum_csp::{CspDirectiveType, CspValue}; use axum_macros::FromRef; use compact_jwt::{Jws, JwsSigner, JwsUnverified}; -use generic::*; use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE}; use http::{HeaderMap, HeaderValue, StatusCode}; use hyper::server::accept::Accept; use hyper::server::conn::{AddrStream, Http}; use hyper::Body; -use javascript::*; use kanidm_proto::constants::APPLICATION_JSON; use kanidm_proto::v1::OperationError; use kanidmd_lib::status::StatusActor; @@ -288,7 +288,17 @@ pub async fn create_https_server( } res = match config.tls_config { Some(tls_param) => { - tokio::spawn(server_loop(tls_param, addr, app)) + // This isn't optimal, but we can't share this with the + // other path for integration tests because that doesn't + // do tls (yet?) 
+ let listener = match TcpListener::bind(addr).await { + Ok(l) => l, + Err(err) => { + error!(?err, "Failed to bind tcp listener"); + return + } + }; + tokio::spawn(server_loop(tls_param, listener, app)) }, None => { tokio::spawn(axum_server::bind(addr).serve(app)) @@ -307,7 +317,7 @@ pub async fn create_https_server( async fn server_loop( tls_param: TlsConfiguration, - addr: SocketAddr, + listener: TcpListener, app: IntoMakeServiceWithConnectInfo, ) -> Result<(), std::io::Error> { let mut tls_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?; @@ -335,7 +345,6 @@ async fn server_loop( ) })?; let acceptor = tls_builder.build(); - let listener = TcpListener::bind(addr).await?; let protocol = Arc::new(Http::new()); let mut listener = @@ -355,7 +364,7 @@ async fn server_loop( } /// This handles an individual connection. -async fn handle_conn( +pub(crate) async fn handle_conn( acceptor: SslAcceptor, stream: AddrStream, svc: ResponseFuture, diff --git a/server/core/src/ldaps.rs b/server/core/src/ldaps.rs index ffbe63596..4cfddc4c0 100644 --- a/server/core/src/ldaps.rs +++ b/server/core/src/ldaps.rs @@ -1,5 +1,5 @@ -use std::marker::Unpin; use std::net; +use std::pin::Pin; use std::str::FromStr; use crate::actors::v1_read::QueryServerReadV1; @@ -10,8 +10,7 @@ use kanidmd_lib::prelude::*; use ldap3_proto::proto::LdapMsg; use ldap3_proto::LdapCodec; use openssl::ssl::{Ssl, SslAcceptor}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, TcpStream}; use tokio_openssl::SslStream; use tokio_util::codec::{FramedRead, FramedWrite}; @@ -47,12 +46,31 @@ async fn client_process_msg( qe_r_ref.handle_ldaprequest(eventid, protomsg, uat).await } -async fn client_process( - mut r: FramedRead, - mut w: FramedWrite, +async fn client_process( + tcpstream: TcpStream, + tls_acceptor: SslAcceptor, client_address: net::SocketAddr, qe_r_ref: &'static QueryServerReadV1, ) { + // Start the event + // From the parameters we need to 
create an SslContext. + let mut tlsstream = match Ssl::new(tls_acceptor.context()) + .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) + { + Ok(ta) => ta, + Err(e) => { + error!("LDAP TLS setup error, continuing -> {:?}", e); + return; + } + }; + if let Err(e) = SslStream::accept(Pin::new(&mut tlsstream)).await { + error!("LDAP TLS accept error, continuing -> {:?}", e); + return; + }; + let (r, w) = tokio::io::split(tlsstream); + let mut r = FramedRead::new(r, LdapCodec); + let mut w = FramedWrite::new(w, LdapCodec); + // This is a connected client session. we need to associate some state to the session let mut session = LdapSession::new(); // Now that we have the session we begin an event loop to process input OR we return. @@ -108,7 +126,7 @@ async fn client_process( /// TLS LDAP Listener, hands off to [client_process] async fn tls_acceptor( listener: TcpListener, - ssl_acceptor: SslAcceptor, + tls_acceptor: SslAcceptor, qe_r_ref: &'static QueryServerReadV1, mut rx: broadcast::Receiver, ) { @@ -122,25 +140,8 @@ async fn tls_acceptor( accept_result = listener.accept() => { match accept_result { Ok((tcpstream, client_socket_addr)) => { - // Start the event - // From the parameters we need to create an SslContext. 
- let mut tlsstream = match Ssl::new(ssl_acceptor.context()) - .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) - { - Ok(ta) => ta, - Err(e) => { - error!("LDAP TLS setup error, continuing -> {:?}", e); - continue; - } - }; - if let Err(e) = SslStream::accept(Pin::new(&mut tlsstream)).await { - error!("LDAP TLS accept error, continuing -> {:?}", e); - continue; - }; - let (r, w) = tokio::io::split(tlsstream); - let r = FramedRead::new(r, LdapCodec); - let w = FramedWrite::new(w, LdapCodec); - tokio::spawn(client_process(r, w, client_socket_addr, qe_r_ref)); + let clone_tls_acceptor = tls_acceptor.clone(); + tokio::spawn(client_process(tcpstream, clone_tls_acceptor, client_socket_addr, qe_r_ref)); } Err(e) => { error!("LDAP acceptor error, continuing -> {:?}", e); diff --git a/server/core/src/lib.rs b/server/core/src/lib.rs index 0d6e1c1e4..93c46091d 100644 --- a/server/core/src/lib.rs +++ b/server/core/src/lib.rs @@ -32,6 +32,7 @@ mod crypto; mod https; mod interval; mod ldaps; +mod repl; use std::path::Path; use std::sync::Arc; @@ -867,22 +868,6 @@ pub async fn create_server_core( } }; - // If we are NOT in integration test mode, start the admin socket now - let maybe_admin_sock_handle = if config.integration_test_config.is_none() { - let broadcast_rx = broadcast_tx.subscribe(); - - let admin_handle = AdminActor::create_admin_sock( - config.adminbindpath.as_str(), - server_write_ref, - broadcast_rx, - ) - .await?; - - Some(admin_handle) - } else { - None - }; - // If we have been requested to init LDAP, configure it now. let maybe_ldap_acceptor_handle = match &config.ldapaddress { Some(la) => { @@ -913,6 +898,26 @@ pub async fn create_server_core( } }; + // If we have replication configured, setup the listener with it's initial replication + // map (if any). + let (maybe_repl_handle, maybe_repl_ctrl_tx) = match &config.repl_config { + Some(rc) => { + if !config_test { + // ⚠️ only start the sockets and listeners in non-config-test modes. 
+ let (h, repl_ctrl_tx) = + repl::create_repl_server(idms_arc.clone(), rc, broadcast_tx.subscribe()) + .await?; + (Some(h), Some(repl_ctrl_tx)) + } else { + (None, None) + } + } + None => { + debug!("Replication not requested, skipping"); + (None, None) + } + }; + let maybe_http_acceptor_handle = if config_test { admin_info!("this config rocks! 🪨 "); None @@ -941,6 +946,23 @@ pub async fn create_server_core( Some(h) }; + // If we are NOT in integration test mode, start the admin socket now + let maybe_admin_sock_handle = if config.integration_test_config.is_none() { + let broadcast_rx = broadcast_tx.subscribe(); + + let admin_handle = AdminActor::create_admin_sock( + config.adminbindpath.as_str(), + server_write_ref, + broadcast_rx, + maybe_repl_ctrl_tx, + ) + .await?; + + Some(admin_handle) + } else { + None + }; + let mut handles = vec![interval_handle, delayed_handle, auditd_handle]; if let Some(backup_handle) = maybe_backup_handle { @@ -959,6 +981,10 @@ pub async fn create_server_core( handles.push(http_handle) } + if let Some(repl_handle) = maybe_repl_handle { + handles.push(repl_handle) + } + Ok(CoreHandle { clean_shutdown: false, tx: broadcast_tx, diff --git a/server/core/src/repl/codec.rs b/server/core/src/repl/codec.rs new file mode 100644 index 000000000..c9b68fe89 --- /dev/null +++ b/server/core/src/repl/codec.rs @@ -0,0 +1,250 @@ +use bytes::{BufMut, BytesMut}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange}; + +#[derive(Serialize, Deserialize, Debug)] +pub enum ConsumerRequest { + Ping, + Incremental(ReplRuvRange), + Refresh, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum SupplierResponse { + Pong, + Incremental(ReplIncrementalContext), + Refresh(ReplRefreshContext), +} + +#[derive(Default)] +pub struct ConsumerCodec { + max_frame_bytes: usize, +} + +impl ConsumerCodec { + pub fn 
new(max_frame_bytes: usize) -> Self { + ConsumerCodec { max_frame_bytes } + } +} + +impl Decoder for ConsumerCodec { + type Error = io::Error; + type Item = SupplierResponse; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + decode_length_checked_json(self.max_frame_bytes, src) + } +} + +impl Encoder for ConsumerCodec { + type Error = io::Error; + + fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> { + encode_length_checked_json(msg, dst) + } +} + +#[derive(Default)] +pub struct SupplierCodec { + max_frame_bytes: usize, +} + +impl SupplierCodec { + pub fn new(max_frame_bytes: usize) -> Self { + SupplierCodec { max_frame_bytes } + } +} + +impl Decoder for SupplierCodec { + type Error = io::Error; + type Item = ConsumerRequest; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + decode_length_checked_json(self.max_frame_bytes, src) + } +} + +impl Encoder for SupplierCodec { + type Error = io::Error; + + fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> { + encode_length_checked_json(msg, dst) + } +} + +fn encode_length_checked_json(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> { + // Null the head of the buffer. + let zero_len = u64::MIN.to_be_bytes(); + dst.extend_from_slice(&zero_len); + + // skip the buffer ahead 8 bytes. + // Remember, this split returns the *already set* bytes. + // ⚠️ Can't use split or split_at - these return the + // len bytes into a new bytes mut which confuses unsplit + // by appending the value when we need to append our json. 
+ let json_buf = dst.split_off(zero_len.len()); + + let mut json_writer = json_buf.writer(); + + serde_json::to_writer(&mut json_writer, &msg).map_err(|err| { + error!(?err, "consumer encoding error"); + io::Error::new(io::ErrorKind::Other, "JSON encode error") + })?; + + let json_buf = json_writer.into_inner(); + + let final_len = json_buf.len() as u64; + let final_len_bytes = final_len.to_be_bytes(); + + if final_len_bytes.len() != dst.len() { + error!("consumer buffer size error"); + return Err(io::Error::new(io::ErrorKind::Other, "buffer length error")); + } + + dst.copy_from_slice(&final_len_bytes); + + // Now stitch them back together. + dst.unsplit(json_buf); + + Ok(()) +} + +fn decode_length_checked_json( + max_frame_bytes: usize, + src: &mut BytesMut, +) -> Result, io::Error> { + trace!(capacity = ?src.capacity()); + + if src.len() < 8 { + // Not enough for the length header. + trace!("Insufficient bytes for length header."); + return Ok(None); + } + + let (src_len_bytes, json_bytes) = src.split_at(8); + let mut len_be_bytes = [0; 8]; + + assert_eq!(len_be_bytes.len(), src_len_bytes.len()); + len_be_bytes.copy_from_slice(src_len_bytes); + let req_len = u64::from_be_bytes(len_be_bytes); + + if req_len == 0 { + error!("request has size 0"); + return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request")); + } + + if req_len > max_frame_bytes as u64 { + error!( + "requested decode frame too large {} > {}", + req_len, max_frame_bytes + ); + return Err(io::Error::new( + io::ErrorKind::OutOfMemory, + "request too large", + )); + } + + if (src.len() as u64) < req_len { + trace!( + "Insufficient bytes for json, need: {} have: {}", + req_len, + src.len() + ); + return Ok(None); + } + + // Okay, we have enough. Lets go. + let res = serde_json::from_slice(json_bytes) + .map(|msg| Some(msg)) + .map_err(|err| { + error!(?err, "received invalid input"); + io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error") + }); + + // Trim to length. 
+ + if src.len() as u64 == req_len { + src.clear(); + } else { + let mut rem = src.split_off((8 + req_len) as usize); + std::mem::swap(&mut rem, src); + }; + + res +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::{Decoder, Encoder}; + + use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse}; + + #[test] + fn test_repl_codec() { + sketching::test_init(); + + let mut consumer_codec = ConsumerCodec::new(32); + + let mut buf = BytesMut::with_capacity(32); + + // Empty buffer + assert!(matches!(consumer_codec.decode(&mut buf), Ok(None))); + + let zero = [0, 0, 0, 0]; + buf.extend_from_slice(&zero); + + // Not enough to fill the length header. + assert!(matches!(consumer_codec.decode(&mut buf), Ok(None))); + + // Length header reports a zero size request. + let zero = [0, 0, 0, 0]; + buf.extend_from_slice(&zero); + assert!(buf.len() == 8); + assert!(consumer_codec.decode(&mut buf).is_err()); + + // Clear buffer - setup a request with a length > allowed max. + buf.clear(); + let len_bytes = (34 as u64).to_be_bytes(); + buf.extend_from_slice(&len_bytes); + + // Even though the buf len is only 8, this will error as the overall + // request will be too large. + assert!(buf.len() == 8); + assert!(consumer_codec.decode(&mut buf).is_err()); + + // Assert that we request more data on a validly sized req + buf.clear(); + let len_bytes = (20 as u64).to_be_bytes(); + buf.extend_from_slice(&len_bytes); + // Pad in some extra bytes. + buf.extend_from_slice(&zero); + assert!(buf.len() == 12); + assert!(matches!(consumer_codec.decode(&mut buf), Ok(None))); + + // Make a request that is correctly sized. + buf.clear(); + let mut supplier_codec = SupplierCodec::new(32); + + assert!(consumer_codec + .encode(ConsumerRequest::Ping, &mut buf) + .is_ok()); + assert!(matches!( + supplier_codec.decode(&mut buf), + Ok(Some(ConsumerRequest::Ping)) + )); + // The buf will have been cleared by the supplier codec here. 
+ assert!(buf.is_empty()); + assert!(supplier_codec + .encode(SupplierResponse::Pong, &mut buf) + .is_ok()); + assert!(matches!( + consumer_codec.decode(&mut buf), + Ok(Some(SupplierResponse::Pong)) + )); + assert!(buf.is_empty()); + } +} diff --git a/server/core/src/repl/mod.rs b/server/core/src/repl/mod.rs new file mode 100644 index 000000000..683498a4c --- /dev/null +++ b/server/core/src/repl/mod.rs @@ -0,0 +1,712 @@ +use openssl::{ + pkey::{PKey, Private}, + ssl::{Ssl, SslAcceptor, SslConnector, SslMethod, SslVerifyMode}, + x509::{store::X509StoreBuilder, X509}, +}; +use std::collections::VecDeque; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::time::{interval, sleep, timeout}; +use tokio_openssl::SslStream; +use tokio_util::codec::{FramedRead, FramedWrite}; +use tracing::{error, Instrument}; +use url::Url; +use uuid::Uuid; + +use futures_util::sink::SinkExt; +use futures_util::stream::StreamExt; + +use kanidmd_lib::prelude::duration_from_epoch_now; +use kanidmd_lib::prelude::IdmServer; +use kanidmd_lib::repl::proto::ConsumerState; +use kanidmd_lib::server::QueryServerTransaction; + +use crate::config::RepNodeConfig; +use crate::config::ReplicationConfiguration; +use crate::CoreAction; + +use self::codec::{ConsumerRequest, SupplierResponse}; + +mod codec; + +pub(crate) enum ReplCtrl { + GetCertificate { respond: oneshot::Sender }, + RenewCertificate { respond: oneshot::Sender }, +} + +pub(crate) async fn create_repl_server( + idms: Arc, + repl_config: &ReplicationConfiguration, + rx: broadcast::Receiver, +) -> Result<(tokio::task::JoinHandle<()>, mpsc::Sender), ()> { + // We need to start the tcp listener. This will persist over ssl reloads! 
+ let listener = TcpListener::bind(&repl_config.bindaddress) + .await + .map_err(|e| { + error!( + "Could not bind to replication address {} -> {:?}", + repl_config.bindaddress, e + ); + })?; + + // Create the control channel. Use a low msg count, there won't be that much going on. + let (ctrl_tx, ctrl_rx) = mpsc::channel(4); + + // We need to start the tcp listener. This will persist over ssl reloads! + info!( + "Starting replication interface https://{} ...", + repl_config.bindaddress + ); + let repl_handle = tokio::spawn(repl_acceptor( + listener, + idms, + repl_config.clone(), + rx, + ctrl_rx, + )); + + info!("Created replication interface"); + Ok((repl_handle, ctrl_tx)) +} + +#[instrument(level = "info", skip_all)] +async fn repl_run_consumer( + max_frame_bytes: usize, + domain: &str, + sock_addrs: &[SocketAddr], + tls_connector: &SslConnector, + automatic_refresh: bool, + idms: &IdmServer, +) { + let replica_connect_timeout = Duration::from_secs(2); + + // This is pretty gnarly, but we need to loop to try out each socket addr. 
+ for sock_addr in sock_addrs { + debug!("Connecting to {} replica via {}", domain, sock_addr); + + let tcpstream = match timeout(replica_connect_timeout, TcpStream::connect(sock_addr)).await + { + Ok(Ok(tc)) => tc, + Ok(Err(err)) => { + error!(?err, "Failed to connect to {}", sock_addr); + continue; + } + Err(_) => { + error!("Timeout connecting to {}", sock_addr); + continue; + } + }; + + trace!("connection established"); + + let mut tlsstream = match Ssl::new(tls_connector.context()) + .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) + { + Ok(ta) => ta, + Err(e) => { + error!("replication client TLS setup error, continuing -> {:?}", e); + continue; + } + }; + + if let Err(e) = SslStream::connect(Pin::new(&mut tlsstream)).await { + error!("replication client TLS accept error, continuing -> {:?}", e); + continue; + }; + let (r, w) = tokio::io::split(tlsstream); + let mut r = FramedRead::new(r, codec::ConsumerCodec::new(max_frame_bytes)); + let mut w = FramedWrite::new(w, codec::ConsumerCodec::new(max_frame_bytes)); + + // Perform incremental. + let consumer_ruv_range = { + let mut read_txn = idms.proxy_read().await; + match read_txn.qs_read.consumer_get_state() { + Ok(ruv_range) => ruv_range, + Err(err) => { + error!( + ?err, + "consumer ruv range could not be accessed, unable to continue." + ); + break; + } + } + }; + + if let Err(err) = w + .send(ConsumerRequest::Incremental(consumer_ruv_range)) + .await + { + error!(?err, "consumer encode error, unable to continue."); + break; + } + + let changes = if let Some(codec_msg) = r.next().await { + match codec_msg { + Ok(SupplierResponse::Incremental(changes)) => { + // Success - return to bypass the error message. 
+ changes + } + Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Refresh(_)) => { + error!("Supplier Response contains invalid State"); + break; + } + Err(err) => { + error!(?err, "consumer decode error, unable to continue."); + break; + } + } + } else { + error!("Connection closed"); + break; + }; + + // Now apply the changes if possible + let consumer_state = { + let ct = duration_from_epoch_now(); + let mut write_txn = idms.proxy_write(ct).await; + match write_txn + .qs_write + .consumer_apply_changes(&changes) + .and_then(|cs| write_txn.commit().map(|()| cs)) + { + Ok(state) => state, + Err(err) => { + error!(?err, "consumer was not able to apply changes."); + break; + } + } + }; + + match consumer_state { + ConsumerState::Ok => { + info!("Incremental Replication Success"); + // return to bypass the failure message. + return; + } + ConsumerState::RefreshRequired => { + if automatic_refresh { + warn!("Consumer is out of date and must be refreshed. This will happen *now*."); + } else { + error!("Consumer is out of date and must be refreshed. You must manually resolve this situation."); + return; + }; + } + } + + if let Err(err) = w.send(ConsumerRequest::Refresh).await { + error!(?err, "consumer encode error, unable to continue."); + break; + } + + let refresh = if let Some(codec_msg) = r.next().await { + match codec_msg { + Ok(SupplierResponse::Refresh(changes)) => { + // Success - return to bypass the error message. 
+ changes + } + Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Incremental(_)) => { + error!("Supplier Response contains invalid State"); + break; + } + Err(err) => { + error!(?err, "consumer decode error, unable to continue."); + break; + } + } + } else { + error!("Connection closed"); + break; + }; + + // Now apply the refresh if possible + let ct = duration_from_epoch_now(); + let mut write_txn = idms.proxy_write(ct).await; + if let Err(err) = write_txn + .qs_write + .consumer_apply_refresh(&refresh) + .and_then(|cs| write_txn.commit().map(|()| cs)) + { + error!(?err, "consumer was not able to apply refresh."); + break; + } + + warn!("Replication refresh was successful."); + return; + } + + error!("Unable to complete replication successfully."); +} + +async fn repl_task( + origin: Url, + client_key: PKey, + client_cert: X509, + supplier_cert: X509, + max_frame_bytes: usize, + task_poll_interval: Duration, + mut task_rx: broadcast::Receiver<()>, + automatic_refresh: bool, + idms: Arc, +) { + if origin.scheme() != "repl" { + error!("Replica origin is not repl:// - refusing to proceed."); + return; + } + + let domain = match origin.domain() { + Some(d) => d, + None => { + error!("Replica origin does not have a valid domain name, unable to proceed. Perhaps you tried to use an ip address?"); + return; + } + }; + + let socket_addrs = match origin.socket_addrs(|| Some(443)) { + Ok(sa) => sa, + Err(err) => { + error!(?err, "Replica origin could not resolve to ip:port"); + return; + } + }; + + // Setup our tls connector. 
+ let mut ssl_builder = match SslConnector::builder(SslMethod::tls_client()) { + Ok(sb) => sb, + Err(err) => { + error!(?err, "Unable to configure tls connector"); + return; + } + }; + + let setup_client_cert = ssl_builder + .set_certificate(&client_cert) + .and_then(|_| ssl_builder.set_private_key(&client_key)) + .and_then(|_| ssl_builder.check_private_key()); + if let Err(err) = setup_client_cert { + error!(?err, "Unable to configure client certificate/key"); + return; + } + + // Add the supplier cert. + // ⚠️ note that here we need to build a new cert store. This is because + // openssl SslConnector adds the default system cert locations with + // the call to ::builder and we *don't* want this. We want our certstore + // to pin a single certificate! + let mut cert_store = match X509StoreBuilder::new() { + Ok(csb) => csb, + Err(err) => { + error!(?err, "Unable to configure certificate store builder."); + return; + } + }; + + if let Err(err) = cert_store.add_cert(supplier_cert) { + error!(?err, "Unable to add supplier certificate to cert store"); + return; + } + + let cert_store = cert_store.build(); + ssl_builder.set_cert_store(cert_store); + + // Configure the expected hostname of the remote. + let verify_param = ssl_builder.verify_param_mut(); + if let Err(err) = verify_param.set_host(domain) { + error!(?err, "Unable to set domain name for tls peer verification"); + return; + } + + // Assert the expected supplier certificate is correct and has a valid domain san + ssl_builder.set_verify(SslVerifyMode::PEER); + let tls_connector = ssl_builder.build(); + + let mut repl_interval = interval(task_poll_interval); + + info!("Replica task for {} has started.", origin); + + // Okay, all the parameters are setup. Now we wait on our interval. + loop { + tokio::select! { + Ok(()) = task_rx.recv() => { + break; + } + _ = repl_interval.tick() => { + // Interval passed, attempt a replication run. 
+ let eventid = Uuid::new_v4(); + let span = info_span!("replication_run_consumer", uuid = ?eventid); + let _enter = span.enter(); + repl_run_consumer( + max_frame_bytes, + domain, + &socket_addrs, + &tls_connector, + automatic_refresh, + &idms + ).await; + } + } + } + + info!("Replica task for {} has stopped.", origin); +} + +#[instrument(level = "info", skip_all)] +async fn handle_repl_conn( + max_frame_bytes: usize, + tcpstream: TcpStream, + client_address: SocketAddr, + tls_parms: SslAcceptor, + idms: Arc, +) { + debug!(?client_address, "replication client connected 🛫"); + + let mut tlsstream = match Ssl::new(tls_parms.context()) + .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) + { + Ok(ta) => ta, + Err(err) => { + error!(?err, "LDAP TLS setup error, disconnecting client"); + return; + } + }; + if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await { + error!(?err, "LDAP TLS accept error, disconnecting client"); + return; + }; + let (r, w) = tokio::io::split(tlsstream); + let mut r = FramedRead::new(r, codec::SupplierCodec::new(max_frame_bytes)); + let mut w = FramedWrite::new(w, codec::SupplierCodec::new(max_frame_bytes)); + + while let Some(codec_msg) = r.next().await { + match codec_msg { + Ok(ConsumerRequest::Ping) => { + debug!("consumer requested ping"); + if let Err(err) = w.send(SupplierResponse::Pong).await { + error!(?err, "supplier encode error, unable to continue."); + break; + } + } + Ok(ConsumerRequest::Incremental(consumer_ruv_range)) => { + let mut read_txn = idms.proxy_read().await; + + let changes = match read_txn + .qs_read + .supplier_provide_changes(consumer_ruv_range) + { + Ok(changes) => changes, + Err(err) => { + error!(?err, "supplier provide changes failed."); + break; + } + }; + + if let Err(err) = w.send(SupplierResponse::Incremental(changes)).await { + error!(?err, "supplier encode error, unable to continue."); + break; + } + } + Ok(ConsumerRequest::Refresh) => { + let mut read_txn = idms.proxy_read().await; + 
+ let changes = match read_txn.qs_read.supplier_provide_refresh() { + Ok(changes) => changes, + Err(err) => { + error!(?err, "supplier provide refresh failed."); + break; + } + }; + + if let Err(err) = w.send(SupplierResponse::Refresh(changes)).await { + error!(?err, "supplier encode error, unable to continue."); + break; + } + } + Err(err) => { + error!(?err, "supplier decode error, unable to continue."); + break; + } + } + } + + debug!(?client_address, "replication client disconnected 🛬"); +} + +async fn repl_acceptor( + listener: TcpListener, + idms: Arc, + repl_config: ReplicationConfiguration, + mut rx: broadcast::Receiver, + mut ctrl_rx: mpsc::Receiver, +) { + info!("Starting Replication Acceptor ..."); + // Persistent parts + // These all probably need changes later ... + let task_poll_interval = Duration::from_secs(10); + let retry_timeout = Duration::from_secs(60); + let max_frame_bytes = 268435456; + + // Setup a broadcast to control our tasks. + let (task_tx, task_rx1) = broadcast::channel(2); + // Note, we drop this task here since each task will re-subscribe. That way the + // broadcast doesn't jam up because we aren't draining this task. + drop(task_rx1); + let mut task_handles = VecDeque::new(); + + // Create another broadcast to control the replication tasks and their need to reload. + + // Spawn a KRC communication task? + + // In future we need to update this from the KRC if configured, and we default this + // to "empty". But if this map exists in the config, we have to always use that. + let replication_node_map = repl_config.manual.clone(); + + // This needs to have an event loop that can respond to changes. + // For now we just design it to reload ssl if the map changes internally. + 'event: loop { + info!("Starting replication reload ..."); + // Tell existing tasks to shutdown. + // Note: We ignore the result here since an err can occur *if* there are + // no tasks currently listening on the channel. 
+ info!("Stopping {} Replication Tasks ...", task_handles.len()); + debug_assert!(task_handles.len() >= task_tx.receiver_count()); + let _ = task_tx.send(()); + for task_handle in task_handles.drain(..) { + // Let each task join. + let res: Result<(), _> = task_handle.await; + if res.is_err() { + warn!("Failed to join replication task, continuing ..."); + } + } + + // Now we can start to re-load configurations and setup our client tasks + // as well. + + // Get the private key / cert. + let res = { + // Does this actually need to be a read in case we need to write + // to sqlite? + let ct = duration_from_epoch_now(); + let mut idms_prox_write = idms.proxy_write(ct).await; + idms_prox_write + .qs_write + .supplier_get_key_cert() + .and_then(|res| idms_prox_write.commit().map(|()| res)) + }; + + let (server_key, server_cert) = match res { + Ok(r) => r, + Err(err) => { + error!(?err, "CRITICAL: Unable to access supplier certificate/key."); + sleep(retry_timeout).await; + continue; + } + }; + + info!( + replication_cert_not_before = ?server_cert.not_before(), + replication_cert_not_after = ?server_cert.not_after(), + ); + + let mut client_certs = Vec::new(); + + // For each node in the map, either spawn a task to pull from that node, + // or setup the node as allowed to pull from us. 
+ for (origin, node) in replication_node_map.iter() { + // Setup client certs + match node { + RepNodeConfig::MutualPull { + partner_cert: consumer_cert, + automatic_refresh: _, + } + | RepNodeConfig::AllowPull { consumer_cert } => { + client_certs.push(consumer_cert.clone()) + } + RepNodeConfig::Pull { + supplier_cert: _, + automatic_refresh: _, + } => {} + }; + + match node { + RepNodeConfig::MutualPull { + partner_cert: supplier_cert, + automatic_refresh, + } + | RepNodeConfig::Pull { + supplier_cert, + automatic_refresh, + } => { + let task_rx = task_tx.subscribe(); + + let handle = tokio::spawn(repl_task( + origin.clone(), + server_key.clone(), + server_cert.clone(), + supplier_cert.clone(), + max_frame_bytes, + task_poll_interval, + task_rx, + *automatic_refresh, + idms.clone(), + )); + + task_handles.push_back(handle); + debug_assert!(task_handles.len() == task_tx.receiver_count()); + } + RepNodeConfig::AllowPull { consumer_cert: _ } => {} + }; + } + + // ⚠️ This section is critical to the security of replication + // Since replication relies on mTLS we MUST ensure these options + // are absolutely correct! + // + // Setup the TLS builder. + let mut tls_builder = match SslAcceptor::mozilla_modern_v5(SslMethod::tls()) { + Ok(tls_builder) => tls_builder, + Err(err) => { + error!(?err, "CRITICAL, unable to create SslAcceptorBuilder."); + sleep(retry_timeout).await; + continue; + } + }; + + // tls_builder.set_keylog_callback(keylog_cb); + if let Err(err) = tls_builder + .set_certificate(&server_cert) + .and_then(|_| tls_builder.set_private_key(&server_key)) + .and_then(|_| tls_builder.check_private_key()) + { + error!(?err, "CRITICAL, unable to set server_cert and server key."); + sleep(retry_timeout).await; + continue; + }; + + // ⚠️ CRITICAL - ensure that the cert store only has client certs from + // the repl map added. 
+ let cert_store = tls_builder.cert_store_mut(); + for client_cert in client_certs.into_iter() { + if let Err(err) = cert_store.add_cert(client_cert.clone()) { + error!(?err, "CRITICAL, unable to add client certificates."); + sleep(retry_timeout).await; + continue; + } + } + + // ⚠️ CRITICAL - Both verifications here are needed. PEER requests + // the client cert to be sent. FAIL_IF_NO_PEER_CERT triggers an + // error if the cert is NOT present. FAIL_IF_NO_PEER_CERT on it's own + // DOES NOTHING. + let mut verify = SslVerifyMode::PEER; + verify.insert(SslVerifyMode::FAIL_IF_NO_PEER_CERT); + tls_builder.set_verify(verify); + + let tls_acceptor = tls_builder.build(); + + loop { + // This is great to diagnose when spans are entered or present and they capture + // things incorrectly. + // eprintln!("🔥 C ---> {:?}", tracing::Span::current()); + let eventid = Uuid::new_v4(); + + tokio::select! { + Ok(action) = rx.recv() => { + match action { + CoreAction::Shutdown => break 'event, + } + } + Some(ctrl_msg) = ctrl_rx.recv() => { + match ctrl_msg { + ReplCtrl::GetCertificate { + respond + } => { + let _span = debug_span!("supplier_accept_loop", uuid = ?eventid).entered(); + if let Err(_) = respond.send(server_cert.clone()) { + warn!("Server certificate was requested, but requsetor disconnected"); + } else { + trace!("Sent server certificate via control channel"); + } + } + ReplCtrl::RenewCertificate { + respond + } => { + let span = debug_span!("supplier_accept_loop", uuid = ?eventid); + async { + debug!("renewing replication certificate ..."); + // Renew the cert. 
+ let res = { + let ct = duration_from_epoch_now(); + let mut idms_prox_write = idms.proxy_write(ct).await; + idms_prox_write + .qs_write + .supplier_renew_key_cert() + .and_then(|res| idms_prox_write.commit().map(|()| res)) + }; + + let success = res.is_ok(); + + if let Err(err) = res { + error!(?err, "failed to renew server certificate"); + } + + if let Err(_) = respond.send(success) { + warn!("Server certificate renewal was requested, but requsetor disconnected"); + } else { + trace!("Sent server certificate renewal status via control channel"); + } + } + .instrument(span) + .await; + + // Start a reload. + continue 'event; + } + } + } + // Handle accepts. + // Handle *reloads* + /* + _ = reload.recv() => { + info!("initiate tls reload"); + continue + } + */ + accept_result = listener.accept() => { + match accept_result { + Ok((tcpstream, client_socket_addr)) => { + let clone_idms = idms.clone(); + let clone_tls_acceptor = tls_acceptor.clone(); + // We don't care about the join handle here - once a client connects + // it sticks to whatever ssl settings it had at launch. + let _ = tokio::spawn( + handle_repl_conn(max_frame_bytes, tcpstream, client_socket_addr, clone_tls_acceptor, clone_idms) + ); + } + Err(e) => { + error!("replication acceptor error, continuing -> {:?}", e); + } + } + } + } // end select + // Continue to poll/loop + } + } + // Shutdown child tasks. + info!("Stopping {} Replication Tasks ...", task_handles.len()); + debug_assert!(task_handles.len() >= task_tx.receiver_count()); + let _ = task_tx.send(()); + for task_handle in task_handles.drain(..) { + // Let each task join. 
+ let res: Result<(), _> = task_handle.await; + if res.is_err() { + warn!("Failed to join replication task, continuing ..."); + } + } + + info!("Stopped Replication Acceptor"); +} diff --git a/server/daemon/run_insecure_dev_server.sh b/server/daemon/run_insecure_dev_server.sh index 74bd3f790..ed04ae2ee 100755 --- a/server/daemon/run_insecure_dev_server.sh +++ b/server/daemon/run_insecure_dev_server.sh @@ -18,7 +18,7 @@ if [ ! -d "${KANI_TMP}" ]; then mkdir -p "${KANI_TMP}" fi -CONFIG_FILE="../../examples/insecure_server.toml" +CONFIG_FILE=${CONFIG_FILE:="../../examples/insecure_server.toml"} if [ ! -f "${CONFIG_FILE}" ]; then SCRIPT_DIR="$(dirname -a "$0")" diff --git a/server/daemon/src/main.rs b/server/daemon/src/main.rs index e33cf6ccd..d65df5922 100644 --- a/server/daemon/src/main.rs +++ b/server/daemon/src/main.rs @@ -77,7 +77,9 @@ impl KanidmdOpt { | KanidmdOpt::DbScan { commands: DbScanOpt::RestoreQuarantined { commonopts, .. }, } - | KanidmdOpt::RecoverAccount { commonopts, .. } => commonopts, + | KanidmdOpt::ShowReplicationCertificate { commonopts } + | KanidmdOpt::RenewReplicationCertificate { commonopts } => commonopts, + KanidmdOpt::RecoverAccount { commonopts, .. } => commonopts, KanidmdOpt::DbScan { commands: DbScanOpt::ListIndex(dopt), } => &dopt.commonopts, @@ -145,6 +147,14 @@ async fn submit_admin_req(path: &str, req: AdminTaskRequest, output_mode: Consol info!(new_password = ?password) } }, + Some(Ok(AdminTaskResponse::ShowReplicationCertificate { cert })) => match output_mode { + ConsoleOutputMode::JSON => { + eprintln!("{{\"certificate\":\"{}\"}}", cert) + } + ConsoleOutputMode::Text => { + info!(certificate = ?cert) + } + }, _ => { error!("Error making request to admin socket"); } @@ -258,6 +268,13 @@ async fn main() -> ExitCode { } }; + // Stop early if replication was found + if sconfig.repl_config.is_some() && + !sconfig.i_acknowledge_that_replication_is_in_development + { + error!("Unable to proceed. 
Replication should not be configured manually."); + return ExitCode::FAILURE + } #[cfg(target_family = "unix")] { @@ -341,6 +358,11 @@ async fn main() -> ExitCode { config.update_output_mode(opt.commands.commonopt().output_mode.to_owned().into()); config.update_trust_x_forward_for(sconfig.trust_x_forward_for); config.update_admin_bind_path(&sconfig.adminbindpath); + + config.update_replication_config( + sconfig.repl_config.clone() + ); + match &opt.commands { // we aren't going to touch the DB so we can carry on KanidmdOpt::HealthCheck(_) => (), @@ -531,6 +553,26 @@ async fn main() -> ExitCode { info!("Running in db verification mode ..."); verify_server_core(&config).await; } + KanidmdOpt::ShowReplicationCertificate { + commonopts + } => { + info!("Running show replication certificate ..."); + let output_mode: ConsoleOutputMode = commonopts.output_mode.to_owned().into(); + submit_admin_req(config.adminbindpath.as_str(), + AdminTaskRequest::ShowReplicationCertificate, + output_mode, + ).await; + } + KanidmdOpt::RenewReplicationCertificate { + commonopts + } => { + info!("Running renew replication certificate ..."); + let output_mode: ConsoleOutputMode = commonopts.output_mode.to_owned().into(); + submit_admin_req(config.adminbindpath.as_str(), + AdminTaskRequest::RenewReplicationCertificate, + output_mode, + ).await; + } KanidmdOpt::RecoverAccount { name, commonopts } => { diff --git a/server/daemon/src/opt.rs b/server/daemon/src/opt.rs index 70f16ccab..723fabac6 100644 --- a/server/daemon/src/opt.rs +++ b/server/daemon/src/opt.rs @@ -154,6 +154,16 @@ enum KanidmdOpt { #[clap(flatten)] commonopts: CommonOpt, }, + /// Display this server's replication certificate + ShowReplicationCertificate { + #[clap(flatten)] + commonopts: CommonOpt, + }, + /// Renew this server's replication certificate + RenewReplicationCertificate { + #[clap(flatten)] + commonopts: CommonOpt, + }, // #[clap(name = "reset_server_id")] // ResetServerId(CommonOpt), #[clap(name = "db-scan")] diff 
--git a/server/lib/src/be/dbentry.rs b/server/lib/src/be/dbentry.rs index 557f03f02..3a9606e0e 100644 --- a/server/lib/src/be/dbentry.rs +++ b/server/lib/src/be/dbentry.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; use smartstring::alias::String as AttrString; use uuid::Uuid; +use super::keystorage::{KeyHandle, KeyHandleId}; use crate::be::dbvalue::{DbValueEmailAddressV1, DbValuePhoneNumberV1, DbValueSetV2, DbValueV1}; use crate::prelude::entries::Attribute; use crate::prelude::OperationError; @@ -57,6 +58,13 @@ pub enum DbBackup { db_ts_max: Duration, entries: Vec, }, + V3 { + db_s_uuid: Uuid, + db_d_uuid: Uuid, + db_ts_max: Duration, + keyhandles: BTreeMap, + entries: Vec, + }, } fn from_vec_dbval1(attr_val: NonEmpty) -> Result { diff --git a/server/lib/src/be/idl_arc_sqlite.rs b/server/lib/src/be/idl_arc_sqlite.rs index 94172d35a..58d677344 100644 --- a/server/lib/src/be/idl_arc_sqlite.rs +++ b/server/lib/src/be/idl_arc_sqlite.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::collections::BTreeSet; use std::convert::TryInto; use std::ops::DerefMut; @@ -19,6 +20,7 @@ use crate::be::idl_sqlite::{ use crate::be::idxkey::{ IdlCacheKey, IdlCacheKeyRef, IdlCacheKeyToRef, IdxKey, IdxKeyRef, IdxKeyToRef, IdxSlope, }; +use crate::be::keystorage::{KeyHandle, KeyHandleId}; use crate::be::{BackendConfig, IdList, IdRawEntry}; use crate::entry::{Entry, EntryCommitted, EntrySealed}; use crate::prelude::*; @@ -56,6 +58,7 @@ pub struct IdlArcSqlite { op_ts_max: CowCell>, allids: CowCell, maxid: CowCell, + keyhandles: CowCell>, } pub struct IdlArcSqliteReadTransaction<'a> { @@ -67,13 +70,14 @@ pub struct IdlArcSqliteReadTransaction<'a> { } pub struct IdlArcSqliteWriteTransaction<'a> { - db: IdlSqliteWriteTransaction, + pub(super) db: IdlSqliteWriteTransaction, entry_cache: ARCacheWriteTxn<'a, u64, Arc, ()>, idl_cache: ARCacheWriteTxn<'a, IdlCacheKey, Box, ()>, name_cache: ARCacheWriteTxn<'a, NameCacheKey, NameCacheValue, ()>, op_ts_max: CowCellWriteTxn<'a, 
Option>, allids: CowCellWriteTxn<'a, IDLBitRange>, maxid: CowCellWriteTxn<'a, u64>, + pub(super) keyhandles: CowCellWriteTxn<'a, HashMap>, } macro_rules! get_identry { @@ -353,6 +357,8 @@ pub trait IdlArcSqliteTransaction { fn get_db_ts_max(&self) -> Result, OperationError>; + fn get_key_handles(&mut self) -> Result, OperationError>; + fn verify(&self) -> Vec>; fn is_dirty(&self) -> bool; @@ -417,6 +423,10 @@ impl<'a> IdlArcSqliteTransaction for IdlArcSqliteReadTransaction<'a> { self.db.get_db_ts_max() } + fn get_key_handles(&mut self) -> Result, OperationError> { + self.db.get_key_handles() + } + fn verify(&self) -> Vec> { verify!(self) } @@ -511,6 +521,10 @@ impl<'a> IdlArcSqliteTransaction for IdlArcSqliteWriteTransaction<'a> { } } + fn get_key_handles(&mut self) -> Result, OperationError> { + self.db.get_key_handles() + } + fn verify(&self) -> Vec> { verify!(self) } @@ -593,6 +607,7 @@ impl<'a> IdlArcSqliteWriteTransaction<'a> { op_ts_max, allids, maxid, + keyhandles, } = self; // Write any dirty items to the disk. @@ -656,15 +671,20 @@ impl<'a> IdlArcSqliteWriteTransaction<'a> { e })?; - // Undo the caches in the reverse order. - db.commit().map(|()| { - op_ts_max.commit(); - name_cache.commit(); - idl_cache.commit(); - entry_cache.commit(); - allids.commit(); - maxid.commit(); - }) + // Ensure the db commit succeeds first. + db.commit()?; + + // Can no longer fail from this point. + op_ts_max.commit(); + name_cache.commit(); + idl_cache.commit(); + allids.commit(); + maxid.commit(); + keyhandles.commit(); + // Unlock the entry cache last to remove contention on everything else. 
+ entry_cache.commit(); + + Ok(()) } pub fn get_id2entry_max_id(&self) -> Result { @@ -1238,6 +1258,8 @@ impl IdlArcSqlite { let maxid = CowCell::new(0); + let keyhandles = CowCell::new(HashMap::default()); + let op_ts_max = CowCell::new(None); Ok(IdlArcSqlite { @@ -1248,6 +1270,7 @@ impl IdlArcSqlite { op_ts_max, allids, maxid, + keyhandles, }) } @@ -1283,6 +1306,7 @@ impl IdlArcSqlite { let allids_write = self.allids.write(); let maxid_write = self.maxid.write(); let db_write = self.db.write(); + let keyhandles_write = self.keyhandles.write(); IdlArcSqliteWriteTransaction { db: db_write, entry_cache: entry_cache_write, @@ -1291,6 +1315,7 @@ impl IdlArcSqlite { op_ts_max: op_ts_max_write, allids: allids_write, maxid: maxid_write, + keyhandles: keyhandles_write, } } diff --git a/server/lib/src/be/idl_sqlite.rs b/server/lib/src/be/idl_sqlite.rs index b12f55a2d..21f98a3cf 100644 --- a/server/lib/src/be/idl_sqlite.rs +++ b/server/lib/src/be/idl_sqlite.rs @@ -1,9 +1,12 @@ +use std::collections::BTreeMap; use std::collections::VecDeque; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; +use super::keystorage::{KeyHandle, KeyHandleId}; + // use crate::valueset; use hashbrown::HashMap; use idlset::v2::IDLBitRange; @@ -24,13 +27,13 @@ const DBV_ID2ENTRY: &str = "id2entry"; const DBV_INDEXV: &str = "indexv"; #[allow(clippy::needless_pass_by_value)] // needs to accept value from `map_err` -fn sqlite_error(e: rusqlite::Error) -> OperationError { +pub(super) fn sqlite_error(e: rusqlite::Error) -> OperationError { admin_error!(?e, "SQLite Error"); OperationError::SqliteError } #[allow(clippy::needless_pass_by_value)] // needs to accept value from `map_err` -fn serde_json_error(e: serde_json::Error) -> OperationError { +pub(super) fn serde_json_error(e: serde_json::Error) -> OperationError { admin_error!(?e, "Serde JSON Error"); OperationError::SerdeJsonError } @@ -482,6 +485,29 @@ pub trait IdlSqliteTransaction { }) } + fn 
get_key_handles(&mut self) -> Result, OperationError> { + let mut stmt = self + .get_conn()? + .prepare(&format!( + "SELECT id, data FROM {}.keyhandles", + self.get_db_name() + )) + .map_err(sqlite_error)?; + + let kh_iter = stmt + .query_map([], |row| Ok((row.get(0)?, row.get(1)?))) + .map_err(sqlite_error)?; + + kh_iter + .map(|v| { + let (id, data): (Vec, Vec) = v.map_err(sqlite_error)?; + let id = serde_json::from_slice(id.as_slice()).map_err(serde_json_error)?; + let data = serde_json::from_slice(data.as_slice()).map_err(serde_json_error)?; + Ok((id, data)) + }) + .collect() + } + #[instrument(level = "debug", name = "idl_sqlite::get_allids", skip_all)] fn get_allids(&self) -> Result { let mut stmt = self @@ -1079,6 +1105,19 @@ impl IdlSqliteWriteTransaction { } } + pub(crate) fn create_keyhandles(&self) -> Result<(), OperationError> { + self.get_conn()? + .execute( + &format!( + "CREATE TABLE IF NOT EXISTS {}.keyhandles (id TEXT PRIMARY KEY, data TEXT)", + self.get_db_name() + ), + [], + ) + .map(|_| ()) + .map_err(sqlite_error) + } + pub fn create_idx(&self, attr: Attribute, itype: IndexType) -> Result<(), OperationError> { // Is there a better way than formatting this? I can't seem // to template into the str. @@ -1565,7 +1604,7 @@ impl IdlSqliteWriteTransaction { dbv_id2entry = 6; info!(entry = %dbv_id2entry, "dbv_id2entry migrated (externalid2uuid)"); } - // * if v6 -> complete. + // * if v6 -> create id2entry_quarantine. if dbv_id2entry == 6 { self.get_conn()? .execute( @@ -1584,7 +1623,13 @@ impl IdlSqliteWriteTransaction { dbv_id2entry = 7; info!(entry = %dbv_id2entry, "dbv_id2entry migrated (quarantine)"); } - // * if v7 -> complete. + // * if v7 -> create keyhandles storage. 
+ if dbv_id2entry == 7 { + self.create_keyhandles()?; + dbv_id2entry = 8; + info!(entry = %dbv_id2entry, "dbv_id2entry migrated (keyhandles)"); + } + // * if v8 -> complete self.set_db_version_key(DBV_ID2ENTRY, dbv_id2entry)?; diff --git a/server/lib/src/be/keystorage.rs b/server/lib/src/be/keystorage.rs new file mode 100644 index 000000000..54dee2206 --- /dev/null +++ b/server/lib/src/be/keystorage.rs @@ -0,0 +1,191 @@ +use crate::rusqlite::OptionalExtension; +use kanidm_lib_crypto::prelude::{PKey, Private, X509}; +use kanidm_lib_crypto::serialise::{pkeyb64, x509b64}; +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; +use std::hash::Hash; + +use super::idl_arc_sqlite::IdlArcSqliteWriteTransaction; +use super::idl_sqlite::IdlSqliteTransaction; +use super::idl_sqlite::IdlSqliteWriteTransaction; +use super::idl_sqlite::{serde_json_error, sqlite_error}; +use super::BackendWriteTransaction; +use crate::prelude::OperationError; + +/// These are key handles for storing keys related to various cryptographic components +/// within Kanidm. Generally these are for keys that are "static", as in have known +/// long term uses. This could be the servers private replication key, a TPM Storage +/// Root Key, or the Duplicable Storage Key. In future these may end up being in +/// a HSM or similar, but we'll always need a way to persist serialised forms of these. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub enum KeyHandleId { + ReplicationKey, +} + +/// This is a key handle that contains the actual data that is persisted in the DB. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum KeyHandle { + X509Key { + #[serde(with = "pkeyb64")] + private: PKey, + #[serde(with = "x509b64")] + x509: X509, + }, +} + +impl<'a> BackendWriteTransaction<'a> { + /// Retrieve a key stored in the database by it's key handle. 
This + /// handle may require further processing for the key to be usable + /// in higher level contexts as this is simply the storage layer + /// for these keys. + pub(crate) fn get_key_handle( + &mut self, + handle: KeyHandleId, + ) -> Result, OperationError> { + self.idlayer.get_key_handle(handle) + } + + /// Update the content of a keyhandle with this new data. + pub(crate) fn set_key_handle( + &mut self, + handle: KeyHandleId, + data: KeyHandle, + ) -> Result<(), OperationError> { + self.idlayer.set_key_handle(handle, data) + } +} + +impl<'a> IdlArcSqliteWriteTransaction<'a> { + pub(crate) fn get_key_handle( + &mut self, + handle: KeyHandleId, + ) -> Result, OperationError> { + if let Some(kh) = self.keyhandles.get(&handle) { + Ok(Some(kh.clone())) + } else { + let r = self.db.get_key_handle(handle); + + if let Ok(Some(kh)) = &r { + self.keyhandles.insert(handle, kh.clone()); + } + + r + } + } + + /// Update the content of a keyhandle with this new data. + #[instrument(level = "debug", skip(self, data))] + pub(crate) fn set_key_handle( + &mut self, + handle: KeyHandleId, + data: KeyHandle, + ) -> Result<(), OperationError> { + self.db.set_key_handle(handle, &data)?; + self.keyhandles.insert(handle, data); + Ok(()) + } + + pub(super) fn set_key_handles( + &mut self, + keyhandles: BTreeMap, + ) -> Result<(), OperationError> { + self.db.set_key_handles(&keyhandles)?; + self.keyhandles.clear(); + self.keyhandles.extend(keyhandles.into_iter()); + Ok(()) + } +} + +impl IdlSqliteWriteTransaction { + pub(crate) fn get_key_handle( + &mut self, + handle: KeyHandleId, + ) -> Result, OperationError> { + let s_handle = serde_json::to_vec(&handle).map_err(serde_json_error)?; + + let mut stmt = self + .get_conn()? 
+ .prepare(&format!( + "SELECT data FROM {}.keyhandles WHERE id = :id", + self.get_db_name() + )) + .map_err(sqlite_error)?; + let data_raw: Option> = stmt + .query_row(&[(":id", &s_handle)], |row| row.get(0)) + // We don't mind if it doesn't exist + .optional() + .map_err(sqlite_error)?; + + let data: Option = match data_raw { + Some(d) => serde_json::from_slice(d.as_slice()) + .map(Some) + .map_err(serde_json_error)?, + None => None, + }; + + Ok(data) + } + + pub(super) fn get_key_handles( + &mut self, + ) -> Result, OperationError> { + let mut stmt = self + .get_conn()? + .prepare(&format!( + "SELECT id, data FROM {}.keyhandles", + self.get_db_name() + )) + .map_err(sqlite_error)?; + + let kh_iter = stmt + .query_map([], |row| Ok((row.get(0)?, row.get(1)?))) + .map_err(sqlite_error)?; + + kh_iter + .map(|v| { + let (id, data): (Vec, Vec) = v.map_err(sqlite_error)?; + let id = serde_json::from_slice(id.as_slice()).map_err(serde_json_error)?; + let data = serde_json::from_slice(data.as_slice()).map_err(serde_json_error)?; + Ok((id, data)) + }) + .collect() + } + + /// Update the content of a keyhandle with this new data. + #[instrument(level = "debug", skip(self, data))] + pub(crate) fn set_key_handle( + &mut self, + handle: KeyHandleId, + data: &KeyHandle, + ) -> Result<(), OperationError> { + let s_handle = serde_json::to_vec(&handle).map_err(serde_json_error)?; + let s_data = serde_json::to_vec(&data).map_err(serde_json_error)?; + + self.get_conn()? + .prepare(&format!( + "INSERT OR REPLACE INTO {}.keyhandles (id, data) VALUES(:id, :data)", + self.get_db_name() + )) + .and_then(|mut stmt| stmt.execute(&[(":id", &s_handle), (":data", &s_data)])) + .map(|_| ()) + .map_err(sqlite_error) + } + + pub(super) fn set_key_handles( + &mut self, + keyhandles: &BTreeMap, + ) -> Result<(), OperationError> { + self.get_conn()? 
+ .execute( + &format!("DELETE FROM {}.keyhandles", self.get_db_name()), + [], + ) + .map(|_| ()) + .map_err(sqlite_error)?; + + for (handle, data) in keyhandles { + self.set_key_handle(*handle, data)?; + } + Ok(()) + } +} diff --git a/server/lib/src/be/mod.rs b/server/lib/src/be/mod.rs index 00500bb99..85ba4f869 100644 --- a/server/lib/src/be/mod.rs +++ b/server/lib/src/be/mod.rs @@ -36,6 +36,7 @@ pub mod dbvalue; mod idl_arc_sqlite; mod idl_sqlite; pub(crate) mod idxkey; +pub(crate) mod keystorage; pub(crate) use self::idxkey::{IdxKey, IdxKeyRef, IdxKeyToRef, IdxSlope}; use crate::be::idl_arc_sqlite::{ @@ -872,10 +873,13 @@ pub trait BackendTransaction { .get_db_ts_max() .and_then(|u| u.ok_or(OperationError::InvalidDbState))?; - let bak = DbBackup::V2 { + let keyhandles = idlayer.get_key_handles()?; + + let bak = DbBackup::V3 { db_s_uuid, db_d_uuid, db_ts_max, + keyhandles, entries, }; @@ -1724,6 +1728,20 @@ impl<'a> BackendWriteTransaction<'a> { idlayer.set_db_ts_max(db_ts_max)?; entries } + DbBackup::V3 { + db_s_uuid, + db_d_uuid, + db_ts_max, + keyhandles, + entries, + } => { + // Do stuff. 
+ idlayer.write_db_s_uuid(db_s_uuid)?; + idlayer.write_db_d_uuid(db_d_uuid)?; + idlayer.set_db_ts_max(db_ts_max)?; + idlayer.set_key_handles(keyhandles)?; + entries + } }; info!("Restoring {} entries ...", dbentries.len()); @@ -2505,6 +2523,15 @@ mod tests { } => { let _ = entries.pop(); } + DbBackup::V3 { + db_s_uuid: _, + db_d_uuid: _, + db_ts_max: _, + keyhandles: _, + entries, + } => { + let _ = entries.pop(); + } }; let serialized_entries_str = serde_json::to_string_pretty(&dbbak).unwrap(); diff --git a/server/lib/src/lib.rs b/server/lib/src/lib.rs index 9edfbe5de..074d7a5f7 100644 --- a/server/lib/src/lib.rs +++ b/server/lib/src/lib.rs @@ -53,7 +53,7 @@ pub mod valueset; #[macro_use] mod plugins; pub mod idm; -mod repl; +pub mod repl; pub mod schema; pub mod server; pub mod status; diff --git a/server/lib/src/repl/consumer.rs b/server/lib/src/repl/consumer.rs index e421d15b7..4f3f6bb4e 100644 --- a/server/lib/src/repl/consumer.rs +++ b/server/lib/src/repl/consumer.rs @@ -4,11 +4,6 @@ use crate::prelude::*; use std::collections::{BTreeMap, BTreeSet}; use std::sync::Arc; -pub enum ConsumerState { - Ok, - RefreshRequired, -} - impl<'a> QueryServerWriteTransaction<'a> { // Apply the state changes if they are valid. @@ -255,6 +250,11 @@ impl<'a> QueryServerWriteTransaction<'a> { ctx: &ReplIncrementalContext, ) -> Result { match ctx { + ReplIncrementalContext::DomainMismatch => { + error!("Unable to proceed with consumer incremental - the supplier has indicated that our domain_uuid's are not equivalent. This can occur when adding a new consumer to an existing topology."); + error!("This server's content must be refreshed to proceed. 
If you have configured automatic refresh, this will occur shortly."); + Ok(ConsumerState::RefreshRequired) + } ReplIncrementalContext::NoChangesAvailable => { info!("no changes are available"); Ok(ConsumerState::Ok) diff --git a/server/lib/src/repl/mod.rs b/server/lib/src/repl/mod.rs index 31e56acb7..4581aaf50 100644 --- a/server/lib/src/repl/mod.rs +++ b/server/lib/src/repl/mod.rs @@ -1,10 +1,10 @@ -pub mod cid; -pub mod entry; -pub mod ruv; +pub(crate) mod cid; +pub(crate) mod entry; +pub(crate) mod ruv; -pub mod consumer; +pub(crate) mod consumer; pub mod proto; -pub mod supplier; +pub(crate) mod supplier; #[cfg(test)] mod tests; diff --git a/server/lib/src/repl/proto.rs b/server/lib/src/repl/proto.rs index afde4ed70..6170746c1 100644 --- a/server/lib/src/repl/proto.rs +++ b/server/lib/src/repl/proto.rs @@ -16,6 +16,11 @@ use webauthn_rs::prelude::{ // Re-export this for our own usage. pub use kanidm_lib_crypto::ReplPasswordV1; +pub enum ConsumerState { + Ok, + RefreshRequired, +} + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct ReplCidV1 { #[serde(rename = "t")] @@ -63,22 +68,15 @@ pub struct ReplCidRange { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub enum ReplRuvRange { V1 { + domain_uuid: Uuid, ranges: BTreeMap, }, } -impl Default for ReplRuvRange { - fn default() -> Self { - ReplRuvRange::V1 { - ranges: BTreeMap::default(), - } - } -} - impl ReplRuvRange { pub fn is_empty(&self) -> bool { match self { - ReplRuvRange::V1 { ranges } => ranges.is_empty(), + ReplRuvRange::V1 { ranges, .. 
} => ranges.is_empty(), } } } @@ -689,6 +687,7 @@ pub enum ReplRefreshContext { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum ReplIncrementalContext { + DomainMismatch, NoChangesAvailable, RefreshRequired, UnwillingToSupply, diff --git a/server/lib/src/repl/supplier.rs b/server/lib/src/repl/supplier.rs index f4ae5784a..a9d939be8 100644 --- a/server/lib/src/repl/supplier.rs +++ b/server/lib/src/repl/supplier.rs @@ -5,6 +5,74 @@ use super::ruv::{RangeDiffStatus, ReplicationUpdateVector, ReplicationUpdateVect use crate::be::BackendTransaction; use crate::prelude::*; +use crate::be::keystorage::{KeyHandle, KeyHandleId}; +use kanidm_lib_crypto::mtls::build_self_signed_server_and_client_identity; +use kanidm_lib_crypto::prelude::{PKey, Private, X509}; + +impl<'a> QueryServerWriteTransaction<'a> { + fn supplier_generate_key_cert(&mut self) -> Result<(PKey, X509), OperationError> { + // Invalid, must need to re-generate. + let domain_name = "localhost"; + let expiration_days = 1; + let s_uuid = self.get_server_uuid(); + + let (private, x509) = + build_self_signed_server_and_client_identity(s_uuid, domain_name, expiration_days) + .map_err(|err| { + error!(?err, "Unable to generate self signed key/cert"); + // What error? 
+ OperationError::CryptographyError + })?; + + let kh = KeyHandle::X509Key { + private: private.clone(), + x509: x509.clone(), + }; + + self.get_be_txn() + .set_key_handle(KeyHandleId::ReplicationKey, kh) + .map_err(|err| { + error!(?err, "Unable to persist replication key"); + err + }) + .map(|()| (private, x509)) + } + + #[instrument(level = "info", skip_all)] + pub fn supplier_renew_key_cert(&mut self) -> Result<(), OperationError> { + self.supplier_generate_key_cert().map(|_| ()) + } + + #[instrument(level = "info", skip_all)] + pub fn supplier_get_key_cert(&mut self) -> Result<(PKey, X509), OperationError> { + // Later we need to put this through a HSM or similar, but we will always need a way + // to persist a handle, so we still need the db write and load components. + + // Does the handle exist? + let maybe_key_handle = self + .get_be_txn() + .get_key_handle(KeyHandleId::ReplicationKey) + .map_err(|err| { + error!(?err, "Unable to access replication key"); + err + })?; + + // Can you process the keyhande? + let key_cert = match maybe_key_handle { + Some(KeyHandle::X509Key { private, x509 }) => (private, x509), + /* + Some(Keyhandle::...) => { + // invalid key + // error? regenerate? + } + */ + None => self.supplier_generate_key_cert()?, + }; + + Ok(key_cert) + } +} + impl<'a> QueryServerReadTransaction<'a> { // Given a consumers state, calculate the differential of changes they // need to be sent to bring them to the equivalent state. @@ -19,10 +87,19 @@ impl<'a> QueryServerReadTransaction<'a> { ctx_ruv: ReplRuvRange, ) -> Result { // Convert types if needed. This way we can compare ruv's correctly. 
- let ctx_ranges = match ctx_ruv { - ReplRuvRange::V1 { ranges } => ranges, + let (ctx_domain_uuid, ctx_ranges) = match ctx_ruv { + ReplRuvRange::V1 { + domain_uuid, + ranges, + } => (domain_uuid, ranges), }; + if ctx_domain_uuid != self.d_info.d_uuid { + error!("Replication - Consumer Domain UUID does not match our local domain uuid."); + debug!(consumer_domain_uuid = ?ctx_domain_uuid, supplier_domain_uuid = ?self.d_info.d_uuid); + return Ok(ReplIncrementalContext::DomainMismatch); + } + let our_ranges = self .get_be_txn() .get_ruv() diff --git a/server/lib/src/repl/tests.rs b/server/lib/src/repl/tests.rs index d5d6d4e1c..5a54ebb59 100644 --- a/server/lib/src/repl/tests.rs +++ b/server/lib/src/repl/tests.rs @@ -1,7 +1,7 @@ use crate::be::BackendTransaction; use crate::prelude::*; -use crate::repl::consumer::ConsumerState; use crate::repl::entry::State; +use crate::repl::proto::ConsumerState; use crate::repl::proto::ReplIncrementalContext; use crate::repl::ruv::ReplicationUpdateVectorTransaction; use crate::repl::ruv::{RangeDiffStatus, ReplicationUpdateVector}; @@ -2941,6 +2941,46 @@ async fn test_repl_increment_attrunique_conflict_complex( drop(server_a_txn); } +// Test the behaviour of a "new server join". This will have the supplier and +// consumer mismatch on the domain_uuid, leading to the consumer with a +// refresh required message. 
This should then be refreshed and succeed + +#[qs_pair_test] +async fn test_repl_initial_consumer_join(server_a: &QueryServer, server_b: &QueryServer) { + let ct = duration_from_epoch_now(); + + let mut server_a_txn = server_a.write(ct).await; + let mut server_b_txn = server_b.read().await; + + let a_ruv_range = server_a_txn + .consumer_get_state() + .expect("Unable to access RUV range"); + + let changes = server_b_txn + .supplier_provide_changes(a_ruv_range) + .expect("Unable to generate supplier changes"); + + assert!(matches!(changes, ReplIncrementalContext::DomainMismatch)); + + let result = server_a_txn + .consumer_apply_changes(&changes) + .expect("Unable to apply changes to consumer."); + + assert!(matches!(result, ConsumerState::RefreshRequired)); + + drop(server_a_txn); + drop(server_b_txn); + + // Then a refresh resolves. + let mut server_a_txn = server_a.write(ct).await; + let mut server_b_txn = server_b.read().await; + + assert!(repl_initialise(&mut server_b_txn, &mut server_a_txn) + .and_then(|_| server_a_txn.commit()) + .is_ok()); + drop(server_b_txn); +} + // Test change of domain version over incremental. // // todo when I have domain version migrations working. diff --git a/server/lib/src/server/mod.rs b/server/lib/src/server/mod.rs index 2b75e7eca..c1dbadacd 100644 --- a/server/lib/src/server/mod.rs +++ b/server/lib/src/server/mod.rs @@ -896,6 +896,8 @@ pub trait QueryServerTransaction<'a> { // // ... + let domain_uuid = self.get_domain_uuid(); + // Which then the supplier will use to actually retrieve the set of entries. // and the needed attributes we need. let ruv_snapshot = self.get_be_txn().get_ruv(); @@ -903,7 +905,10 @@ pub trait QueryServerTransaction<'a> { // What's the current set of ranges? 
ruv_snapshot .current_ruv_range() - .map(|ranges| ReplRuvRange::V1 { ranges }) + .map(|ranges| ReplRuvRange::V1 { + domain_uuid, + ranges, + }) } } @@ -1247,6 +1252,11 @@ impl QueryServer { } impl<'a> QueryServerWriteTransaction<'a> { + pub(crate) fn get_server_uuid(&self) -> Uuid { + // Cid has our server id within + self.cid.s_uuid + } + pub(crate) fn get_curtime(&self) -> Duration { self.curtime } diff --git a/server/testkit/Cargo.toml b/server/testkit/Cargo.toml index f1b4a20d8..3308749c4 100644 --- a/server/testkit/Cargo.toml +++ b/server/testkit/Cargo.toml @@ -50,3 +50,7 @@ webauthn-authenticator-rs = { workspace = true } oauth2_ext = { workspace = true, default-features = false } futures = { workspace = true } time = { workspace = true } +openssl = { workspace = true } +tokio-openssl = { workspace = true } +kanidm_lib_crypto = { workspace = true } +uuid = { workspace = true } diff --git a/server/testkit/src/lib.rs b/server/testkit/src/lib.rs index c56824109..699145840 100644 --- a/server/testkit/src/lib.rs +++ b/server/testkit/src/lib.rs @@ -54,7 +54,7 @@ pub async fn setup_async_test(mut config: Configuration) -> (KanidmClient, CoreH counter += 1; #[allow(clippy::panic)] if counter >= 5 { - eprintln!("Unable to allocate port!"); + tracing::error!("Unable to allocate port!"); panic!(); } }; diff --git a/server/testkit/tests/mtls_test.rs b/server/testkit/tests/mtls_test.rs new file mode 100644 index 000000000..bff41071f --- /dev/null +++ b/server/testkit/tests/mtls_test.rs @@ -0,0 +1,277 @@ +use kanidm_lib_crypto::mtls::build_self_signed_server_and_client_identity; + +use kanidmd_testkit::{is_free_port, PORT_ALLOC}; +use std::sync::atomic::Ordering; + +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::time::Duration; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::oneshot; +use tokio::time; +use tracing::{error, trace}; + +use openssl::ssl::{Ssl, SslAcceptor, SslMethod, SslRef}; +use openssl::ssl::{SslConnector, 
SslVerifyMode}; +use openssl::x509::X509; +use tokio_openssl::SslStream; + +use uuid::Uuid; + +fn keylog_cb(_ssl_ref: &SslRef, key: &str) { + trace!(?key); +} + +async fn setup_mtls_test( + testcase: TestCase, +) -> ( + SslStream, + tokio::task::JoinHandle>, + oneshot::Sender<()>, +) { + sketching::test_init(); + + let mut counter = 0; + let port = loop { + let possible_port = PORT_ALLOC.fetch_add(1, Ordering::SeqCst); + if is_free_port(possible_port) { + break possible_port; + } + counter += 1; + #[allow(clippy::panic)] + if counter >= 5 { + error!("Unable to allocate port!"); + panic!(); + } + }; + + trace!("{:?}", port); + + // First we need the two certificates. + let client_uuid = Uuid::new_v4(); + let (client_key, client_cert) = + build_self_signed_server_and_client_identity(client_uuid, "localhost", 1).unwrap(); + + let server_san = if testcase == TestCase::ServerCertSanInvalid { + "evilcorp.com" + } else { + "localhost" + }; + let server_uuid = Uuid::new_v4(); + let (server_key, server_cert): (_, X509) = + build_self_signed_server_and_client_identity(server_uuid, server_san, 1).unwrap(); + let server_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), port); + + let listener = TcpListener::bind(&server_addr) + .await + .expect("Failed to bind"); + + // Setup the TLS parameters. + let (tx, mut rx) = oneshot::channel(); + + let mut ssl_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()).unwrap(); + ssl_builder.set_keylog_callback(keylog_cb); + + ssl_builder.set_certificate(&server_cert).unwrap(); + ssl_builder.set_private_key(&server_key).unwrap(); + ssl_builder.check_private_key().unwrap(); + + if testcase != TestCase::ServerWithoutClientCa { + let cert_store = ssl_builder.cert_store_mut(); + cert_store.add_cert(client_cert.clone()).unwrap(); + } + // Request a client cert. 
+ let mut verify = SslVerifyMode::PEER; + verify.insert(SslVerifyMode::FAIL_IF_NO_PEER_CERT); + ssl_builder.set_verify(verify); + // Setup the client cert store. + + let tls_parms = ssl_builder.build(); + + // Start the server in a task. + // The server is designed to stop/die as soon as a single connection has been made. + let handle = tokio::spawn(async move { + // This is our safety net. + let sleep = time::sleep(Duration::from_secs(15)); + tokio::pin!(sleep); + + trace!("Started listener"); + tokio::select! { + Ok((tcpstream, client_socket_addr)) = listener.accept() => { + let mut tlsstream = match Ssl::new(tls_parms.context()) + .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) + { + Ok(ta) => ta, + Err(err) => { + error!("LDAP TLS setup error, continuing -> {:?}", err); + let ossl_err = err.errors().last().unwrap(); + + return Err( + ossl_err.code() + ); + } + }; + + if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await { + error!("LDAP TLS accept error, continuing -> {:?}", err); + + let ossl_err = err.ssl_error().and_then(|e| e.errors().last()).unwrap(); + + return Err( + ossl_err.code() + ); + }; + + trace!("Got connection. {:?}", client_socket_addr); + + let tlsstream_ref = tlsstream.ssl(); + + match tlsstream_ref.peer_certificate() { + Some(peer_cert) => { + trace!("{:?}", peer_cert.subject_name()); + } + None => { + return Err(2); + } + } + + Ok(()) + } + Ok(()) = &mut rx => { + trace!("stopping listener"); + Err(1) + } + _ = &mut sleep => { + error!("timeout"); + Err(1) + } + else => { + trace!("error condition in accept"); + Err(1) + } + } + }); + + // Create the client and connect. We do this inline to be sensitive to errors. 
+ let tcpclient = TcpStream::connect(server_addr).await.unwrap(); + trace!("connection established"); + + let mut ssl_builder = SslConnector::builder(SslMethod::tls_client()).unwrap(); + if testcase != TestCase::ClientWithoutClientCert { + ssl_builder.set_certificate(&client_cert).unwrap(); + ssl_builder.set_private_key(&client_key).unwrap(); + ssl_builder.check_private_key().unwrap(); + } + // Add the server cert + if testcase != TestCase::ClientWithoutServerCa { + let cert_store = ssl_builder.cert_store_mut(); + cert_store.add_cert(server_cert).unwrap(); + } + + let verify_param = ssl_builder.verify_param_mut(); + verify_param.set_host("localhost").unwrap(); + + ssl_builder.set_verify(SslVerifyMode::PEER); + let tls_parms = ssl_builder.build(); + let tlsstream = Ssl::new(tls_parms.context()) + .and_then(|tls_obj| SslStream::new(tls_obj, tcpclient)) + .unwrap(); + + (tlsstream, handle, tx) +} + +#[derive(PartialEq, Eq, Debug)] +enum TestCase { + Valid, + ServerCertSanInvalid, + ServerWithoutClientCa, + ClientWithoutClientCert, + ClientWithoutServerCa, +} + +#[tokio::test] +async fn test_mtls_basic_auth() { + let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::Valid).await; + + SslStream::connect(Pin::new(&mut tlsstream)).await.unwrap(); + + trace!("Waiting on listener ..."); + let result = handle.await.expect("Failed to stop task."); + + // If this isn't true, it means something failed in the server accept process. 
+    assert!(result.is_ok());
+}
+
+#[tokio::test]
+async fn test_mtls_server_san_invalid() {
+    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ServerCertSanInvalid).await;
+
+    let err: openssl::ssl::Error = SslStream::connect(Pin::new(&mut tlsstream))
+        .await
+        .unwrap_err();
+    trace!(?err);
+    // Certificate Verification Failure
+    let ossl_err = err.ssl_error().and_then(|e| e.errors().last()).unwrap();
+    assert_eq!(ossl_err.code(), 167772294);
+
+    trace!("Waiting on listener ...");
+    let result = handle.await.expect("Failed to stop task.");
+
+    // Must be an Err - the server should not have accepted the connection.
+    trace!(?result);
+    // SSL Read bytes (client disconnected)
+    assert_eq!(result, Err(167773202));
+}
+
+#[tokio::test]
+async fn test_mtls_server_without_client_ca() {
+    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ServerWithoutClientCa).await;
+
+    // The client isn't the one that errors, the server does.
+    SslStream::connect(Pin::new(&mut tlsstream)).await.unwrap();
+
+    trace!("Waiting on listener ...");
+    let result = handle.await.expect("Failed to stop task.");
+
+    // Must be an Err - the server should not have accepted the connection.
+    trace!(?result);
+    // Certificate Verification Failure
+    assert_eq!(result, Err(167772294));
+}
+
+#[tokio::test]
+async fn test_mtls_client_without_client_cert() {
+    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ClientWithoutClientCert).await;
+
+    // The client isn't the one that errors, the server does.
+    SslStream::connect(Pin::new(&mut tlsstream)).await.unwrap();
+
+    trace!("Waiting on listener ...");
+    let result = handle.await.expect("Failed to stop task.");
+
+    // Must be an Err - the server should not have accepted the connection.
+ trace!(?result); + // Peer Did Not Provide Certificate + assert_eq!(result, Err(167772359)); +} + +#[tokio::test] +async fn test_mtls_client_without_server_ca() { + let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ClientWithoutServerCa).await; + + let err: openssl::ssl::Error = SslStream::connect(Pin::new(&mut tlsstream)) + .await + .unwrap_err(); + trace!(?err); + // Tls Post Process Certificate (Certificate Verify Failed) + let ossl_err = err.ssl_error().and_then(|e| e.errors().last()).unwrap(); + assert_eq!(ossl_err.code(), 167772294); + + trace!("Waiting on listener ..."); + let result = handle.await.expect("Failed to stop task."); + + // Must be FALSE server should not have accepted the connection. + trace!(?result); + // SSL Read bytes (client disconnected) + assert_eq!(result, Err(167773208)); +} diff --git a/unix_integration/Cargo.toml b/unix_integration/Cargo.toml index 609331379..18943465f 100644 --- a/unix_integration/Cargo.toml +++ b/unix_integration/Cargo.toml @@ -78,6 +78,7 @@ kanidm_utils_users = { workspace = true } [dev-dependencies] kanidmd_core = { workspace = true } +kanidmd_testkit = { workspace = true } [build-dependencies] clap = { workspace = true, features = ["derive"] } diff --git a/unix_integration/tests/cache_layer_test.rs b/unix_integration/tests/cache_layer_test.rs index ca8ae47d0..a79b1105b 100644 --- a/unix_integration/tests/cache_layer_test.rs +++ b/unix_integration/tests/cache_layer_test.rs @@ -1,8 +1,7 @@ #![deny(warnings)] use std::future::Future; -use std::net::TcpStream; use std::pin::Pin; -use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::atomic::Ordering; use std::time::Duration; use kanidm_client::{KanidmClient, KanidmClientBuilder}; @@ -18,10 +17,10 @@ use kanidm_unix_common::resolver::Resolver; use kanidm_unix_common::unix_config::TpmPolicy; use kanidmd_core::config::{Configuration, IntegrationTestConfig, ServerRole}; use kanidmd_core::create_server_core; +use kanidmd_testkit::{is_free_port, 
PORT_ALLOC}; use tokio::task; use tracing::log::{debug, trace}; -static PORT_ALLOC: AtomicU16 = AtomicU16::new(28080); const ADMIN_TEST_USER: &str = "admin"; const ADMIN_TEST_PASSWORD: &str = "integration test admin password"; const TESTACCOUNT1_PASSWORD_A: &str = "password a for account1 test"; @@ -29,13 +28,6 @@ const TESTACCOUNT1_PASSWORD_B: &str = "password b for account1 test"; const TESTACCOUNT1_PASSWORD_INC: &str = "never going to work"; const ACCOUNT_EXPIRE: &str = "1970-01-01T00:00:00+00:00"; -fn is_free_port(port: u16) -> bool { - match TcpStream::connect(("0.0.0.0", port)) { - Ok(_) => false, - Err(_) => true, - } -} - type Fixture = Box Pin>>>; fn fixture(f: fn(KanidmClient) -> T) -> Fixture