mod apidocs;
pub(crate) mod cache_buster;
pub(crate) mod errors;
mod extractors;
mod generic;
mod javascript;
mod manifest;
pub(crate) mod middleware;
mod oauth2;
pub(crate) mod trace;
mod v1;
mod v1_domain;
mod v1_oauth2;
mod v1_scim;
mod views;

use self::extractors::ClientConnInfo;
use self::javascript::*;
use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
use crate::config::{Configuration, ServerRole};
use crate::CoreAction;

use axum::{
    body::Body,
    extract::connect_info::IntoMakeServiceWithConnectInfo,
    http::{HeaderMap, HeaderValue, Request},
    middleware::{from_fn, from_fn_with_state},
    response::Redirect,
    routing::*,
    Router,
};

use axum_extra::extract::cookie::CookieJar;
use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
use futures::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
use openssl::ssl::{Ssl, SslAcceptor};

use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};

use serde::de::DeserializeOwned;
use sketching::*;
use std::fmt::Write;
use tokio::{
    net::{TcpListener, TcpStream},
    sync::broadcast,
    sync::mpsc,
    task,
};
use tokio_openssl::SslStream;
use tower::Service;
use tower_http::{services::ServeDir, trace::TraceLayer};
use url::Url;
use uuid::Uuid;

use std::io::ErrorKind;
use std::path::PathBuf;
use std::pin::Pin;
use std::{net::SocketAddr, str::FromStr};

#[derive(Clone)]
pub struct ServerState {
    pub(crate) status_ref: &'static StatusActor,
    pub(crate) qe_w_ref: &'static QueryServerWriteV1,
    pub(crate) qe_r_ref: &'static QueryServerReadV1,
    // Store the token management parts.
    pub(crate) jws_signer: JwsHs256Signer,
    pub(crate) trust_x_forward_for: bool,
    pub(crate) csp_header: HeaderValue,
    pub(crate) origin: Url,
    pub(crate) domain: String,
    // This is set to true by default, and is only false on integration tests.
    pub(crate) secure_cookies: bool,
}

impl ServerState {
    /// Deserialize some input string validating that it was signed by our instance's
    /// HMAC signer. This is used for short lived server-only sessions and context
    /// data. This has applications in both accessing cookie content and header content.
    fn deserialise_from_str<T: DeserializeOwned>(&self, input: &str) -> Option<T> {
        match JwsCompact::from_str(input) {
            Ok(val) => match self.jws_signer.verify(&val) {
                Ok(val) => val.from_json::<T>().ok(),
                Err(err) => {
                    error!(?err, "Failed to deserialise JWT from request");
                    if matches!(err, JwtError::InvalidSignature) {
                        // The server has an ephemeral in memory HMAC signer. This is important as
                        // auth (login) sessions on one node shouldn't validate on another. Sessions
                        // that are shared beween nodes use the internal ECDSA signer.
                        //
                        // But because of this if the server restarts it rolls the key. Additionally
                        // it can occur if the load balancer isn't sticking sessions to the correct
                        // node. That can cause this error. So we want to specifically call it out
                        // to admins so they can investigate that the fault is occurring *outside*
                        // of kanidm.
                        warn!("Invalid Signature errors can occur if your instance restarted recently, if a load balancer is not configured for sticky sessions, or a session was tampered with.");
                    }
                    None
                }
            },
            Err(_) => None,
        }
    }

    #[instrument(level = "trace", skip_all)]
    fn get_current_auth_session_id(&self, headers: &HeaderMap, jar: &CookieJar) -> Option<Uuid> {
        // We see if there is a signed header copy first.
        headers
            .get(KSESSIONID)
            .and_then(|hv| {
                trace!("trying header");
                // Get the first header value.
                hv.to_str().ok()
            })
            .or_else(|| {
                trace!("trying cookie");
                jar.get(COOKIE_AUTH_SESSION_ID).map(|c| c.value())
            })
            .and_then(|s| {
                trace!(id_jws = %s);
                self.deserialise_from_str::<Uuid>(s)
            })
    }
}

