20240725 allow connection to older servers (#2930)

Co-authored-by: James Hodgkinson <james@terminaloutcomes.com>
This commit is contained in:
Firstyear 2024-07-25 16:11:14 +10:00 committed by GitHub
parent 38b0a6f8af
commit 7bbb193cdf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 141 additions and 30 deletions

View file

@ -22,6 +22,7 @@ use std::io::{ErrorKind, Read};
#[cfg(target_family = "unix")] // not needed for windows builds #[cfg(target_family = "unix")] // not needed for windows builds
use std::os::unix::fs::MetadataExt; use std::os::unix::fs::MetadataExt;
use std::path::Path; use std::path::Path;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use compact_jwt::Jwk; use compact_jwt::Jwk;
@ -30,10 +31,11 @@ use kanidm_proto::constants::uri::V1_AUTH_VALID;
use kanidm_proto::constants::{ use kanidm_proto::constants::{
ATTR_DOMAIN_DISPLAY_NAME, ATTR_DOMAIN_LDAP_BASEDN, ATTR_DOMAIN_SSID, ATTR_ENTRY_MANAGED_BY, ATTR_DOMAIN_DISPLAY_NAME, ATTR_DOMAIN_LDAP_BASEDN, ATTR_DOMAIN_SSID, ATTR_ENTRY_MANAGED_BY,
ATTR_KEY_ACTION_REVOKE, ATTR_LDAP_ALLOW_UNIX_PW_BIND, ATTR_NAME, CLIENT_TOKEN_CACHE, KOPID, ATTR_KEY_ACTION_REVOKE, ATTR_LDAP_ALLOW_UNIX_PW_BIND, ATTR_NAME, CLIENT_TOKEN_CACHE, KOPID,
KVERSION, KSESSIONID, KVERSION,
}; };
use kanidm_proto::internal::*; use kanidm_proto::internal::*;
use kanidm_proto::v1::*; use kanidm_proto::v1::*;
use reqwest::cookie::{CookieStore, Jar};
use reqwest::Response; use reqwest::Response;
pub use reqwest::StatusCode; pub use reqwest::StatusCode;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@ -192,10 +194,12 @@ fn test_kanidmclientbuilder_display() {
#[derive(Debug)] #[derive(Debug)]
pub struct KanidmClient { pub struct KanidmClient {
pub(crate) client: reqwest::Client, pub(crate) client: reqwest::Client,
client_cookies: Arc<Jar>,
pub(crate) addr: String, pub(crate) addr: String,
pub(crate) origin: Url, pub(crate) origin: Url,
pub(crate) builder: KanidmClientBuilder, pub(crate) builder: KanidmClientBuilder,
pub(crate) bearer_token: RwLock<Option<String>>, pub(crate) bearer_token: RwLock<Option<String>>,
pub(crate) auth_session_id: RwLock<Option<String>>,
pub(crate) check_version: Mutex<bool>, pub(crate) check_version: Mutex<bool>,
/// Where to store the tokens when you auth, only modify in testing. /// Where to store the tokens when you auth, only modify in testing.
token_cache_path: String, token_cache_path: String,
@ -501,11 +505,14 @@ impl KanidmClientBuilder {
self.display_warnings(&address); self.display_warnings(&address);
let client_cookies = Arc::new(Jar::default());
let client_builder = reqwest::Client::builder() let client_builder = reqwest::Client::builder()
.user_agent(KanidmClientBuilder::user_agent()) .user_agent(KanidmClientBuilder::user_agent())
// We don't directly use cookies, but it may be required for load balancers that // We don't directly use cookies, but it may be required for load balancers that
// implement sticky sessions with cookies. // implement sticky sessions with cookies.
.cookie_store(true) .cookie_store(true)
.cookie_provider(client_cookies.clone())
.danger_accept_invalid_hostnames(!self.verify_hostnames) .danger_accept_invalid_hostnames(!self.verify_hostnames)
.danger_accept_invalid_certs(!self.verify_ca); .danger_accept_invalid_certs(!self.verify_ca);
@ -546,9 +553,11 @@ impl KanidmClientBuilder {
Ok(KanidmClient { Ok(KanidmClient {
client, client,
client_cookies,
addr: address, addr: address,
builder: self, builder: self,
bearer_token: RwLock::new(None), bearer_token: RwLock::new(None),
auth_session_id: RwLock::new(None),
origin, origin,
check_version: Mutex::new(true), check_version: Mutex::new(true),
token_cache_path, token_cache_path,
@ -765,7 +774,9 @@ impl KanidmClient {
) -> Result<T, ClientError> { ) -> Result<T, ClientError> {
trace!("perform_auth_post_request connecting to {}", dest); trace!("perform_auth_post_request connecting to {}", dest);
let response = self.client.post(self.make_url(dest)).json(&request); let auth_url = self.make_url(dest);
let response = self.client.post(auth_url.clone()).json(&request);
// If we have a bearer token, set it now. // If we have a bearer token, set it now.
let response = { let response = {
@ -777,6 +788,17 @@ impl KanidmClient {
} }
}; };
// If we have a session header, set it now. This is only used when connecting
// to an older server.
let response = {
let sguard = self.auth_session_id.read().await;
if let Some(sessionid) = &(*sguard) {
response.header(KSESSIONID, sessionid)
} else {
response
}
};
let response = response let response = response
.send() .send()
.await .await
@ -798,6 +820,42 @@ impl KanidmClient {
} }
} }
// Do we have a cookie? Our job here isn't to parse and validate the cookies, but just to
// know if the session id was set *in* our cookie store at all.
let cookie_present = self
.client_cookies
.cookies(&auth_url)
.map(|cookie_header| {
cookie_header
.to_str()
.ok()
.map(|cookie_str| {
cookie_str
.split(';')
.filter_map(|c| c.split_once('='))
.any(|(name, _)| name == COOKIE_AUTH_SESSION_ID)
})
.unwrap_or_default()
})
.unwrap_or_default();
{
let headers = response.headers();
let mut sguard = self.auth_session_id.write().await;
trace!(?cookie_present);
if cookie_present {
// Clear and auth session id if present, we have the cookie instead.
*sguard = None;
} else {
// This situation occurs when a newer client connects to an older server
debug!("Auth SessionID cookie not present, falling back to header.");
*sguard = headers
.get(KSESSIONID)
.and_then(|hv| hv.to_str().ok().map(str::to_string));
}
}
response response
.json() .json()
.await .await

View file

@ -36,6 +36,7 @@ use kanidm_unix_resolver::unix_config::{HsmType, KanidmUnixdConfig};
use kanidm_utils_users::{get_current_gid, get_current_uid, get_effective_gid, get_effective_uid}; use kanidm_utils_users::{get_current_gid, get_current_uid, get_effective_gid, get_effective_uid};
use libc::umask; use libc::umask;
use sketching::tracing::span;
use sketching::tracing_forest::traits::*; use sketching::tracing_forest::traits::*;
use sketching::tracing_forest::util::*; use sketching::tracing_forest::util::*;
use sketching::tracing_forest::{self}; use sketching::tracing_forest::{self};
@ -211,6 +212,9 @@ async fn handle_client(
trace!("Waiting for requests ..."); trace!("Waiting for requests ...");
while let Some(Ok(req)) = reqs.next().await { while let Some(Ok(req)) = reqs.next().await {
let span = span!(Level::INFO, "client_request");
let _enter = span.enter();
let resp = match req { let resp = match req {
ClientRequest::SshKey(account_id) => { ClientRequest::SshKey(account_id) => {
debug!("sshkey req"); debug!("sshkey req");

View file

@ -150,7 +150,7 @@ fn create_home_directory(
debug!(?use_selinux, "selinux for home dir labeling"); debug!(?use_selinux, "selinux for home dir labeling");
#[cfg(all(target_family = "unix", feature = "selinux"))] #[cfg(all(target_family = "unix", feature = "selinux"))]
let labeler = if use_selinux { let labeler = if use_selinux {
selinux_util::SelinuxLabeler::new(info.gid, home_prefix)? selinux_util::SelinuxLabeler::new(info.gid, &home_mount_prefix_path)?
} else { } else {
selinux_util::SelinuxLabeler::new_noop() selinux_util::SelinuxLabeler::new_noop()
}; };

View file

@ -310,7 +310,7 @@ impl<'a> DbTxn<'a> {
key: &K, key: &K,
) -> Result<(), CacheError> { ) -> Result<(), CacheError> {
let data = serde_json::to_vec(key).map_err(|e| { let data = serde_json::to_vec(key).map_err(|e| {
error!("insert_hsm_machine_key json error -> {:?}", e); error!("json error -> {:?}", e);
CacheError::SerdeJson CacheError::SerdeJson
})?; })?;
@ -451,7 +451,6 @@ impl<'a> DbTxn<'a> {
} }
pub fn commit(mut self) -> Result<(), CacheError> { pub fn commit(mut self) -> Result<(), CacheError> {
// debug!("Committing BE txn");
if self.committed { if self.committed {
error!("Invalid state, SQL transaction was already committed!"); error!("Invalid state, SQL transaction was already committed!");
return Err(CacheError::TransactionInvalidState); return Err(CacheError::TransactionInvalidState);
@ -1016,7 +1015,7 @@ impl<'a> Drop for DbTxn<'a> {
mod tests { mod tests {
use super::{Cache, Db}; use super::{Cache, Db};
use crate::idprovider::interface::{GroupToken, Id, UserToken}; use crate::idprovider::interface::{GroupToken, Id, ProviderOrigin, UserToken};
use kanidm_hsm_crypto::{AuthValue, Tpm}; use kanidm_hsm_crypto::{AuthValue, Tpm};
const TESTACCOUNT1_PASSWORD_A: &str = "password a for account1 test"; const TESTACCOUNT1_PASSWORD_A: &str = "password a for account1 test";
@ -1042,6 +1041,7 @@ mod tests {
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut ut1 = UserToken { let mut ut1 = UserToken {
provider: ProviderOrigin::Files,
name: "testuser".to_string(), name: "testuser".to_string(),
spn: "testuser@example.com".to_string(), spn: "testuser@example.com".to_string(),
displayname: "Test User".to_string(), displayname: "Test User".to_string(),
@ -1126,6 +1126,7 @@ mod tests {
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut gt1 = GroupToken { let mut gt1 = GroupToken {
provider: ProviderOrigin::Files,
name: "testgroup".to_string(), name: "testgroup".to_string(),
spn: "testgroup@example.com".to_string(), spn: "testgroup@example.com".to_string(),
gidnumber: 2000, gidnumber: 2000,
@ -1201,6 +1202,7 @@ mod tests {
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let gt1 = GroupToken { let gt1 = GroupToken {
provider: ProviderOrigin::Files,
name: "testuser".to_string(), name: "testuser".to_string(),
spn: "testuser@example.com".to_string(), spn: "testuser@example.com".to_string(),
gidnumber: 2000, gidnumber: 2000,
@ -1208,6 +1210,7 @@ mod tests {
}; };
let gt2 = GroupToken { let gt2 = GroupToken {
provider: ProviderOrigin::Files,
name: "testgroup".to_string(), name: "testgroup".to_string(),
spn: "testgroup@example.com".to_string(), spn: "testgroup@example.com".to_string(),
gidnumber: 2001, gidnumber: 2001,
@ -1215,6 +1218,7 @@ mod tests {
}; };
let mut ut1 = UserToken { let mut ut1 = UserToken {
provider: ProviderOrigin::Files,
name: "testuser".to_string(), name: "testuser".to_string(),
spn: "testuser@example.com".to_string(), spn: "testuser@example.com".to_string(),
displayname: "Test User".to_string(), displayname: "Test User".to_string(),
@ -1284,6 +1288,7 @@ mod tests {
let uuid1 = uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16"); let uuid1 = uuid::uuid!("0302b99c-f0f6-41ab-9492-852692b0fd16");
let mut ut1 = UserToken { let mut ut1 = UserToken {
provider: ProviderOrigin::Files,
name: "testuser".to_string(), name: "testuser".to_string(),
spn: "testuser@example.com".to_string(), spn: "testuser@example.com".to_string(),
displayname: "Test User".to_string(), displayname: "Test User".to_string(),
@ -1353,6 +1358,7 @@ mod tests {
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut gt1 = GroupToken { let mut gt1 = GroupToken {
provider: ProviderOrigin::Files,
name: "testgroup".to_string(), name: "testgroup".to_string(),
spn: "testgroup@example.com".to_string(), spn: "testgroup@example.com".to_string(),
gidnumber: 2000, gidnumber: 2000,
@ -1360,6 +1366,7 @@ mod tests {
}; };
let gt2 = GroupToken { let gt2 = GroupToken {
provider: ProviderOrigin::Files,
name: "testgroup".to_string(), name: "testgroup".to_string(),
spn: "testgroup@example.com".to_string(), spn: "testgroup@example.com".to_string(),
gidnumber: 2001, gidnumber: 2001,
@ -1408,6 +1415,7 @@ mod tests {
assert!(dbtxn.migrate().is_ok()); assert!(dbtxn.migrate().is_ok());
let mut ut1 = UserToken { let mut ut1 = UserToken {
provider: ProviderOrigin::Files,
name: "testuser".to_string(), name: "testuser".to_string(),
spn: "testuser@example.com".to_string(), spn: "testuser@example.com".to_string(),
displayname: "Test User".to_string(), displayname: "Test User".to_string(),
@ -1420,6 +1428,7 @@ mod tests {
}; };
let ut2 = UserToken { let ut2 = UserToken {
provider: ProviderOrigin::Files,
name: "testuser".to_string(), name: "testuser".to_string(),
spn: "testuser@example.com".to_string(), spn: "testuser@example.com".to_string(),
displayname: "Test User".to_string(), displayname: "Test User".to_string(),

View file

@ -39,8 +39,20 @@ pub enum Id {
Gid(u32), Gid(u32),
} }
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub enum ProviderOrigin {
// To allow transition, we have an ignored type that effectively
// causes these items to be nixed.
#[default]
Ignore,
Files,
Kanidm,
}
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GroupToken { pub struct GroupToken {
#[serde(default)]
pub provider: ProviderOrigin,
pub name: String, pub name: String,
pub spn: String, pub spn: String,
pub uuid: Uuid, pub uuid: Uuid,
@ -49,6 +61,8 @@ pub struct GroupToken {
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UserToken { pub struct UserToken {
#[serde(default)]
pub provider: ProviderOrigin,
pub name: String, pub name: String,
pub spn: String, pub spn: String,
pub uuid: Uuid, pub uuid: Uuid,

View file

@ -17,6 +17,7 @@ use super::interface::{
Id, Id,
IdProvider, IdProvider,
IdpError, IdpError,
ProviderOrigin,
UserToken, UserToken,
}; };
use kanidm_unix_common::unix_proto::PamAuthRequest; use kanidm_unix_common::unix_proto::PamAuthRequest;
@ -52,6 +53,7 @@ impl From<UnixUserToken> for UserToken {
let groups = groups.into_iter().map(GroupToken::from).collect(); let groups = groups.into_iter().map(GroupToken::from).collect();
UserToken { UserToken {
provider: ProviderOrigin::Kanidm,
name, name,
spn, spn,
uuid, uuid,
@ -75,6 +77,7 @@ impl From<UnixGroupToken> for GroupToken {
} = value; } = value;
GroupToken { GroupToken {
provider: ProviderOrigin::Kanidm,
name, name,
spn, spn,
uuid, uuid,

View file

@ -68,6 +68,7 @@ pub struct Resolver {
hmac_key: HmacKey, hmac_key: HmacKey,
// A local passwd/shadow resolver. // A local passwd/shadow resolver.
nxset: Mutex<HashSet<Id>>,
// A set of remote resolvers // A set of remote resolvers
client: Box<dyn IdProvider + Sync + Send>, client: Box<dyn IdProvider + Sync + Send>,
@ -83,7 +84,6 @@ pub struct Resolver {
uid_attr_map: UidAttr, uid_attr_map: UidAttr,
gid_attr_map: UidAttr, gid_attr_map: UidAttr,
allow_id_overrides: HashSet<Id>, allow_id_overrides: HashSet<Id>,
nxset: Mutex<HashSet<Id>>,
nxcache: Mutex<LruCache<Id, SystemTime>>, nxcache: Mutex<LruCache<Id, SystemTime>>,
} }
@ -209,6 +209,7 @@ impl Resolver {
self.set_cachestate(CacheState::Offline).await; self.set_cachestate(CacheState::Offline).await;
} }
#[instrument(level = "debug", skip_all)]
pub async fn clear_cache(&self) -> Result<(), ()> { pub async fn clear_cache(&self) -> Result<(), ()> {
let mut nxcache_txn = self.nxcache.lock().await; let mut nxcache_txn = self.nxcache.lock().await;
nxcache_txn.clear(); nxcache_txn.clear();
@ -216,6 +217,7 @@ impl Resolver {
dbtxn.clear().and_then(|_| dbtxn.commit()).map_err(|_| ()) dbtxn.clear().and_then(|_| dbtxn.commit()).map_err(|_| ())
} }
#[instrument(level = "debug", skip_all)]
pub async fn invalidate(&self) -> Result<(), ()> { pub async fn invalidate(&self) -> Result<(), ()> {
let mut nxcache_txn = self.nxcache.lock().await; let mut nxcache_txn = self.nxcache.lock().await;
nxcache_txn.clear(); nxcache_txn.clear();
@ -596,8 +598,8 @@ impl Resolver {
} }
} }
#[instrument(level = "debug", skip(self))]
async fn get_usertoken(&self, account_id: Id) -> Result<Option<UserToken>, ()> { async fn get_usertoken(&self, account_id: Id) -> Result<Option<UserToken>, ()> {
debug!("get_usertoken");
// get the item from the cache // get the item from the cache
let (expired, item) = self.get_cached_usertoken(&account_id).await.map_err(|e| { let (expired, item) = self.get_cached_usertoken(&account_id).await.map_err(|e| {
debug!("get_usertoken error -> {:?}", e); debug!("get_usertoken error -> {:?}", e);
@ -648,8 +650,8 @@ impl Resolver {
}) })
} }
#[instrument(level = "debug", skip(self))]
async fn get_grouptoken(&self, grp_id: Id) -> Result<Option<GroupToken>, ()> { async fn get_grouptoken(&self, grp_id: Id) -> Result<Option<GroupToken>, ()> {
debug!("get_grouptoken");
let (expired, item) = self.get_cached_grouptoken(&grp_id).await.map_err(|e| { let (expired, item) = self.get_cached_grouptoken(&grp_id).await.map_err(|e| {
debug!("get_grouptoken error -> {:?}", e); debug!("get_grouptoken error -> {:?}", e);
})?; })?;
@ -707,6 +709,7 @@ impl Resolver {
} }
// Get ssh keys for an account id // Get ssh keys for an account id
#[instrument(level = "debug", skip(self))]
pub async fn get_sshkeys(&self, account_id: &str) -> Result<Vec<String>, ()> { pub async fn get_sshkeys(&self, account_id: &str) -> Result<Vec<String>, ()> {
let token = self.get_usertoken(Id::Name(account_id.to_string())).await?; let token = self.get_usertoken(Id::Name(account_id.to_string())).await?;
Ok(token Ok(token
@ -763,6 +766,7 @@ impl Resolver {
.to_string() .to_string()
} }
#[instrument(level = "debug", skip_all)]
pub async fn get_nssaccounts(&self) -> Result<Vec<NssUser>, ()> { pub async fn get_nssaccounts(&self) -> Result<Vec<NssUser>, ()> {
self.get_cached_usertokens().await.map(|l| { self.get_cached_usertokens().await.map(|l| {
l.into_iter() l.into_iter()
@ -788,10 +792,12 @@ impl Resolver {
})) }))
} }
#[instrument(level = "debug", skip(self))]
pub async fn get_nssaccount_name(&self, account_id: &str) -> Result<Option<NssUser>, ()> { pub async fn get_nssaccount_name(&self, account_id: &str) -> Result<Option<NssUser>, ()> {
self.get_nssaccount(Id::Name(account_id.to_string())).await self.get_nssaccount(Id::Name(account_id.to_string())).await
} }
#[instrument(level = "debug", skip(self))]
pub async fn get_nssaccount_gid(&self, gid: u32) -> Result<Option<NssUser>, ()> { pub async fn get_nssaccount_gid(&self, gid: u32) -> Result<Option<NssUser>, ()> {
self.get_nssaccount(Id::Gid(gid)).await self.get_nssaccount(Id::Gid(gid)).await
} }
@ -805,6 +811,7 @@ impl Resolver {
.to_string() .to_string()
} }
#[instrument(level = "debug", skip_all)]
pub async fn get_nssgroups(&self) -> Result<Vec<NssGroup>, ()> { pub async fn get_nssgroups(&self) -> Result<Vec<NssGroup>, ()> {
let l = self.get_cached_grouptokens().await?; let l = self.get_cached_grouptokens().await?;
let mut r: Vec<_> = Vec::with_capacity(l.len()); let mut r: Vec<_> = Vec::with_capacity(l.len());
@ -835,14 +842,17 @@ impl Resolver {
} }
} }
#[instrument(level = "debug", skip(self))]
pub async fn get_nssgroup_name(&self, grp_id: &str) -> Result<Option<NssGroup>, ()> { pub async fn get_nssgroup_name(&self, grp_id: &str) -> Result<Option<NssGroup>, ()> {
self.get_nssgroup(Id::Name(grp_id.to_string())).await self.get_nssgroup(Id::Name(grp_id.to_string())).await
} }
#[instrument(level = "debug", skip(self))]
pub async fn get_nssgroup_gid(&self, gid: u32) -> Result<Option<NssGroup>, ()> { pub async fn get_nssgroup_gid(&self, gid: u32) -> Result<Option<NssGroup>, ()> {
self.get_nssgroup(Id::Gid(gid)).await self.get_nssgroup(Id::Gid(gid)).await
} }
#[instrument(level = "debug", skip(self))]
pub async fn pam_account_allowed(&self, account_id: &str) -> Result<Option<bool>, ()> { pub async fn pam_account_allowed(&self, account_id: &str) -> Result<Option<bool>, ()> {
let token = self.get_usertoken(Id::Name(account_id.to_string())).await?; let token = self.get_usertoken(Id::Name(account_id.to_string())).await?;
@ -871,6 +881,7 @@ impl Resolver {
} }
} }
#[instrument(level = "debug", skip(self, shutdown_rx))]
pub async fn pam_account_authenticate_init( pub async fn pam_account_authenticate_init(
&self, &self,
account_id: &str, account_id: &str,
@ -945,6 +956,7 @@ impl Resolver {
} }
} }
#[instrument(level = "debug", skip_all)]
pub async fn pam_account_authenticate_step( pub async fn pam_account_authenticate_step(
&self, &self,
auth_session: &mut AuthSession, auth_session: &mut AuthSession,
@ -1142,6 +1154,7 @@ impl Resolver {
} }
// Can this be cfg debug/test? // Can this be cfg debug/test?
#[instrument(level = "debug", skip(self, password))]
pub async fn pam_account_authenticate( pub async fn pam_account_authenticate(
&self, &self,
account_id: &str, account_id: &str,
@ -1209,6 +1222,7 @@ impl Resolver {
} }
} }
#[instrument(level = "debug", skip(self))]
pub async fn pam_account_beginsession( pub async fn pam_account_beginsession(
&self, &self,
account_id: &str, account_id: &str,
@ -1224,6 +1238,7 @@ impl Resolver {
})) }))
} }
#[instrument(level = "debug", skip_all)]
pub async fn test_connection(&self) -> bool { pub async fn test_connection(&self) -> bool {
let state = self.get_cachestate().await; let state = self.get_cachestate().await;
match state { match state {

View file

@ -1,6 +1,6 @@
use kanidm_utils_users::get_user_name_by_uid; use kanidm_utils_users::get_user_name_by_uid;
use std::ffi::CString; use std::ffi::{CString, OsStr};
use std::path::Path; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use selinux::{ use selinux::{
@ -20,14 +20,13 @@ pub fn supported() -> bool {
} }
} }
fn do_setfscreatecon_for_path(path_raw: &str, labeler: &Labeler<File>) -> Result<(), String> { fn do_setfscreatecon_for_path(path_raw: &Path, labeler: &Labeler<File>) -> Result<(), String> {
match labeler.look_up(&CString::new(path_raw.to_owned()).unwrap(), 0) { let path_c_string = CString::new(path_raw.as_os_str().as_encoded_bytes())
Ok(context) => { .map_err(|_| "Invalid Path String".to_string())?;
if context.set_for_new_file_system_objects(true).is_err() { match labeler.look_up(&path_c_string, 0) {
return Err("Failed setting creation context home directory path".to_string()); Ok(context) => context
} .set_for_new_file_system_objects(true)
Ok(()) .map_err(|_| "Failed setting creation context home directory path".to_string()),
}
Err(_) => { Err(_) => {
return Err("Failed looking up default context for home directory path".to_string()); return Err("Failed looking up default context for home directory path".to_string());
} }
@ -46,12 +45,12 @@ pub enum SelinuxLabeler {
None, None,
Enabled { Enabled {
labeler: Labeler<File>, labeler: Labeler<File>,
sel_lookup_path_raw: String, sel_lookup_path_raw: PathBuf,
}, },
} }
impl SelinuxLabeler { impl SelinuxLabeler {
pub fn new(gid: u32, home_prefix: &str) -> Result<Self, String> { pub fn new(gid: u32, home_prefix: &Path) -> Result<Self, String> {
let labeler = get_labeler()?; let labeler = get_labeler()?;
// Construct a path for SELinux context lookups. // Construct a path for SELinux context lookups.
@ -64,7 +63,7 @@ impl SelinuxLabeler {
#[cfg(all(target_family = "unix", feature = "selinux"))] #[cfg(all(target_family = "unix", feature = "selinux"))]
// Yes, gid, because we use the GID number for both the user's UID and primary GID // Yes, gid, because we use the GID number for both the user's UID and primary GID
let sel_lookup_path_raw = match get_user_name_by_uid(gid) { let sel_lookup_path_raw = match get_user_name_by_uid(gid) {
Some(v) => format!("{}{}", home_prefix, v.to_str().unwrap()), Some(v) => home_prefix.join(v),
None => { None => {
return Err("Failed looking up username by uid for SELinux relabeling".to_string()); return Err("Failed looking up username by uid for SELinux relabeling".to_string());
} }
@ -97,23 +96,32 @@ impl SelinuxLabeler {
labeler, labeler,
sel_lookup_path_raw, sel_lookup_path_raw,
} => { } => {
let sel_lookup_path = Path::new(&sel_lookup_path_raw).join(path.as_ref()); let sel_lookup_path = sel_lookup_path_raw.join(path.as_ref());
do_setfscreatecon_for_path(&sel_lookup_path.to_str().unwrap().to_string(), &labeler) do_setfscreatecon_for_path(&sel_lookup_path, &labeler)
} }
} }
} }
pub fn setup_equivalence_rule<P: AsRef<Path>>(&self, path: P) -> Result<(), String> { pub fn setup_equivalence_rule<P: AsRef<OsStr>>(&self, path: P) -> Result<(), String> {
match &self { match &self {
SelinuxLabeler::None => Ok(()), SelinuxLabeler::None => Ok(()),
SelinuxLabeler::Enabled { SelinuxLabeler::Enabled {
labeler: _, labeler: _,
sel_lookup_path_raw, sel_lookup_path_raw,
} => Command::new("semanage") } => {
.args(["fcontext", "-ae", sel_lookup_path_raw, path.as_ref()]) // Looks weird but needed to force the type to be os str
let arg1: &OsStr = "fcontext".as_ref();
Command::new("semanage")
.args([
arg1,
"-ae".as_ref(),
sel_lookup_path_raw.as_ref(),
path.as_ref(),
])
.spawn() .spawn()
.map(|_| ()) .map(|_| ())
.map_err(|_| "Failed creating SELinux policy equivalence rule".to_string()), .map_err(|_| "Failed creating SELinux policy equivalence rule".to_string())
}
} }
} }