diff --git a/tools/cli/Cargo.toml b/tools/cli/Cargo.toml index f5efca76f..1794db991 100644 --- a/tools/cli/Cargo.toml +++ b/tools/cli/Cargo.toml @@ -53,7 +53,7 @@ shellexpand = { workspace = true } time = { workspace = true, features = ["serde", "std"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } -tokio = { workspace = true, features = ["rt", "macros", "fs"] } +tokio = { workspace = true, features = ["rt", "macros", "fs", "signal"] } url = { workspace = true, features = ["serde"] } uuid = { workspace = true } zxcvbn = { workspace = true } diff --git a/tools/cli/src/cli/main.rs b/tools/cli/src/cli/main.rs index 6cb8d75a2..76e47ad0f 100644 --- a/tools/cli/src/cli/main.rs +++ b/tools/cli/src/cli/main.rs @@ -13,13 +13,48 @@ use clap::Parser; use kanidm_cli::KanidmClientParser; +use std::process::ExitCode; use std::thread; use tokio::runtime; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::prelude::*; use tracing_subscriber::{fmt, EnvFilter}; -fn main() { +#[cfg(target_family = "unix")] +use tokio::signal::unix::{signal, SignalKind}; + +#[cfg(target_family = "unix")] +async fn signal_handler(opt: KanidmClientParser) -> ExitCode { + // We need a signal handler to deal with a few things that can occur during runtime, especially + // sigpipe on linux. + + let mut signal_quit = signal(SignalKind::quit()).expect("Invalid Signal"); + let mut signal_term = signal(SignalKind::terminate()).expect("Invalid Signal"); + let mut signal_pipe = signal(SignalKind::pipe()).expect("Invalid Signal"); + + tokio::select! { + _ = opt.commands.exec() => { + ExitCode::SUCCESS + } + _ = signal_quit.recv() => { + ExitCode::SUCCESS + } + _ = signal_term.recv() => { + ExitCode::SUCCESS + } + _ = signal_pipe.recv() => { + ExitCode::SUCCESS + } + } +} + +#[cfg(target_family = "windows")] +async fn signal_handler(opt: KanidmClientParser) -> ExitCode { + opt.commands.exec().await; + ExitCode::SUCCESS +} + +fn main() -> ExitCode { let opt = KanidmClientParser::parse(); let fmt_layer = fmt::layer().with_writer(std::io::stderr); @@ -30,7 +65,7 @@ fn main() { Ok(f) => f, Err(e) => { eprintln!("ERROR! Unable to start tracing {:?}", e); - return; + return ExitCode::FAILURE; } } } else { @@ -61,5 +96,5 @@ fn main() { #[cfg(debug_assertions)] tracing::debug!("Using {} worker threads", par_count); - rt.block_on(async { opt.commands.exec().await }); + rt.block_on(signal_handler(opt)) }