diff --git a/unix_integration/src/daemon.rs b/unix_integration/src/daemon.rs index e4bf558c6..3de69d052 100644 --- a/unix_integration/src/daemon.rs +++ b/unix_integration/src/daemon.rs @@ -35,6 +35,7 @@ use sketching::tracing_forest::traits::*; use sketching::tracing_forest::util::*; use sketching::tracing_forest::{self}; use tokio::net::{UnixListener, UnixStream}; +use tokio::sync::broadcast; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; use tokio::time; @@ -683,36 +684,47 @@ async fn main() -> ExitCode { let task_channel_tx_cln = task_channel_tx.clone(); - tokio::spawn(async move { + // Start to build the worker tasks + let (broadcast_tx, mut broadcast_rx) = broadcast::channel(4); + let mut c_broadcast_rx = broadcast_tx.subscribe(); + + let task_b = tokio::spawn(async move { loop { - match task_listener.accept().await { - Ok((socket, _addr)) => { - // Did it come from root? - if let Ok(ucred) = socket.peer_cred() { - if ucred.uid() == 0 { - // all good! - } else { - // move along. - debug!("Task handler not running as root, ignoring ..."); - continue; - } - } else { - // move along. - debug!("Task handler not running as root, ignoring ..."); - continue; - }; - debug!("A task handler has connected."); - // It did? Great, now we can wait and spin on that one - // client. - if let Err(e) = - handle_task_client(socket, &task_channel_tx, &mut task_channel_rx).await - { - error!("Task client error occurred; error = {:?}", e); - } - // If they DC we go back to accept. + tokio::select! { + _ = c_broadcast_rx.recv() => { + break; } - Err(err) => { - error!("Task Accept error -> {:?}", err); + accept_res = task_listener.accept() => { + match accept_res { + Ok((socket, _addr)) => { + // Did it come from root? + if let Ok(ucred) = socket.peer_cred() { + if ucred.uid() == 0 { + // all good! + } else { + // move along. + debug!("Task handler not running as root, ignoring ..."); + continue; + } + } else { + // move along. + debug!("Task handler not running as root, ignoring ..."); + continue; + }; + debug!("A task handler has connected."); + // It did? Great, now we can wait and spin on that one + // client. + if let Err(e) = + handle_task_client(socket, &task_channel_tx, &mut task_channel_rx).await + { + error!("Task client error occurred; error = {:?}", e); + } + // If they DC we go back to accept. + } + Err(err) => { + error!("Task Accept error -> {:?}", err); + } + } } } // done @@ -720,30 +732,83 @@ async fn main() -> ExitCode { }); // TODO: Setup a task that handles pre-fetching here. - - let server = async move { + let task_a = tokio::spawn(async move { loop { let tc_tx = task_channel_tx_cln.clone(); - match listener.accept().await { - Ok((socket, _addr)) => { - let cachelayer_ref = cachelayer.clone(); - tokio::spawn(async move { - if let Err(e) = handle_client(socket, cachelayer_ref.clone(), &tc_tx).await - { - error!("handle_client error occurred; error = {:?}", e); - } - }); + + tokio::select! { + _ = broadcast_rx.recv() => { + break; } - Err(err) => { - error!("Error while handling connection -> {:?}", err); + accept_res = listener.accept() => { + match accept_res { + Ok((socket, _addr)) => { + let cachelayer_ref = cachelayer.clone(); + tokio::spawn(async move { + if let Err(e) = handle_client(socket, cachelayer_ref.clone(), &tc_tx).await + { + error!("handle_client error occurred; error = {:?}", e); + } + }); + } + Err(err) => { + error!("Error while handling connection -> {:?}", err); + } + } } } + } - }; + }); info!("Server started ..."); - server.await; + loop { + tokio::select! { + Ok(()) = tokio::signal::ctrl_c() => { + break + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::terminate(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + break + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::alarm(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::hangup(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::user_defined1(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::user_defined2(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + } + } + info!("Signal received, shutting down"); + // Send a broadcast that we are done. + if let Err(e) = broadcast_tx.send(true) { + error!("Unable to shutdown workers {:?}", e); + } + + let _ = task_a.await; + let _ = task_b.await; + ExitCode::SUCCESS }) .await diff --git a/unix_integration/src/tasks_daemon.rs b/unix_integration/src/tasks_daemon.rs index 8e46bfbd4..08485bca8 100644 --- a/unix_integration/src/tasks_daemon.rs +++ b/unix_integration/src/tasks_daemon.rs @@ -28,6 +28,7 @@ use sketching::tracing_forest::traits::*; use sketching::tracing_forest::util::*; use sketching::tracing_forest::{self}; use tokio::net::UnixStream; +use tokio::sync::broadcast; use tokio::time; use tokio_util::codec::{Decoder, Encoder, Framed}; use users::{get_effective_gid, get_effective_uid}; @@ -262,29 +263,82 @@ async fn main() -> ExitCode { let task_sock_path = cfg.task_sock_path.clone(); debug!("Attempting to use {} ...", task_sock_path); - let server = async move { + let (broadcast_tx, mut broadcast_rx) = broadcast::channel(4); + + let server = tokio::spawn(async move { loop { info!("Attempting to connect to kanidm_unixd ..."); - // Try to connect to the daemon. - match UnixStream::connect(&task_sock_path).await { - // Did we connect? - Ok(stream) => { - info!("Found kanidm_unixd, waiting for tasks ..."); - // Yep! Now let the main handler do it's job. - // If it returns (dc, etc, then we loop and try again). - handle_tasks(stream, &cfg).await; + + tokio::select! { + _ = broadcast_rx.recv() => { + break; } - Err(e) => { - error!("Unable to find kanidm_unixd, sleeping ..."); - debug!("\\---> {:?}", e); - // Back off. - time::sleep(Duration::from_millis(5000)).await; + connect_res = UnixStream::connect(&task_sock_path) => { + match connect_res { + Ok(stream) => { + info!("Found kanidm_unixd, waiting for tasks ..."); + // Yep! Now let the main handler do it's job. + // If it returns (dc, etc, then we loop and try again). + handle_tasks(stream, &cfg).await; + } + Err(e) => { + debug!("\\---> {:?}", e); + error!("Unable to find kanidm_unixd, sleeping ..."); + // Back off. + time::sleep(Duration::from_millis(5000)).await; + } + } } } } - }; + }); - server.await; + info!("Server started ..."); + + loop { + tokio::select! { + Ok(()) = tokio::signal::ctrl_c() => { + break + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::terminate(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + break + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::alarm(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::hangup(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::user_defined1(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + Some(()) = async move { + let sigterm = tokio::signal::unix::SignalKind::user_defined2(); + tokio::signal::unix::signal(sigterm).unwrap().recv().await + } => { + // Ignore + } + } + } + info!("Signal received, shutting down"); + // Send a broadcast that we are done. + if let Err(e) = broadcast_tx.send(true) { + error!("Unable to shutdown workers {:?}", e); + } + + let _ = server.await; ExitCode::SUCCESS }) .await