pub(crate) fn get_js_files(role: ServerRole) -> Result<Vec<JavaScriptFile>, ()> {
    let mut all_pages: Vec<JavaScriptFile> = Vec::new();

    if !matches!(role, ServerRole::WriteReplicaNoUI) {
        // let's set up the list of js module hashes
        let pkg_path = env!("KANIDM_SERVER_UI_PKG_PATH").to_owned();

        let filelist = [
            "external/bootstrap.bundle.min.js",
            "external/htmx.min.1.9.12.js",
            "external/confetti.js",
            "external/base64.js",
            "modules/cred_update.mjs",
            "pkhtml.js",
            "style.js",
        ];

        for filepath in filelist {
            match generate_integrity_hash(format!("{}/{}", pkg_path, filepath,)) {
                Ok(hash) => {
                    debug!("Integrity hash for {}: {}", filepath, hash);
                    let js = JavaScriptFile { hash };
                    all_pages.push(js)
                }
                Err(err) => {
                    admin_error!(
                        ?err,
                        "Failed to generate integrity hash for {} - cancelling startup!",
                        filepath
                    );
                    return Err(());
                }
            }
        }
    }
    Ok(all_pages)
}

pub async fn create_https_server(
    config: Configuration,
    jws_signer: JwsHs256Signer,
    status_ref: &'static StatusActor,
    qe_w_ref: &'static QueryServerWriteV1,
    qe_r_ref: &'static QueryServerReadV1,
    server_message_tx: broadcast::Sender<CoreAction>,
    maybe_tls_acceptor: Option<SslAcceptor>,
    tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
) -> Result<task::JoinHandle<()>, ()> {
    let rx = server_message_tx.subscribe();

    let all_js_files = get_js_files(config.role)?;
    // set up the CSP headers
    // script-src 'self'
    //      'sha384-Zao7ExRXVZOJobzS/uMp0P1jtJz3TTqJU4nYXkdmsjpiVD+/wcwCyX7FGqRIqvIz'
    //      'sha384-MrcW6ZMFYlzcLA8Nl+NtUVF0sA7MsXsP1UyJoMp4YLEuNSfAP+JcXn/tWtIaxVXM';

    let js_directives = all_js_files
        .into_iter()
        .map(|f| f.hash)
        .collect::<Vec<String>>();

    let js_checksums: String = js_directives
        .iter()
        .fold(String::new(), |mut output, value| {
            let _ = write!(output, " 'sha384-{}'", value);
            output
        });

    let csp_header = format!(
        concat!(
            "default-src 'self'; ",
            "base-uri 'self' https:; ",
            "form-action 'self' https:;",
            "frame-ancestors 'none'; ",
            "img-src 'self' data:; ",
            "worker-src 'none'; ",
            "script-src 'self' 'unsafe-eval'{};",
        ),
        js_checksums
    );

    let csp_header = HeaderValue::from_str(&csp_header).map_err(|err| {
        error!(?err, "Unable to generate content security policy");
    })?;

    let trust_x_forward_for = config.trust_x_forward_for;

    let origin = Url::parse(&config.origin)
        // Should be impossible!
        .map_err(|err| {
            error!(?err, "Unable to parse origin URL - refusing to start. You must correct the value for origin. {:?}", config.origin);
        })?;

    let state = ServerState {
        status_ref,
        qe_w_ref,
        qe_r_ref,
        jws_signer,
        trust_x_forward_for,
        csp_header,
        origin,
        domain: config.domain.clone(),
        secure_cookies: config.integration_test_config.is_none(),
    };

    let static_routes = match config.role {
        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
            Router::new()
                .route("/ui/images/oauth2/:rs_name", get(oauth2::oauth2_image_get))
                .route("/ui/images/domain", get(v1_domain::image_get))
                .route("/manifest.webmanifest", get(manifest::manifest)) // skip_route_check
                // Layers only apply to routes that are *already* added, not the ones
                // added after.
                .layer(middleware::compression::new())
                .layer(from_fn(middleware::caching::cache_me_short))
                .route("/", get(|| async { Redirect::to("/ui") }))
                .nest("/ui", views::view_router())
            // Can't compress on anything that changes
        }
        ServerRole::WriteReplicaNoUI => Router::new(),
    };
    let app = Router::new()
        .merge(oauth2::route_setup(state.clone()))
        .merge(v1_scim::route_setup())
        .merge(v1::route_setup(state.clone()))
        .route("/robots.txt", get(generic::robots_txt));

    let app = match config.role {
        ServerRole::WriteReplicaNoUI => app,
        ServerRole::WriteReplica | ServerRole::ReadOnlyReplica => {
            let pkg_path = PathBuf::from(env!("KANIDM_SERVER_UI_PKG_PATH"));
            if !pkg_path.exists() {
                eprintln!(
                    "Couldn't find htmx UI package path: ({}), quitting.",
                    env!("KANIDM_SERVER_UI_PKG_PATH")
                );
                std::process::exit(1);
            }
            let pkg_router = Router::new()
                .nest_service("/pkg", ServeDir::new(pkg_path))
                // TODO: Add in the br precompress
                .layer(from_fn(middleware::caching::cache_me_short));

            app.merge(pkg_router)
        }
    };

    // this sets up the default span which logs the URL etc.
    let trace_layer = TraceLayer::new_for_http()
        .make_span_with(trace::DefaultMakeSpanKanidmd::new())
        // setting these to trace because all they do is print "started processing request", and we are already doing that enough!
        .on_response(trace::DefaultOnResponseKanidmd::new());

    let app = app
        .merge(static_routes)
        .layer(from_fn_with_state(
            state.clone(),
            middleware::security_headers::security_headers_layer,
        ))
        .layer(from_fn(middleware::version_middleware))
        .layer(from_fn(
            middleware::hsts_header::strict_transport_security_layer,
        ));

    // layer which checks the responses have a content-type of JSON when we're in debug mode
    #[cfg(any(test, debug_assertions))]
    let app = app.layer(from_fn(middleware::are_we_json_yet));

    let app = app
        .route("/status", get(generic::status))
        // This must be the LAST middleware.
        // This is because the last middleware here is the first to be entered and the last
        // to be exited, and this middleware sets up ids' and other bits for for logging
        // coherence to be maintained.
        .layer(from_fn(middleware::kopid_middleware))
        .merge(apidocs::router())
        // this MUST be the last layer before with_state else the span never starts and everything breaks.
        .layer(trace_layer)
        .with_state(state)
        // the connect_info bit here lets us pick up the remote address of the client
        .into_make_service_with_connect_info::<ClientConnInfo>();

    let addr = SocketAddr::from_str(&config.address).map_err(|err| {
        error!(
            "Failed to parse address ({:?}) from config: {:?}",
            config.address, err
        );
    })?;

    info!("Starting the web server...");

    match maybe_tls_acceptor {
        Some(tls_acceptor) => {
            let listener = match TcpListener::bind(addr).await {
                Ok(l) => l,
                Err(err) => {
                    error!(?err, "Failed to bind tcp listener");
                    return Err(());
                }
            };
            Ok(task::spawn(server_loop(
                tls_acceptor,
                listener,
                app,
                rx,
                server_message_tx,
                tls_acceptor_reload_rx,
            )))
        }
        None => Ok(task::spawn(server_loop_plaintext(addr, app, rx))),
    }
}

