diff --git a/Cargo.toml b/Cargo.toml index 294bebc04..289d46b32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -144,6 +144,7 @@ image = { version = "0.24.9", default-features = false, features = [ "jpeg", "webp", ] } +itertools = "0.12.1" enum-iterator = "1.5.0" js-sys = "^0.3.69" kanidmd_web_ui_shared = { path = "./server/web_ui/shared" } diff --git a/server/lib/Cargo.toml b/server/lib/Cargo.toml index b4fc6a6ac..1f01a3e39 100644 --- a/server/lib/Cargo.toml +++ b/server/lib/Cargo.toml @@ -36,6 +36,7 @@ fernet = { workspace = true, features = ["fernet_danger_timestamps"] } # futures-util = { workspace = true } hashbrown = { workspace = true } idlset = { workspace = true } +itertools = { workspace = true } kanidm_proto = { workspace = true } kanidm_lib_crypto = { workspace = true } lazy_static = { workspace = true } diff --git a/server/lib/src/idm/oauth2.rs b/server/lib/src/idm/oauth2.rs index 3df05002a..4fa2798e1 100644 --- a/server/lib/src/idm/oauth2.rs +++ b/server/lib/src/idm/oauth2.rs @@ -254,9 +254,7 @@ impl ClaimValue { } fn to_json_value(&self) -> serde_json::Value { - let join_char = match self.join { - OauthClaimMapJoin::CommaSeparatedValue => ',', - OauthClaimMapJoin::SpaceSeparatedValue => ' ', + let join_str = match self.join { OauthClaimMapJoin::JsonArray => { let arr: Vec<_> = self .values @@ -268,9 +266,10 @@ impl ClaimValue { // This shortcuts out. return serde_json::Value::Array(arr); } + joiner => joiner.to_str(), }; - let joined = str_concat!(&self.values, join_char); + let joined = str_concat!(&self.values, join_str); serde_json::Value::String(joined) } diff --git a/server/lib/src/macros.rs b/server/lib/src/macros.rs index 10e37ef10..2276de470 100644 --- a/server/lib/src/macros.rs +++ b/server/lib/src/macros.rs @@ -655,21 +655,7 @@ macro_rules! limmediate_warning { macro_rules! str_concat { ($str_iter:expr, $join_char:expr) => {{ - // Sub 1 here because we need N minus 1 join chars - let max_join_chars: usize = $str_iter.len() - 1; - let data_len: usize = $str_iter - .iter() - .map(|s| s.len()) - .fold(max_join_chars, |acc, x| acc + x); - - let mut joined = String::with_capacity(data_len); - for (i, value) in $str_iter.iter().enumerate() { - joined.push_str(value); - if i < max_join_chars { - joined.push($join_char); - } - } - - joined + use itertools::Itertools; + $str_iter.iter().join($join_char) }}; } diff --git a/server/lib/src/server/access/mod.rs b/server/lib/src/server/access/mod.rs index 96caf2838..fc901d8be 100644 --- a/server/lib/src/server/access/mod.rs +++ b/server/lib/src/server/access/mod.rs @@ -493,11 +493,7 @@ pub trait AccessControlsTransaction<'a> { false } else if !requested_rem.is_subset(&rem) { security_error!("requested_rem is not a subset of allowed"); - security_error!( - "requested_rem: {:?} !⊆ allowed: {:?}", - requested_rem, - rem - ); + security_error!("requested_rem: {:?} !⊆ allowed: {:?}", requested_rem, rem); false } else if !requested_classes.is_subset(&cls) { security_error!("requested_classes is not a subset of allowed"); @@ -626,11 +622,7 @@ pub trait AccessControlsTransaction<'a> { false } else if !requested_rem.is_subset(&rem) { security_error!("requested_rem is not a subset of allowed"); - security_error!( - "requested_rem: {:?} !⊆ allowed: {:?}", - requested_rem, - rem - ); + security_error!("requested_rem: {:?} !⊆ allowed: {:?}", requested_rem, rem); false } else if !requested_classes.is_subset(&cls) { security_error!("requested_classes is not a subset of allowed"); diff --git a/server/lib/src/server/mod.rs b/server/lib/src/server/mod.rs index b4be9c6a8..d3526a339 100644 --- a/server/lib/src/server/mod.rs +++ b/server/lib/src/server/mod.rs @@ -767,7 +767,7 @@ pub trait QueryServerTransaction<'a> { let mut v = Vec::new(); for (claim_name, mapping) in r_map.iter() { for (group_ref, claims) in mapping.values() { - let join_char = mapping.join().to_char(); + let join_char = mapping.join().to_str(); let nv = self.uuid_to_spn(*group_ref)?; let resolved_id = match nv { @@ -775,7 +775,7 @@ pub trait QueryServerTransaction<'a> { None => uuid_to_proto_string(*group_ref), }; - let joined = str_concat!(claims, ','); + let joined = str_concat!(claims, ","); v.push(format!( "{}:{}:{}:{:?}", diff --git a/server/lib/src/value.rs b/server/lib/src/value.rs index 00ae8cfb0..d939071fa 100644 --- a/server/lib/src/value.rs +++ b/server/lib/src/value.rs @@ -999,12 +999,12 @@ pub enum OauthClaimMapJoin { } impl OauthClaimMapJoin { - pub(crate) fn to_char(self) -> char { + pub(crate) fn to_str(self) -> &'static str { match self { - OauthClaimMapJoin::CommaSeparatedValue => ',', - OauthClaimMapJoin::SpaceSeparatedValue => ' ', + OauthClaimMapJoin::CommaSeparatedValue => ",", + OauthClaimMapJoin::SpaceSeparatedValue => " ", // Should this be something else? - OauthClaimMapJoin::JsonArray => ';', + OauthClaimMapJoin::JsonArray => ";", } } } diff --git a/server/lib/src/valueset/oauth.rs b/server/lib/src/valueset/oauth.rs index af8bbd49f..892cfc7e1 100644 --- a/server/lib/src/valueset/oauth.rs +++ b/server/lib/src/valueset/oauth.rs @@ -463,6 +463,14 @@ impl ValueSetOauthClaimMap { Some(Box::new(ValueSetOauthClaimMap { map })) } */ + + fn trim(&mut self) { + self.map + .values_mut() + .for_each(|mapping_mut| mapping_mut.values.retain(|_k, v| !v.is_empty())); + + self.map.retain(|_k, v| !v.values.is_empty()); + } } impl ValueSetT for ValueSetOauthClaimMap { @@ -561,11 +569,7 @@ impl ValueSetT for ValueSetOauthClaimMap { }; // Trim anything that is now empty. - self.map - .values_mut() - .for_each(|mapping_mut| mapping_mut.values.retain(|_k, v| !v.is_empty())); - - self.map.retain(|_k, v| !v.values.is_empty()); + self.trim(); res } @@ -622,6 +626,16 @@ impl ValueSetT for ValueSetOauthClaimMap { fn validate(&self, _schema_attr: &SchemaAttribute) -> bool { self.map.keys().all(|s| OAUTHSCOPE_RE.is_match(s)) + && self + .map + .values() + .flat_map(|mapping| { + mapping + .values + .values() + .map(|claim_values| claim_values.is_empty()) + }) + .all(|is_empty| !is_empty) && self .map .values() @@ -637,9 +651,9 @@ impl ValueSetT for ValueSetOauthClaimMap { fn to_proto_string_clone_iter(&self) -> Box + '_> { Box::new(self.map.iter().flat_map(|(name, mapping)| { mapping.values.iter().map(move |(group, claims)| { - let join_char = mapping.join.to_char(); + let join_str = mapping.join.to_str(); - let joined = str_concat!(claims, join_char); + let joined = str_concat!(claims, join_str); format!( "{}: {} \"{:?}\"", @@ -725,3 +739,25 @@ impl ValueSetT for ValueSetOauthClaimMap { )) } } + +#[cfg(test)] +mod tests { + use super::ValueSetOauthClaimMap; + use crate::valueset::ValueSetT; + use std::collections::BTreeSet; + + #[test] + fn test_oauth_claim_invalid_str_concat_when_empty() { + let group_uuid = uuid::uuid!("5a6b8783-3f67-4ebb-b6aa-77fd6e66589f"); + let vs = + ValueSetOauthClaimMap::new_value("claim".to_string(), group_uuid, BTreeSet::default()); + + // Invalid handling of an empty claim map would cause a crash. + let proto_value = vs.to_proto_string_clone_iter().next().unwrap(); + + assert_eq!( + &proto_value, + "claim: 5a6b8783-3f67-4ebb-b6aa-77fd6e66589f \"\"\"\"" + ); + } +}