This commit is contained in:
William Brown 2025-04-03 15:41:24 +10:00
parent 9b3a4ad761
commit bd9cfda678
8 changed files with 161 additions and 58 deletions

46
Cargo.lock generated
View file

@ -188,7 +188,7 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acb1161c6b64d1c3d83108213c2a2533a342ac225aabd0bda218278c2ddb00c0" checksum = "acb1161c6b64d1c3d83108213c2a2533a342ac225aabd0bda218278c2ddb00c0"
dependencies = [ dependencies = [
"nom", "nom 7.1.3",
] ]
[[package]] [[package]]
@ -200,7 +200,7 @@ dependencies = [
"asn1-rs-derive", "asn1-rs-derive",
"asn1-rs-impl", "asn1-rs-impl",
"displaydoc", "displaydoc",
"nom", "nom 7.1.3",
"num-traits", "num-traits",
"rusticata-macros", "rusticata-macros",
"thiserror 1.0.69", "thiserror 1.0.69",
@ -675,7 +675,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [ dependencies = [
"nom", "nom 7.1.3",
] ]
[[package]] [[package]]
@ -1148,7 +1148,7 @@ checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553"
dependencies = [ dependencies = [
"asn1-rs", "asn1-rs",
"displaydoc", "displaydoc",
"nom", "nom 7.1.3",
"num-bigint", "num-bigint",
"num-traits", "num-traits",
"rusticata-macros", "rusticata-macros",
@ -1213,7 +1213,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9313f104b590510b46fc01c0a324fc76505c13871454d3c48490468d04c8d395" checksum = "9313f104b590510b46fc01c0a324fc76505c13871454d3c48490468d04c8d395"
dependencies = [ dependencies = [
"libc", "libc",
"nom", "nom 7.1.3",
] ]
[[package]] [[package]]
@ -2271,6 +2271,18 @@ version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "haproxy-protocol"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f61fc527a2f089b57ebc09301b6371bbbff4ce7b547306c17dfa55766655bec6"
dependencies = [
"hex",
"nom 8.0.0",
"tokio",
"tracing",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.12.3" version = "0.12.3"
@ -3140,6 +3152,7 @@ dependencies = [
"filetime", "filetime",
"futures", "futures",
"futures-util", "futures-util",
"haproxy-protocol",
"hyper 1.6.0", "hyper 1.6.0",
"hyper-util", "hyper-util",
"kanidm_build_profiles", "kanidm_build_profiles",
@ -3311,7 +3324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2df7f9fd9f64cf8f59e1a4a0753fe7d575a5b38d3d7ac5758dcee9357d83ef0a" checksum = "2df7f9fd9f64cf8f59e1a4a0753fe7d575a5b38d3d7ac5758dcee9357d83ef0a"
dependencies = [ dependencies = [
"bytes", "bytes",
"nom", "nom 7.1.3",
] ]
[[package]] [[package]]
@ -3343,7 +3356,7 @@ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"bytes", "bytes",
"lber", "lber",
"nom", "nom 7.1.3",
"peg", "peg",
"serde", "serde",
"thiserror 1.0.69", "thiserror 1.0.69",
@ -3675,6 +3688,15 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nom"
version = "8.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "nonempty" name = "nonempty"
version = "0.8.1" version = "0.8.1"
@ -4873,7 +4895,7 @@ version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
dependencies = [ dependencies = [
"nom", "nom 7.1.3",
] ]
[[package]] [[package]]
@ -5364,7 +5386,7 @@ checksum = "34285eaade87ba166c4f17c0ae1e35d52659507db81888beae277e962b9e5a02"
dependencies = [ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"base64urlsafedata", "base64urlsafedata",
"nom", "nom 7.1.3",
"openssl", "openssl",
"serde", "serde",
"serde_cbor_2", "serde_cbor_2",
@ -6341,7 +6363,7 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"futures", "futures",
"hex", "hex",
"nom", "nom 7.1.3",
"num-derive", "num-derive",
"num-traits", "num-traits",
"openssl", "openssl",
@ -6386,7 +6408,7 @@ dependencies = [
"compact_jwt", "compact_jwt",
"der-parser", "der-parser",
"hex", "hex",
"nom", "nom 7.1.3",
"openssl", "openssl",
"rand 0.8.5", "rand 0.8.5",
"rand_chacha 0.3.1", "rand_chacha 0.3.1",
@ -6888,7 +6910,7 @@ dependencies = [
"data-encoding", "data-encoding",
"der-parser", "der-parser",
"lazy_static", "lazy_static",
"nom", "nom 7.1.3",
"oid-registry", "oid-registry",
"rusticata-macros", "rusticata-macros",
"thiserror 1.0.69", "thiserror 1.0.69",

View file

@ -177,6 +177,7 @@ fs4 = "^0.12.0"
futures = "^0.3.31" futures = "^0.3.31"
futures-util = { version = "^0.3.30", features = ["sink"] } futures-util = { version = "^0.3.30", features = ["sink"] }
gix = { version = "0.64.0", default-features = false } gix = { version = "0.64.0", default-features = false }
haproxy-protocol = { version = "0.0.1" }
hashbrown = { version = "0.14.3", features = ["serde", "inline-more", "ahash"] } hashbrown = { version = "0.14.3", features = ["serde", "inline-more", "ahash"] }
hex = "^0.4.3" hex = "^0.4.3"
http = "1.2.0" http = "1.2.0"

View file

@ -34,6 +34,7 @@ cron = { workspace = true }
filetime = { workspace = true } filetime = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
futures-util = { workspace = true } futures-util = { workspace = true }
haproxy-protocol = { workspace = true, features = ["tokio"] }
hyper = { workspace = true } hyper = { workspace = true }
hyper-util = { workspace = true } hyper-util = { workspace = true }
kanidm_proto = { workspace = true } kanidm_proto = { workspace = true }

View file

@ -109,7 +109,7 @@ pub enum LdapAddressInfo {
} }
impl LdapAddressInfo { impl LdapAddressInfo {
pub fn proxy_v2(&self) -> bool { pub fn is_proxy_v2(&self) -> bool {
matches!(self, Self::ProxyV2) matches!(self, Self::ProxyV2)
} }
} }
@ -138,7 +138,7 @@ impl HttpAddressInfo {
matches!(self, Self::XForwardFor) matches!(self, Self::XForwardFor)
} }
pub fn proxy_v2(&self) -> bool { pub fn is_proxy_v2(&self) -> bool {
matches!(self, Self::ProxyV2) matches!(self, Self::ProxyV2)
} }
} }

View file

@ -5,7 +5,6 @@ use axum::{
http::{ http::{
header::HeaderName, header::AUTHORIZATION as AUTHORISATION, request::Parts, StatusCode, header::HeaderName, header::AUTHORIZATION as AUTHORISATION, request::Parts, StatusCode,
}, },
serve::IncomingStream,
RequestPartsExt, RequestPartsExt,
}; };
@ -40,7 +39,8 @@ impl FromRequestParts<ServerState> for TrustedClientIp {
state: &ServerState, state: &ServerState,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let ConnectInfo(ClientConnInfo { let ConnectInfo(ClientConnInfo {
addr, connection_addr: _,
client_addr,
client_cert: _, client_cert: _,
}) = parts }) = parts
.extract::<ConnectInfo<ClientConnInfo>>() .extract::<ConnectInfo<ClientConnInfo>>()
@ -75,10 +75,13 @@ impl FromRequestParts<ServerState> for TrustedClientIp {
) )
})? })?
} else { } else {
addr.ip() client_addr.ip()
} }
} else { } else {
addr.ip() // This can either be the client_addr == connection_addr if there are
// no ip address trust sources, or this is the value as reported by haproxy
// proxy header.
client_addr.ip()
}; };
Ok(TrustedClientIp(ip_addr)) Ok(TrustedClientIp(ip_addr))
@ -97,7 +100,11 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation {
parts: &mut Parts, parts: &mut Parts,
state: &ServerState, state: &ServerState,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let ConnectInfo(ClientConnInfo { addr, client_cert }) = parts let ConnectInfo(ClientConnInfo {
connection_addr: _,
client_addr,
client_cert,
}) = parts
.extract::<ConnectInfo<ClientConnInfo>>() .extract::<ConnectInfo<ClientConnInfo>>()
.await .await
.map_err(|_| { .map_err(|_| {
@ -130,10 +137,10 @@ impl FromRequestParts<ServerState> for VerifiedClientInformation {
) )
})? })?
} else { } else {
addr.ip() client_addr.ip()
} }
} else { } else {
addr.ip() client_addr.ip()
}; };
let (basic_authz, bearer_token) = if let Some(header) = parts.headers.get(AUTHORISATION) { let (basic_authz, bearer_token) = if let Some(header) = parts.headers.get(AUTHORISATION) {
@ -201,30 +208,30 @@ impl FromRequestParts<ServerState> for DomainInfo {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ClientConnInfo { pub struct ClientConnInfo {
pub addr: SocketAddr, /// This is the address that is *connected* to kanidm right now
/// for this operation.
#[allow(dead_code)]
pub connection_addr: SocketAddr,
/// This is the client address as reported by a remote IP source
/// such as x-forward-for or proxy-hdr
pub client_addr: SocketAddr,
// Only set if the certificate is VALID // Only set if the certificate is VALID
pub client_cert: Option<ClientCertInfo>, pub client_cert: Option<ClientCertInfo>,
} }
// This is the normal way that our extractors get the ip info
impl Connected<ClientConnInfo> for ClientConnInfo { impl Connected<ClientConnInfo> for ClientConnInfo {
fn connect_info(target: ClientConnInfo) -> Self { fn connect_info(target: ClientConnInfo) -> Self {
target target
} }
} }
// This is only used for plaintext http - in other words, integration tests only.
impl Connected<SocketAddr> for ClientConnInfo { impl Connected<SocketAddr> for ClientConnInfo {
fn connect_info(addr: SocketAddr) -> Self { fn connect_info(connection_addr: SocketAddr) -> Self {
ClientConnInfo { ClientConnInfo {
addr, client_addr: connection_addr.clone(),
client_cert: None, connection_addr,
}
}
}
impl Connected<IncomingStream<'_>> for ClientConnInfo {
fn connect_info(target: IncomingStream<'_>) -> Self {
ClientConnInfo {
addr: target.remote_addr(),
client_cert: None, client_cert: None,
} }
} }

View file

@ -19,7 +19,6 @@ use self::javascript::*;
use crate::actors::{QueryServerReadV1, QueryServerWriteV1}; use crate::actors::{QueryServerReadV1, QueryServerWriteV1};
use crate::config::{Configuration, ServerRole}; use crate::config::{Configuration, ServerRole};
use crate::CoreAction; use crate::CoreAction;
use axum::{ use axum::{
body::Body, body::Body,
extract::connect_info::IntoMakeServiceWithConnectInfo, extract::connect_info::IntoMakeServiceWithConnectInfo,
@ -29,21 +28,23 @@ use axum::{
routing::*, routing::*,
Router, Router,
}; };
use axum_extra::extract::cookie::CookieJar; use axum_extra::extract::cookie::CookieJar;
use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier}; use compact_jwt::{error::JwtError, JwsCompact, JwsHs256Signer, JwsVerifier};
use futures::pin_mut; use futures::pin_mut;
use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::rt::{TokioExecutor, TokioIo};
use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID}; use kanidm_proto::{constants::KSESSIONID, internal::COOKIE_AUTH_SESSION_ID};
use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor}; use kanidmd_lib::{idm::ClientCertInfo, status::StatusActor};
use openssl::ssl::{Ssl, SslAcceptor}; use openssl::ssl::{Ssl, SslAcceptor};
use kanidm_lib_crypto::x509_cert::{der::Decode, x509_public_key_s256, Certificate};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use sketching::*; use sketching::*;
use std::fmt::Write; use std::fmt::Write;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::pin::Pin;
use std::{net::SocketAddr, str::FromStr};
use tokio::{ use tokio::{
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
sync::broadcast, sync::broadcast,
@ -56,11 +57,6 @@ use tower_http::{services::ServeDir, trace::TraceLayer};
use url::Url; use url::Url;
use uuid::Uuid; use uuid::Uuid;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::pin::Pin;
use std::{net::SocketAddr, str::FromStr};
#[derive(Clone)] #[derive(Clone)]
pub struct ServerState { pub struct ServerState {
pub(crate) status_ref: &'static StatusActor, pub(crate) status_ref: &'static StatusActor,
@ -212,6 +208,7 @@ pub async fn create_https_server(
})?; })?;
let trust_x_forward_for = config.http_client_address_info.is_x_forward_for(); let trust_x_forward_for = config.http_client_address_info.is_x_forward_for();
let enable_haproxy_hdr = config.http_client_address_info.is_proxy_v2();
let origin = Url::parse(&config.origin) let origin = Url::parse(&config.origin)
// Should be impossible! // Should be impossible!
@ -337,6 +334,7 @@ pub async fn create_https_server(
rx, rx,
server_message_tx, server_message_tx,
tls_acceptor_reload_rx, tls_acceptor_reload_rx,
enable_haproxy_hdr,
))) )))
} }
None => Ok(task::spawn(server_loop_plaintext(addr, app, rx))), None => Ok(task::spawn(server_loop_plaintext(addr, app, rx))),
@ -350,6 +348,7 @@ async fn server_loop(
mut rx: broadcast::Receiver<CoreAction>, mut rx: broadcast::Receiver<CoreAction>,
server_message_tx: broadcast::Sender<CoreAction>, server_message_tx: broadcast::Sender<CoreAction>,
mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>, mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
enable_haproxy_hdr: bool,
) { ) {
pin_mut!(listener); pin_mut!(listener);
@ -365,7 +364,7 @@ async fn server_loop(
Ok((stream, addr)) => { Ok((stream, addr)) => {
let tls_acceptor = tls_acceptor.clone(); let tls_acceptor = tls_acceptor.clone();
let app = app.clone(); let app = app.clone();
task::spawn(handle_conn(tls_acceptor, stream, app, addr)); task::spawn(handle_conn(tls_acceptor, stream, app, addr, enable_haproxy_hdr));
} }
Err(err) => { Err(err) => {
error!("Web server exited with {:?}", err); error!("Web server exited with {:?}", err);
@ -415,8 +414,36 @@ pub(crate) async fn handle_conn(
acceptor: SslAcceptor, acceptor: SslAcceptor,
stream: TcpStream, stream: TcpStream,
mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>, mut app: IntoMakeServiceWithConnectInfo<Router, ClientConnInfo>,
addr: SocketAddr, connection_addr: SocketAddr,
enable_haproxy_hdr: bool,
) -> Result<(), std::io::Error> { ) -> Result<(), std::io::Error> {
let (stream, client_addr) = if enable_haproxy_hdr {
match ProxyHdrV2::parse_from_read(stream).await {
Ok((stream, hdr)) => {
let remote_socket_addr = match hdr.to_remote_addr() {
RemoteAddress::Local => {
debug!("haproxy check - will not contain client data");
return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
}
RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
remote_addr => {
error!(?remote_addr, "remote address in proxy header is invalid");
return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
}
};
(stream, remote_socket_addr)
}
Err(err) => {
error!(?err, "Unable to process proxy v2 header");
return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
}
}
} else {
(stream, connection_addr.clone())
};
let ssl = Ssl::new(acceptor.context()).map_err(|e| { let ssl = Ssl::new(acceptor.context()).map_err(|e| {
error!("Failed to create TLS context: {:?}", e); error!("Failed to create TLS context: {:?}", e);
std::io::Error::from(ErrorKind::ConnectionAborted) std::io::Error::from(ErrorKind::ConnectionAborted)
@ -459,7 +486,11 @@ pub(crate) async fn handle_conn(
None None
}; };
let client_conn_info = ClientConnInfo { addr, client_cert }; let client_conn_info = ClientConnInfo {
connection_addr,
client_addr,
client_cert,
};
debug!(?client_conn_info); debug!(?client_conn_info);

View file

@ -2,12 +2,13 @@ use crate::actors::QueryServerReadV1;
use crate::CoreAction; use crate::CoreAction;
use futures_util::sink::SinkExt; use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt; use futures_util::stream::StreamExt;
use haproxy_protocol::{ProxyHdrV2, RemoteAddress};
use kanidmd_lib::idm::ldap::{LdapBoundToken, LdapResponseState}; use kanidmd_lib::idm::ldap::{LdapBoundToken, LdapResponseState};
use kanidmd_lib::prelude::*; use kanidmd_lib::prelude::*;
use ldap3_proto::proto::LdapMsg; use ldap3_proto::proto::LdapMsg;
use ldap3_proto::LdapCodec; use ldap3_proto::LdapCodec;
use openssl::ssl::{Ssl, SslAcceptor}; use openssl::ssl::{Ssl, SslAcceptor};
use std::net; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
@ -33,7 +34,7 @@ impl LdapSession {
#[instrument(name = "ldap-request", skip(client_address, qe_r_ref))] #[instrument(name = "ldap-request", skip(client_address, qe_r_ref))]
async fn client_process_msg( async fn client_process_msg(
uat: Option<LdapBoundToken>, uat: Option<LdapBoundToken>,
client_address: net::SocketAddr, client_address: SocketAddr,
protomsg: LdapMsg, protomsg: LdapMsg,
qe_r_ref: &'static QueryServerReadV1, qe_r_ref: &'static QueryServerReadV1,
) -> Option<LdapResponseState> { ) -> Option<LdapResponseState> {
@ -50,7 +51,8 @@ async fn client_process_msg(
async fn client_process<STREAM>( async fn client_process<STREAM>(
stream: STREAM, stream: STREAM,
client_address: net::SocketAddr, client_address: SocketAddr,
connection_address: SocketAddr,
qe_r_ref: &'static QueryServerReadV1, qe_r_ref: &'static QueryServerReadV1,
) where ) where
STREAM: AsyncRead + AsyncWrite, STREAM: AsyncRead + AsyncWrite,
@ -67,6 +69,8 @@ async fn client_process<STREAM>(
let uat = session.uat.clone(); let uat = session.uat.clone();
let caddr = client_address; let caddr = client_address;
debug!(?client_address, ?connection_address);
match client_process_msg(uat, caddr, protomsg, qe_r_ref).await { match client_process_msg(uat, caddr, protomsg, qe_r_ref).await {
// I'd really have liked to have put this near the [LdapResponseState::Bind] but due // I'd really have liked to have put this near the [LdapResponseState::Bind] but due
// to the handing of `audit` it isn't possible due to borrows, etc. // to the handing of `audit` it isn't possible due to borrows, etc.
@ -112,28 +116,61 @@ async fn client_process<STREAM>(
} }
async fn client_tls_accept( async fn client_tls_accept(
tcpstream: TcpStream, stream: TcpStream,
tls_acceptor: SslAcceptor, tls_acceptor: SslAcceptor,
client_socket_addr: net::SocketAddr, connection_addr: SocketAddr,
qe_r_ref: &'static QueryServerReadV1, qe_r_ref: &'static QueryServerReadV1,
enable_haproxy_hdr: bool,
) { ) {
let (stream, client_addr) = if enable_haproxy_hdr {
match ProxyHdrV2::parse_from_read(stream).await {
Ok((stream, hdr)) => {
let remote_socket_addr = match hdr.to_remote_addr() {
RemoteAddress::Local => {
debug!("haproxy check - will not contain client data");
return;
}
RemoteAddress::TcpV4 { src, dst: _ } => SocketAddr::from(src),
RemoteAddress::TcpV6 { src, dst: _ } => SocketAddr::from(src),
remote_addr => {
error!(?remote_addr, "remote address in proxy header is invalid");
return;
}
};
(stream, remote_socket_addr)
}
Err(err) => {
error!(?err, "Unable to process proxy v2 header");
return;
}
}
} else {
(stream, connection_addr.clone())
};
// Start the event // Start the event
// From the parameters we need to create an SslContext. // From the parameters we need to create an SslContext.
let mut tlsstream = match Ssl::new(tls_acceptor.context()) let mut tlsstream = match Ssl::new(tls_acceptor.context())
.and_then(|tls_obj| SslStream::new(tls_obj, tcpstream)) .and_then(|tls_obj| SslStream::new(tls_obj, stream))
{ {
Ok(ta) => ta, Ok(ta) => ta,
Err(err) => { Err(err) => {
error!(?err, %client_socket_addr, "LDAP TLS setup error"); error!(?err, %client_addr, %connection_addr, "LDAP TLS setup error");
return; return;
} }
}; };
if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await { if let Err(err) = SslStream::accept(Pin::new(&mut tlsstream)).await {
error!(?err, %client_socket_addr, "LDAP TLS accept error"); error!(?err, %client_addr, %connection_addr, "LDAP TLS accept error");
return; return;
}; };
tokio::spawn(client_process(tlsstream, client_socket_addr, qe_r_ref)); tokio::spawn(client_process(
tlsstream,
client_addr,
connection_addr,
qe_r_ref,
));
} }
/// TLS LDAP Listener, hands off to [client_tls_accept] /// TLS LDAP Listener, hands off to [client_tls_accept]
@ -143,6 +180,7 @@ async fn ldap_tls_acceptor(
qe_r_ref: &'static QueryServerReadV1, qe_r_ref: &'static QueryServerReadV1,
mut rx: broadcast::Receiver<CoreAction>, mut rx: broadcast::Receiver<CoreAction>,
mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>, mut tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
enable_haproxy_hdr: bool,
) { ) {
loop { loop {
tokio::select! { tokio::select! {
@ -155,7 +193,7 @@ async fn ldap_tls_acceptor(
match accept_result { match accept_result {
Ok((tcpstream, client_socket_addr)) => { Ok((tcpstream, client_socket_addr)) => {
let clone_tls_acceptor = tls_acceptor.clone(); let clone_tls_acceptor = tls_acceptor.clone();
tokio::spawn(client_tls_accept(tcpstream, clone_tls_acceptor, client_socket_addr, qe_r_ref)); tokio::spawn(client_tls_accept(tcpstream, clone_tls_acceptor, client_socket_addr, qe_r_ref, enable_haproxy_hdr));
} }
Err(err) => { Err(err) => {
warn!(?err, "LDAP acceptor error, continuing"); warn!(?err, "LDAP acceptor error, continuing");
@ -187,7 +225,7 @@ async fn ldap_plaintext_acceptor(
accept_result = listener.accept() => { accept_result = listener.accept() => {
match accept_result { match accept_result {
Ok((tcpstream, client_socket_addr)) => { Ok((tcpstream, client_socket_addr)) => {
tokio::spawn(client_process(tcpstream, client_socket_addr, qe_r_ref)); tokio::spawn(client_process(tcpstream, client_socket_addr.clone(), client_socket_addr, qe_r_ref));
} }
Err(e) => { Err(e) => {
error!("LDAP acceptor error, continuing -> {:?}", e); error!("LDAP acceptor error, continuing -> {:?}", e);
@ -205,6 +243,7 @@ pub(crate) async fn create_ldap_server(
qe_r_ref: &'static QueryServerReadV1, qe_r_ref: &'static QueryServerReadV1,
rx: broadcast::Receiver<CoreAction>, rx: broadcast::Receiver<CoreAction>,
tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>, tls_acceptor_reload_rx: mpsc::Receiver<SslAcceptor>,
enable_haproxy_hdr: bool,
) -> Result<tokio::task::JoinHandle<()>, ()> { ) -> Result<tokio::task::JoinHandle<()>, ()> {
if address.starts_with(":::") { if address.starts_with(":::") {
// takes :::xxxx to xxxx // takes :::xxxx to xxxx
@ -212,7 +251,7 @@ pub(crate) async fn create_ldap_server(
error!("Address '{}' looks like an attempt to wildcard bind with IPv6 on port {} - please try using ldapbindaddress = '[::]:{}'", address, port, port); error!("Address '{}' looks like an attempt to wildcard bind with IPv6 on port {} - please try using ldapbindaddress = '[::]:{}'", address, port, port);
}; };
let addr = net::SocketAddr::from_str(address).map_err(|e| { let addr = SocketAddr::from_str(address).map_err(|e| {
error!("Could not parse LDAP server address {} -> {:?}", address, e); error!("Could not parse LDAP server address {} -> {:?}", address, e);
})?; })?;
@ -233,6 +272,7 @@ pub(crate) async fn create_ldap_server(
qe_r_ref, qe_r_ref,
rx, rx,
tls_acceptor_reload_rx, tls_acceptor_reload_rx,
enable_haproxy_hdr,
)) ))
} }
None => tokio::spawn(ldap_plaintext_acceptor(listener, qe_r_ref, rx)), None => tokio::spawn(ldap_plaintext_acceptor(listener, qe_r_ref, rx)),

View file

@ -1087,6 +1087,7 @@ pub async fn create_server_core(
server_read_ref, server_read_ref,
broadcast_tx.subscribe(), broadcast_tx.subscribe(),
ldap_tls_acceptor_reload_rx, ldap_tls_acceptor_reload_rx,
config.ldap_client_address_info.is_proxy_v2(),
) )
.await?; .await?;
Some(h) Some(h)