async fn server_loop(
    mut tls_acceptor: SslAcceptor,
    listener: TcpListener,
    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
    mut rx: broadcast::Receiver<CoreAction>,
    server_message_tx: broadcast::Sender<CoreAction>,
    mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
) {
    pin_mut!(listener);

    loop {
        tokio::select! {
            Ok(action) = rx.recv() => {
                match action {
                    CoreAction::Shutdown => break,
                }
            }
            accept = listener.accept() => {
                match accept {
                    Ok((stream, addr)) => {
                        let tls_acceptor = tls_acceptor.clone();
                        let app = app.clone();
                        task::spawn(handle_conn(tls_acceptor, stream, app, addr));
                    }
                    Err(err) => {
                        error!("Web server exited with {:?}", err);
                        if let Err(err) = server_message_tx.send(CoreAction::Shutdown) {
                            error!("Web server failed to send shutdown message! {:?}", err)
                        };
                        break;
                    }
                }
            }
            Some(mut new_tls_acceptor) = tls_acceptor_reload_rx.recv() => {
                std::mem::swap(&mut tls_acceptor, &mut new_tls_acceptor);
                info!("Reloaded http tls acceptor");
            }
        }
    }

    info!("Stopped {}", super::TaskName::HttpsServer);
}

async fn server_loop_plaintext(
    addr: SocketAddr,
    app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
    mut rx: broadcast::Receiver<CoreAction>,
) {
    let listener = axum_server::bind(addr).serve(app);

    pin_mut!(listener);

    loop {
        tokio::select! {
            Ok(action) = rx.recv() => {
                match action {
                    CoreAction::Shutdown =>
                        break,
                }
            }
            _ = &mut listener => {}
        }
    }

    info!("Stopped {}", super::TaskName::HttpsServer);
}

