68 20230919 replication configuration (#2131)

This commit is contained in:
Firstyear 2023-09-29 12:02:13 +10:00 committed by GitHub
parent 034ddd624a
commit 3e345174b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 2297 additions and 180 deletions

7
Cargo.lock generated
View file

@ -2873,6 +2873,7 @@ dependencies = [
"sketching",
"tracing",
"tss-esapi",
"uuid",
]
[[package]]
@ -2953,6 +2954,7 @@ dependencies = [
"kanidm_proto",
"kanidm_utils_users",
"kanidmd_core",
"kanidmd_testkit",
"libc",
"libsqlite3-sys",
"lru 0.8.1",
@ -2999,6 +3001,7 @@ dependencies = [
"http",
"hyper",
"kanidm_build_profiles",
"kanidm_lib_crypto",
"kanidm_proto",
"kanidm_utils_users",
"kanidmd_lib",
@ -3021,6 +3024,7 @@ dependencies = [
"tower-http",
"tracing",
"tracing-subscriber",
"url",
"urlencoding",
"uuid",
]
@ -3100,6 +3104,7 @@ dependencies = [
"hyper-tls",
"kanidm_build_profiles",
"kanidm_client",
"kanidm_lib_crypto",
"kanidm_proto",
"kanidmd_core",
"kanidmd_lib",
@ -3113,8 +3118,10 @@ dependencies = [
"testkit-macros",
"time",
"tokio",
"tokio-openssl",
"tracing",
"url",
"uuid",
"webauthn-authenticator-rs",
]

View file

@ -74,10 +74,11 @@ webauthn-rs-proto = { git = "https://github.com/kanidm/webauthn-rs.git", rev = "
kanidmd_core = { path = "./server/core" }
kanidmd_lib = { path = "./server/lib" }
kanidmd_lib_macros = { path = "./server/lib-macros" }
kanidmd_testkit = { path = "./server/testkit" }
kanidm_build_profiles = { path = "./libs/profiles", version = "1.1.0-rc.14-dev" }
kanidm_client = { path = "./libs/client", version = "1.1.0-rc.14-dev" }
kanidm_lib_crypto = { path = "./libs/crypto" }
kanidm_lib_file_permissions = { path = "./libs/file_permissions" }
kanidm_client = { path = "./libs/client", version = "1.1.0-rc.14-dev" }
kanidm_proto = { path = "./proto", version = "1.1.0-rc.14-dev" }
kanidm_unix_int = { path = "./unix_integration" }
kanidm_utils_users = { path = "./libs/users" }

View file

@ -133,7 +133,7 @@ install-tools:
codespell: ## spell-check things.
codespell:
codespell -c \
-L 'crate,unexpect,Pres,pres,ACI,aci,te,ue,unx,aNULL' \
-L 'crate,unexpect,Pres,pres,ACI,aci,ser,te,ue,unx,aNULL' \
--skip='./target,./pykanidm/.venv,./pykanidm/.mypy_cache,./.mypy_cache,./pykanidm/poetry.lock' \
--skip='./book/book/*' \
--skip='./book/src/images/*' \

View file

@ -21,6 +21,7 @@ rand = { workspace = true }
serde = { workspace = true, features = ["derive"] }
tracing = { workspace = true }
tss-esapi = { workspace = true, optional = true }
uuid = { workspace = true }
[dev-dependencies]
sketching = { workspace = true }

View file

@ -23,11 +23,16 @@ use std::fmt;
use std::time::{Duration, Instant};
use kanidm_proto::v1::OperationError;
use openssl::error::ErrorStack as OpenSSLErrorStack;
use openssl::hash::{self, MessageDigest};
use openssl::nid::Nid;
use openssl::pkcs5::pbkdf2_hmac;
use openssl::sha::Sha512;
pub mod mtls;
pub mod prelude;
pub mod serialise;
#[cfg(feature = "tpm")]
pub use tss_esapi::{handles::ObjectHandle as TpmHandle, Context as TpmContext, Error as TpmError};
#[cfg(not(feature = "tpm"))]
@ -68,13 +73,21 @@ pub enum CryptoError {
Tpm2FeatureMissing,
Tpm2InputExceeded,
Tpm2ContextMissing,
OpenSSL,
OpenSSL(u64),
Md4Disabled,
Argon2,
Argon2Version,
Argon2Parameters,
}
impl From<OpenSSLErrorStack> for CryptoError {
    /// Convert an openssl error stack into a [`CryptoError`], preserving the
    /// first (root-cause) error code so it survives serialisation/logging.
    fn from(ossl_err: OpenSSLErrorStack) -> Self {
        // Log the full stack before we collapse it to a single code.
        error!(?ossl_err);
        // `.first()` over `.get(0)` (clippy::get_first). 0 means "no code
        // available" — the stack was empty.
        let code = ossl_err.errors().first().map(|e| e.code()).unwrap_or(0);
        CryptoError::OpenSSL(code)
    }
}
#[allow(clippy::from_over_into)]
impl Into<OperationError> for CryptoError {
fn into(self) -> OperationError {
@ -785,8 +798,8 @@ impl Password {
// Turn key to a vec.
Kdf::PBKDF2(pbkdf2_cost, salt, key)
})
.map_err(|_| CryptoError::OpenSSL)
.map(|material| Password { material })
.map_err(|e| e.into())
}
pub fn new_argon2id(policy: &CryptoPolicy, cleartext: &str) -> Result<Self, CryptoError> {
@ -810,6 +823,7 @@ impl Password {
})
.map_err(|_| CryptoError::Argon2)
.map(|material| Password { material })
.map_err(|e| e.into())
}
pub fn new_argon2id_tpm(
@ -843,6 +857,7 @@ impl Password {
key,
})
.map(|material| Password { material })
.map_err(|e| e.into())
}
#[inline]
@ -962,11 +977,11 @@ impl Password {
MessageDigest::sha256(),
chal_key.as_mut_slice(),
)
.map_err(|_| CryptoError::OpenSSL)
.map(|()| {
// Actually compare the outputs.
&chal_key == key
})
.map_err(|e| e.into())
}
(Kdf::PBKDF2_SHA1(cost, salt, key), _) => {
let key_len = key.len();
@ -979,11 +994,11 @@ impl Password {
MessageDigest::sha1(),
chal_key.as_mut_slice(),
)
.map_err(|_| CryptoError::OpenSSL)
.map(|()| {
// Actually compare the outputs.
&chal_key == key
})
.map_err(|e| e.into())
}
(Kdf::PBKDF2_SHA512(cost, salt, key), _) => {
let key_len = key.len();
@ -996,11 +1011,11 @@ impl Password {
MessageDigest::sha512(),
chal_key.as_mut_slice(),
)
.map_err(|_| CryptoError::OpenSSL)
.map(|()| {
// Actually compare the outputs.
&chal_key == key
})
.map_err(|e| e.into())
}
(Kdf::SSHA512(salt, key), _) => {
let mut hasher = Sha512::new();

89
libs/crypto/src/mtls.rs Normal file
View file

@ -0,0 +1,89 @@
use crate::CryptoError;
use openssl::asn1;
use openssl::bn;
use openssl::ec;
use openssl::error::ErrorStack as OpenSSLError;
use openssl::hash;
use openssl::nid::Nid;
use openssl::pkey::{PKey, Private};
use openssl::x509::extension::BasicConstraints;
use openssl::x509::extension::ExtendedKeyUsage;
use openssl::x509::extension::KeyUsage;
use openssl::x509::extension::SubjectAlternativeName;
use openssl::x509::extension::SubjectKeyIdentifier;
use openssl::x509::X509NameBuilder;
use openssl::x509::X509;
use uuid::Uuid;
/// Gets an [ec::EcGroup] for the NIST P-256 curve (`prime256v1`).
///
/// This is the curve used to generate the key pair in
/// `build_self_signed_server_and_client_identity`.
///
/// # Errors
/// Returns the underlying [OpenSSLError] stack if the curve lookup fails.
pub fn get_group() -> Result<ec::EcGroup, OpenSSLError> {
    ec::EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)
}
/// Build a self-signed X509 certificate plus its EC P-256 private key,
/// valid for BOTH TLS server and client authentication — i.e. a single
/// mTLS identity for a replication node.
///
/// * `cn` - the node's uuid; rendered hyphenated into the certificate CN.
/// * `domain_name` - DNS name placed in the subjectAltName extension.
/// * `expiration_days` - validity window, starting from "now".
///
/// The certificate is signed by its own key (issuer == subject) with
/// SHA-256, serial number 1, and critical BasicConstraints/KeyUsage
/// extensions.
///
/// # Errors
/// Any openssl failure is converted into a [CryptoError].
pub fn build_self_signed_server_and_client_identity(
    cn: Uuid,
    domain_name: &str,
    expiration_days: u32,
) -> Result<(PKey<Private>, X509), CryptoError> {
    // Generate a fresh P-256 keypair for this identity.
    let ecgroup = get_group()?;
    let eckey = ec::EcKey::generate(&ecgroup)?;
    let ca_key = PKey::from_ec_key(eckey)?;
    let mut x509_name = X509NameBuilder::new()?;

    // x509_name.append_entry_by_text("C", "AU")?;
    // x509_name.append_entry_by_text("ST", "QLD")?;
    x509_name.append_entry_by_text("O", "Kanidm Replication")?;
    x509_name.append_entry_by_text("CN", &cn.as_hyphenated().to_string())?;
    let x509_name = x509_name.build();

    let mut cert_builder = X509::builder()?;
    // Yes, 2 actually means 3 here ... (X509 versions are zero-indexed)
    cert_builder.set_version(2)?;

    // Self-signed, so a fixed serial of 1 is acceptable here.
    let serial_number = bn::BigNum::from_u32(1).and_then(|serial| serial.to_asn1_integer())?;

    cert_builder.set_serial_number(&serial_number)?;
    cert_builder.set_subject_name(&x509_name)?;
    // Issuer == subject: this cert signs itself.
    cert_builder.set_issuer_name(&x509_name)?;

    // Valid from now until now + expiration_days.
    let not_before = asn1::Asn1Time::days_from_now(0)?;
    cert_builder.set_not_before(&not_before)?;
    let not_after = asn1::Asn1Time::days_from_now(expiration_days)?;
    cert_builder.set_not_after(&not_after)?;

    // Do we need pathlen 0?
    cert_builder.append_extension(BasicConstraints::new().critical().build()?)?;
    cert_builder.append_extension(
        KeyUsage::new()
            .critical()
            .digital_signature()
            .key_encipherment()
            .build()?,
    )?;

    // Usable as both a TLS server and a TLS client certificate (mTLS).
    cert_builder.append_extension(
        ExtendedKeyUsage::new()
            .server_auth()
            .client_auth()
            .build()?,
    )?;

    let subject_key_identifier =
        SubjectKeyIdentifier::new().build(&cert_builder.x509v3_context(None, None))?;
    cert_builder.append_extension(subject_key_identifier)?;

    // SAN carries the node's DNS name for hostname verification.
    let subject_alt_name = SubjectAlternativeName::new()
        .dns(domain_name)
        .build(&cert_builder.x509v3_context(None, None))?;
    cert_builder.append_extension(subject_alt_name)?;

    cert_builder.set_pubkey(&ca_key)?;

    // Sign with our own key - self signed.
    cert_builder.sign(&ca_key, hash::MessageDigest::sha256())?;
    let ca_cert = cert_builder.build();

    Ok((ca_key, ca_cert))
}

View file

@ -0,0 +1,2 @@
pub use openssl::pkey::{PKey, Private, Public};
pub use openssl::x509::X509;

View file

@ -0,0 +1,85 @@
/// Serde adapter for storing an openssl private key as URL-safe base64
/// over its DER encoding. Use as `#[serde(with = "pkeyb64")]`.
pub mod pkeyb64 {
    use base64::{engine::general_purpose, Engine as _};
    use openssl::pkey::{PKey, Private};
    use serde::{
        de::Error as DeError, ser::Error as SerError, Deserialize, Deserializer, Serializer,
    };
    use tracing::error;

    /// Serialize `key` as a URL-safe base64 string of its DER form.
    ///
    /// # Errors
    /// Fails if openssl cannot encode the key to DER.
    pub fn serialize<S>(key: &PKey<Private>, ser: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let der = key.private_key_to_der().map_err(|err| {
            error!(?err, "openssl private_key_to_der");
            S::Error::custom("openssl private_key_to_der")
        })?;

        let s = general_purpose::URL_SAFE.encode(der);

        ser.serialize_str(&s)
    }

    /// Deserialize a private key from a URL-safe base64 DER string.
    ///
    /// # Errors
    /// Fails on invalid base64 or invalid DER content.
    pub fn deserialize<'de, D>(des: D) -> Result<PKey<Private>, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Deserialize to an owned String rather than `<&str>` — borrowed
        // string deserialization fails for any input that cannot be
        // zero-copy borrowed (e.g. JSON strings containing escapes, or
        // formats that only yield owned data).
        let raw = String::deserialize(des)?;
        let s = general_purpose::URL_SAFE.decode(raw).map_err(|err| {
            error!(?err, "base64 url-safe invalid");
            D::Error::custom("base64 url-safe invalid")
        })?;

        PKey::private_key_from_der(&s).map_err(|err| {
            error!(?err, "openssl pkey invalid der");
            D::Error::custom("openssl pkey invalid der")
        })
    }
}
/// Serde adapter for storing an [X509] certificate as URL-safe base64 over
/// its DER encoding. Use as `#[serde(with = "x509b64")]`.
pub mod x509b64 {
    use crate::CryptoError;
    use base64::{engine::general_purpose, Engine as _};
    use openssl::x509::X509;
    use serde::{
        de::Error as DeError, ser::Error as SerError, Deserialize, Deserializer, Serializer,
    };
    use tracing::error;

    /// Encode a certificate to the same URL-safe base64 DER string that
    /// [serialize] produces, outside of a serde context.
    ///
    /// # Errors
    /// Fails if openssl cannot encode the certificate to DER.
    pub fn cert_to_string(cert: &X509) -> Result<String, CryptoError> {
        cert.to_der()
            .map_err(|err| {
                error!(?err, "openssl cert to_der");
                err.into()
            })
            .map(|der| general_purpose::URL_SAFE.encode(der))
    }

    /// Serialize `cert` as a URL-safe base64 string of its DER form.
    ///
    /// # Errors
    /// Fails if openssl cannot encode the certificate to DER.
    pub fn serialize<S>(cert: &X509, ser: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let der = cert.to_der().map_err(|err| {
            error!(?err, "openssl cert to_der");
            // Fixed: previously reported "openssl private_key_to_der",
            // copy-pasted from pkeyb64 — this path encodes a certificate.
            S::Error::custom("openssl cert to_der")
        })?;

        let s = general_purpose::URL_SAFE.encode(der);

        ser.serialize_str(&s)
    }

    /// Deserialize a certificate from a URL-safe base64 DER string.
    ///
    /// # Errors
    /// Fails on invalid base64 or invalid DER content.
    pub fn deserialize<'de, D>(des: D) -> Result<X509, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Owned String instead of `<&str>` so inputs that cannot be
        // zero-copy borrowed (escaped JSON, owned-data formats) still work.
        let raw = String::deserialize(des)?;
        let s = general_purpose::URL_SAFE.decode(raw).map_err(|err| {
            error!(?err, "base64 url-safe invalid");
            D::Error::custom("base64 url-safe invalid")
        })?;

        X509::from_der(&s).map_err(|err| {
            error!(?err, "openssl x509 invalid der");
            D::Error::custom("openssl x509 invalid der")
        })
    }
}

