Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pingora-core/src/listeners/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,29 @@ pub use l4::{ServerAddress, TcpSocketOptions};
/// The APIs to customize things like certificate during TLS server side handshake
#[async_trait]
pub trait TlsAccept {
// TODO: return error?
/// This function is called in the middle of a TLS handshake. Structs who
/// implement this function should provide tls certificate and key to the
/// [TlsRef] via `ssl_use_certificate` and `ssl_use_private_key`.
/// Note. This is only supported for openssl and boringssl
async fn certificate_callback(&self, _ssl: &mut TlsRef) -> () {
async fn certificate_callback(&self, _ssl: &mut TlsRef) {
// does nothing by default
}

/// Preferred variant of [`Self::certificate_callback`] for implementations
/// that need to reject certificate selection with an explicit error.
///
/// Returning an error will abort the handshake with a diagnostic message
/// derived from the error. By default this preserves backwards compatibility
/// by delegating to [`Self::certificate_callback`].
///
/// If both methods are implemented, this method is authoritative. Call
/// [`Self::certificate_callback`] from this method if you want to reuse the
/// legacy mutation logic before returning a structured result.
async fn certificate_callback_result(&self, ssl: &mut TlsRef) -> Result<()> {
self.certificate_callback(ssl).await;
Ok(())
}

/// This function is called after the TLS handshake is complete.
///
/// Any value returned from this function (other than `None`) will be stored in the
Expand Down
115 changes: 113 additions & 2 deletions pingora-core/src/protocols/tls/boringssl_openssl/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ pub async fn handshake_with_callback<S: IO>(
if !done {
// safety: we do hold a mut ref of tls_stream
let ssl_mut = unsafe { ext::ssl_mut(tls_stream.ssl()) };
callbacks.certificate_callback(ssl_mut).await;
callbacks
.certificate_callback_result(ssl_mut)
.await
.explain_err(TLSHandshakeFailure, |e| {
format!("certificate callback failed: {e}")
})?;
Pin::new(&mut tls_stream)
.resume_accept()
.await
Expand Down Expand Up @@ -145,9 +150,11 @@ mod tests {
use crate::protocols::tls::TlsRef;
use crate::tls::ext;
use crate::tls::ssl;
use pingora_error::{Error, ErrorType};

use async_trait::async_trait;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::io::DuplexStream;

Expand All @@ -165,6 +172,17 @@ mod tests {
let _ = stream.read(&mut buf).await;
}

async fn best_effort_client_task(client: DuplexStream) {
let ssl_context = ssl::SslContext::builder(ssl::SslMethod::tls())
.unwrap()
.build();
let mut ssl = ssl::Ssl::new(&ssl_context).unwrap();
ssl.set_hostname("pingora.org").unwrap();
ssl.set_verify(ssl::SslVerifyMode::NONE); // we don have a valid cert
let mut stream = SslStream::new(ssl, client).unwrap();
let _ = Pin::new(&mut stream).connect().await;
}

#[tokio::test]
#[cfg(feature = "any_tls")]
async fn test_async_cert() {
Expand All @@ -175,7 +193,7 @@ mod tests {
struct Callback;
#[async_trait]
impl TlsAccept for Callback {
async fn certificate_callback(&self, ssl: &mut TlsRef) -> () {
async fn certificate_callback(&self, ssl: &mut TlsRef) {
assert_eq!(
ssl.servername(ssl::NameType::HOST_NAME).unwrap(),
"pingora.org"
Expand Down Expand Up @@ -204,6 +222,99 @@ mod tests {
.unwrap();
}

#[tokio::test]
#[cfg(feature = "any_tls")]
async fn test_async_cert_error() {
let acceptor = ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls())
.unwrap()
.build();

struct Callback;
#[async_trait]
impl TlsAccept for Callback {
async fn certificate_callback_result(
&self,
_ssl: &mut TlsRef,
) -> pingora_error::Result<()> {
Error::e_explain(
ErrorType::InternalError,
"dynamic cert rejected by callback",
)
}
}

let cb: TlsAcceptCallbacks = Box::new(Callback);

let (client, server) = tokio::io::duplex(1024);

tokio::spawn(best_effort_client_task(client));

let err = handshake_with_callback(&acceptor, server, &cb)
.await
.unwrap_err();
let err_str = err.to_string();
assert!(err_str.contains("certificate callback failed:"));
assert!(err_str.contains("dynamic cert rejected by callback"));
assert_eq!(err.etype(), &ErrorType::TLSHandshakeFailure);
}

#[tokio::test]
#[cfg(feature = "any_tls")]
async fn test_async_cert_result_is_authoritative() {
let acceptor = ssl::SslAcceptor::mozilla_intermediate_v5(ssl::SslMethod::tls())
.unwrap()
.build();

struct Callback {
legacy_called: Arc<AtomicBool>,
result_called: Arc<AtomicBool>,
}

#[async_trait]
impl TlsAccept for Callback {
async fn certificate_callback(&self, _ssl: &mut TlsRef) {
self.legacy_called.store(true, Ordering::SeqCst);
}

async fn certificate_callback_result(
&self,
ssl: &mut TlsRef,
) -> pingora_error::Result<()> {
self.result_called.store(true, Ordering::SeqCst);

let cert = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR"));
let key = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR"));

let cert_bytes = std::fs::read(cert).unwrap();
let cert = crate::tls::x509::X509::from_pem(&cert_bytes).unwrap();

let key_bytes = std::fs::read(key).unwrap();
let key = crate::tls::pkey::PKey::private_key_from_pem(&key_bytes).unwrap();
ext::ssl_use_certificate(ssl, &cert).unwrap();
ext::ssl_use_private_key(ssl, &key).unwrap();
Ok(())
}
}

let legacy_called = Arc::new(AtomicBool::new(false));
let result_called = Arc::new(AtomicBool::new(false));
let cb: TlsAcceptCallbacks = Box::new(Callback {
legacy_called: legacy_called.clone(),
result_called: result_called.clone(),
});

let (client, server) = tokio::io::duplex(1024);

tokio::spawn(client_task(client));

handshake_with_callback(&acceptor, server, &cb)
.await
.unwrap();

assert!(result_called.load(Ordering::SeqCst));
assert!(!legacy_called.load(Ordering::SeqCst));
}

#[tokio::test]
#[cfg(feature = "openssl_derived")]
async fn test_handshake_complete_callback() {
Expand Down
Loading