Add ID to task calls allowing other messages to be included.

This commit is contained in:
William Brown 2025-03-13 15:16:13 +10:00
parent 39adf992b3
commit 67a20ad697
3 changed files with 121 additions and 79 deletions
unix_integration

View file

@ -200,6 +200,12 @@ pub struct HomeDirectoryInfo {
pub aliases: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TaskRequestFrame {
pub id: u64,
pub req: TaskRequest,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum TaskRequest {
HomeDirectory(HomeDirectoryInfo),
@ -207,7 +213,7 @@ pub enum TaskRequest {
#[derive(Serialize, Deserialize, Debug)]
pub enum TaskResponse {
Success,
Success(u64),
Error(String),
}

View file

@ -10,6 +10,32 @@
#![deny(clippy::needless_pass_by_value)]
#![deny(clippy::trivially_copy_pass_by_ref)]
use bytes::{BufMut, BytesMut};
use clap::{Arg, ArgAction, Command};
use futures::{SinkExt, StreamExt};
use kanidm_client::KanidmClientBuilder;
use kanidm_hsm_crypto::{soft::SoftTpm, AuthValue, BoxedDynTpm, Tpm};
use kanidm_proto::constants::DEFAULT_CLIENT_CONFIG_PATH;
use kanidm_proto::internal::OperationError;
use kanidm_unix_common::constants::DEFAULT_CONFIG_PATH;
use kanidm_unix_common::unix_passwd::{parse_etc_group, parse_etc_passwd, parse_etc_shadow};
use kanidm_unix_common::unix_proto::{
ClientRequest, ClientResponse, TaskRequest, TaskRequestFrame, TaskResponse,
};
use kanidm_unix_resolver::db::{Cache, Db};
use kanidm_unix_resolver::idprovider::interface::IdProvider;
use kanidm_unix_resolver::idprovider::kanidm::KanidmProvider;
use kanidm_unix_resolver::idprovider::system::SystemProvider;
use kanidm_unix_resolver::resolver::Resolver;
use kanidm_unix_resolver::unix_config::{HsmType, UnixdConfig};
use kanidm_utils_users::{get_current_gid, get_current_uid, get_effective_gid, get_effective_uid};
use libc::umask;
use notify_debouncer_full::{new_debouncer, notify::RecursiveMode, DebouncedEvent};
use sketching::tracing::span;
use sketching::tracing_forest::traits::*;
use sketching::tracing_forest::util::*;
use sketching::tracing_forest::{self};
use std::collections::BTreeMap;
use std::error::Error;
use std::fs::metadata;
use std::io;
@ -20,29 +46,6 @@ use std::process::ExitCode;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use bytes::{BufMut, BytesMut};
use clap::{Arg, ArgAction, Command};
use futures::{SinkExt, StreamExt};
use kanidm_client::KanidmClientBuilder;
use kanidm_proto::constants::DEFAULT_CLIENT_CONFIG_PATH;
use kanidm_proto::internal::OperationError;
use kanidm_unix_common::constants::DEFAULT_CONFIG_PATH;
use kanidm_unix_common::unix_passwd::{parse_etc_group, parse_etc_passwd, parse_etc_shadow};
use kanidm_unix_common::unix_proto::{ClientRequest, ClientResponse, TaskRequest, TaskResponse};
use kanidm_unix_resolver::db::{Cache, Db};
use kanidm_unix_resolver::idprovider::interface::IdProvider;
use kanidm_unix_resolver::idprovider::kanidm::KanidmProvider;
use kanidm_unix_resolver::idprovider::system::SystemProvider;
use kanidm_unix_resolver::resolver::Resolver;
use kanidm_unix_resolver::unix_config::{HsmType, UnixdConfig};
use kanidm_utils_users::{get_current_gid, get_current_uid, get_effective_gid, get_effective_uid};
use libc::umask;
use sketching::tracing::span;
use sketching::tracing_forest::traits::*;
use sketching::tracing_forest::util::*;
use sketching::tracing_forest::{self};
use time::OffsetDateTime;
use tokio::fs::File;
use tokio::io::AsyncReadExt; // for read_to_end()
@ -52,17 +55,16 @@ use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tokio_util::codec::{Decoder, Encoder, Framed};
use kanidm_hsm_crypto::{soft::SoftTpm, AuthValue, BoxedDynTpm, Tpm};
use notify_debouncer_full::{new_debouncer, notify::RecursiveMode, DebouncedEvent};
#[cfg(not(target_os = "illumos"))]
#[global_allocator]
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
//=== the codec
type AsyncTaskRequest = (TaskRequest, oneshot::Sender<()>);
struct AsyncTaskRequest {
task_req: TaskRequest,
task_chan: oneshot::Sender<()>,
}
#[derive(Default)]
struct ClientCodec;
@ -117,11 +119,11 @@ impl Decoder for TaskCodec {
}
}
impl Encoder<TaskRequest> for TaskCodec {
impl Encoder<TaskRequestFrame> for TaskCodec {
type Error = io::Error;
fn encode(&mut self, msg: TaskRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
debug!("Attempting to send request -> {:?} ...", msg);
fn encode(&mut self, msg: TaskRequestFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
debug!("Attempting to send request -> {:?} ...", msg.id);
let data = serde_json::to_vec(&msg).map_err(|e| {
error!("socket encoding error -> {:?}", e);
io::Error::new(io::ErrorKind::Other, "JSON encode error")
@ -148,46 +150,76 @@ fn rm_if_exist(p: &str) {
async fn handle_task_client(
stream: UnixStream,
task_channel_tx: &Sender<AsyncTaskRequest>,
_task_channel_tx: &Sender<AsyncTaskRequest>,
task_channel_rx: &mut Receiver<AsyncTaskRequest>,
broadcast_rx: &mut broadcast::Receiver<bool>,
) -> Result<(), Box<dyn Error>> {
// setup the codec
let mut reqs = Framed::new(stream, TaskCodec);
// setup the codec, this is to the unix socket which the task daemon
// connected to us with.
let mut last_task_id: u64 = 0;
let mut task_handles = BTreeMap::new();
let mut framed_stream = Framed::new(stream, TaskCodec);
loop {
// TODO wait on the channel OR the task handler, so we know
// when it closes.
let v = match task_channel_rx.recv().await {
Some(v) => v,
None => return Ok(()),
};
debug!("Sending Task -> {:?}", v.0);
// Write the req to the socket.
if let Err(_e) = reqs.send(v.0.clone()).await {
// re-queue the event if not timed out.
// This is indicated by the one shot being dropped.
if !v.1.is_closed() {
let _ = task_channel_tx
.send_timeout(v, Duration::from_millis(100))
.await;
tokio::select! {
// We have been commanded to stop operation.
_ = broadcast_rx.recv() => {
return Ok(())
}
// now return the error.
return Err(Box::new(IoError::new(ErrorKind::Other, "oh no!")));
}
task_request = task_channel_rx.recv() => {
let Some(AsyncTaskRequest {
task_req,
task_chan
}) = task_request else {
// Task channel has died, cease operation.
return Ok(())
};
match reqs.next().await {
Some(Ok(TaskResponse::Success)) => {
debug!("Task was acknowledged and completed.");
// Send a result back via the one-shot
// Ignore if it fails.
let _ = v.1.send(());
debug!("Sending Task -> {:?}", task_req);
last_task_id += 1;
let task_id = last_task_id;
// Setup the task handle so we know who to get back to.
task_handles.insert(task_id, task_chan);
let task_frame = TaskRequestFrame {
id: task_id,
req: task_req,
};
if let Err(err) = framed_stream.send(task_frame).await {
warn!("Unable to queue task for completion");
return Err(Box::new(err));
}
// Task sent
}
other => {
error!("Error -> {:?}", other);
return Err(Box::new(IoError::new(ErrorKind::Other, "oh no!")));
response = framed_stream.next() => {
// Process incoming messages. They may be out of order.
match response {
Some(Ok(TaskResponse::Success(task_id))) => {
debug!("Task was acknowledged and completed.");
if let Some(handle) = task_handles.remove(&task_id) {
// Send a result back via the one-shot
// Ignore if it fails.
let _ = handle.send(());
}
// If the ID was unregistered, ignore.
}
// Other things ....
// Some(Ok(TaskResponse::
other => {
error!("Error -> {:?}", other);
return Err(Box::new(IoError::new(ErrorKind::Other, "oh no!")));
}
}
}
}
}
}
@ -341,7 +373,10 @@ async fn handle_client(
match task_channel_tx
.send_timeout(
(TaskRequest::HomeDirectory(info), tx),
AsyncTaskRequest {
task_req: TaskRequest::HomeDirectory(info),
task_chan: tx,
},
Duration::from_millis(100),
)
.await
@ -1040,6 +1075,7 @@ async fn main() -> ExitCode {
let task_b = tokio::spawn(async move {
loop {
tokio::select! {
// Wait on the broadcast to see if we need to close down.
_ = c_broadcast_rx.recv() => {
break;
}
@ -1062,16 +1098,11 @@ async fn main() -> ExitCode {
// It did? Great, now we can wait and spin on that one
// client.
tokio::select! {
_ = d_broadcast_rx.recv() => {
break;
}
// We have to check for signals here else this tasks waits forever.
Err(e) = handle_task_client(socket, &task_channel_tx, &mut task_channel_rx) => {
error!("Task client error occurred; error = {:?}", e);
}
// We have to check for signals here else this tasks waits forever.
if let Err(err) = handle_task_client(socket, &task_channel_tx, &mut task_channel_rx, &mut d_broadcast_rx).await {
error!(?err, "Task client error occurred");
}
// If they DC we go back to accept.
// If they disconnect we go back to accept.
}
Err(err) => {
error!("Task Accept error -> {:?}", err);

View file

@ -21,7 +21,9 @@ use std::{fs, io};
use bytes::{BufMut, BytesMut};
use futures::{SinkExt, StreamExt};
use kanidm_unix_common::constants::DEFAULT_CONFIG_PATH;
use kanidm_unix_common::unix_proto::{HomeDirectoryInfo, TaskRequest, TaskResponse};
use kanidm_unix_common::unix_proto::{
HomeDirectoryInfo, TaskRequest, TaskRequestFrame, TaskResponse,
};
use kanidm_unix_resolver::unix_config::UnixdConfig;
use kanidm_utils_users::{get_effective_gid, get_effective_uid};
use libc::{lchown, umask};
@ -41,10 +43,10 @@ struct TaskCodec;
impl Decoder for TaskCodec {
type Error = io::Error;
type Item = TaskRequest;
type Item = TaskRequestFrame;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match serde_json::from_slice::<TaskRequest>(src) {
match serde_json::from_slice::<TaskRequestFrame>(src) {
Ok(msg) => {
// Clear the buffer for the next message.
src.clear();
@ -274,7 +276,10 @@ async fn handle_tasks(stream: UnixStream, cfg: &UnixdConfig) {
loop {
match reqs.next().await {
Some(Ok(TaskRequest::HomeDirectory(info))) => {
Some(Ok(TaskRequestFrame {
id,
req: TaskRequest::HomeDirectory(info),
})) => {
debug!("Received task -> HomeDirectory({:?})", info);
let resp = match create_home_directory(
@ -284,7 +289,7 @@ async fn handle_tasks(stream: UnixStream, cfg: &UnixdConfig) {
cfg.use_etc_skel,
cfg.selinux,
) {
Ok(()) => TaskResponse::Success,
Ok(()) => TaskResponse::Success(id),
Err(msg) => TaskResponse::Error(msg),
};