View file

@ -29,6 +29,7 @@ hyper = { workspace = true }
kanidm_proto = { workspace = true }
kanidm_utils_users = { workspace = true }
kanidmd_lib = { workspace = true }
kanidm_lib_crypto = { workspace = true }
ldap3_proto = { workspace = true }
libc = { workspace = true }
openssl = { workspace = true }
@ -57,8 +58,8 @@ tracing = { workspace = true, features = ["attributes"] }
tracing-subscriber = { workspace = true, features = ["time", "json"] }
urlencoding = { workspace = true }
tempfile = { workspace = true }
url = { workspace = true, features = ["serde"] }
uuid = { workspace = true, features = ["serde", "v4"] }
[build-dependencies]
kanidm_build_profiles = { workspace = true }

View file

@ -61,8 +61,8 @@ impl QueryServerWriteV1 {
proto_ml: &ProtoModifyList,
filter: Filter<FilterInvalid>,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
@ -109,8 +109,8 @@ impl QueryServerWriteV1 {
ml: &ModifyList<ModifyInvalid>,
filter: Filter<FilterInvalid>,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
@ -163,8 +163,8 @@ impl QueryServerWriteV1 {
req: CreateRequest,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
@ -200,8 +200,8 @@ impl QueryServerWriteV1 {
req: ModifyRequest,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
.map_err(|e| {
@ -236,8 +236,8 @@ impl QueryServerWriteV1 {
req: DeleteRequest,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
.map_err(|e| {
@ -273,8 +273,8 @@ impl QueryServerWriteV1 {
eventid: Uuid,
) -> Result<(), OperationError> {
// Given a protoEntry, turn this into a modification set.
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
.map_err(|e| {
@ -315,8 +315,8 @@ impl QueryServerWriteV1 {
filter: Filter<FilterInvalid>,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
.map_err(|e| {
@ -350,8 +350,8 @@ impl QueryServerWriteV1 {
filter: Filter<FilterInvalid>,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
.map_err(|e| {
@ -919,8 +919,8 @@ impl QueryServerWriteV1 {
filter: Filter<FilterInvalid>,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
.map_err(|e| {
@ -1196,8 +1196,8 @@ impl QueryServerWriteV1 {
) -> Result<(), OperationError> {
// Because this is from internal, we can generate a real modlist, rather
// than relying on the proto ones.
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
@ -1254,8 +1254,8 @@ impl QueryServerWriteV1 {
filter: Filter<FilterInvalid>,
eventid: Uuid,
) -> Result<(), OperationError> {
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
@ -1311,8 +1311,8 @@ impl QueryServerWriteV1 {
) -> Result<(), OperationError> {
// Because this is from internal, we can generate a real modlist, rather
// than relying on the proto ones.
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let ident = idms_prox_write
.validate_and_parse_token_to_ident(uat.as_deref(), ct)
@ -1514,7 +1514,8 @@ impl QueryServerWriteV1 {
)]
pub async fn handle_purgerecycledevent(&self, msg: PurgeRecycledEvent) {
trace!(?msg, "Begin purge recycled event");
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let res = idms_prox_write
.qs_write
.purge_recycled()
@ -1551,7 +1552,8 @@ impl QueryServerWriteV1 {
eventid: Uuid,
) -> Result<String, OperationError> {
trace!(%name, "Begin admin recover account event");
let mut idms_prox_write = self.idms.proxy_write(duration_from_epoch_now()).await;
let ct = duration_from_epoch_now();
let mut idms_prox_write = self.idms.proxy_write(ct).await;
let pw = idms_prox_write.recover_account(name.as_str(), None)?;
idms_prox_write.commit().map(|()| pw)

View file

@ -1,7 +1,9 @@
use crate::actors::v1_write::QueryServerWriteV1;
use crate::repl::ReplCtrl;
use crate::CoreAction;
use bytes::{BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use kanidm_lib_crypto::serialise::x509b64;
use kanidm_utils_users::get_current_uid;
use serde::{Deserialize, Serialize};
use std::error::Error;
@ -9,18 +11,23 @@ use std::io;
use std::path::Path;
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_util::codec::{Decoder, Encoder, Framed};
use tracing::{span, Level};
use tracing::{span, Instrument, Level};
use uuid::Uuid;
#[derive(Serialize, Deserialize, Debug)]
pub enum AdminTaskRequest {
RecoverAccount { name: String },
ShowReplicationCertificate,
RenewReplicationCertificate,
}
#[derive(Serialize, Deserialize, Debug)]
pub enum AdminTaskResponse {
RecoverAccount { password: String },
ShowReplicationCertificate { cert: String },
Error,
}
@ -99,6 +106,7 @@ impl AdminActor {
sock_path: &str,
server: &'static QueryServerWriteV1,
mut broadcast_rx: broadcast::Receiver<CoreAction>,
repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
) -> Result<tokio::task::JoinHandle<()>, ()> {
debug!("🧹 Cleaning up sockets from previous invocations");
rm_if_exist(sock_path);
@ -144,8 +152,9 @@ impl AdminActor {
};
// spawn the worker.
let task_repl_ctrl_tx = repl_ctrl_tx.clone();
tokio::spawn(async move {
if let Err(e) = handle_client(socket, server).await {
if let Err(e) = handle_client(socket, server, task_repl_ctrl_tx).await {
error!(err = ?e, "admin client error");
}
});
@ -177,9 +186,61 @@ fn rm_if_exist(p: &str) {
}
}
async fn show_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
let (tx, rx) = oneshot::channel();
if ctrl_tx
.send(ReplCtrl::GetCertificate { respond: tx })
.await
.is_err()
{
error!("replication control channel has shutdown");
return AdminTaskResponse::Error;
}
match rx.await {
Ok(cert) => x509b64::cert_to_string(&cert)
.map(|cert| AdminTaskResponse::ShowReplicationCertificate { cert })
.unwrap_or(AdminTaskResponse::Error),
Err(_) => {
error!("replication control channel did not respond with certificate.");
AdminTaskResponse::Error
}
}
}
/// Ask the replication controller to renew the node's certificate, then —
/// on success — fetch and return the fresh certificate via
/// [show_replication_certificate]. All failure modes are logged and
/// collapsed to [AdminTaskResponse::Error].
async fn renew_replication_certificate(ctrl_tx: &mut mpsc::Sender<ReplCtrl>) -> AdminTaskResponse {
    let (respond, rx) = oneshot::channel();

    let sent = ctrl_tx.send(ReplCtrl::RenewCertificate { respond }).await;
    if sent.is_err() {
        error!("replication control channel has shutdown");
        return AdminTaskResponse::Error;
    }

    let success = match rx.await {
        Ok(success) => success,
        Err(_) => {
            error!("replication control channel did not respond with renewal status.");
            return AdminTaskResponse::Error;
        }
    };

    if !success {
        error!("replication control channel indicated that certificate renewal failed.");
        return AdminTaskResponse::Error;
    }

    // Renewal succeeded — show the newly issued certificate.
    show_replication_certificate(ctrl_tx).await
}
async fn handle_client(
sock: UnixStream,
server: &'static QueryServerWriteV1,
mut repl_ctrl_tx: Option<mpsc::Sender<ReplCtrl>>,
) -> Result<(), Box<dyn Error>> {
debug!("Accepted admin socket connection");
@ -190,9 +251,10 @@ async fn handle_client(
// Setup the logging span
let eventid = Uuid::new_v4();
let nspan = span!(Level::INFO, "handle_admin_client_request", uuid = ?eventid);
let _span = nspan.enter();
// let _span = nspan.enter();
let resp = match req {
let resp = async {
match req {
AdminTaskRequest::RecoverAccount { name } => {
match server.handle_admin_recover_account(name, eventid).await {
Ok(password) => AdminTaskResponse::RecoverAccount { password },
@ -202,10 +264,27 @@ async fn handle_client(
}
}
}
};
AdminTaskRequest::ShowReplicationCertificate => match repl_ctrl_tx.as_mut() {
Some(ctrl_tx) => show_replication_certificate(ctrl_tx).await,
None => {
error!("replication not configured, unable to display certificate.");
AdminTaskResponse::Error
}
},
AdminTaskRequest::RenewReplicationCertificate => match repl_ctrl_tx.as_mut() {
Some(ctrl_tx) => renew_replication_certificate(ctrl_tx).await,
None => {
error!("replication not configured, unable to renew certificate.");
AdminTaskResponse::Error
}
},
}
}
.instrument(nspan)
.await;
reqs.send(resp).await?;
reqs.flush().await?;
trace!("flushed response!");
}
debug!("Disconnecting client ...");

View file

@ -4,25 +4,25 @@
//! These components should be "per server". Any "per domain" config should be in the system
//! or domain entries that are able to be replicated.
use std::collections::BTreeMap;
use std::fmt;
use std::fs::File;
use std::io::Read;
use std::net::SocketAddr;
use std::path::Path;
use std::str::FromStr;
use kanidm_proto::constants::DEFAULT_SERVER_ADDRESS;
use kanidm_proto::messages::ConsoleOutputMode;
use serde::{Deserialize, Serialize};
use kanidm_lib_crypto::prelude::X509;
use kanidm_lib_crypto::serialise::x509b64;
use serde::Deserialize;
use sketching::tracing_subscriber::EnvFilter;
use url::Url;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct IntegrationTestConfig {
pub admin_user: String,
pub admin_password: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct OnlineBackup {
pub path: String,
#[serde(default = "default_online_backup_schedule")]
@ -39,13 +39,54 @@ fn default_online_backup_versions() -> usize {
7
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct TlsConfiguration {
pub chain: String,
pub key: String,
}
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum RepNodeConfig {
#[serde(rename = "allow-pull")]
AllowPull {
#[serde(with = "x509b64")]
consumer_cert: X509,
},
#[serde(rename = "pull")]
Pull {
#[serde(with = "x509b64")]
supplier_cert: X509,
automatic_refresh: bool,
},
#[serde(rename = "mutual-pull")]
MutualPull {
#[serde(with = "x509b64")]
partner_cert: X509,
automatic_refresh: bool,
},
/*
AllowPush {
},
Push {
},
*/
}
#[derive(Deserialize, Debug, Clone)]
pub struct ReplicationConfiguration {
pub origin: Url,
pub bindaddress: SocketAddr,
#[serde(flatten)]
pub manual: BTreeMap<Url, RepNodeConfig>,
}
/// This is the Server Configuration as read from server.toml. Important to note
/// is that not all flags or values from Configuration are exposed via this structure
/// to prevent certain settings being set (e.g. integration test modes)
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerConfig {
pub bindaddress: Option<String>,
pub ldapbindaddress: Option<String>,
@ -59,10 +100,15 @@ pub struct ServerConfig {
pub tls_key: Option<String>,
pub online_backup: Option<OnlineBackup>,
pub domain: String,
// TODO -this should be URL
pub origin: String,
pub log_level: Option<LogLevel>,
#[serde(default)]
pub role: ServerRole,
pub log_level: Option<LogLevel>,
#[serde(default)]
pub i_acknowledge_that_replication_is_in_development: bool,
#[serde(rename = "replication")]
pub repl_config: Option<ReplicationConfiguration>,
}
impl ServerConfig {
@ -85,7 +131,7 @@ impl ServerConfig {
}
}
#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default, Eq, PartialEq)]
#[derive(Debug, Deserialize, Clone, Copy, Default, Eq, PartialEq)]
pub enum ServerRole {
#[default]
WriteReplica,
@ -116,7 +162,7 @@ impl FromStr for ServerRole {
}
}
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
#[derive(Clone, Deserialize, Debug, Default)]
pub enum LogLevel {
#[default]
#[serde(rename = "info")]
@ -160,7 +206,22 @@ impl From<LogLevel> for EnvFilter {
}
}
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub struct IntegrationTestConfig {
pub admin_user: String,
pub admin_password: String,
}
#[derive(Debug, Clone)]
pub struct IntegrationReplConfig {
// We can bake in a private key for mTLS here.
// pub private_key: PKey
// We might need some condition variables / timers to force replication
// events? Or a channel to submit with oneshot responses.
}
#[derive(Debug, Clone)]
pub struct Configuration {
pub address: String,
pub ldapaddress: Option<String>,
@ -180,41 +241,64 @@ pub struct Configuration {
pub role: ServerRole,
pub output_mode: ConsoleOutputMode,
pub log_level: LogLevel,
/// Replication settings.
pub repl_config: Option<ReplicationConfiguration>,
/// This allows internally setting some unsafe options for replication.
pub integration_repl_config: Option<Box<IntegrationReplConfig>>,
}
impl fmt::Display for Configuration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "address: {}, ", self.address)?;
write!(f, "domain: {}, ", self.domain)
.and_then(|_| match &self.ldapaddress {
write!(f, "domain: {}, ", self.domain)?;
match &self.ldapaddress {
Some(la) => write!(f, "ldap address: {}, ", la),
None => write!(f, "ldap address: disabled, "),
})
.and_then(|_| write!(f, "admin bind path: {}, ", self.adminbindpath))
.and_then(|_| write!(f, "thread count: {}, ", self.threads))
.and_then(|_| write!(f, "dbpath: {}, ", self.db_path))
.and_then(|_| match self.db_arc_size {
}?;
write!(f, "origin: {} ", self.origin)?;
write!(f, "admin bind path: {}, ", self.adminbindpath)?;
write!(f, "thread count: {}, ", self.threads)?;
write!(f, "dbpath: {}, ", self.db_path)?;
match self.db_arc_size {
Some(v) => write!(f, "arcsize: {}, ", v),
None => write!(f, "arcsize: AUTO, "),
})
.and_then(|_| write!(f, "max request size: {}b, ", self.maximum_request))
.and_then(|_| write!(f, "trust X-Forwarded-For: {}, ", self.trust_x_forward_for))
.and_then(|_| write!(f, "with TLS: {}, ", self.tls_config.is_some()))
// TODO: include the backup timings
.and_then(|_| match &self.online_backup {
Some(_) => write!(f, "online_backup: enabled, "),
}?;
write!(f, "max request size: {}b, ", self.maximum_request)?;
write!(f, "trust X-Forwarded-For: {}, ", self.trust_x_forward_for)?;
write!(f, "with TLS: {}, ", self.tls_config.is_some())?;
match &self.online_backup {
Some(bck) => write!(
f,
"online_backup: enabled - schedule: {} versions: {}, ",
bck.schedule, bck.versions
),
None => write!(f, "online_backup: disabled, "),
})
.and_then(|_| write!(f, "role: {}, ", self.role.to_string()))
.and_then(|_| {
}?;
write!(
f,
"integration mode: {}, ",
self.integration_test_config.is_some()
)
})
.and_then(|_| write!(f, "console output format: {:?} ", self.output_mode))
.and_then(|_| write!(f, "log_level: {}", self.log_level.clone().to_string()))
)?;
write!(f, "console output format: {:?} ", self.output_mode)?;
write!(f, "log_level: {}", self.log_level.clone().to_string())?;
write!(f, "role: {}, ", self.role.to_string())?;
match &self.repl_config {
Some(repl) => {
write!(f, "replication: enabled")?;
write!(f, "repl_origin: {} ", repl.origin)?;
write!(f, "repl_address: {} ", repl.bindaddress)?;
write!(
f,
"integration repl config mode: {}, ",
self.integration_repl_config.is_some()
)?;
}
None => {
write!(f, "replication: disabled, ")?;
}
}
Ok(())
}
}
@ -240,9 +324,11 @@ impl Configuration {
online_backup: None,
domain: "idm.example.com".to_string(),
origin: "https://idm.example.com".to_string(),
role: ServerRole::WriteReplica,
output_mode: ConsoleOutputMode::default(),
log_level: Default::default(),
role: ServerRole::WriteReplica,
repl_config: None,
integration_repl_config: None,
}
}
@ -335,6 +421,10 @@ impl Configuration {
self.output_mode = om;
}
pub fn update_replication_config(&mut self, repl_config: Option<ReplicationConfiguration>) {
self.repl_config = repl_config;
}
pub fn update_tls(&mut self, chain: &Option<String>, key: &Option<String>) {
match (chain, key) {
(None, None) => {}

View file

@ -460,7 +460,6 @@ pub(crate) fn build_cert(
cert_builder.append_extension(
KeyUsage::new()
.critical()
// .non_repudiation()
.digital_signature()
.key_encipherment()
.build()?,

View file

@ -2,14 +2,16 @@ mod extractors;
mod generic;
mod javascript;
mod manifest;
mod middleware;
pub(crate) mod middleware;
mod oauth2;
mod tests;
mod trace;
pub(crate) mod trace;
mod ui;
mod v1;
mod v1_scim;
use self::generic::*;
use self::javascript::*;
use crate::actors::v1_read::QueryServerReadV1;
use crate::actors::v1_write::QueryServerWriteV1;
use crate::config::{Configuration, ServerRole, TlsConfiguration};
@ -21,13 +23,11 @@ use axum::Router;
use axum_csp::{CspDirectiveType, CspValue};
use axum_macros::FromRef;
use compact_jwt::{Jws, JwsSigner, JwsUnverified};
use generic::*;
use http::header::{ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
use http::{HeaderMap, HeaderValue, StatusCode};
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrStream, Http};
use hyper::Body;
use javascript::*;
use kanidm_proto::constants::APPLICATION_JSON;
use kanidm_proto::v1::OperationError;
use kanidmd_lib::status::StatusActor;
@ -288,7 +288,17 @@ pub async fn create_https_server(
}
res = match config.tls_config {
Some(tls_param) => {
tokio::spawn(server_loop(tls_param, addr, app))
// This isn't optimal, but we can't share this with the
// other path for integration tests because that doesn't
// do tls (yet?)
let listener = match TcpListener::bind(addr).await {
Ok(l) => l,
Err(err) => {
error!(?err, "Failed to bind tcp listener");
return
}
};
tokio::spawn(server_loop(tls_param, listener, app))
},
None => {
tokio::spawn(axum_server::bind(addr).serve(app))
@ -307,7 +317,7 @@ pub async fn create_https_server(
async fn server_loop(
tls_param: TlsConfiguration,
addr: SocketAddr,
listener: TcpListener,
app: IntoMakeServiceWithConnectInfo<Router, SocketAddr>,
) -> Result<(), std::io::Error> {
let mut tls_builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
@ -335,7 +345,6 @@ async fn server_loop(
)
})?;
let acceptor = tls_builder.build();
let listener = TcpListener::bind(addr).await?;
let protocol = Arc::new(Http::new());
let mut listener =
@ -355,7 +364,7 @@ async fn server_loop(
}
/// This handles an individual connection.
async fn handle_conn(
pub(crate) async fn handle_conn(
acceptor: SslAcceptor,
stream: AddrStream,
svc: ResponseFuture<Router, SocketAddr>,

View file

@ -1,5 +1,5 @@
use std::marker::Unpin;
use std::net;
use std::pin::Pin;
use std::str::FromStr;
use crate::actors::v1_read::QueryServerReadV1;
@ -10,8 +10,7 @@ use kanidmd_lib::prelude::*;
use ldap3_proto::proto::LdapMsg;
use ldap3_proto::LdapCodec;
use openssl::ssl::{Ssl, SslAcceptor};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::net::{TcpListener, TcpStream};
use tokio_openssl::SslStream;
use tokio_util::codec::{FramedRead, FramedWrite};
@ -47,12 +46,31 @@ async fn client_process_msg(
qe_r_ref.handle_ldaprequest(eventid, protomsg, uat).await
}
async fn client_process<W: AsyncWrite + Unpin, R: AsyncRead + Unpin>(
mut r: FramedRead<R, LdapCodec>,
mut w: FramedWrite<W, LdapCodec>,
async fn client_process(
tcpstream: TcpStream,
tls_acceptor: SslAcceptor,
client_address: net::SocketAddr,
qe_r_ref: &'static QueryServerReadV1,
) {
// Start the event
// From the parameters we need to create an SslContext.
let mut tlsstream = match Ssl::new(tls_acceptor.context())
.and_then(|tls_obj| SslStream::new(tls_obj, tcpstream))
{
Ok(ta) => ta,
Err(e) => {
error!("LDAP TLS setup error, continuing -> {:?}", e);
return;
}
};
if let Err(e) = SslStream::accept(Pin::new(&mut tlsstream)).await {
error!("LDAP TLS accept error, continuing -> {:?}", e);
return;
};
let (r, w) = tokio::io::split(tlsstream);
let mut r = FramedRead::new(r, LdapCodec);
let mut w = FramedWrite::new(w, LdapCodec);
// This is a connected client session. we need to associate some state to the session
let mut session = LdapSession::new();
// Now that we have the session we begin an event loop to process input OR we return.
@ -108,7 +126,7 @@ async fn client_process<W: AsyncWrite + Unpin, R: AsyncRead + Unpin>(
/// TLS LDAP Listener, hands off to [client_process]
async fn tls_acceptor(
listener: TcpListener,
ssl_acceptor: SslAcceptor,
tls_acceptor: SslAcceptor,
qe_r_ref: &'static QueryServerReadV1,
mut rx: broadcast::Receiver<CoreAction>,
) {
@ -122,25 +140,8 @@ async fn tls_acceptor(
accept_result = listener.accept() => {
match accept_result {
Ok((tcpstream, client_socket_addr)) => {
// Start the event
// From the parameters we need to create an SslContext.
let mut tlsstream = match Ssl::new(ssl_acceptor.context())
.and_then(|tls_obj| SslStream::new(tls_obj, tcpstream))
{
Ok(ta) => ta,
Err(e) => {
error!("LDAP TLS setup error, continuing -> {:?}", e);
continue;
}
};
if let Err(e) = SslStream::accept(Pin::new(&mut tlsstream)).await {
error!("LDAP TLS accept error, continuing -> {:?}", e);
continue;
};
let (r, w) = tokio::io::split(tlsstream);
let r = FramedRead::new(r, LdapCodec);
let w = FramedWrite::new(w, LdapCodec);
tokio::spawn(client_process(r, w, client_socket_addr, qe_r_ref));
let clone_tls_acceptor = tls_acceptor.clone();
tokio::spawn(client_process(tcpstream, clone_tls_acceptor, client_socket_addr, qe_r_ref));
}
Err(e) => {
error!("LDAP acceptor error, continuing -> {:?}", e);

View file

@ -32,6 +32,7 @@ mod crypto;
mod https;
mod interval;
mod ldaps;
mod repl;
use std::path::Path;
use std::sync::Arc;
@ -867,22 +868,6 @@ pub async fn create_server_core(
}
};
// If we are NOT in integration test mode, start the admin socket now
let maybe_admin_sock_handle = if config.integration_test_config.is_none() {
let broadcast_rx = broadcast_tx.subscribe();
let admin_handle = AdminActor::create_admin_sock(
config.adminbindpath.as_str(),
server_write_ref,
broadcast_rx,
)
.await?;
Some(admin_handle)
} else {
None
};
// If we have been requested to init LDAP, configure it now.
let maybe_ldap_acceptor_handle = match &config.ldapaddress {
Some(la) => {
@ -913,6 +898,26 @@ pub async fn create_server_core(
}
};
// If we have replication configured, setup the listener with it's initial replication
// map (if any).
let (maybe_repl_handle, maybe_repl_ctrl_tx) = match &config.repl_config {
Some(rc) => {
if !config_test {
// ⚠️ only start the sockets and listeners in non-config-test modes.
let (h, repl_ctrl_tx) =
repl::create_repl_server(idms_arc.clone(), rc, broadcast_tx.subscribe())
.await?;
(Some(h), Some(repl_ctrl_tx))
} else {
(None, None)
}
}
None => {
debug!("Replication not requested, skipping");
(None, None)
}
};
let maybe_http_acceptor_handle = if config_test {
admin_info!("this config rocks! 🪨 ");
None
@ -941,6 +946,23 @@ pub async fn create_server_core(
Some(h)
};
// If we are NOT in integration test mode, start the admin socket now
let maybe_admin_sock_handle = if config.integration_test_config.is_none() {
let broadcast_rx = broadcast_tx.subscribe();
let admin_handle = AdminActor::create_admin_sock(
config.adminbindpath.as_str(),
server_write_ref,
broadcast_rx,
maybe_repl_ctrl_tx,
)
.await?;
Some(admin_handle)
} else {
None
};
let mut handles = vec![interval_handle, delayed_handle, auditd_handle];
if let Some(backup_handle) = maybe_backup_handle {
@ -959,6 +981,10 @@ pub async fn create_server_core(
handles.push(http_handle)
}
if let Some(repl_handle) = maybe_repl_handle {
handles.push(repl_handle)
}
Ok(CoreHandle {
clean_shutdown: false,
tx: broadcast_tx,

View file

@ -0,0 +1,250 @@
use bytes::{BufMut, BytesMut};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
use kanidmd_lib::repl::proto::{ReplIncrementalContext, ReplRefreshContext, ReplRuvRange};
/// Requests sent from a replication consumer to a supplier over the
/// length-prefixed JSON codec.
#[derive(Serialize, Deserialize, Debug)]
pub enum ConsumerRequest {
    /// Liveness check - the supplier answers with [`SupplierResponse::Pong`].
    Ping,
    /// Request changes newer than this consumer's RUV (replication update vector) range.
    Incremental(ReplRuvRange),
    /// Request a complete refresh of the supplier's state.
    Refresh,
}
/// Responses sent from a supplier back to a consumer, mirroring the
/// [`ConsumerRequest`] variants.
#[derive(Serialize, Deserialize, Debug)]
pub enum SupplierResponse {
    /// Answer to [`ConsumerRequest::Ping`].
    Pong,
    /// Change set for an incremental replication run.
    Incremental(ReplIncrementalContext),
    /// Full supplier state for a consumer refresh.
    Refresh(ReplRefreshContext),
}
/// Framing codec used on the consumer side of a replication connection:
/// encodes [`ConsumerRequest`]s and decodes [`SupplierResponse`]s as
/// length-prefixed JSON frames.
#[derive(Default)]
pub struct ConsumerCodec {
    // Maximum accepted frame payload size - decode rejects anything larger
    // to bound memory use from a remote peer.
    max_frame_bytes: usize,
}

impl ConsumerCodec {
    /// Create a codec that rejects frames whose payload exceeds `max_frame_bytes`.
    pub fn new(max_frame_bytes: usize) -> Self {
        ConsumerCodec { max_frame_bytes }
    }
}
// The consumer reads supplier responses ...
impl Decoder for ConsumerCodec {
    type Error = io::Error;
    type Item = SupplierResponse;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Shared frame logic; returns Ok(None) while the frame is incomplete.
        decode_length_checked_json(self.max_frame_bytes, src)
    }
}

// ... and writes consumer requests.
impl Encoder<ConsumerRequest> for ConsumerCodec {
    type Error = io::Error;

    fn encode(&mut self, msg: ConsumerRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
        encode_length_checked_json(msg, dst)
    }
}
/// Framing codec used on the supplier side of a replication connection:
/// decodes [`ConsumerRequest`]s and encodes [`SupplierResponse`]s as
/// length-prefixed JSON frames.
#[derive(Default)]
pub struct SupplierCodec {
    // Maximum accepted frame payload size - decode rejects anything larger.
    max_frame_bytes: usize,
}

impl SupplierCodec {
    /// Create a codec that rejects frames whose payload exceeds `max_frame_bytes`.
    pub fn new(max_frame_bytes: usize) -> Self {
        SupplierCodec { max_frame_bytes }
    }
}
// The supplier reads consumer requests ...
impl Decoder for SupplierCodec {
    type Error = io::Error;
    type Item = ConsumerRequest;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Shared frame logic; returns Ok(None) while the frame is incomplete.
        decode_length_checked_json(self.max_frame_bytes, src)
    }
}

// ... and writes supplier responses.
impl Encoder<SupplierResponse> for SupplierCodec {
    type Error = io::Error;

    fn encode(&mut self, msg: SupplierResponse, dst: &mut BytesMut) -> Result<(), Self::Error> {
        encode_length_checked_json(msg, dst)
    }
}
/// Serialise `msg` as JSON and append it to `dst` as a single frame:
/// an 8 byte big-endian payload length, followed by the JSON payload.
///
/// The frame is appended strictly at the end of `dst` and never touches
/// bytes already present. This matters because `Encoder::encode` may be
/// called with a buffer that still holds previously-encoded, unflushed
/// frames: the earlier split_off/unsplit approach split at absolute
/// index 8 of `dst`, so with a non-empty buffer the length header was
/// written over the *previous* frame's first 8 bytes, corrupting the
/// stream.
///
/// # Errors
/// Returns an `io::Error` if JSON serialisation fails.
fn encode_length_checked_json<R: Serialize>(msg: R, dst: &mut BytesMut) -> Result<(), io::Error> {
    // Serialise first so the exact payload length is known for the header.
    let json = serde_json::to_vec(&msg).map_err(|err| {
        error!(?err, "consumer encoding error");
        io::Error::new(io::ErrorKind::Other, "JSON encode error")
    })?;

    // Header: payload length (excluding these 8 header bytes), big-endian u64.
    dst.extend_from_slice(&(json.len() as u64).to_be_bytes());
    // Payload.
    dst.extend_from_slice(&json);
    Ok(())
}
/// Decode one length-prefixed JSON frame from `src`.
///
/// Frame layout: 8 byte big-endian payload length, then exactly that many
/// payload bytes. Returns `Ok(None)` while the frame is incomplete (the
/// codec will be called again with more data), `Ok(Some(msg))` once a full
/// frame has been consumed, and `Err` for zero-length frames, frames larger
/// than `max_frame_bytes`, or invalid JSON.
///
/// Fixes over the previous revision:
/// - The availability check compared `src.len()` (header *included*) with
///   `req_len` (payload only), so a partially received frame whose total
///   size reached `req_len` was handed to serde and rejected as invalid
///   JSON instead of waiting for more bytes.
/// - The payload slice was "everything after the header", so a second
///   pipelined frame in the buffer made deserialisation fail on trailing
///   data. The slice is now exactly `req_len` bytes, and any surplus bytes
///   are retained in `src` for the next call.
fn decode_length_checked_json<T: DeserializeOwned>(
    max_frame_bytes: usize,
    src: &mut BytesMut,
) -> Result<Option<T>, io::Error> {
    trace!(capacity = ?src.capacity());

    if src.len() < 8 {
        // Not enough for the length header.
        trace!("Insufficient bytes for length header.");
        return Ok(None);
    }

    let mut len_be_bytes = [0; 8];
    len_be_bytes.copy_from_slice(&src[0..8]);
    let req_len = u64::from_be_bytes(len_be_bytes);

    if req_len == 0 {
        error!("request has size 0");
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty request"));
    }

    if req_len > max_frame_bytes as u64 {
        error!(
            "requested decode frame too large {} > {}",
            req_len, max_frame_bytes
        );
        return Err(io::Error::new(
            io::ErrorKind::OutOfMemory,
            "request too large",
        ));
    }

    // A complete frame is the 8 byte header plus req_len payload bytes.
    // req_len <= max_frame_bytes (a usize) so this cast cannot truncate.
    let frame_len = 8 + req_len as usize;
    if src.len() < frame_len {
        trace!(
            "Insufficient bytes for json, need: {} have: {}",
            req_len,
            src.len()
        );
        return Ok(None);
    }

    // Okay, we have enough. Lets go. Deserialise exactly the payload bytes.
    let res = serde_json::from_slice(&src[8..frame_len])
        .map(Some)
        .map_err(|err| {
            error!(?err, "received invalid input");
            io::Error::new(io::ErrorKind::InvalidInput, "JSON decode error")
        });

    // Consume exactly this frame, keeping any pipelined surplus for the
    // next decode call.
    if src.len() == frame_len {
        src.clear();
    } else {
        let mut rem = src.split_off(frame_len);
        std::mem::swap(&mut rem, src);
    };

    res
}
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use tokio_util::codec::{Decoder, Encoder};

    use super::{ConsumerCodec, ConsumerRequest, SupplierCodec, SupplierResponse};

    // Exercises the framing codec: partial headers, zero-length frames,
    // oversize frames, partial payloads, and a full round trip in both
    // directions.
    #[test]
    fn test_repl_codec() {
        sketching::test_init();

        let mut consumer_codec = ConsumerCodec::new(32);

        let mut buf = BytesMut::with_capacity(32);

        // Empty buffer
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);

        // Not enough to fill the length header.
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Length header reports a zero size request.
        let zero = [0, 0, 0, 0];
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // Clear buffer - setup a request with a length > allowed max.
        buf.clear();
        let len_bytes = 34_u64.to_be_bytes();
        buf.extend_from_slice(&len_bytes);

        // Even though the buf len is only 8, this will error as the overall
        // request will be too large.
        assert_eq!(buf.len(), 8);
        assert!(consumer_codec.decode(&mut buf).is_err());

        // Assert that we request more data on a validly sized req
        buf.clear();
        let len_bytes = 20_u64.to_be_bytes();
        buf.extend_from_slice(&len_bytes);
        // Pad in some extra bytes.
        buf.extend_from_slice(&zero);
        assert_eq!(buf.len(), 12);
        assert!(matches!(consumer_codec.decode(&mut buf), Ok(None)));

        // Make a request that is correctly sized.
        buf.clear();
        let mut supplier_codec = SupplierCodec::new(32);

        assert!(consumer_codec
            .encode(ConsumerRequest::Ping, &mut buf)
            .is_ok());
        assert!(matches!(
            supplier_codec.decode(&mut buf),
            Ok(Some(ConsumerRequest::Ping))
        ));
        // The buf will have been cleared by the supplier codec here.
        assert!(buf.is_empty());
        assert!(supplier_codec
            .encode(SupplierResponse::Pong, &mut buf)
            .is_ok());
        assert!(matches!(
            consumer_codec.decode(&mut buf),
            Ok(Some(SupplierResponse::Pong))
        ));
        assert!(buf.is_empty());
    }
}

712
server/core/src/repl/mod.rs Normal file
View file

@ -0,0 +1,712 @@
use openssl::{
pkey::{PKey, Private},
ssl::{Ssl, SslAcceptor, SslConnector, SslMethod, SslVerifyMode},
x509::{store::X509StoreBuilder, X509},
};
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::{interval, sleep, timeout};
use tokio_openssl::SslStream;
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{error, Instrument};
use url::Url;
use uuid::Uuid;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use kanidmd_lib::prelude::duration_from_epoch_now;
use kanidmd_lib::prelude::IdmServer;
use kanidmd_lib::repl::proto::ConsumerState;
use kanidmd_lib::server::QueryServerTransaction;
use crate::config::RepNodeConfig;
use crate::config::ReplicationConfiguration;
use crate::CoreAction;
use self::codec::{ConsumerRequest, SupplierResponse};
mod codec;
/// Control messages for the replication acceptor task, typically issued via
/// the admin socket. Each variant carries a oneshot channel for the reply.
pub(crate) enum ReplCtrl {
    /// Fetch the current replication server certificate.
    GetCertificate { respond: oneshot::Sender<X509> },
    /// Regenerate the replication key/certificate; the reply indicates success.
    RenewCertificate { respond: oneshot::Sender<bool> },
}
/// Bind the replication TCP listener and spawn the [`repl_acceptor`] task.
///
/// Returns the acceptor's join handle (for coordinated shutdown) and an mpsc
/// sender for [`ReplCtrl`] messages. The TCP listener is created once here so
/// that it persists across TLS/config reloads performed inside the acceptor.
///
/// # Errors
/// Returns `Err(())` (after logging) if the bind address cannot be bound.
pub(crate) async fn create_repl_server(
    idms: Arc<IdmServer>,
    repl_config: &ReplicationConfiguration,
    rx: broadcast::Receiver<CoreAction>,
) -> Result<(tokio::task::JoinHandle<()>, mpsc::Sender<ReplCtrl>), ()> {
    // We need to start the tcp listener. This will persist over ssl reloads!
    let listener = TcpListener::bind(&repl_config.bindaddress)
        .await
        .map_err(|e| {
            error!(
                "Could not bind to replication address {} -> {:?}",
                repl_config.bindaddress, e
            );
        })?;

    // Create the control channel. Use a low msg count, there won't be that much going on.
    let (ctrl_tx, ctrl_rx) = mpsc::channel(4);

    // We need to start the tcp listener. This will persist over ssl reloads!
    // NOTE(review): the log below says "https://" but this listener speaks the
    // repl:// protocol - confirm whether the scheme in the message is intended.
    info!(
        "Starting replication interface https://{} ...",
        repl_config.bindaddress
    );
    let repl_handle = tokio::spawn(repl_acceptor(
        listener,
        idms,
        repl_config.clone(),
        rx,
        ctrl_rx,
    ));

    info!("Created replication interface");
    Ok((repl_handle, ctrl_tx))
}
/// Run one consumer-side replication attempt against a supplier.
///
/// Tries each resolved socket address of the replica in turn. For the first
/// address that connects and completes the TLS handshake, it performs an
/// incremental pull; if the supplier indicates the consumer is too far behind
/// (`ConsumerState::RefreshRequired`) and `automatic_refresh` is set, it
/// follows up with a full refresh. Returns on success; falls through to the
/// next address (or the final error log) otherwise.
#[instrument(level = "info", skip_all)]
async fn repl_run_consumer(
    max_frame_bytes: usize,
    domain: &str,
    sock_addrs: &[SocketAddr],
    tls_connector: &SslConnector,
    automatic_refresh: bool,
    idms: &IdmServer,
) {
    let replica_connect_timeout = Duration::from_secs(2);

    // This is pretty gnarly, but we need to loop to try out each socket addr.
    for sock_addr in sock_addrs {
        debug!("Connecting to {} replica via {}", domain, sock_addr);

        let tcpstream = match timeout(replica_connect_timeout, TcpStream::connect(sock_addr)).await
        {
            Ok(Ok(tc)) => tc,
            Ok(Err(err)) => {
                error!(?err, "Failed to connect to {}", sock_addr);
                continue;
            }
            Err(_) => {
                error!("Timeout connecting to {}", sock_addr);
                continue;
            }
        };

        trace!("connection established");

        // Client-side TLS handshake; the connector pins the supplier's
        // certificate and verifies the hostname (see repl_task setup).
        let mut tlsstream = match Ssl::new(tls_connector.context())
            .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream))
        {
            Ok(ta) => ta,
            Err(e) => {
                error!("replication client TLS setup error, continuing -> {:?}", e);
                continue;
            }
        };

        if let Err(e) = SslStream::connect(Pin::new(&mut tlsstream)).await {
            error!("replication client TLS accept error, continuing -> {:?}", e);
            continue;
        };
        let (r, w) = tokio::io::split(tlsstream);
        let mut r = FramedRead::new(r, codec::ConsumerCodec::new(max_frame_bytes));
        let mut w = FramedWrite::new(w, codec::ConsumerCodec::new(max_frame_bytes));

        // Perform incremental.
        // Snapshot our RUV range under a read transaction, released before
        // the network round trip.
        let consumer_ruv_range = {
            let mut read_txn = idms.proxy_read().await;
            match read_txn.qs_read.consumer_get_state() {
                Ok(ruv_range) => ruv_range,
                Err(err) => {
                    error!(
                        ?err,
                        "consumer ruv range could not be accessed, unable to continue."
                    );
                    break;
                }
            }
        };

        if let Err(err) = w
            .send(ConsumerRequest::Incremental(consumer_ruv_range))
            .await
        {
            error!(?err, "consumer encode error, unable to continue.");
            break;
        }

        let changes = if let Some(codec_msg) = r.next().await {
            match codec_msg {
                Ok(SupplierResponse::Incremental(changes)) => {
                    // Success - return to bypass the error message.
                    changes
                }
                Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Refresh(_)) => {
                    error!("Supplier Response contains invalid State");
                    break;
                }
                Err(err) => {
                    error!(?err, "consumer decode error, unable to continue.");
                    break;
                }
            }
        } else {
            error!("Connection closed");
            break;
        };

        // Now apply the changes if possible
        // Changes are applied and committed inside a single write transaction.
        let consumer_state = {
            let ct = duration_from_epoch_now();
            let mut write_txn = idms.proxy_write(ct).await;
            match write_txn
                .qs_write
                .consumer_apply_changes(&changes)
                .and_then(|cs| write_txn.commit().map(|()| cs))
            {
                Ok(state) => state,
                Err(err) => {
                    error!(?err, "consumer was not able to apply changes.");
                    break;
                }
            }
        };

        match consumer_state {
            ConsumerState::Ok => {
                info!("Incremental Replication Success");
                // return to bypass the failure message.
                return;
            }
            ConsumerState::RefreshRequired => {
                if automatic_refresh {
                    warn!("Consumer is out of date and must be refreshed. This will happen *now*.");
                } else {
                    error!("Consumer is out of date and must be refreshed. You must manually resolve this situation.");
                    return;
                };
            }
        }

        // Refresh path: only reached when RefreshRequired + automatic_refresh.
        if let Err(err) = w.send(ConsumerRequest::Refresh).await {
            error!(?err, "consumer encode error, unable to continue.");
            break;
        }

        let refresh = if let Some(codec_msg) = r.next().await {
            match codec_msg {
                Ok(SupplierResponse::Refresh(changes)) => {
                    // Success - return to bypass the error message.
                    changes
                }
                Ok(SupplierResponse::Pong) | Ok(SupplierResponse::Incremental(_)) => {
                    error!("Supplier Response contains invalid State");
                    break;
                }
                Err(err) => {
                    error!(?err, "consumer decode error, unable to continue.");
                    break;
                }
            }
        } else {
            error!("Connection closed");
            break;
        };

        // Now apply the refresh if possible
        let ct = duration_from_epoch_now();
        let mut write_txn = idms.proxy_write(ct).await;
        if let Err(err) = write_txn
            .qs_write
            .consumer_apply_refresh(&refresh)
            .and_then(|cs| write_txn.commit().map(|()| cs))
        {
            error!(?err, "consumer was not able to apply refresh.");
            break;
        }

        warn!("Replication refresh was successful.");
        return;
    }

    // Reached only when every address failed or a break occurred above.
    error!("Unable to complete replication successfully.");
}
/// Long-lived consumer task for a single replica origin (a `repl://` URL).
///
/// Validates the origin, builds a TLS connector that presents our client
/// cert/key and pins *only* the supplier's certificate (no system CA roots),
/// then runs [`repl_run_consumer`] on every `task_poll_interval` tick until a
/// message arrives on `task_rx` (the reload/shutdown broadcast).
async fn repl_task(
    origin: Url,
    client_key: PKey<Private>,
    client_cert: X509,
    supplier_cert: X509,
    max_frame_bytes: usize,
    task_poll_interval: Duration,
    mut task_rx: broadcast::Receiver<()>,
    automatic_refresh: bool,
    idms: Arc<IdmServer>,
) {
    if origin.scheme() != "repl" {
        error!("Replica origin is not repl:// - refusing to proceed.");
        return;
    }

    // A hostname (not an IP) is required so TLS peer verification can match
    // the supplier certificate's SAN.
    let domain = match origin.domain() {
        Some(d) => d,
        None => {
            error!("Replica origin does not have a valid domain name, unable to proceed. Perhaps you tried to use an ip address?");
            return;
        }
    };

    // NOTE(review): port 443 is used as the default when the origin omits a
    // port - confirm this is the intended default for repl://.
    let socket_addrs = match origin.socket_addrs(|| Some(443)) {
        Ok(sa) => sa,
        Err(err) => {
            error!(?err, "Replica origin could not resolve to ip:port");
            return;
        }
    };

    // Setup our tls connector.
    let mut ssl_builder = match SslConnector::builder(SslMethod::tls_client()) {
        Ok(sb) => sb,
        Err(err) => {
            error!(?err, "Unable to configure tls connector");
            return;
        }
    };

    // Present our own cert/key for mTLS, and check they actually match.
    let setup_client_cert = ssl_builder
        .set_certificate(&client_cert)
        .and_then(|_| ssl_builder.set_private_key(&client_key))
        .and_then(|_| ssl_builder.check_private_key());
    if let Err(err) = setup_client_cert {
        error!(?err, "Unable to configure client certificate/key");
        return;
    }

    // Add the supplier cert.
    // ⚠️  note that here we need to build a new cert store. This is because
    // openssl SslConnector adds the default system cert locations with
    // the call to ::builder and we *don't* want this. We want our certstore
    // to pin a single certificate!
    let mut cert_store = match X509StoreBuilder::new() {
        Ok(csb) => csb,
        Err(err) => {
            error!(?err, "Unable to configure certificate store builder.");
            return;
        }
    };

    if let Err(err) = cert_store.add_cert(supplier_cert) {
        error!(?err, "Unable to add supplier certificate to cert store");
        return;
    }

    let cert_store = cert_store.build();
    ssl_builder.set_cert_store(cert_store);

    // Configure the expected hostname of the remote.
    let verify_param = ssl_builder.verify_param_mut();
    if let Err(err) = verify_param.set_host(domain) {
        error!(?err, "Unable to set domain name for tls peer verification");
        return;
    }

    // Assert the expected supplier certificate is correct and has a valid domain san
    ssl_builder.set_verify(SslVerifyMode::PEER);
    let tls_connector = ssl_builder.build();

    let mut repl_interval = interval(task_poll_interval);

    info!("Replica task for {} has started.", origin);

    // Okay, all the parameters are setup. Now we wait on our interval.
    loop {
        tokio::select! {
            // Reload/shutdown signal from the acceptor.
            Ok(()) = task_rx.recv() => {
                break;
            }
            _ = repl_interval.tick() => {
                // Interval passed, attempt a replication run.
                let eventid = Uuid::new_v4();
                let span = info_span!("replication_run_consumer", uuid = ?eventid);
                // NOTE(review): holding span.enter() across an .await is
                // discouraged by tracing docs (spans can attach to the wrong
                // task on re-poll); consider .instrument(span) instead.
                let _enter = span.enter();
                repl_run_consumer(
                    max_frame_bytes,
                    domain,
                    &socket_addrs,
                    &tls_connector,
                    automatic_refresh,
                    &idms
                ).await;
            }
        }
    }

    info!("Replica task for {} has stopped.", origin);
}
#[instrument(level = "info", skip_all)]
async fn handle_repl_conn(
max_frame_bytes: usize,
tcpstream: TcpStream,
client_address: SocketAddr,
tls_parms: SslAcceptor,
idms: Arc<IdmServer>,
) {
debug!(?client_address, "replication client connected 🛫");
let mut tlsstream = match Ssl::new(tls_parms.context())
.and_then(|tls_obj| SslStream::new(tls_obj, tcpstream))
{
Ok(ta) => ta,
Err(err) => {
error!(?err, "LDAP TLS setup error, disconnecting client");
return;
}
};
if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await {
error!(?err, "LDAP TLS accept error, disconnecting client");
return;
};
let (r, w) = tokio::io::split(tlsstream);
let mut r = FramedRead::new(r, codec::SupplierCodec::new(max_frame_bytes));
let mut w = FramedWrite::new(w, codec::SupplierCodec::new(max_frame_bytes));
while let Some(codec_msg) = r.next().await {
match codec_msg {
Ok(ConsumerRequest::Ping) => {
debug!("consumer requested ping");
if let Err(err) = w.send(SupplierResponse::Pong).await {
error!(?err, "supplier encode error, unable to continue.");
break;
}
}
Ok(ConsumerRequest::Incremental(consumer_ruv_range)) => {
let mut read_txn = idms.proxy_read().await;
let changes = match read_txn
.qs_read
.supplier_provide_changes(consumer_ruv_range)
{
Ok(changes) => changes,
Err(err) => {
error!(?err, "supplier provide changes failed.");
break;
}
};
if let Err(err) = w.send(SupplierResponse::Incremental(changes)).await {
error!(?err, "supplier encode error, unable to continue.");
break;
}
}
Ok(ConsumerRequest::Refresh) => {
let mut read_txn = idms.proxy_read().await;
let changes = match read_txn.qs_read.supplier_provide_refresh() {
Ok(changes) => changes,
Err(err) => {
error!(?err, "supplier provide refresh failed.");
break;
}
};
if let Err(err) = w.send(SupplierResponse::Refresh(changes)).await {
error!(?err, "supplier encode error, unable to continue.");
break;
}
}
Err(err) => {
error!(?err, "supplier decode error, unable to continue.");
break;
}
}
}
debug!(?client_address, "replication client disconnected 🛬");
}
/// Main replication acceptor / coordinator task.
///
/// Runs an outer `'event` reload loop: each iteration stops any running
/// per-replica consumer tasks, (re)loads the supplier key/certificate,
/// spawns consumer tasks for pull-configured nodes, builds a pinned mTLS
/// acceptor for allow-pull nodes, then enters an inner accept/control loop
/// until a shutdown action or a certificate renewal (which restarts the
/// reload cycle) arrives.
async fn repl_acceptor(
    listener: TcpListener,
    idms: Arc<IdmServer>,
    repl_config: ReplicationConfiguration,
    mut rx: broadcast::Receiver<CoreAction>,
    mut ctrl_rx: mpsc::Receiver<ReplCtrl>,
) {
    info!("Starting Replication Acceptor ...");
    // Persistent parts
    // These all probably need changes later ...
    let task_poll_interval = Duration::from_secs(10);
    let retry_timeout = Duration::from_secs(60);
    let max_frame_bytes = 268435456;

    // Setup a broadcast to control our tasks.
    let (task_tx, task_rx1) = broadcast::channel(2);
    // Note, we drop this task here since each task will re-subscribe. That way the
    // broadcast doesn't jam up because we aren't draining this task.
    drop(task_rx1);
    let mut task_handles = VecDeque::new();

    // Create another broadcast to control the replication tasks and their need to reload.

    // Spawn a KRC communication task?

    // In future we need to update this from the KRC if configured, and we default this
    // to "empty". But if this map exists in the config, we have to always use that.
    let replication_node_map = repl_config.manual.clone();

    // This needs to have an event loop that can respond to changes.
    // For now we just design it to reload ssl if the map changes internally.
    'event: loop {
        info!("Starting replication reload ...");
        // Tell existing tasks to shutdown.
        // Note: We ignore the result here since an err can occur *if* there are
        // no tasks currently listening on the channel.
        info!("Stopping {} Replication Tasks ...", task_handles.len());
        debug_assert!(task_handles.len() >= task_tx.receiver_count());
        let _ = task_tx.send(());
        for task_handle in task_handles.drain(..) {
            // Let each task join.
            let res: Result<(), _> = task_handle.await;
            if res.is_err() {
                warn!("Failed to join replication task, continuing ...");
            }
        }

        // Now we can start to re-load configurations and setup our client tasks
        // as well.

        // Get the private key / cert.
        let res = {
            // Does this actually need to be a read in case we need to write
            // to sqlite?
            let ct = duration_from_epoch_now();
            let mut idms_prox_write = idms.proxy_write(ct).await;
            idms_prox_write
                .qs_write
                .supplier_get_key_cert()
                .and_then(|res| idms_prox_write.commit().map(|()| res))
        };

        let (server_key, server_cert) = match res {
            Ok(r) => r,
            Err(err) => {
                error!(?err, "CRITICAL: Unable to access supplier certificate/key.");
                // Back off and retry the whole reload.
                sleep(retry_timeout).await;
                continue;
            }
        };

        info!(
            replication_cert_not_before = ?server_cert.not_before(),
            replication_cert_not_after = ?server_cert.not_after(),
        );

        let mut client_certs = Vec::new();

        // For each node in the map, either spawn a task to pull from that node,
        // or setup the node as allowed to pull from us.
        for (origin, node) in replication_node_map.iter() {
            // Setup client certs
            match node {
                RepNodeConfig::MutualPull {
                    partner_cert: consumer_cert,
                    automatic_refresh: _,
                }
                | RepNodeConfig::AllowPull { consumer_cert } => {
                    client_certs.push(consumer_cert.clone())
                }
                RepNodeConfig::Pull {
                    supplier_cert: _,
                    automatic_refresh: _,
                } => {}
            };

            match node {
                RepNodeConfig::MutualPull {
                    partner_cert: supplier_cert,
                    automatic_refresh,
                }
                | RepNodeConfig::Pull {
                    supplier_cert,
                    automatic_refresh,
                } => {
                    let task_rx = task_tx.subscribe();

                    // One consumer task per pull-configured origin.
                    let handle = tokio::spawn(repl_task(
                        origin.clone(),
                        server_key.clone(),
                        server_cert.clone(),
                        supplier_cert.clone(),
                        max_frame_bytes,
                        task_poll_interval,
                        task_rx,
                        *automatic_refresh,
                        idms.clone(),
                    ));

                    task_handles.push_back(handle);
                    debug_assert!(task_handles.len() == task_tx.receiver_count());
                }
                RepNodeConfig::AllowPull { consumer_cert: _ } => {}
            };
        }

        // ⚠️  This section is critical to the security of replication
        //    Since replication relies on mTLS we MUST ensure these options
        //    are absolutely correct!
        //
        // Setup the TLS builder.
        let mut tls_builder = match SslAcceptor::mozilla_modern_v5(SslMethod::tls()) {
            Ok(tls_builder) => tls_builder,
            Err(err) => {
                error!(?err, "CRITICAL, unable to create SslAcceptorBuilder.");
                sleep(retry_timeout).await;
                continue;
            }
        };

        // tls_builder.set_keylog_callback(keylog_cb);
        if let Err(err) = tls_builder
            .set_certificate(&server_cert)
            .and_then(|_| tls_builder.set_private_key(&server_key))
            .and_then(|_| tls_builder.check_private_key())
        {
            error!(?err, "CRITICAL, unable to set server_cert and server key.");
            sleep(retry_timeout).await;
            continue;
        };

        // ⚠️  CRITICAL - ensure that the cert store only has client certs from
        // the repl map added.
        let cert_store = tls_builder.cert_store_mut();
        for client_cert in client_certs.into_iter() {
            if let Err(err) = cert_store.add_cert(client_cert.clone()) {
                error!(?err, "CRITICAL, unable to add client certificates.");
                // NOTE(review): this `continue` advances the *for* loop (next
                // cert), not the 'event reload loop like the sibling error
                // paths above - so after the sleep the acceptor is still built
                // with a partially-populated store. Confirm whether
                // `continue 'event` was intended here.
                sleep(retry_timeout).await;
                continue;
            }
        }

        // ⚠️  CRITICAL - Both verifications here are needed. PEER requests
        // the client cert to be sent. FAIL_IF_NO_PEER_CERT triggers an
        // error if the cert is NOT present. FAIL_IF_NO_PEER_CERT on it's own
        // DOES NOTHING.
        let mut verify = SslVerifyMode::PEER;
        verify.insert(SslVerifyMode::FAIL_IF_NO_PEER_CERT);
        tls_builder.set_verify(verify);

        let tls_acceptor = tls_builder.build();

        loop {
            // This is great to diagnose when spans are entered or present and they capture
            // things incorrectly.
            // eprintln!("🔥 C ---> {:?}", tracing::Span::current());
            let eventid = Uuid::new_v4();

            tokio::select! {
                Ok(action) = rx.recv() => {
                    match action {
                        CoreAction::Shutdown => break 'event,
                    }
                }
                Some(ctrl_msg) = ctrl_rx.recv() => {
                    match ctrl_msg {
                        ReplCtrl::GetCertificate {
                            respond
                        } => {
                            let _span = debug_span!("supplier_accept_loop", uuid = ?eventid).entered();
                            if let Err(_) = respond.send(server_cert.clone()) {
                                warn!("Server certificate was requested, but requsetor disconnected");
                            } else {
                                trace!("Sent server certificate via control channel");
                            }
                        }
                        ReplCtrl::RenewCertificate {
                            respond
                        } => {
                            let span = debug_span!("supplier_accept_loop", uuid = ?eventid);
                            async {
                                debug!("renewing replication certificate ...");
                                // Renew the cert.
                                let res = {
                                    let ct = duration_from_epoch_now();
                                    let mut idms_prox_write = idms.proxy_write(ct).await;
                                    idms_prox_write
                                        .qs_write
                                        .supplier_renew_key_cert()
                                        .and_then(|res| idms_prox_write.commit().map(|()| res))
                                };

                                let success = res.is_ok();

                                if let Err(err) = res {
                                    error!(?err, "failed to renew server certificate");
                                }

                                if let Err(_) = respond.send(success) {
                                    warn!("Server certificate renewal was requested, but requsetor disconnected");
                                } else {
                                    trace!("Sent server certificate renewal status via control channel");
                                }
                            }
                            .instrument(span)
                            .await;

                            // Start a reload.
                            continue 'event;
                        }
                    }
                }
                // Handle accepts.
                // Handle *reloads*
                /*
                _ = reload.recv() => {
                    info!("initiate tls reload");
                    continue
                }
                */
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((tcpstream, client_socket_addr)) => {
                            let clone_idms = idms.clone();
                            let clone_tls_acceptor = tls_acceptor.clone();
                            // We don't care about the join handle here - once a client connects
                            // it sticks to whatever ssl settings it had at launch.
                            let _ = tokio::spawn(
                                handle_repl_conn(max_frame_bytes, tcpstream, client_socket_addr, clone_tls_acceptor, clone_idms)
                            );
                        }
                        Err(e) => {
                            error!("replication acceptor error, continuing -> {:?}", e);
                        }
                    }
                }
            } // end select
            // Continue to poll/loop
        }
    }
    // Shutdown child tasks.
    info!("Stopping {} Replication Tasks ...", task_handles.len());
    debug_assert!(task_handles.len() >= task_tx.receiver_count());
    let _ = task_tx.send(());
    for task_handle in task_handles.drain(..) {
        // Let each task join.
        let res: Result<(), _> = task_handle.await;
        if res.is_err() {
            warn!("Failed to join replication task, continuing ...");
        }
    }

    info!("Stopped Replication Acceptor");
}

View file

@ -18,7 +18,7 @@ if [ ! -d "${KANI_TMP}" ]; then
mkdir -p "${KANI_TMP}"
fi
CONFIG_FILE="../../examples/insecure_server.toml"
CONFIG_FILE=${CONFIG_FILE:="../../examples/insecure_server.toml"}
if [ ! -f "${CONFIG_FILE}" ]; then
SCRIPT_DIR="$(dirname -a "$0")"

View file

@ -77,7 +77,9 @@ impl KanidmdOpt {
| KanidmdOpt::DbScan {
commands: DbScanOpt::RestoreQuarantined { commonopts, .. },
}
| KanidmdOpt::RecoverAccount { commonopts, .. } => commonopts,
| KanidmdOpt::ShowReplicationCertificate { commonopts }
| KanidmdOpt::RenewReplicationCertificate { commonopts } => commonopts,
KanidmdOpt::RecoverAccount { commonopts, .. } => commonopts,
KanidmdOpt::DbScan {
commands: DbScanOpt::ListIndex(dopt),
} => &dopt.commonopts,
@ -145,6 +147,14 @@ async fn submit_admin_req(path: &str, req: AdminTaskRequest, output_mode: Consol
info!(new_password = ?password)
}
},
Some(Ok(AdminTaskResponse::ShowReplicationCertificate { cert })) => match output_mode {
ConsoleOutputMode::JSON => {
eprintln!("{{\"certificate\":\"{}\"}}", cert)
}
ConsoleOutputMode::Text => {
info!(certificate = ?cert)
}
},
_ => {
error!("Error making request to admin socket");
}
@ -258,6 +268,13 @@ async fn main() -> ExitCode {
}
};
// Stop early if replication was found
if sconfig.repl_config.is_some() &&
!sconfig.i_acknowledge_that_replication_is_in_development
{
error!("Unable to proceed. Replication should not be configured manually.");
return ExitCode::FAILURE
}
#[cfg(target_family = "unix")]
{
@ -341,6 +358,11 @@ async fn main() -> ExitCode {
config.update_output_mode(opt.commands.commonopt().output_mode.to_owned().into());
config.update_trust_x_forward_for(sconfig.trust_x_forward_for);
config.update_admin_bind_path(&sconfig.adminbindpath);
config.update_replication_config(
sconfig.repl_config.clone()
);
match &opt.commands {
// we aren't going to touch the DB so we can carry on
KanidmdOpt::HealthCheck(_) => (),
@ -531,6 +553,26 @@ async fn main() -> ExitCode {
info!("Running in db verification mode ...");
verify_server_core(&config).await;
}
KanidmdOpt::ShowReplicationCertificate {
commonopts
} => {
info!("Running show replication certificate ...");
let output_mode: ConsoleOutputMode = commonopts.output_mode.to_owned().into();
submit_admin_req(config.adminbindpath.as_str(),
AdminTaskRequest::ShowReplicationCertificate,
output_mode,
).await;
}
KanidmdOpt::RenewReplicationCertificate {
commonopts
} => {
info!("Running renew replication certificate ...");
let output_mode: ConsoleOutputMode = commonopts.output_mode.to_owned().into();
submit_admin_req(config.adminbindpath.as_str(),
AdminTaskRequest::RenewReplicationCertificate,
output_mode,
).await;
}
KanidmdOpt::RecoverAccount {
name, commonopts
} => {

View file

@ -154,6 +154,16 @@ enum KanidmdOpt {
#[clap(flatten)]
commonopts: CommonOpt,
},
/// Display this server's replication certificate
ShowReplicationCertificate {
#[clap(flatten)]
commonopts: CommonOpt,
},
/// Renew this server's replication certificate
RenewReplicationCertificate {
#[clap(flatten)]
commonopts: CommonOpt,
},
// #[clap(name = "reset_server_id")]
// ResetServerId(CommonOpt),
#[clap(name = "db-scan")]

View file

@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
use smartstring::alias::String as AttrString;
use uuid::Uuid;
use super::keystorage::{KeyHandle, KeyHandleId};
use crate::be::dbvalue::{DbValueEmailAddressV1, DbValuePhoneNumberV1, DbValueSetV2, DbValueV1};
use crate::prelude::entries::Attribute;
use crate::prelude::OperationError;
@ -57,6 +58,13 @@ pub enum DbBackup {
db_ts_max: Duration,
entries: Vec<DbEntry>,
},
V3 {
db_s_uuid: Uuid,
db_d_uuid: Uuid,
db_ts_max: Duration,
keyhandles: BTreeMap<KeyHandleId, KeyHandle>,
entries: Vec<DbEntry>,
},
}
fn from_vec_dbval1(attr_val: NonEmpty<DbValueV1>) -> Result<DbValueSetV2, OperationError> {

View file

@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::convert::TryInto;
use std::ops::DerefMut;
@ -19,6 +20,7 @@ use crate::be::idl_sqlite::{
use crate::be::idxkey::{
IdlCacheKey, IdlCacheKeyRef, IdlCacheKeyToRef, IdxKey, IdxKeyRef, IdxKeyToRef, IdxSlope,
};
use crate::be::keystorage::{KeyHandle, KeyHandleId};
use crate::be::{BackendConfig, IdList, IdRawEntry};
use crate::entry::{Entry, EntryCommitted, EntrySealed};
use crate::prelude::*;
@ -56,6 +58,7 @@ pub struct IdlArcSqlite {
op_ts_max: CowCell<Option<Duration>>,
allids: CowCell<IDLBitRange>,
maxid: CowCell<u64>,
keyhandles: CowCell<HashMap<KeyHandleId, KeyHandle>>,
}
pub struct IdlArcSqliteReadTransaction<'a> {
@ -67,13 +70,14 @@ pub struct IdlArcSqliteReadTransaction<'a> {
}
pub struct IdlArcSqliteWriteTransaction<'a> {
db: IdlSqliteWriteTransaction,
pub(super) db: IdlSqliteWriteTransaction,
entry_cache: ARCacheWriteTxn<'a, u64, Arc<EntrySealedCommitted>, ()>,
idl_cache: ARCacheWriteTxn<'a, IdlCacheKey, Box<IDLBitRange>, ()>,
name_cache: ARCacheWriteTxn<'a, NameCacheKey, NameCacheValue, ()>,
op_ts_max: CowCellWriteTxn<'a, Option<Duration>>,
allids: CowCellWriteTxn<'a, IDLBitRange>,
maxid: CowCellWriteTxn<'a, u64>,
pub(super) keyhandles: CowCellWriteTxn<'a, HashMap<KeyHandleId, KeyHandle>>,
}
macro_rules! get_identry {
@ -353,6 +357,8 @@ pub trait IdlArcSqliteTransaction {
fn get_db_ts_max(&self) -> Result<Option<Duration>, OperationError>;
fn get_key_handles(&mut self) -> Result<BTreeMap<KeyHandleId, KeyHandle>, OperationError>;
fn verify(&self) -> Vec<Result<(), ConsistencyError>>;
fn is_dirty(&self) -> bool;
@ -417,6 +423,10 @@ impl<'a> IdlArcSqliteTransaction for IdlArcSqliteReadTransaction<'a> {
self.db.get_db_ts_max()
}
fn get_key_handles(&mut self) -> Result<BTreeMap<KeyHandleId, KeyHandle>, OperationError> {
self.db.get_key_handles()
}
fn verify(&self) -> Vec<Result<(), ConsistencyError>> {
verify!(self)
}
@ -511,6 +521,10 @@ impl<'a> IdlArcSqliteTransaction for IdlArcSqliteWriteTransaction<'a> {
}
}
fn get_key_handles(&mut self) -> Result<BTreeMap<KeyHandleId, KeyHandle>, OperationError> {
self.db.get_key_handles()
}
fn verify(&self) -> Vec<Result<(), ConsistencyError>> {
verify!(self)
}
@ -593,6 +607,7 @@ impl<'a> IdlArcSqliteWriteTransaction<'a> {
op_ts_max,
allids,
maxid,
keyhandles,
} = self;
// Write any dirty items to the disk.
@ -656,15 +671,20 @@ impl<'a> IdlArcSqliteWriteTransaction<'a> {
e
})?;
// Undo the caches in the reverse order.
db.commit().map(|()| {
// Ensure the db commit succeeds first.
db.commit()?;
// Can no longer fail from this point.
op_ts_max.commit();
name_cache.commit();
idl_cache.commit();
entry_cache.commit();
allids.commit();
maxid.commit();
})
keyhandles.commit();
// Unlock the entry cache last to remove contention on everything else.
entry_cache.commit();
Ok(())
}
pub fn get_id2entry_max_id(&self) -> Result<u64, OperationError> {
@ -1238,6 +1258,8 @@ impl IdlArcSqlite {
let maxid = CowCell::new(0);
let keyhandles = CowCell::new(HashMap::default());
let op_ts_max = CowCell::new(None);
Ok(IdlArcSqlite {
@ -1248,6 +1270,7 @@ impl IdlArcSqlite {
op_ts_max,
allids,
maxid,
keyhandles,
})
}
@ -1283,6 +1306,7 @@ impl IdlArcSqlite {
let allids_write = self.allids.write();
let maxid_write = self.maxid.write();
let db_write = self.db.write();
let keyhandles_write = self.keyhandles.write();
IdlArcSqliteWriteTransaction {
db: db_write,
entry_cache: entry_cache_write,
@ -1291,6 +1315,7 @@ impl IdlArcSqlite {
op_ts_max: op_ts_max_write,
allids: allids_write,
maxid: maxid_write,
keyhandles: keyhandles_write,
}
}

View file

@ -1,9 +1,12 @@
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::convert::{TryFrom, TryInto};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use super::keystorage::{KeyHandle, KeyHandleId};
// use crate::valueset;
use hashbrown::HashMap;
use idlset::v2::IDLBitRange;
@ -24,13 +27,13 @@ const DBV_ID2ENTRY: &str = "id2entry";
const DBV_INDEXV: &str = "indexv";
#[allow(clippy::needless_pass_by_value)] // needs to accept value from `map_err`
fn sqlite_error(e: rusqlite::Error) -> OperationError {
pub(super) fn sqlite_error(e: rusqlite::Error) -> OperationError {
admin_error!(?e, "SQLite Error");
OperationError::SqliteError
}
#[allow(clippy::needless_pass_by_value)] // needs to accept value from `map_err`
fn serde_json_error(e: serde_json::Error) -> OperationError {
pub(super) fn serde_json_error(e: serde_json::Error) -> OperationError {
admin_error!(?e, "Serde JSON Error");
OperationError::SerdeJsonError
}
@ -482,6 +485,29 @@ pub trait IdlSqliteTransaction {
})
}
/// Load every keyhandle from the `keyhandles` table into an ordered map.
/// Both columns hold serde_json serialised bytes which are decoded here.
///
/// NOTE(review): this default impl duplicates
/// `IdlSqliteWriteTransaction::get_key_handles` — consider deduplicating.
fn get_key_handles(&mut self) -> Result<BTreeMap<KeyHandleId, KeyHandle>, OperationError> {
    let mut stmt = self
        .get_conn()?
        .prepare(&format!(
            "SELECT id, data FROM {}.keyhandles",
            self.get_db_name()
        ))
        .map_err(sqlite_error)?;

    let kh_iter = stmt
        .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
        .map_err(sqlite_error)?;

    kh_iter
        .map(|v| {
            // Each row is a (serialised id, serialised handle) pair.
            let (id, data): (Vec<u8>, Vec<u8>) = v.map_err(sqlite_error)?;
            let id = serde_json::from_slice(id.as_slice()).map_err(serde_json_error)?;
            let data = serde_json::from_slice(data.as_slice()).map_err(serde_json_error)?;
            Ok((id, data))
        })
        .collect()
}
#[instrument(level = "debug", name = "idl_sqlite::get_allids", skip_all)]
fn get_allids(&self) -> Result<IDLBitRange, OperationError> {
let mut stmt = self
@ -1079,6 +1105,19 @@ impl IdlSqliteWriteTransaction {
}
}
/// Create the `keyhandles` table if it does not already exist. Invoked by the
/// v7 -> v8 id2entry schema migration.
///
/// NOTE(review): both columns are declared TEXT but are written elsewhere with
/// serde_json byte vectors — sqlite's flexible typing accepts this, but confirm
/// the declared affinity is intentional.
pub(crate) fn create_keyhandles(&self) -> Result<(), OperationError> {
    self.get_conn()?
        .execute(
            &format!(
                "CREATE TABLE IF NOT EXISTS {}.keyhandles (id TEXT PRIMARY KEY, data TEXT)",
                self.get_db_name()
            ),
            [],
        )
        .map(|_| ())
        .map_err(sqlite_error)
}
pub fn create_idx(&self, attr: Attribute, itype: IndexType) -> Result<(), OperationError> {
// Is there a better way than formatting this? I can't seem
// to template into the str.
@ -1565,7 +1604,7 @@ impl IdlSqliteWriteTransaction {
dbv_id2entry = 6;
info!(entry = %dbv_id2entry, "dbv_id2entry migrated (externalid2uuid)");
}
// * if v6 -> complete.
// * if v6 -> create id2entry_quarantine.
if dbv_id2entry == 6 {
self.get_conn()?
.execute(
@ -1584,7 +1623,13 @@ impl IdlSqliteWriteTransaction {
dbv_id2entry = 7;
info!(entry = %dbv_id2entry, "dbv_id2entry migrated (quarantine)");
}
// * if v7 -> complete.
// * if v7 -> create keyhandles storage.
if dbv_id2entry == 7 {
self.create_keyhandles()?;
dbv_id2entry = 8;
info!(entry = %dbv_id2entry, "dbv_id2entry migrated (keyhandles)");
}
// * if v8 -> complete
self.set_db_version_key(DBV_ID2ENTRY, dbv_id2entry)?;

View file

@ -0,0 +1,191 @@
use crate::rusqlite::OptionalExtension;
use kanidm_lib_crypto::prelude::{PKey, Private, X509};
use kanidm_lib_crypto::serialise::{pkeyb64, x509b64};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::hash::Hash;
use super::idl_arc_sqlite::IdlArcSqliteWriteTransaction;
use super::idl_sqlite::IdlSqliteTransaction;
use super::idl_sqlite::IdlSqliteWriteTransaction;
use super::idl_sqlite::{serde_json_error, sqlite_error};
use super::BackendWriteTransaction;
use crate::prelude::OperationError;
/// These are key handles for storing keys related to various cryptographic components
/// within Kanidm. Generally these are for keys that are "static", as in have known
/// long term uses. This could be the servers private replication key, a TPM Storage
/// Root Key, or the Duplicable Storage Key. In future these may end up being in
/// a HSM or similar, but we'll always need a way to persist serialised forms of these.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum KeyHandleId {
    /// The handle under which this server's replication key and certificate
    /// are persisted.
    ReplicationKey,
}
/// This is a key handle that contains the actual data that is persisted in the DB.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum KeyHandle {
    /// An X509 certificate and its matching private key, serialised through
    /// the `pkeyb64` / `x509b64` serde helpers.
    X509Key {
        #[serde(with = "pkeyb64")]
        private: PKey<Private>,
        #[serde(with = "x509b64")]
        x509: X509,
    },
}
impl<'a> BackendWriteTransaction<'a> {
    /// Retrieve a key stored in the database by its key handle. This
    /// handle may require further processing for the key to be usable
    /// in higher level contexts as this is simply the storage layer
    /// for these keys. Delegates to the caching id layer.
    pub(crate) fn get_key_handle(
        &mut self,
        handle: KeyHandleId,
    ) -> Result<Option<KeyHandle>, OperationError> {
        self.idlayer.get_key_handle(handle)
    }

    /// Update the content of a keyhandle with this new data. Delegates to the
    /// caching id layer, which persists and then updates its cache.
    pub(crate) fn set_key_handle(
        &mut self,
        handle: KeyHandleId,
        data: KeyHandle,
    ) -> Result<(), OperationError> {
        self.idlayer.set_key_handle(handle, data)
    }
}
impl<'a> IdlArcSqliteWriteTransaction<'a> {
    /// Fetch a keyhandle, preferring the in-memory cache and falling back to
    /// the sqlite layer. A value found on disk is cached for subsequent reads.
    pub(crate) fn get_key_handle(
        &mut self,
        handle: KeyHandleId,
    ) -> Result<Option<KeyHandle>, OperationError> {
        if let Some(cached) = self.keyhandles.get(&handle) {
            return Ok(Some(cached.clone()));
        }
        // Cache miss - consult the database layer.
        let loaded = self.db.get_key_handle(handle)?;
        if let Some(kh) = loaded.as_ref() {
            self.keyhandles.insert(handle, kh.clone());
        }
        Ok(loaded)
    }

    /// Update the content of a keyhandle with this new data.
    #[instrument(level = "debug", skip(self, data))]
    pub(crate) fn set_key_handle(
        &mut self,
        handle: KeyHandleId,
        data: KeyHandle,
    ) -> Result<(), OperationError> {
        // Persist first so the cache never holds a value the db rejected.
        self.db.set_key_handle(handle, &data)?;
        self.keyhandles.insert(handle, data);
        Ok(())
    }

    /// Replace the full set of keyhandles (used by the backup restore path).
    pub(super) fn set_key_handles(
        &mut self,
        keyhandles: BTreeMap<KeyHandleId, KeyHandle>,
    ) -> Result<(), OperationError> {
        self.db.set_key_handles(&keyhandles)?;
        // Mirror the now-authoritative db content into the cache.
        self.keyhandles.clear();
        for (kh_id, kh) in keyhandles {
            self.keyhandles.insert(kh_id, kh);
        }
        Ok(())
    }
}
impl IdlSqliteWriteTransaction {
    /// Load a single keyhandle from the `keyhandles` table by its id.
    /// Returns `Ok(None)` when no row exists for this handle.
    pub(crate) fn get_key_handle(
        &mut self,
        handle: KeyHandleId,
    ) -> Result<Option<KeyHandle>, OperationError> {
        // The id column stores the serde_json serialised form of the handle id.
        let s_handle = serde_json::to_vec(&handle).map_err(serde_json_error)?;

        let mut stmt = self
            .get_conn()?
            .prepare(&format!(
                "SELECT data FROM {}.keyhandles WHERE id = :id",
                self.get_db_name()
            ))
            .map_err(sqlite_error)?;

        let data_raw: Option<Vec<u8>> = stmt
            .query_row(&[(":id", &s_handle)], |row| row.get(0))
            // We don't mind if it doesn't exist
            .optional()
            .map_err(sqlite_error)?;

        // Deserialise the row content (if any) back into a KeyHandle.
        let data: Option<KeyHandle> = match data_raw {
            Some(d) => serde_json::from_slice(d.as_slice())
                .map(Some)
                .map_err(serde_json_error)?,
            None => None,
        };

        Ok(data)
    }

    /// Load every keyhandle in the table into an ordered map. Used to snapshot
    /// the full handle set (e.g. for backups).
    pub(super) fn get_key_handles(
        &mut self,
    ) -> Result<BTreeMap<KeyHandleId, KeyHandle>, OperationError> {
        let mut stmt = self
            .get_conn()?
            .prepare(&format!(
                "SELECT id, data FROM {}.keyhandles",
                self.get_db_name()
            ))
            .map_err(sqlite_error)?;

        let kh_iter = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
            .map_err(sqlite_error)?;

        kh_iter
            .map(|v| {
                // Both columns hold serde_json serialised bytes.
                let (id, data): (Vec<u8>, Vec<u8>) = v.map_err(sqlite_error)?;
                let id = serde_json::from_slice(id.as_slice()).map_err(serde_json_error)?;
                let data = serde_json::from_slice(data.as_slice()).map_err(serde_json_error)?;
                Ok((id, data))
            })
            .collect()
    }

    /// Update the content of a keyhandle with this new data.
    #[instrument(level = "debug", skip(self, data))]
    pub(crate) fn set_key_handle(
        &mut self,
        handle: KeyHandleId,
        data: &KeyHandle,
    ) -> Result<(), OperationError> {
        let s_handle = serde_json::to_vec(&handle).map_err(serde_json_error)?;
        let s_data = serde_json::to_vec(&data).map_err(serde_json_error)?;

        // Upsert, so that renewing an existing handle replaces the prior row.
        self.get_conn()?
            .prepare(&format!(
                "INSERT OR REPLACE INTO {}.keyhandles (id, data) VALUES(:id, :data)",
                self.get_db_name()
            ))
            .and_then(|mut stmt| stmt.execute(&[(":id", &s_handle), (":data", &s_data)]))
            .map(|_| ())
            .map_err(sqlite_error)
    }

    /// Replace the entire keyhandles table content with the provided set.
    pub(super) fn set_key_handles(
        &mut self,
        keyhandles: &BTreeMap<KeyHandleId, KeyHandle>,
    ) -> Result<(), OperationError> {
        // Clear first - handles absent from the new set must not survive.
        self.get_conn()?
            .execute(
                &format!("DELETE FROM {}.keyhandles", self.get_db_name()),
                [],
            )
            .map(|_| ())
            .map_err(sqlite_error)?;

        for (handle, data) in keyhandles {
            self.set_key_handle(*handle, data)?;
        }

        Ok(())
    }
}

View file

@ -36,6 +36,7 @@ pub mod dbvalue;
mod idl_arc_sqlite;
mod idl_sqlite;
pub(crate) mod idxkey;
pub(crate) mod keystorage;
pub(crate) use self::idxkey::{IdxKey, IdxKeyRef, IdxKeyToRef, IdxSlope};
use crate::be::idl_arc_sqlite::{
@ -872,10 +873,13 @@ pub trait BackendTransaction {
.get_db_ts_max()
.and_then(|u| u.ok_or(OperationError::InvalidDbState))?;
let bak = DbBackup::V2 {
let keyhandles = idlayer.get_key_handles()?;
let bak = DbBackup::V3 {
db_s_uuid,
db_d_uuid,
db_ts_max,
keyhandles,
entries,
};
@ -1724,6 +1728,20 @@ impl<'a> BackendWriteTransaction<'a> {
idlayer.set_db_ts_max(db_ts_max)?;
entries
}
DbBackup::V3 {
db_s_uuid,
db_d_uuid,
db_ts_max,
keyhandles,
entries,
} => {
// Do stuff.
idlayer.write_db_s_uuid(db_s_uuid)?;
idlayer.write_db_d_uuid(db_d_uuid)?;
idlayer.set_db_ts_max(db_ts_max)?;
idlayer.set_key_handles(keyhandles)?;
entries
}
};
info!("Restoring {} entries ...", dbentries.len());
@ -2505,6 +2523,15 @@ mod tests {
} => {
let _ = entries.pop();
}
DbBackup::V3 {
db_s_uuid: _,
db_d_uuid: _,
db_ts_max: _,
keyhandles: _,
entries,
} => {
let _ = entries.pop();
}
};
let serialized_entries_str = serde_json::to_string_pretty(&dbbak).unwrap();

View file

@ -53,7 +53,7 @@ pub mod valueset;
#[macro_use]
mod plugins;
pub mod idm;
mod repl;
pub mod repl;
pub mod schema;
pub mod server;
pub mod status;

View file

@ -4,11 +4,6 @@ use crate::prelude::*;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;
pub enum ConsumerState {
Ok,
RefreshRequired,
}
impl<'a> QueryServerWriteTransaction<'a> {
// Apply the state changes if they are valid.
@ -255,6 +250,11 @@ impl<'a> QueryServerWriteTransaction<'a> {
ctx: &ReplIncrementalContext,
) -> Result<ConsumerState, OperationError> {
match ctx {
ReplIncrementalContext::DomainMismatch => {
error!("Unable to proceed with consumer incremental - the supplier has indicated that our domain_uuid's are not equivalent. This can occur when adding a new consumer to an existing topology.");
error!("This server's content must be refreshed to proceed. If you have configured automatic refresh, this will occur shortly.");
Ok(ConsumerState::RefreshRequired)
}
ReplIncrementalContext::NoChangesAvailable => {
info!("no changes are available");
Ok(ConsumerState::Ok)

View file

@ -1,10 +1,10 @@
pub mod cid;
pub mod entry;
pub mod ruv;
pub(crate) mod cid;
pub(crate) mod entry;
pub(crate) mod ruv;
pub mod consumer;
pub(crate) mod consumer;
pub mod proto;
pub mod supplier;
pub(crate) mod supplier;
#[cfg(test)]
mod tests;

View file

@ -16,6 +16,11 @@ use webauthn_rs::prelude::{
// Re-export this for our own usage.
pub use kanidm_lib_crypto::ReplPasswordV1;
/// The outcome of a consumer applying a replication update.
pub enum ConsumerState {
    /// The update was applied (or no changes were needed).
    Ok,
    /// The consumer must be fully refreshed from the supplier before
    /// incremental replication can proceed (e.g. on domain uuid mismatch).
    RefreshRequired,
}
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct ReplCidV1 {
#[serde(rename = "t")]
@ -63,22 +68,15 @@ pub struct ReplCidRange {
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub enum ReplRuvRange {
V1 {
domain_uuid: Uuid,
ranges: BTreeMap<Uuid, ReplCidRange>,
},
}
impl Default for ReplRuvRange {
fn default() -> Self {
ReplRuvRange::V1 {
ranges: BTreeMap::default(),
}
}
}
impl ReplRuvRange {
pub fn is_empty(&self) -> bool {
match self {
ReplRuvRange::V1 { ranges } => ranges.is_empty(),
ReplRuvRange::V1 { ranges, .. } => ranges.is_empty(),
}
}
}
@ -689,6 +687,7 @@ pub enum ReplRefreshContext {
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ReplIncrementalContext {
DomainMismatch,
NoChangesAvailable,
RefreshRequired,
UnwillingToSupply,

View file

@ -5,6 +5,74 @@ use super::ruv::{RangeDiffStatus, ReplicationUpdateVector, ReplicationUpdateVect
use crate::be::BackendTransaction;
use crate::prelude::*;
use crate::be::keystorage::{KeyHandle, KeyHandleId};
use kanidm_lib_crypto::mtls::build_self_signed_server_and_client_identity;
use kanidm_lib_crypto::prelude::{PKey, Private, X509};
impl<'a> QueryServerWriteTransaction<'a> {
    /// Generate a fresh self-signed replication identity (private key and x509
    /// certificate), persist it as a backend keyhandle, and return it.
    fn supplier_generate_key_cert(&mut self) -> Result<(PKey<Private>, X509), OperationError> {
        // No usable stored key exists - we must (re)generate.
        // NOTE(review): the SAN is fixed to "localhost" and the lifetime to
        // 1 day here - confirm these are the intended values for replication
        // identities.
        let domain_name = "localhost";
        let expiration_days = 1;
        let s_uuid = self.get_server_uuid();

        let (private, x509) =
            build_self_signed_server_and_client_identity(s_uuid, domain_name, expiration_days)
                .map_err(|err| {
                    error!(?err, "Unable to generate self signed key/cert");
                    // TODO: a more specific error variant may be warranted here.
                    OperationError::CryptographyError
                })?;

        let kh = KeyHandle::X509Key {
            private: private.clone(),
            x509: x509.clone(),
        };

        // Persist before returning so later transactions can reload the identity.
        self.get_be_txn()
            .set_key_handle(KeyHandleId::ReplicationKey, kh)
            .map_err(|err| {
                error!(?err, "Unable to persist replication key");
                err
            })
            .map(|()| (private, x509))
    }

    /// Force-rotate the replication identity, replacing the stored one.
    #[instrument(level = "info", skip_all)]
    pub fn supplier_renew_key_cert(&mut self) -> Result<(), OperationError> {
        self.supplier_generate_key_cert().map(|_| ())
    }

    /// Fetch the replication identity, generating and persisting one if none
    /// exists yet.
    #[instrument(level = "info", skip_all)]
    pub fn supplier_get_key_cert(&mut self) -> Result<(PKey<Private>, X509), OperationError> {
        // Later we need to put this through a HSM or similar, but we will always need a way
        // to persist a handle, so we still need the db write and load components.

        // Does the handle exist?
        let maybe_key_handle = self
            .get_be_txn()
            .get_key_handle(KeyHandleId::ReplicationKey)
            .map_err(|err| {
                error!(?err, "Unable to access replication key");
                err
            })?;

        // Can we process the keyhandle?
        let key_cert = match maybe_key_handle {
            Some(KeyHandle::X509Key { private, x509 }) => (private, x509),
            /*
            Some(Keyhandle::...) => {
                // invalid key
                // error? regenerate?
            }
            */
            // Nothing stored yet - create and persist a new identity.
            None => self.supplier_generate_key_cert()?,
        };

        Ok(key_cert)
    }
}
impl<'a> QueryServerReadTransaction<'a> {
// Given a consumers state, calculate the differential of changes they
// need to be sent to bring them to the equivalent state.
@ -19,10 +87,19 @@ impl<'a> QueryServerReadTransaction<'a> {
ctx_ruv: ReplRuvRange,
) -> Result<ReplIncrementalContext, OperationError> {
// Convert types if needed. This way we can compare ruv's correctly.
let ctx_ranges = match ctx_ruv {
ReplRuvRange::V1 { ranges } => ranges,
let (ctx_domain_uuid, ctx_ranges) = match ctx_ruv {
ReplRuvRange::V1 {
domain_uuid,
ranges,
} => (domain_uuid, ranges),
};
if ctx_domain_uuid != self.d_info.d_uuid {
error!("Replication - Consumer Domain UUID does not match our local domain uuid.");
debug!(consumer_domain_uuid = ?ctx_domain_uuid, supplier_domain_uuid = ?self.d_info.d_uuid);
return Ok(ReplIncrementalContext::DomainMismatch);
}
let our_ranges = self
.get_be_txn()
.get_ruv()

View file

@ -1,7 +1,7 @@
use crate::be::BackendTransaction;
use crate::prelude::*;
use crate::repl::consumer::ConsumerState;
use crate::repl::entry::State;
use crate::repl::proto::ConsumerState;
use crate::repl::proto::ReplIncrementalContext;
use crate::repl::ruv::ReplicationUpdateVectorTransaction;
use crate::repl::ruv::{RangeDiffStatus, ReplicationUpdateVector};
@ -2941,6 +2941,46 @@ async fn test_repl_increment_attrunique_conflict_complex(
drop(server_a_txn);
}
// Test the behaviour of a "new server join". This will have the supplier and
// consumer mismatch on the domain_uuid, leading to the consumer with a
// refresh required message. This should then be refreshed and succeed.
#[qs_pair_test]
async fn test_repl_initial_consumer_join(server_a: &QueryServer, server_b: &QueryServer) {
    let ct = duration_from_epoch_now();
    let mut server_a_txn = server_a.write(ct).await;
    let mut server_b_txn = server_b.read().await;

    // Consumer (a) presents its RUV range, including its domain uuid.
    let a_ruv_range = server_a_txn
        .consumer_get_state()
        .expect("Unable to access RUV range");

    let changes = server_b_txn
        .supplier_provide_changes(a_ruv_range)
        .expect("Unable to generate supplier changes");

    // The supplier (b) must refuse with DomainMismatch, not send changes.
    assert!(matches!(changes, ReplIncrementalContext::DomainMismatch));

    // Applying the mismatch response tells the consumer to refresh.
    let result = server_a_txn
        .consumer_apply_changes(&changes)
        .expect("Unable to apply changes to consumer.");

    assert!(matches!(result, ConsumerState::RefreshRequired));

    drop(server_a_txn);
    drop(server_b_txn);

    // Then a refresh resolves.
    let mut server_a_txn = server_a.write(ct).await;
    let mut server_b_txn = server_b.read().await;

    assert!(repl_initialise(&mut server_b_txn, &mut server_a_txn)
        .and_then(|_| server_a_txn.commit())
        .is_ok());

    drop(server_b_txn);
}
// Test change of domain version over incremental.
//
// todo when I have domain version migrations working.

View file

@ -896,6 +896,8 @@ pub trait QueryServerTransaction<'a> {
//
// ...
let domain_uuid = self.get_domain_uuid();
// Which then the supplier will use to actually retrieve the set of entries.
// and the needed attributes we need.
let ruv_snapshot = self.get_be_txn().get_ruv();
@ -903,7 +905,10 @@ pub trait QueryServerTransaction<'a> {
// What's the current set of ranges?
ruv_snapshot
.current_ruv_range()
.map(|ranges| ReplRuvRange::V1 { ranges })
.map(|ranges| ReplRuvRange::V1 {
domain_uuid,
ranges,
})
}
}
@ -1247,6 +1252,11 @@ impl QueryServer {
}
impl<'a> QueryServerWriteTransaction<'a> {
pub(crate) fn get_server_uuid(&self) -> Uuid {
// Cid has our server id within
self.cid.s_uuid
}
pub(crate) fn get_curtime(&self) -> Duration {
self.curtime
}

View file

@ -50,3 +50,7 @@ webauthn-authenticator-rs = { workspace = true }
oauth2_ext = { workspace = true, default-features = false }
futures = { workspace = true }
time = { workspace = true }
openssl = { workspace = true }
tokio-openssl = { workspace = true }
kanidm_lib_crypto = { workspace = true }
uuid = { workspace = true }

View file

@ -54,7 +54,7 @@ pub async fn setup_async_test(mut config: Configuration) -> (KanidmClient, CoreH
counter += 1;
#[allow(clippy::panic)]
if counter >= 5 {
eprintln!("Unable to allocate port!");
tracing::error!("Unable to allocate port!");
panic!();
}
};

View file

@ -0,0 +1,277 @@
use kanidm_lib_crypto::mtls::build_self_signed_server_and_client_identity;
use kanidmd_testkit::{is_free_port, PORT_ALLOC};
use std::sync::atomic::Ordering;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio::time;
use tracing::{error, trace};
use openssl::ssl::{Ssl, SslAcceptor, SslMethod, SslRef};
use openssl::ssl::{SslConnector, SslVerifyMode};
use openssl::x509::X509;
use tokio_openssl::SslStream;
use uuid::Uuid;
/// SSL keylog callback - emits negotiated TLS key material at trace level.
fn keylog_cb(_ssl_ref: &SslRef, key: &str) {
    trace!(?key);
}
/// Build a connected (but not yet handshaken) mTLS client stream, plus the
/// server task handle and a oneshot sender to stop the server early. The
/// spawned server accepts exactly one connection, verifies a peer certificate
/// is present, then exits; on TLS failure it returns the last openssl error
/// code (as u64) so tests can assert on the exact failure mode.
async fn setup_mtls_test(
    testcase: TestCase,
) -> (
    SslStream<TcpStream>,
    tokio::task::JoinHandle<Result<(), u64>>,
    oneshot::Sender<()>,
) {
    sketching::test_init();

    // Allocate a free port, bailing out after a handful of attempts.
    let mut counter = 0;
    let port = loop {
        let possible_port = PORT_ALLOC.fetch_add(1, Ordering::SeqCst);
        if is_free_port(possible_port) {
            break possible_port;
        }
        counter += 1;
        #[allow(clippy::panic)]
        if counter >= 5 {
            error!("Unable to allocate port!");
            panic!();
        }
    };

    trace!("{:?}", port);

    // First we need the two certificates.
    let client_uuid = Uuid::new_v4();
    let (client_key, client_cert) =
        build_self_signed_server_and_client_identity(client_uuid, "localhost", 1).unwrap();

    // Give the server a SAN the client will reject when exercising the
    // hostname-verification failure case.
    let server_san = if testcase == TestCase::ServerCertSanInvalid {
        "evilcorp.com"
    } else {
        "localhost"
    };
    let server_uuid = Uuid::new_v4();
    let (server_key, server_cert): (_, X509) =
        build_self_signed_server_and_client_identity(server_uuid, server_san, 1).unwrap();

    // Bind on the IPv6 loopback.
    let server_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), port);

    let listener = TcpListener::bind(&server_addr)
        .await
        .expect("Failed to bind");

    // Setup the TLS parameters.
    let (tx, mut rx) = oneshot::channel();

    let mut ssl_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()).unwrap();
    ssl_builder.set_keylog_callback(keylog_cb);

    ssl_builder.set_certificate(&server_cert).unwrap();
    ssl_builder.set_private_key(&server_key).unwrap();
    ssl_builder.check_private_key().unwrap();

    // Trust the client's certificate, unless this case tests the server
    // lacking the client CA.
    if testcase != TestCase::ServerWithoutClientCa {
        let cert_store = ssl_builder.cert_store_mut();
        cert_store.add_cert(client_cert.clone()).unwrap();
    }

    // Request a client cert, and fail the handshake if none is presented.
    let mut verify = SslVerifyMode::PEER;
    verify.insert(SslVerifyMode::FAIL_IF_NO_PEER_CERT);
    ssl_builder.set_verify(verify);

    // Setup the client cert store.
    let tls_parms = ssl_builder.build();

    // Start the server in a task.
    // The server is designed to stop/die as soon as a single connection has been made.
    let handle = tokio::spawn(async move {
        // This is our safety net.
        let sleep = time::sleep(Duration::from_secs(15));
        tokio::pin!(sleep);

        trace!("Started listener");
        tokio::select! {
            Ok((tcpstream, client_socket_addr)) = listener.accept() => {
                // Wrap the accepted tcp stream for the TLS handshake.
                let mut tlsstream = match Ssl::new(tls_parms.context())
                    .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream))
                {
                    Ok(ta) => ta,
                    Err(err) => {
                        error!("LDAP TLS setup error, continuing -> {:?}", err);
                        let ossl_err = err.errors().last().unwrap();

                        return Err(
                            ossl_err.code()
                        );
                    }
                };

                if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await {
                    error!("LDAP TLS accept error, continuing -> {:?}", err);

                    let ossl_err = err.ssl_error().and_then(|e| e.errors().last()).unwrap();

                    return Err(
                        ossl_err.code()
                    );
                };

                trace!("Got connection. {:?}", client_socket_addr);

                let tlsstream_ref = tlsstream.ssl();

                // A missing peer certificate is reported with the sentinel code 2.
                match tlsstream_ref.peer_certificate() {
                    Some(peer_cert) => {
                        trace!("{:?}", peer_cert.subject_name());
                    }
                    None => {
                        return Err(2);
                    }
                }

                Ok(())
            }
            // Explicit stop requested by the test via the oneshot.
            Ok(()) = &mut rx => {
                trace!("stopping listener");
                Err(1)
            }
            _ = &mut sleep => {
                error!("timeout");
                Err(1)
            }
            else => {
                trace!("error condition in accept");
                Err(1)
            }
        }
    });

    // Create the client and connect. We do this inline to be sensitive to errors.
    let tcpclient = TcpStream::connect(server_addr).await.unwrap();
    trace!("connection established");

    let mut ssl_builder = SslConnector::builder(SslMethod::tls_client()).unwrap();

    // Present a client identity, unless this case tests its absence.
    if testcase != TestCase::ClientWithoutClientCert {
        ssl_builder.set_certificate(&client_cert).unwrap();
        ssl_builder.set_private_key(&client_key).unwrap();
        ssl_builder.check_private_key().unwrap();
    }
    // Add the server cert
    if testcase != TestCase::ClientWithoutServerCa {
        let cert_store = ssl_builder.cert_store_mut();
        cert_store.add_cert(server_cert).unwrap();
    }

    // The client always verifies the server hostname as "localhost".
    let verify_param = ssl_builder.verify_param_mut();
    verify_param.set_host("localhost").unwrap();

    ssl_builder.set_verify(SslVerifyMode::PEER);
    let tls_parms = ssl_builder.build();

    // The handshake itself is left to the caller via SslStream::connect.
    let tlsstream = Ssl::new(tls_parms.context())
        .and_then(|tls_obj| SslStream::new(tls_obj, tcpclient))
        .unwrap();

    (tlsstream, handle, tx)
}
/// The mTLS negotiation scenarios exercised by these tests.
#[derive(PartialEq, Eq, Debug)]
enum TestCase {
    /// Both peers present valid certs and trust each other's cert.
    Valid,
    /// Server cert SAN ("evilcorp.com") does not match the hostname the
    /// client verifies ("localhost").
    ServerCertSanInvalid,
    /// Server does not have the client's cert in its trust store.
    ServerWithoutClientCa,
    /// Client presents no certificate at all.
    ClientWithoutClientCert,
    /// Client does not have the server's cert in its trust store.
    ClientWithoutServerCa,
}
#[tokio::test]
async fn test_mtls_basic_auth() {
    // Happy path: both peers present valid certs and trust each other.
    let (mut stream, server_task, _shutdown_tx) = setup_mtls_test(TestCase::Valid).await;

    // Drive the client side of the handshake - must succeed.
    SslStream::connect(Pin::new(&mut stream))
        .await
        .unwrap();

    trace!("Waiting on listener ...");
    let outcome = server_task.await.expect("Failed to stop task.");
    // An Err here means the server-side accept process failed.
    assert!(outcome.is_ok());
}
#[tokio::test]
async fn test_mtls_server_san_invalid() {
    // Server SAN is "evilcorp.com" but the client verifies "localhost".
    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ServerCertSanInvalid).await;

    // The client side of the handshake must fail hostname verification.
    let err: openssl::ssl::Error = SslStream::connect(Pin::new(&mut tlsstream))
        .await
        .unwrap_err();
    trace!(?err);

    // Certification Verification Failure
    // (167772294 == 0x0A000086 - presumably openssl "certificate verify
    // failed"; confirm against the linked openssl's error codes.)
    let ossl_err = err.ssl_error().and_then(|e| e.errors().last()).unwrap();

    assert_eq!(ossl_err.code(), 167772294);

    trace!("Waiting on listener ...");
    let result = handle.await.expect("Failed to stop task.");
    // Must be FALSE server should not have accepted the connection.
    trace!(?result);
    // SSL Read bytes (client disconnected)
    // (167773202 == 0x0A000412 - TODO confirm this maps to the peer
    // disconnect during handshake.)
    assert_eq!(result, Err(167773202));
}
#[tokio::test]
async fn test_mtls_server_without_client_ca() {
    // Server has no CA for the client's cert, so it rejects the client.
    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ServerWithoutClientCa).await;

    // The client isn't the one that errors, the server does.
    SslStream::connect(Pin::new(&mut tlsstream)).await.unwrap();

    trace!("Waiting on listener ...");
    let result = handle.await.expect("Failed to stop task.");
    // Must be FALSE server should not have accepted the connection.
    trace!(?result);
    // Certification Verification Failure
    // (167772294 == 0x0A000086 - presumably openssl "certificate verify
    // failed"; confirm against the linked openssl's error codes.)
    assert_eq!(result, Err(167772294));
}
#[tokio::test]
async fn test_mtls_client_without_client_cert() {
    // Client presents no cert; the server requires one (FAIL_IF_NO_PEER_CERT).
    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ClientWithoutClientCert).await;

    // The client isn't the one that errors, the server does.
    SslStream::connect(Pin::new(&mut tlsstream)).await.unwrap();

    trace!("Waiting on listener ...");
    let result = handle.await.expect("Failed to stop task.");
    // Must be FALSE server should not have accepted the connection.
    trace!(?result);
    // Peer Did Not Provide Certificate
    // (167772359 == 0x0A0000C7 - presumably openssl "peer did not return a
    // certificate"; confirm against the linked openssl's error codes.)
    assert_eq!(result, Err(167772359));
}
#[tokio::test]
async fn test_mtls_client_without_server_ca() {
    // Client has no CA for the server's cert, so it rejects the server.
    let (mut tlsstream, handle, _tx) = setup_mtls_test(TestCase::ClientWithoutServerCa).await;

    let err: openssl::ssl::Error = SslStream::connect(Pin::new(&mut tlsstream))
        .await
        .unwrap_err();
    trace!(?err);

    // Tls Post Process Certificate (Certificate Verify Failed)
    // (167772294 == 0x0A000086 - presumably openssl "certificate verify
    // failed"; confirm against the linked openssl's error codes.)
    let ossl_err = err.ssl_error().and_then(|e| e.errors().last()).unwrap();

    assert_eq!(ossl_err.code(), 167772294);

    trace!("Waiting on listener ...");
    let result = handle.await.expect("Failed to stop task.");
    // Must be FALSE server should not have accepted the connection.
    trace!(?result);
    // SSL Read bytes (client disconnected)
    // (167773208 == 0x0A000418 - TODO confirm this maps to the peer
    // disconnect during handshake.)
    assert_eq!(result, Err(167773208));
}

View file

@ -78,6 +78,7 @@ kanidm_utils_users = { workspace = true }
[dev-dependencies]
kanidmd_core = { workspace = true }
kanidmd_testkit = { workspace = true }
[build-dependencies]
clap = { workspace = true, features = ["derive"] }

View file

@ -1,8 +1,7 @@
#![deny(warnings)]
use std::future::Future;
use std::net::TcpStream;
use std::pin::Pin;
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::atomic::Ordering;
use std::time::Duration;
use kanidm_client::{KanidmClient, KanidmClientBuilder};
@ -18,10 +17,10 @@ use kanidm_unix_common::resolver::Resolver;
use kanidm_unix_common::unix_config::TpmPolicy;
use kanidmd_core::config::{Configuration, IntegrationTestConfig, ServerRole};
use kanidmd_core::create_server_core;
use kanidmd_testkit::{is_free_port, PORT_ALLOC};
use tokio::task;
use tracing::log::{debug, trace};
static PORT_ALLOC: AtomicU16 = AtomicU16::new(28080);
const ADMIN_TEST_USER: &str = "admin";
const ADMIN_TEST_PASSWORD: &str = "integration test admin password";
const TESTACCOUNT1_PASSWORD_A: &str = "password a for account1 test";
@ -29,13 +28,6 @@ const TESTACCOUNT1_PASSWORD_B: &str = "password b for account1 test";
const TESTACCOUNT1_PASSWORD_INC: &str = "never going to work";
const ACCOUNT_EXPIRE: &str = "1970-01-01T00:00:00+00:00";
fn is_free_port(port: u16) -> bool {
match TcpStream::connect(("0.0.0.0", port)) {
Ok(_) => false,
Err(_) => true,
}
}
type Fixture = Box<dyn FnOnce(KanidmClient) -> Pin<Box<dyn Future<Output = ()>>>>;
fn fixture<T>(f: fn(KanidmClient) -> T) -> Fixture