diff --git a/Cargo.lock b/Cargo.lock index fe306e47e..743d6a71b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -188,7 +188,7 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acb1161c6b64d1c3d83108213c2a2533a342ac225aabd0bda218278c2ddb00c0" dependencies = [ - "nom", + "nom 7.1.3", ] [[package]] @@ -200,7 +200,7 @@ dependencies = [ "asn1-rs-derive", "asn1-rs-impl", "displaydoc", - "nom", + "nom 7.1.3", "num-traits", "rusticata-macros", "thiserror 1.0.69", @@ -675,7 +675,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" dependencies = [ - "nom", + "nom 7.1.3", ] [[package]] @@ -1148,7 +1148,7 @@ checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" dependencies = [ "asn1-rs", "displaydoc", - "nom", + "nom 7.1.3", "num-bigint", "num-traits", "rusticata-macros", @@ -1213,7 +1213,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9313f104b590510b46fc01c0a324fc76505c13871454d3c48490468d04c8d395" dependencies = [ "libc", - "nom", + "nom 7.1.3", ] [[package]] @@ -2271,6 +2271,18 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" +[[package]] +name = "haproxy-protocol" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f61fc527a2f089b57ebc09301b6371bbbff4ce7b547306c17dfa55766655bec6" +dependencies = [ + "hex", + "nom 8.0.0", + "tokio", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -3140,6 +3152,8 @@ dependencies = [ "filetime", "futures", "futures-util", + "haproxy-protocol", + "hashbrown 0.14.5", "hyper 1.6.0", "hyper-util", "kanidm_build_profiles", @@ -3247,6 +3261,10 @@ dependencies = [ "escargot", "fantoccini", "futures", + "hex", + "http-body-util", + "hyper 1.6.0", + "hyper-util", "jsonschema", "kanidm_build_profiles", "kanidm_client", @@ -3311,7 +3329,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2df7f9fd9f64cf8f59e1a4a0753fe7d575a5b38d3d7ac5758dcee9357d83ef0a" dependencies = [ "bytes", - "nom", + "nom 7.1.3", ] [[package]] @@ -3343,7 +3361,7 @@ dependencies = [ "base64 0.21.7", "bytes", "lber", - "nom", + "nom 7.1.3", "peg", "serde", "thiserror 1.0.69", @@ -3675,6 +3693,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + [[package]] name = "nonempty" version = "0.8.1" @@ -4873,7 +4900,7 @@ version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" dependencies = [ - "nom", + "nom 7.1.3", ] [[package]] @@ -5364,7 +5391,7 @@ checksum = "34285eaade87ba166c4f17c0ae1e35d52659507db81888beae277e962b9e5a02" dependencies = [ "base64 0.21.7", "base64urlsafedata", - "nom", + "nom 7.1.3", "openssl", "serde", "serde_cbor_2", @@ -6341,7 +6368,7 @@ dependencies = [ "bitflags 1.3.2", "futures", "hex", - "nom", + "nom 7.1.3", "num-derive", "num-traits", "openssl", @@ -6386,7 +6413,7 @@ dependencies = [ "compact_jwt", "der-parser", "hex", - "nom", + "nom 7.1.3", "openssl", "rand 0.8.5", "rand_chacha 0.3.1", @@ -6888,7 +6915,7 @@ dependencies = [ "data-encoding", "der-parser", "lazy_static", - "nom", + "nom 7.1.3", "oid-registry", "rusticata-macros", "thiserror 1.0.69", diff --git a/Cargo.toml b/Cargo.toml index 400184969..863bfa9b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -177,9 +177,11 @@ fs4 = "^0.12.0" futures = "^0.3.31" futures-util = { version = "^0.3.30", features = ["sink"] } gix = { version = "0.64.0", default-features = false } +haproxy-protocol = { version = "0.0.1" } hashbrown = { version = "0.14.3", features = ["serde", "inline-more", "ahash"] } hex = "^0.4.3" http = "1.2.0" +http-body-util = "0.1" hyper = { version = "1.5.1", features = [ "full", ] } # hyper full includes client/server/http2 diff --git a/examples/server.toml b/examples/server.toml index 9a41738c5..5a69fbbcc 100644 --- a/examples/server.toml +++ b/examples/server.toml @@ -13,16 +13,6 @@ bindaddress = "[::]:443" # Defaults to "" (disabled) # ldapbindaddress = "[::]:636" # -# HTTPS requests can be reverse proxied by a loadbalancer. -# To preserve the original IP of the caller, these systems -# will often add a header such as "Forwarded" or -# "X-Forwarded-For". If set to true, then this header is -# respected as the "authoritative" source of the IP of the -# connected client. If you are not using a load balancer -# then you should leave this value as default. -# Defaults to false -# trust_x_forward_for = false -# # The path to the kanidm database. db_path = "/var/lib/private/kanidm/kanidm.db" # @@ -86,6 +76,32 @@ domain = "idm.example.com" # origin = "https://idm.example.com" origin = "https://idm.example.com:8443" # + +# HTTPS requests can be reverse proxied by a loadbalancer. +# To preserve the original IP of the caller, these systems +# will often add a header such as "Forwarded" or +# "X-Forwarded-For". Some other proxies can use the PROXY +# protocol v2 header. +# This setting allows configuration of the range of trusted +# IPs which can supply this header information, and which +# format the information is provided in. +# Defaults to "none" (no trusted sources) +# Only one option can be used at a time. +# [http_client_address_info] +# proxy-v2 = ["127.0.0.1"] +# # OR +# x-forward-for = ["127.0.0.1"] + +# LDAPS requests can be reverse proxied by a loadbalancer. +# To preserve the original IP of the caller, these systems +# can add a header such as the PROXY protocol v2 header. +# This setting allows configuration of the range of trusted +# IPs which can supply this header information, and which +# format the information is provided in. +# Defaults to "none" (no trusted sources) +# [ldap_client_address_info] +# proxy-v2 = ["127.0.0.1"] + [online_backup] # The path to the output folder for online backups path = "/var/lib/private/kanidm/backups/" diff --git a/examples/server_container.toml b/examples/server_container.toml index f57923a40..2d706b77d 100644 --- a/examples/server_container.toml +++ b/examples/server_container.toml @@ -13,16 +13,6 @@ bindaddress = "[::]:8443" # Defaults to "" (disabled) # ldapbindaddress = "[::]:3636" # -# HTTPS requests can be reverse proxied by a loadbalancer. -# To preserve the original IP of the caller, these systems -# will often add a header such as "Forwarded" or -# "X-Forwarded-For". If set to true, then this header is -# respected as the "authoritative" source of the IP of the -# connected client. If you are not using a load balancer -# then you should leave this value as default. -# Defaults to false -# trust_x_forward_for = false -# # The path to the kanidm database. db_path = "/data/kanidm.db" # @@ -85,7 +75,32 @@ domain = "idm.example.com" # not consistent, the server WILL refuse to start! # origin = "https://idm.example.com" origin = "https://idm.example.com:8443" -# + +# HTTPS requests can be reverse proxied by a loadbalancer. +# To preserve the original IP of the caller, these systems +# will often add a header such as "Forwarded" or +# "X-Forwarded-For". Some other proxies can use the PROXY +# protocol v2 header. +# This setting allows configuration of the range of trusted +# IPs which can supply this header information, and which +# format the information is provided in. +# Defaults to "none" (no trusted sources) +# Only one option can be used at a time. +# [http_client_address_info] +# proxy-v2 = ["127.0.0.1"] +# # OR +# x-forward-for = ["127.0.0.1"] + +# LDAPS requests can be reverse proxied by a loadbalancer. +# To preserve the original IP of the caller, these systems +# can add a header such as the PROXY protocol v2 header. +# This setting allows configuration of the range of trusted +# IPs which can supply this header information, and which +# format the information is provided in. +# Defaults to "none" (no trusted sources) +# [ldap_client_address_info] +# proxy-v2 = ["127.0.0.1"] + [online_backup] # The path to the output folder for online backups path = "/data/kanidm/backups/" diff --git a/server/core/Cargo.toml b/server/core/Cargo.toml index 1c753cb44..89cf3c59a 100644 --- a/server/core/Cargo.toml +++ b/server/core/Cargo.toml @@ -34,6 +34,8 @@ cron = { workspace = true } filetime = { workspace = true } futures = { workspace = true } futures-util = { workspace = true } +haproxy-protocol = { workspace = true, features = ["tokio"] } +hashbrown = { workspace = true } hyper = { workspace = true } hyper-util = { workspace = true } kanidm_proto = { workspace = true } diff --git a/server/core/src/config.rs b/server/core/src/config.rs index ad3d3bd9c..01bf005fd 100644 --- a/server/core/src/config.rs +++ b/server/core/src/config.rs @@ -4,18 +4,18 @@ //! These components should be "per server". Any "per domain" config should be in the system //! or domain entries that are able to be replicated. -use std::fmt::{self, Display}; -use std::fs::File; -use std::io::Read; -use std::path::{Path, PathBuf}; -use std::str::FromStr; - +use hashbrown::HashSet; use kanidm_proto::constants::DEFAULT_SERVER_ADDRESS; use kanidm_proto::internal::FsType; use kanidm_proto::messages::ConsoleOutputMode; - use serde::Deserialize; use sketching::LogLevel; +use std::fmt::{self, Display}; +use std::fs::File; +use std::io::Read; +use std::net::IpAddr; +use std::path::{Path, PathBuf}; +use std::str::FromStr; use url::Url; use crate::repl::config::ReplicationConfiguration; @@ -100,6 +100,111 @@ pub struct TlsConfiguration { pub client_ca: Option<PathBuf>, } +#[derive(Deserialize, Debug, Clone, Default)] +pub enum LdapAddressInfo { + #[default] + None, + #[serde(rename = "proxy-v2")] + ProxyV2(HashSet<IpAddr>), +} + +impl LdapAddressInfo { + pub fn trusted_proxy_v2(&self) -> Option<HashSet<IpAddr>> { + if let Self::ProxyV2(trusted) = self { + Some(trusted.clone()) + } else { + None + } + } +} + +impl Display for LdapAddressInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::None => f.write_str("none"), + Self::ProxyV2(trusted) => { + f.write_str("proxy-v2 [ ")?; + for ip in trusted { + write!(f, "{} ", ip)?; + } + f.write_str("]") + } + } + } +} + +pub(crate) enum AddressSet { + NonContiguousIpSet(HashSet<IpAddr>), + All, +} + +impl AddressSet { + pub(crate) fn contains(&self, ip_addr: &IpAddr) -> bool { + match self { + Self::All => true, + Self::NonContiguousIpSet(range) => range.contains(ip_addr), + } + } +} + +#[derive(Deserialize, Debug, Clone, Default)] +pub enum HttpAddressInfo { + #[default] + None, + #[serde(rename = "x-forward-for")] + XForwardFor(HashSet<IpAddr>), + // IMPORTANT: This is undocumented, and only exists for backwards compat + // with config v1 which has a boolean toggle for this option. + #[serde(rename = "x-forward-for-all-source-trusted")] + XForwardForAllSourcesTrusted, + #[serde(rename = "proxy-v2")] + ProxyV2(HashSet<IpAddr>), +} + +impl HttpAddressInfo { + pub(crate) fn trusted_x_forward_for(&self) -> Option<AddressSet> { + match self { + Self::XForwardForAllSourcesTrusted => Some(AddressSet::All), + Self::XForwardFor(trusted) => Some(AddressSet::NonContiguousIpSet(trusted.clone())), + _ => None, + } + } + + pub(crate) fn trusted_proxy_v2(&self) -> Option<HashSet<IpAddr>> { + if let Self::ProxyV2(trusted) = self { + Some(trusted.clone()) + } else { + None + } + } +} + +impl Display for HttpAddressInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::None => f.write_str("none"), + + Self::XForwardFor(trusted) => { + f.write_str("x-forward-for [ ")?; + for ip in trusted { + write!(f, "{} ", ip)?; + } + f.write_str("]") + } + Self::XForwardForAllSourcesTrusted => { + f.write_str("x-forward-for [ ALL SOURCES TRUSTED ]") + } + Self::ProxyV2(trusted) => { + f.write_str("proxy-v2 [ ")?; + for ip in trusted { + write!(f, "{} ", ip)?; + } + f.write_str("]") + } + } + } +} + /// This is the Server Configuration as read from `server.toml` or environment variables. /// /// Fields noted as "REQUIRED" are required for the server to start, even if they show as optional due to how file parsing works. @@ -217,7 +322,10 @@ pub struct ServerConfigV2 { role: Option<ServerRole>, log_level: Option<LogLevel>, online_backup: Option<OnlineBackup>, - trust_x_forward_for: Option<bool>, + + http_client_address_info: Option<HttpAddressInfo>, + ldap_client_address_info: Option<LdapAddressInfo>, + adminbindpath: Option<String>, thread_count: Option<usize>, maximum_request_size_bytes: Option<usize>, @@ -490,7 +598,10 @@ pub struct Configuration { pub db_fs_type: Option<FsType>, pub db_arc_size: Option<usize>, pub maximum_request: usize, - pub trust_x_forward_for: bool, + + pub http_client_address_info: HttpAddressInfo, + pub ldap_client_address_info: LdapAddressInfo, + pub tls_config: Option<TlsConfiguration>, pub integration_test_config: Option<Box<IntegrationTestConfig>>, pub online_backup: Option<OnlineBackup>, @@ -522,7 +633,8 @@ impl Configuration { db_fs_type: None, db_arc_size: None, maximum_request: 256 * 1024, // 256k - trust_x_forward_for: None, + http_client_address_info: HttpAddressInfo::default(), + ldap_client_address_info: LdapAddressInfo::default(), tls_key: None, tls_chain: None, tls_client_ca: None, @@ -547,7 +659,8 @@ impl Configuration { db_fs_type: None, db_arc_size: None, maximum_request: 256 * 1024, // 256k - trust_x_forward_for: false, + http_client_address_info: HttpAddressInfo::default(), + ldap_client_address_info: LdapAddressInfo::default(), tls_config: None, integration_test_config: None, online_backup: None, @@ -587,7 +700,17 @@ impl fmt::Display for Configuration { None => write!(f, "arcsize: AUTO, "), }?; write!(f, "max request size: {}b, ", self.maximum_request)?; - write!(f, "trust X-Forwarded-For: {}, ", self.trust_x_forward_for)?; + write!( + f, + "http client address info: {}, ", + self.http_client_address_info + )?; + write!( + f, + "ldap client address info: {}, ", + self.ldap_client_address_info + )?; + write!(f, "with TLS: {}, ", self.tls_config.is_some())?; match &self.online_backup { Some(bck) => write!( @@ -642,7 +765,8 @@ pub struct ConfigurationBuilder { db_fs_type: Option<FsType>, db_arc_size: Option<usize>, maximum_request: usize, - trust_x_forward_for: Option<bool>, + http_client_address_info: HttpAddressInfo, + ldap_client_address_info: LdapAddressInfo, tls_key: Option<PathBuf>, tls_chain: Option<PathBuf>, tls_client_ca: Option<PathBuf>, @@ -691,8 +815,8 @@ impl ConfigurationBuilder { self.db_arc_size = env_config.db_arc_size; } - if env_config.trust_x_forward_for.is_some() { - self.trust_x_forward_for = env_config.trust_x_forward_for; + if env_config.trust_x_forward_for == Some(true) { + self.http_client_address_info = HttpAddressInfo::XForwardForAllSourcesTrusted; } if env_config.tls_key.is_some() { @@ -813,8 +937,8 @@ impl ConfigurationBuilder { self.db_arc_size = config.db_arc_size; } - if config.trust_x_forward_for.is_some() { - self.trust_x_forward_for = config.trust_x_forward_for; + if config.trust_x_forward_for == Some(true) { + self.http_client_address_info = HttpAddressInfo::XForwardForAllSourcesTrusted; } if config.online_backup.is_some() { @@ -893,8 +1017,12 @@ impl ConfigurationBuilder { self.db_arc_size = config.db_arc_size; } - if config.trust_x_forward_for.is_some() { - self.trust_x_forward_for = config.trust_x_forward_for; + if let Some(http_client_address_info) = config.http_client_address_info { + self.http_client_address_info = http_client_address_info + } + + if let Some(ldap_client_address_info) = config.ldap_client_address_info { + self.ldap_client_address_info = ldap_client_address_info } if config.online_backup.is_some() { @@ -930,7 +1058,8 @@ impl ConfigurationBuilder { db_fs_type, db_arc_size, maximum_request, - trust_x_forward_for, + http_client_address_info, + ldap_client_address_info, tls_key, tls_chain, tls_client_ca, @@ -986,7 +1115,6 @@ impl ConfigurationBuilder { let adminbindpath = adminbindpath.unwrap_or(env!("KANIDM_SERVER_ADMIN_BIND_PATH").to_string()); let address = bindaddress.unwrap_or(DEFAULT_SERVER_ADDRESS.to_string()); - let trust_x_forward_for = trust_x_forward_for.unwrap_or_default(); let output_mode = output_mode.unwrap_or_default(); let role = role.unwrap_or(ServerRole::WriteReplica); let log_level = log_level.unwrap_or_default(); @@ -1000,7 +1128,8 @@ impl ConfigurationBuilder { db_fs_type, db_arc_size, maximum_request, - trust_x_forward_for, + http_client_address_info, + ldap_client_address_info, tls_config, online_backup, domain, diff --git a/server/core/src/https/extractors/mod.rs b/server/core/src/https/extractors/mod.rs index 4d3fd686f..105b4c680 100644 --- a/server/core/src/https/extractors/mod.rs +++ b/server/core/src/https/extractors/mod.rs @@ -5,7 +5,6 @@ use axum::{ http::{ header::HeaderName, header::AUTHORIZATION as AUTHORISATION, request::Parts, StatusCode, }, - serve::IncomingStream, RequestPartsExt, }; @@ -40,7 +39,8 @@ impl FromRequestParts<ServerState> for TrustedClientIp { state: &ServerState, ) -> Result<Self, Self::Rejection> { let ConnectInfo(ClientConnInfo { - addr, + connection_addr, + client_addr, client_cert: _, }) = parts .extract::<ConnectInfo<ClientConnInfo>>() @@ -53,7 +53,13 @@ impl FromRequestParts<ServerState> for TrustedClientIp { ) })?; - let ip_addr = if state.trust_x_forward_for { + let trust_x_forward_for = state + .trust_x_forward_for_ips + .as_ref() + .map(|range| range.contains(&connection_addr.ip())) + .unwrap_or_default(); + + let ip_addr = if trust_x_forward_for { if let Some(x_forward_for) = parts.headers.get(X_FORWARDED_FOR_HEADER) { // X forward for may be comma separated. let first = x_forward_for @@ -75,10 +81,14 @@ impl FromRequestParts<ServerState> for TrustedClientIp { ) })? } else { - addr.ip() + client_addr.ip() } } else { - addr.ip() + // This can either be the client_addr == connection_addr if there are + // no ip address trust sources, or this is the value as reported by + // proxy protocol header. If the proxy protocol header is used, then + // trust_x_forward_for can never have been true so we catch here. + client_addr.ip() }; Ok(TrustedClientIp(ip_addr)) @@ -97,7 +107,11 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation { parts: &mut Parts, state: &ServerState, ) -> Result<Self, Self::Rejection> { - let ConnectInfo(ClientConnInfo { addr, client_cert }) = parts + let ConnectInfo(ClientConnInfo { + connection_addr, + client_addr, + client_cert, + }) = parts .extract::<ConnectInfo<ClientConnInfo>>() .await .map_err(|_| { @@ -108,7 +122,13 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation { ) })?; - let ip_addr = if state.trust_x_forward_for { + let trust_x_forward_for = state + .trust_x_forward_for_ips + .as_ref() + .map(|range| range.contains(&connection_addr.ip())) + .unwrap_or_default(); + + let ip_addr = if trust_x_forward_for { if let Some(x_forward_for) = parts.headers.get(X_FORWARDED_FOR_HEADER) { // X forward for may be comma separated. let first = x_forward_for @@ -130,10 +150,10 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation { ) })? } else { - addr.ip() + client_addr.ip() } } else { - addr.ip() + client_addr.ip() }; let (basic_authz, bearer_token) = if let Some(header) = parts.headers.get(AUTHORISATION) { @@ -201,30 +221,30 @@ impl FromRequestParts<ServerState> for DomainInfo { #[derive(Debug, Clone)] pub struct ClientConnInfo { - pub addr: SocketAddr, + /// This is the address that is *connected* to Kanidm right now + /// for this operation. + #[allow(dead_code)] + pub connection_addr: SocketAddr, + /// This is the client address as reported by a remote IP source + /// such as x-forward-for or the PROXY protocol header + pub client_addr: SocketAddr, // Only set if the certificate is VALID pub client_cert: Option<ClientCertInfo>, } +// This is the normal way that our extractors get the ip info impl Connected<ClientConnInfo> for ClientConnInfo { fn connect_info(target: ClientConnInfo) -> Self { target } } +// This is only used for plaintext http - in other words, integration tests only. impl Connected<SocketAddr> for ClientConnInfo { - fn connect_info(addr: SocketAddr) -> Self { + fn connect_info(connection_addr: SocketAddr) -> Self { ClientConnInfo { - addr, - client_cert: None, - } - } -} - -impl Connected<IncomingStream<'_>> for ClientConnInfo { - fn connect_info(target: IncomingStream<'_>) -> Self { - ClientConnInfo { - addr: target.remote_addr(), + client_addr: connection_addr, + connection_addr, client_cert: None, } } diff --git a/server/core/src/https/mod.rs b/server/core/src/https/mod.rs index 645f35202..1af317b03 100644 --- a/server/core/src/https/mod.rs +++ b/server/core/src/https/mod.rs @@ -17,9 +17,8 @@ mod views; use self::extractors::ClientConnInfo; use self::javascript::*; use crate::actors::{QueryServerReadV1, QueryServerWriteV1}; -use crate::config::{Configuration, ServerRole}; +use crate::config::{AddressSet, Configuration, ServerRole}; use crate::CoreAction; - use axum::{ body::Body, extract::connect_info::IntoMakeServiceWithConnectInfo, @@ -29,22 +28,28 @@ use axum::{ routing::*, Router, }; - use axum_extra::extract::cookie::CookieJar; use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier}; use futures::pin_mut; +use haproxy_protocol::{ProxyHdrV2, RemoteAddress}; +use hashbrown::HashSet; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; +use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate}; use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID}; use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor}; use openssl::ssl::{Ssl, SslAcceptor}; - -use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate}; - use serde::de::DeserializeOwned; use sketching::*; use std::fmt::Write; +use std::io::ErrorKind; +use std::net::IpAddr; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::Arc; +use std::{net::SocketAddr, str::FromStr}; use tokio::{ + io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream}, sync::broadcast, sync::mpsc, @@ -56,11 +61,6 @@ use tower_http::{services::ServeDir, trace::TraceLayer}; use url::Url; use uuid::Uuid; -use std::io::ErrorKind; -use std::path::PathBuf; -use std::pin::Pin; -use std::{net::SocketAddr, str::FromStr}; - #[derive(Clone)] pub struct ServerState { pub(crate) status_ref: &'static StatusActor, @@ -68,7 +68,7 @@ pub struct ServerState { pub(crate) qe_r_ref: &'static QueryServerReadV1, // Store the token management parts. pub(crate) jws_signer: JwsHs256Signer, - pub(crate) trust_x_forward_for: bool, + pub(crate) trust_x_forward_for_ips: Option<Arc<AddressSet>>, pub(crate) csp_header: HeaderValue, pub(crate) origin: Url, pub(crate) domain: String, @@ -211,7 +211,15 @@ pub async fn create_https_server( error!(?err, "Unable to generate content security policy"); })?; - let trust_x_forward_for = config.trust_x_forward_for; + let trust_x_forward_for_ips = config + .http_client_address_info + .trusted_x_forward_for() + .map(Arc::new); + + let trusted_proxy_v2_ips = config + .http_client_address_info + .trusted_proxy_v2() + .map(Arc::new); let origin = Url::parse(&config.origin) // Should be impossible! @@ -224,7 +232,7 @@ pub async fn create_https_server( qe_w_ref, qe_r_ref, jws_signer, - trust_x_forward_for, + trust_x_forward_for_ips, csp_header, origin, domain: config.domain.clone(), @@ -321,35 +329,41 @@ pub async fn create_https_server( info!("Starting the web server..."); - match maybe_tls_acceptor { - Some(tls_acceptor) => { - let listener = match TcpListener::bind(addr).await { - Ok(l) => l, - Err(err) => { - error!(?err, "Failed to bind tcp listener"); - return Err(()); - } - }; - Ok(task::spawn(server_loop( - tls_acceptor, - listener, - app, - rx, - server_message_tx, - tls_acceptor_reload_rx, - ))) + let listener = match TcpListener::bind(addr).await { + Ok(l) => l, + Err(err) => { + error!(?err, "Failed to bind tcp listener"); + return Err(()); } - None => Ok(task::spawn(server_loop_plaintext(addr, app, rx))), + }; + + match maybe_tls_acceptor { + Some(tls_acceptor) => Ok(task::spawn(server_tls_loop( + tls_acceptor, + listener, + app, + rx, + server_message_tx, + tls_acceptor_reload_rx, + trusted_proxy_v2_ips, + ))), + None => Ok(task::spawn(server_plaintext_loop( + listener, + app, + rx, + trusted_proxy_v2_ips, + ))), } } -async fn server_loop( +async fn server_tls_loop( mut tls_acceptor: SslAcceptor, listener: TcpListener, app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, mut rx: broadcast::Receiver<CoreAction>, server_message_tx: broadcast::Sender<CoreAction>, mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, ) { pin_mut!(listener); @@ -365,7 +379,7 @@ async fn server_loop( Ok((stream, addr)) => { let tls_acceptor = tls_acceptor.clone(); let app = app.clone(); - task::spawn(handle_conn(tls_acceptor, stream, app, addr)); + task::spawn(handle_tls_conn(tls_acceptor, stream, app, addr, trusted_proxy_v2_ips.clone())); } Err(err) => { error!("Web server exited with {:?}", err); @@ -386,24 +400,33 @@ async fn server_loop( info!("Stopped {}", super::TaskName::HttpsServer); } -async fn server_loop_plaintext( - addr: SocketAddr, +async fn server_plaintext_loop( + listener: TcpListener, app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, mut rx: broadcast::Receiver<CoreAction>, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, ) { - let listener = axum_server::bind(addr).serve(app); - pin_mut!(listener); loop { tokio::select! { Ok(action) = rx.recv() => { match action { - CoreAction::Shutdown => - break, + CoreAction::Shutdown => break, + } + } + accept = listener.accept() => { + match accept { + Ok((stream, addr)) => { + let app = app.clone(); + task::spawn(handle_conn(stream, app, addr, trusted_proxy_v2_ips.clone())); + } + Err(err) => { + error!("Web server exited with {:?}", err); + break; + } } } - _ = &mut listener => {} } } @@ -412,11 +435,38 @@ async fn server_loop_plaintext( /// This handles an individual connection. pub(crate) async fn handle_conn( + stream: TcpStream, + app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, + connection_addr: SocketAddr, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, +) -> Result<(), std::io::Error> { + let (stream, client_addr) = + process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?; + + let client_conn_info = ClientConnInfo { + connection_addr, + client_addr, + client_cert: None, + }; + + // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. + // `TokioIo` converts between them. + let stream = TokioIo::new(stream); + + process_client_hyper(stream, app, client_conn_info).await +} + +/// This handles an individual connection. +pub(crate) async fn handle_tls_conn( acceptor: SslAcceptor, stream: TcpStream, - mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, - addr: SocketAddr, + app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, + connection_addr: SocketAddr, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, ) -> Result<(), std::io::Error> { + let (stream, client_addr) = + process_client_addr(stream, connection_addr, trusted_proxy_v2_ips).await?; + let ssl = Ssl::new(acceptor.context()).map_err(|e| { error!("Failed to create TLS context: {:?}", e); std::io::Error::from(ErrorKind::ConnectionAborted) @@ -459,42 +509,17 @@ pub(crate) async fn handle_conn( None }; - let client_conn_info = ClientConnInfo { addr, client_cert }; - - debug!(?client_conn_info); - - let svc = axum_server::service::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service( - &mut app, - client_conn_info, - ); - - let svc = svc.await.map_err(|e| { - error!("Failed to build HTTP response: {:?}", e); - std::io::Error::from(ErrorKind::Other) - })?; + let client_conn_info = ClientConnInfo { + connection_addr, + client_addr, + client_cert, + }; // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. // `TokioIo` converts between them. let stream = TokioIo::new(tls_stream); - // Hyper also has its own `Service` trait and doesn't use tower. We can use - // `hyper::service::service_fn` to create a hyper `Service` that calls our app through - // `tower::Service::call`. - let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| { - // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas - // tower's `Service` requires `&mut self`. - // - // We don't need to call `poll_ready` since `Router` is always ready. - svc.clone().call(request) - }); - - hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(stream, hyper_service) - .await - .map_err(|e| { - debug!("Failed to complete connection: {:?}", e); - std::io::Error::from(ErrorKind::ConnectionAborted) - }) + process_client_hyper(stream, app, client_conn_info).await } Err(error) => { trace!("Failed to handle connection: {:?}", error); @@ -502,3 +527,83 @@ pub(crate) async fn handle_conn( } } } + +async fn process_client_addr( + stream: TcpStream, + connection_addr: SocketAddr, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, +) -> Result<(TcpStream, SocketAddr), std::io::Error> { + let enable_proxy_v2_hdr = trusted_proxy_v2_ips + .map(|trusted| trusted.contains(&connection_addr.ip())) + .unwrap_or_default(); + + let (stream, client_addr) = if enable_proxy_v2_hdr { + match ProxyHdrV2::parse_from_read(stream).await { + Ok((stream, hdr)) => { + let remote_socket_addr = match hdr.to_remote_addr() { + RemoteAddress::Local => { + debug!("PROXY protocol liveness check - will not contain client data"); + return Err(std::io::Error::from(ErrorKind::ConnectionAborted)); + } + RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src), + RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src), + remote_addr => { + error!(?remote_addr, "remote address in proxy header is invalid"); + return Err(std::io::Error::from(ErrorKind::ConnectionAborted)); + } + }; + + (stream, remote_socket_addr) + } + Err(err) => { + error!(?connection_addr, ?err, "Unable to process proxy v2 header"); + return Err(std::io::Error::from(ErrorKind::ConnectionAborted)); + } + } + } else { + (stream, connection_addr) + }; + + Ok((stream, client_addr)) +} + +async fn process_client_hyper<T>( + stream: TokioIo<T>, + mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, + client_conn_info: ClientConnInfo, +) -> Result<(), std::io::Error> +where + T: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static, +{ + debug!(?client_conn_info); + + let svc = + axum_server::service::MakeService::<ClientConnInfo, hyper::Request<Body>>::make_service( + &mut app, + client_conn_info, + ); + + let svc = svc.await.map_err(|e| { + error!("Failed to build HTTP response: {:?}", e); + std::io::Error::from(ErrorKind::Other) + })?; + + // Hyper also has its own `Service` trait and doesn't use tower. We can use + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through + // `tower::Service::call`. + let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| { + // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas + // tower's `Service` requires `&mut self`. + // + // We don't need to call `poll_ready` since `Router` is always ready. + svc.clone().call(request) + }); + + hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(stream, hyper_service) + .await + .map_err(|e| { + debug!("Failed to complete connection: {:?}", e); + std::io::Error::from(ErrorKind::ConnectionAborted) + }) +} diff --git a/server/core/src/ldaps.rs b/server/core/src/ldaps.rs index ca57a7e1b..9ce9f01b7 100644 --- a/server/core/src/ldaps.rs +++ b/server/core/src/ldaps.rs @@ -2,14 +2,17 @@ use crate::actors::QueryServerReadV1; use crate::CoreAction; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; +use haproxy_protocol::{ProxyHdrV2, RemoteAddress}; +use hashbrown::HashSet; use kanidmd_lib::idm::ldap::{LdapBoundToken, LdapResponseState}; use kanidmd_lib::prelude::*; use ldap3_proto::proto::LdapMsg; use ldap3_proto::LdapCodec; use openssl::ssl::{Ssl, SslAcceptor}; -use std::net; +use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::str::FromStr; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::broadcast; @@ -33,7 +36,7 @@ impl LdapSession { #[instrument(name = "ldap-request", skip(client_address, qe_r_ref))] async fn client_process_msg( uat: Option<LdapBoundToken>, - client_address: net::SocketAddr, + client_address: SocketAddr, protomsg: LdapMsg, qe_r_ref: &'static QueryServerReadV1, ) -> Option<LdapResponseState> { @@ -50,7 +53,8 @@ async fn client_process_msg( async fn client_process<STREAM>( stream: STREAM, - client_address: net::SocketAddr, + client_address: SocketAddr, + connection_address: SocketAddr, qe_r_ref: &'static QueryServerReadV1, ) where STREAM: AsyncRead + AsyncWrite, @@ -67,6 +71,8 @@ async fn client_process<STREAM>( let uat = session.uat.clone(); let caddr = client_address; + debug!(?client_address, ?connection_address); + match client_process_msg(uat, caddr, protomsg, qe_r_ref).await { // I'd really have liked to have put this near the [LdapResponseState::Bind] but due // to the handing of `audit` it isn't possible due to borrows, etc. @@ -112,28 +118,65 @@ async fn client_process<STREAM>( } async fn client_tls_accept( - tcpstream: TcpStream, + stream: TcpStream, tls_acceptor: SslAcceptor, - client_socket_addr: net::SocketAddr, + connection_addr: SocketAddr, qe_r_ref: &'static QueryServerReadV1, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, ) { + let enable_proxy_v2_hdr = trusted_proxy_v2_ips + .map(|trusted| trusted.contains(&connection_addr.ip())) + .unwrap_or_default(); + + let (stream, client_addr) = if enable_proxy_v2_hdr { + match ProxyHdrV2::parse_from_read(stream).await { + Ok((stream, hdr)) => { + let remote_socket_addr = match hdr.to_remote_addr() { + RemoteAddress::Local => { + debug!("PROXY protocol liveness check - will not contain client data"); + return; + } + RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src), + RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src), + remote_addr => { + error!(?remote_addr, "remote address in proxy header is invalid"); + return; + } + }; + + (stream, remote_socket_addr) + } + Err(err) => { + error!(?connection_addr, ?err, "Unable to process proxy v2 header"); + return; + } + } + } else { + (stream, connection_addr) + }; + // Start the event // From the parameters we need to create an SslContext. let mut tlsstream = match Ssl::new(tls_acceptor.context()) - .and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) + .and_then(|tls_obj| SslStream::new(tls_obj, stream)) { Ok(ta) => ta, Err(err) => { - error!(?err, %client_socket_addr, "LDAP TLS setup error"); + error!(?err, %client_addr, %connection_addr, "LDAP TLS setup error"); return; } }; if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await { - error!(?err, %client_socket_addr, "LDAP TLS accept error"); + error!(?err, %client_addr, %connection_addr, "LDAP TLS accept error"); return; }; - tokio::spawn(client_process(tlsstream, client_socket_addr, qe_r_ref)); + tokio::spawn(client_process( + tlsstream, + client_addr, + connection_addr, + qe_r_ref, + )); } /// TLS LDAP Listener, hands off to [client_tls_accept] @@ -143,6 +186,7 @@ async fn ldap_tls_acceptor( qe_r_ref: &'static QueryServerReadV1, mut rx: broadcast::Receiver<CoreAction>, mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>, + trusted_proxy_v2_ips: Option<Arc<HashSet<IpAddr>>>, ) { loop { tokio::select! { @@ -155,7 +199,7 @@ async fn ldap_tls_acceptor( match accept_result { Ok((tcpstream, client_socket_addr)) => { let clone_tls_acceptor = tls_acceptor.clone(); - tokio::spawn(client_tls_accept(tcpstream, clone_tls_acceptor, client_socket_addr, qe_r_ref)); + tokio::spawn(client_tls_accept(tcpstream, clone_tls_acceptor, client_socket_addr, qe_r_ref, trusted_proxy_v2_ips.clone())); } Err(err) => { warn!(?err, "LDAP acceptor error, continuing"); @@ -187,7 +231,7 @@ async fn ldap_plaintext_acceptor( accept_result = listener.accept() => { match accept_result { Ok((tcpstream, client_socket_addr)) => { - tokio::spawn(client_process(tcpstream, client_socket_addr, qe_r_ref)); + tokio::spawn(client_process(tcpstream, client_socket_addr, client_socket_addr, qe_r_ref)); } Err(e) => { error!("LDAP acceptor error, continuing -> {:?}", e); @@ -205,6 +249,7 @@ pub(crate) async fn create_ldap_server( qe_r_ref: &'static QueryServerReadV1, rx: broadcast::Receiver<CoreAction>, tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>, + trusted_proxy_v2_ips: Option<HashSet<IpAddr>>, ) -> Result<tokio::task::JoinHandle<()>, ()> { if address.starts_with(":::") { // takes :::xxxx to xxxx @@ -212,7 +257,7 @@ pub(crate) async fn create_ldap_server( error!("Address '{}' looks like an attempt to wildcard bind with IPv6 on port {} - please try using ldapbindaddress = '[::]:{}'", address, port, port); }; - let addr = net::SocketAddr::from_str(address).map_err(|e| { + let addr = SocketAddr::from_str(address).map_err(|e| { error!("Could not parse LDAP server address {} -> {:?}", address, e); })?; @@ -223,6 +268,8 @@ pub(crate) async fn create_ldap_server( ); })?; + let trusted_proxy_v2_ips = trusted_proxy_v2_ips.map(Arc::new); + let ldap_acceptor_handle = match opt_ssl_acceptor { Some(ssl_acceptor) => { info!("Starting LDAPS interface ldaps://{} ...", address); @@ -233,6 +280,7 @@ pub(crate) async fn create_ldap_server( qe_r_ref, rx, tls_acceptor_reload_rx, + trusted_proxy_v2_ips, )) } None => tokio::spawn(ldap_plaintext_acceptor(listener, qe_r_ref, rx)), diff --git a/server/core/src/lib.rs b/server/core/src/lib.rs index 1117f446a..392668ba8 100644 --- a/server/core/src/lib.rs +++ b/server/core/src/lib.rs @@ -1087,6 +1087,7 @@ pub async fn create_server_core( server_read_ref, broadcast_tx.subscribe(), ldap_tls_acceptor_reload_rx, + config.ldap_client_address_info.trusted_proxy_v2(), ) .await?; Some(h) diff --git a/server/testkit-macros/src/entry.rs b/server/testkit-macros/src/entry.rs index 81e1ef701..6fa1b9a48 100644 --- a/server/testkit-macros/src/entry.rs +++ b/server/testkit-macros/src/entry.rs @@ -10,16 +10,17 @@ const ALLOWED_ATTRIBUTES: &[&str] = &[ "threads", "db_path", "maximum_request", - "trust_x_forward_for", + "http_client_address_info", "role", "output_mode", "log_level", "ldap", + "with_test_env", ]; #[derive(Default)] struct Flags { - ldap: bool, + target_wants_test_env: bool, } fn parse_attributes( @@ -60,8 +61,11 @@ fn parse_attributes( .unwrap_or_default() .as_str() { + "with_test_env" => { + flags.target_wants_test_env = true; + } "ldap" => { - flags.ldap = true; + flags.target_wants_test_env = true; field_modifications.extend(quote! { ldapbindaddress: Some("on".to_string()),}) } @@ -134,7 +138,7 @@ pub(crate) fn test(args: TokenStream, item: TokenStream) -> TokenStream { #[::core::prelude::v1::test] }; - let test_fn_args = if flags.ldap { + let test_fn_args = if flags.target_wants_test_env { quote! { &test_env } diff --git a/server/testkit/Cargo.toml b/server/testkit/Cargo.toml index 6689649a2..83f87bf50 100644 --- a/server/testkit/Cargo.toml +++ b/server/testkit/Cargo.toml @@ -53,6 +53,10 @@ escargot = "0.5.13" # used for webdriver testing fantoccini = { version = "0.21.5" } futures = { workspace = true } +hex = { workspace = true } +hyper = { workspace = true } +http-body-util = { workspace = true } +hyper-util = { workspace = true } ldap3_client = { workspace = true } oauth2_ext = { workspace = true, default-features = false, features = [ "reqwest", diff --git a/server/testkit/src/lib.rs b/server/testkit/src/lib.rs index 7eef97a25..ec35a2199 100644 --- a/server/testkit/src/lib.rs +++ b/server/testkit/src/lib.rs @@ -15,7 +15,7 @@ use kanidm_proto::internal::{Filter, Modify, ModifyList}; use kanidmd_core::config::{Configuration, IntegrationTestConfig}; use kanidmd_core::{create_server_core, CoreHandle}; use kanidmd_lib::prelude::{Attribute, NAME_SYSTEM_ADMINS}; -use std::net::TcpStream; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; use std::sync::atomic::{AtomicU16, Ordering}; use tokio::task; use tracing::error; @@ -64,6 +64,7 @@ fn port_loop() -> u16 { pub struct AsyncTestEnvironment { pub rsclient: KanidmClient, + pub http_sock_addr: SocketAddr, pub core_handle: CoreHandle, pub ldap_url: Option<Url>, } @@ -86,8 +87,9 @@ pub async fn setup_async_test(mut config: Configuration) -> AsyncTestEnvironment let ldap_url = if config.ldapbindaddress.is_some() { let ldapport = port_loop(); - config.ldapbindaddress = Some(format!("127.0.0.1:{}", ldapport)); - Url::parse(&format!("ldap://127.0.0.1:{}", ldapport)) + let ldap_sock_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), ldapport); + config.ldapbindaddress = Some(ldap_sock_addr.to_string()); + Url::parse(&format!("ldap://{}", ldap_sock_addr)) .inspect_err(|err| error!(?err, "ldap address setup")) .ok() } else { @@ -95,7 +97,9 @@ pub async fn setup_async_test(mut config: Configuration) -> AsyncTestEnvironment }; // Setup the address and origin.. - config.address = format!("127.0.0.1:{}", port); + let http_sock_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port); + + config.address = http_sock_addr.to_string(); config.integration_test_config = Some(int_config); config.domain = "localhost".to_string(); config.origin.clone_from(&addr); @@ -123,6 +127,7 @@ pub async fn setup_async_test(mut config: Configuration) -> AsyncTestEnvironment AsyncTestEnvironment { rsclient, + http_sock_addr, core_handle, ldap_url, } diff --git a/server/testkit/tests/testkit/https_extractors.rs b/server/testkit/tests/testkit/https_extractors.rs deleted file mode 100644 index b664517cb..000000000 --- a/server/testkit/tests/testkit/https_extractors.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::{ - net::{IpAddr, Ipv4Addr}, - str::FromStr, -}; - -use kanidm_client::KanidmClient; -use kanidm_proto::constants::X_FORWARDED_FOR; - -const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); - -// *test where we don't trust the x-forwarded-for header - -#[kanidmd_testkit::test(trust_x_forward_for = false)] -async fn dont_trust_xff_send_header(rsclient: &KanidmClient) { - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header( - X_FORWARDED_FOR, - "An invalid header that will get through!!!", - ) - .send() - .await - .unwrap(); - let ip_res: IpAddr = res - .json() - .await - .expect("Failed to parse response as IpAddr"); - - assert_eq!(ip_res, DEFAULT_IP_ADDRESS); -} - -#[kanidmd_testkit::test(trust_x_forward_for = false)] -async fn dont_trust_xff_dont_send_header(rsclient: &KanidmClient) { - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header( - X_FORWARDED_FOR, - "An invalid header that will get through!!!", - ) - .send() - .await - .unwrap(); - let body = res.bytes().await.unwrap(); - let ip_res: IpAddr = serde_json::from_slice(&body).unwrap_or_else(|op| { - panic!( - "Failed to parse response as IpAddr: {:?} body: {:?}", - op, body, - ) - }); - eprintln!("Body: {:?}", body); - assert_eq!(ip_res, DEFAULT_IP_ADDRESS); -} - -// *test where we trust the x-forwarded-for header - -#[kanidmd_testkit::test(trust_x_forward_for = true)] -async fn trust_xff_send_invalid_header_single_value(rsclient: &KanidmClient) { - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header( - X_FORWARDED_FOR, - "An invalid header that will get through!!!", - ) - .send() - .await - .unwrap(); - - assert_eq!(res.status(), 400); -} - -// TODO: Right now we reject the request only if the leftmost address is invalid. In the future that could change so we could also have a test -// with a valid leftmost address and an invalid address later in the list. Right now it wouldn't work. -// -#[kanidmd_testkit::test(trust_x_forward_for = true)] -async fn trust_xff_send_invalid_header_multiple_values(rsclient: &KanidmClient) { - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header( - X_FORWARDED_FOR, - "203.0.113.195_noooo_my_ip_address, 2001:db8:85a3:8d3:1319:8a2e:370:7348", - ) - .send() - .await - .unwrap(); - - assert_eq!(res.status(), 400); -} - -#[kanidmd_testkit::test(trust_x_forward_for = true)] -async fn trust_xff_send_valid_header_single_ipv4_address(rsclient: &KanidmClient) { - let ip_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7348"; - - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header(X_FORWARDED_FOR, ip_addr) - .send() - .await - .unwrap(); - let ip_res: IpAddr = res - .json() - .await - .expect("Failed to parse response as Vec<IpAddr>"); - - assert_eq!(ip_res, IpAddr::from_str(ip_addr).unwrap()); -} - -#[kanidmd_testkit::test(trust_x_forward_for = true)] -async fn trust_xff_send_valid_header_single_ipv6_address(rsclient: &KanidmClient) { - let ip_addr = "203.0.113.195"; - - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header(X_FORWARDED_FOR, ip_addr) - .send() - .await - .unwrap(); - let ip_res: IpAddr = res - .json() - .await - .expect("Failed to parse response as Vec<IpAddr>"); - - assert_eq!(ip_res, IpAddr::from_str(ip_addr).unwrap()); -} - -#[kanidmd_testkit::test(trust_x_forward_for = true)] -async fn trust_xff_send_valid_header_multiple_address(rsclient: &KanidmClient) { - let first_ip_addr = "203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"; - - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header(X_FORWARDED_FOR, first_ip_addr) - .send() - .await - .unwrap(); - let ip_res: IpAddr = res - .json() - .await - .expect("Failed to parse response as Vec<IpAddr>"); - - assert_eq!( - ip_res, - IpAddr::from_str(first_ip_addr.split(",").collect::<Vec<&str>>()[0]).unwrap() - ); - - let second_ip_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7348, 198.51.100.178, 203.0.113.195"; - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .header(X_FORWARDED_FOR, second_ip_addr) - .send() - .await - .unwrap(); - let ip_res: IpAddr = res - .json() - .await - .expect("Failed to parse response as Vec<IpAddr>"); - - assert_eq!( - ip_res, - IpAddr::from_str(second_ip_addr.split(",").collect::<Vec<&str>>()[0]).unwrap() - ); -} - -#[kanidmd_testkit::test(trust_x_forward_for = true)] -async fn trust_xff_dont_send_header(rsclient: &KanidmClient) { - let client = rsclient.client(); - - let res = client - .get(rsclient.make_url("/v1/debug/ipinfo")) - .send() - .await - .unwrap(); - let ip_res: IpAddr = res - .json() - .await - .expect("Failed to parse response as Vec<IpAddr>"); - - assert_eq!(ip_res, DEFAULT_IP_ADDRESS); -} diff --git a/server/testkit/tests/testkit/ip_addr_extractors.rs b/server/testkit/tests/testkit/ip_addr_extractors.rs new file mode 100644 index 000000000..0d7642a43 --- /dev/null +++ b/server/testkit/tests/testkit/ip_addr_extractors.rs @@ -0,0 +1,324 @@ +use kanidm_client::KanidmClient; +use kanidm_proto::constants::X_FORWARDED_FOR; +use kanidmd_core::config::HttpAddressInfo; +use kanidmd_testkit::AsyncTestEnvironment; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + str::FromStr, +}; +use tracing::error; + +const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + +// ===================================================== +// *test where we don't trust the x-forwarded-for header + +#[kanidmd_testkit::test(http_client_address_info = HttpAddressInfo::None)] +async fn dont_trust_xff_send_header(rsclient: &KanidmClient) { + let client = rsclient.client(); + + // Send an invalid header to x forwdr for + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, "a.b.c.d") + .send() + .await + .unwrap(); + + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as IpAddr"); + + assert_eq!(ip_res, DEFAULT_IP_ADDRESS); + + // Send a valid header for xforward for, but we don't trust it. + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, "203.0.113.195") + .send() + .await + .unwrap(); + + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as IpAddr"); + + assert_eq!(ip_res, DEFAULT_IP_ADDRESS); +} + +// ===================================================== +// *test where we do trust the x-forwarded-for header + +#[kanidmd_testkit::test(http_client_address_info = HttpAddressInfo::XForwardFor ( [DEFAULT_IP_ADDRESS].into() ))] +async fn trust_xff_address_set(rsclient: &KanidmClient) { + inner_test_trust_xff(rsclient).await; +} + +#[kanidmd_testkit::test(http_client_address_info = HttpAddressInfo::XForwardForAllSourcesTrusted)] +async fn trust_xff_all_addresses_trusted(rsclient: &KanidmClient) { + inner_test_trust_xff(rsclient).await; +} + +async fn inner_test_trust_xff(rsclient: &KanidmClient) { + let client = rsclient.client(); + + // An invalid address. + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, "a.b.c.d") + .send() + .await + .unwrap(); + + // Header was invalid + assert_eq!(res.status(), 400); + + // An invalid address - what follows doesn't matter, even if it was valid. We only + // care about the left most address anyway. + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header( + X_FORWARDED_FOR, + "203.0.113.195_noooo_my_ip_address, 2001:db8:85a3:8d3:1319:8a2e:370:7348", + ) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), 400); + + // A valid ipv6 address was provided. + let ip_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7348"; + + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, ip_addr) + .send() + .await + .unwrap(); + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as Vec<IpAddr>"); + + assert_eq!(ip_res, IpAddr::from_str(ip_addr).unwrap()); + + // A valid ipv4 address was provided. + let ip_addr = "203.0.113.195"; + + let client = rsclient.client(); + + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, ip_addr) + .send() + .await + .unwrap(); + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as Vec<IpAddr>"); + + assert_eq!(ip_res, IpAddr::from_str(ip_addr).unwrap()); + + // A valid ipv4 address in the leftmost field. + let first_ip_addr = "203.0.113.195, 2001:db8:85a3:8d3:1319:8a2e:370:7348"; + + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, first_ip_addr) + .send() + .await + .unwrap(); + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as Vec<IpAddr>"); + + assert_eq!( + ip_res, + IpAddr::from_str(first_ip_addr.split(",").collect::<Vec<&str>>()[0]).unwrap() + ); + + // A valid ipv6 address in the left most field. + let second_ip_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7348, 198.51.100.178, 203.0.113.195"; + + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .header(X_FORWARDED_FOR, second_ip_addr) + .send() + .await + .unwrap(); + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as Vec<IpAddr>"); + + assert_eq!( + ip_res, + IpAddr::from_str(second_ip_addr.split(",").collect::<Vec<&str>>()[0]).unwrap() + ); + + // If no header is sent, then the connection IP is used. + let res = client + .get(rsclient.make_url("/v1/debug/ipinfo")) + .send() + .await + .unwrap(); + let ip_res: IpAddr = res + .json() + .await + .expect("Failed to parse response as Vec<IpAddr>"); + + assert_eq!(ip_res, DEFAULT_IP_ADDRESS); +} + +// ===================================================== +// *test where we do trust the PROXY protocol header +// +// NOTE: This is MUCH HARDER TO TEST because we can't just stuff this address +// in front of a reqwest call. We have to open raw connections and write the +// requests to them. +// +// As a result, we are pretty much forced to manually dump binary headers and then +// manually craft get reqs, followed by parsing them. + +#[derive(Debug, PartialEq)] +enum ProxyV2Error { + TcpStream, + TcpWrite, + TornWrite, + HttpHandshake, + HttpRequestBuild, + HttpRequest, + HttpBadRequest, +} + +async fn proxy_v2_make_request( + http_sock_addr: SocketAddr, + hdr: &[u8], +) -> Result<IpAddr, ProxyV2Error> { + use http_body_util::BodyExt; + use http_body_util::Empty; + use hyper::body::Bytes; + use hyper::Request; + use hyper_util::rt::TokioIo; + use tokio::io::AsyncWriteExt as _; + use tokio::net::TcpStream; + + let url = format!("http://{}/v1/debug/ipinfo", http_sock_addr) + .as_str() + .parse::<hyper::Uri>() + .unwrap(); + + let mut stream = TcpStream::connect(http_sock_addr).await.map_err(|err| { + error!(?err); + ProxyV2Error::TcpStream + })?; + + // Write the proxyv2 header + let nbytes = stream.write(hdr).await.map_err(|err| { + error!(?err); + ProxyV2Error::TcpWrite + })?; + + if nbytes != hdr.len() { + return Err(ProxyV2Error::TornWrite); + } + + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io) + .await + .map_err(|err| { + error!(?err); + ProxyV2Error::HttpHandshake + })?; + + // Spawn a task to poll the connection, driving the HTTP state + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection failed: {:?}", err); + } + }); + + let authority = url.authority().unwrap().clone(); + + // Create an HTTP request with an empty body and a HOST header + let req = Request::builder() + .uri(url) + .header(hyper::header::HOST, authority.as_str()) + .body(Empty::<Bytes>::new()) + .map_err(|err| { + error!(?err); + ProxyV2Error::HttpRequestBuild + })?; + + // Await the response... + let mut res = sender.send_request(req).await.map_err(|err| { + error!(?err); + ProxyV2Error::HttpRequest + })?; + + println!("Response status: {}", res.status()); + + if res.status() != 200 { + return Err(ProxyV2Error::HttpBadRequest); + } + + let mut data: Vec<u8> = Vec::new(); + + while let Some(next) = res.frame().await { + let frame = next.unwrap(); + if let Some(chunk) = frame.data_ref() { + data.write_all(chunk).await.unwrap(); + } + } + + tracing::info!(?data); + let ip_res: IpAddr = serde_json::from_slice(&data).unwrap(); + tracing::info!(?ip_res); + + Ok(ip_res) +} + +#[kanidmd_testkit::test(with_test_env = true, http_client_address_info = HttpAddressInfo::ProxyV2 ( [DEFAULT_IP_ADDRESS].into() ))] +async fn trust_proxy_v2_address_set(test_env: &AsyncTestEnvironment) { + // Send with no header - with proxy v2, a header is ALWAYS required + let proxy_hdr: [u8; 0] = []; + + let res = proxy_v2_make_request(test_env.http_sock_addr, &proxy_hdr) + .await + .unwrap_err(); + + // Can't send http request because proxy wasn't sent. + assert_eq!(res, ProxyV2Error::HttpRequest); + + // Send with a valid header + let proxy_hdr = + hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d").unwrap(); + + let res = proxy_v2_make_request(test_env.http_sock_addr, &proxy_hdr) + .await + .unwrap(); + + // The header was valid + assert_eq!(res, IpAddr::V4(Ipv4Addr::new(172, 24, 12, 118))); +} + +#[kanidmd_testkit::test(with_test_env = true, http_client_address_info = HttpAddressInfo::ProxyV2 ( [ IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)) ].into() ))] +async fn trust_proxy_v2_untrusted(test_env: &AsyncTestEnvironment) { + // Send with a valid header, but we aren't a trusted source. + let proxy_hdr = + hex::decode("0d0a0d0a000d0a515549540a2111000cac180c76ac180b8fcdcb027d").unwrap(); + + let res = proxy_v2_make_request(test_env.http_sock_addr, &proxy_hdr) + .await + .unwrap_err(); + + // Can't send http request because we aren't trusted to send it, so this + // ends up falling into a http request that is REJECTED. + assert_eq!(res, ProxyV2Error::HttpBadRequest); +} diff --git a/server/testkit/tests/testkit/mod.rs b/server/testkit/tests/testkit/mod.rs index 2784a85ed..766a5fdb5 100644 --- a/server/testkit/tests/testkit/mod.rs +++ b/server/testkit/tests/testkit/mod.rs @@ -2,10 +2,10 @@ mod apidocs; mod domain; mod group; mod http_manifest; -mod https_extractors; mod https_middleware; mod identity_verification_tests; mod integration; +mod ip_addr_extractors; mod ldap_basic; mod mtls_test; mod oauth2_test;