/// This handles an individual connection.
pub(crate) async fn handle_conn(
    acceptor: SslAcceptor,
    stream: TcpStream,
    mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
    addr: SocketAddr,
) -> Result<(), std::io::Error> {
    let ssl = Ssl::new(acceptor.context()).map_err(|e| {
        error!("Failed to create TLS context: {:?}", e);
        std::io::Error::from(ErrorKind::ConnectionAborted)
    })?;

    let mut tls_stream = SslStream::new(ssl, stream).map_err(|err| {
        error!(?err, "Failed to create TLS stream");
        std::io::Error::from(ErrorKind::ConnectionAborted)
    })?;

    match SslStream::accept(Pin::new(&mut tls_stream)).await {
        Ok(_) => {
            // Process the client cert (if any)
            let client_cert = if let Some(peer_cert) = tls_stream.ssl().peer_certificate() {
                // TODO: This is where we should be checking the CRL!!!

                // Extract the cert from openssl to x509-cert which is a better
                // parser to handle the various extensions.

                let cert_der = peer_cert.to_der().map_err(|ossl_err| {
                    error!(?ossl_err, "unable to process x509 certificate as DER");
                    std::io::Error::from(ErrorKind::ConnectionAborted)
                })?;

                let certificate = Certificate::from_der(&cert_der).map_err(|ossl_err| {
                    error!(?ossl_err, "unable to process DER certificate to x509");
                    std::io::Error::from(ErrorKind::ConnectionAborted)
                })?;

                let public_key_s256 = x509_public_key_s256(&certificate).ok_or_else(|| {
                    error!("subject public key bitstring is not octet aligned");
                    std::io::Error::from(ErrorKind::ConnectionAborted)
                })?;

                Some(ClientCertInfo {
                    public_key_s256,
                    certificate,
                })
            } else {
                None
            };

            let client_conn_info = ClientConnInfo { addr, client_cert };

            debug!(?client_conn_info);

            let svc = axum_server::service::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service(
                &mut app,
                client_conn_info,
            );

            let svc = svc.await.map_err(|e| {
                error!("Failed to build HTTP response: {:?}", e);
                std::io::Error::from(ErrorKind::Other)
            })?;

            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
            // `TokioIo` converts between them.
            let stream = TokioIo::new(tls_stream);

            // Hyper also has its own `Service` trait and doesn't use tower. We can use
            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
            // `tower::Service::call`.
            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
                // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
                // tower's `Service` requires `&mut self`.
                //
                // We don't need to call `poll_ready` since `Router` is always ready.
                svc.clone().call(request)
            });

            hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
                .serve_connection_with_upgrades(stream, hyper_service)
                .await
                .map_err(|e| {
                    debug!("Failed to complete connection: {:?}", e);
                    std::io::Error::from(ErrorKind::ConnectionAborted)
                })
        }
        Err(error) => {
            trace!("Failed to handle connection: {:?}", error);
            Ok(())
        }
    }
}