Skip to content

Commit 51321f0

Browse files
sbernauer and Techassi committed
Rewrite certificate rotation logic
Co-authored-by: Techassi <[email protected]>
1 parent ac29c0e commit 51321f0

File tree

1 file changed

+113
-110
lines changed
  • crates/stackable-webhook/src/tls

1 file changed

+113
-110
lines changed

crates/stackable-webhook/src/tls/mod.rs

Lines changed: 113 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,21 @@
22
//! server, which can be used in combination with an Axum [`Router`].
33
use std::{net::SocketAddr, sync::Arc};
44

5-
use axum::{Router, extract::Request};
5+
use axum::{
6+
Router,
7+
extract::{ConnectInfo, Request},
8+
middleware::AddExtension,
9+
};
610
use cert_resolver::{CertificateResolver, CertificateResolverError};
7-
use futures_util::pin_mut;
811
use hyper::{body::Incoming, service::service_fn};
912
use hyper_util::rt::{TokioExecutor, TokioIo};
1013
use opentelemetry::trace::{FutureExt, SpanKind};
1114
use snafu::{ResultExt, Snafu};
1215
use stackable_operator::time::Duration;
13-
use tokio::{net::TcpListener, select, sync::mpsc, time::interval};
16+
use tokio::{
17+
net::{TcpListener, TcpStream},
18+
sync::mpsc,
19+
};
1420
use tokio_rustls::{
1521
TlsAcceptor,
1622
rustls::{
@@ -48,9 +54,6 @@ pub enum TlsServerError {
4854

4955
#[snafu(display("failed to set safe TLS protocol versions"))]
5056
SetSafeTlsProtocolVersions { source: tokio_rustls::rustls::Error },
51-
52-
#[snafu(display("failed to run certificate rotation loop"))]
53-
RunCertificateRotationLoop { source: tokio::task::JoinError },
5457
}
5558

5659
/// A server which terminates TLS connections and allows clients to communicate
@@ -109,8 +112,8 @@ impl TlsServer {
109112
///
110113
/// It also rotates the certificate on a fixed interval from within the connection accept loop.
111114
pub async fn run(self) -> Result<()> {
112-
let certificate_rotation_loop =
113-
tokio::spawn(async { Self::run_certificate_rotation_loop(self.cert_resolver).await });
115+
let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
116+
let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
114117

115118
let tls_acceptor = TlsAcceptor::from(Arc::new(self.config));
116119
let tcp_listener =
@@ -135,123 +138,123 @@ impl TlsServer {
135138
.router
136139
.into_make_service_with_connect_info::<SocketAddr>();
137140

138-
pin_mut!(certificate_rotation_loop);
139141
loop {
140142
let tls_acceptor = tls_acceptor.clone();
141143

142-
// Wait for either a new TCP connection or the certificate rotation loop to exit
143-
let tcp_stream = select! {
144-
loop_result = &mut certificate_rotation_loop => {
145-
return loop_result.context(RunCertificateRotationLoopSnafu)?;
146-
}
147-
tcp_stream = tcp_listener.accept() => {
148-
tcp_stream
149-
}
150-
};
151-
152-
let (tcp_stream, remote_addr) = match tcp_stream {
153-
Ok((stream, addr)) => (stream, addr),
154-
Err(err) => {
155-
tracing::trace!(%err, "failed to accept incoming TCP connection");
156-
continue;
144+
// Wait for either a new TCP connection or the certificate rotation interval tick
145+
tokio::select! {
146+
// We opt for a biased execution of arms to make sure we always check if the
147+
// certificate needs rotation based on the interval. This ensures that we always use
148+
// a valid certificate for the TLS connection.
149+
biased;
150+
151+
// This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed.
152+
// As such, we will not miss rotating the certificate.
153+
_ = interval.tick() => {
154+
self.cert_resolver
155+
.rotate_certificate()
156+
.await
157+
.context(RotateCertificateSnafu)?
157158
}
158-
};
159159

160-
// Here, the connect info is extracted by calling Tower's Service
161-
// trait function on `IntoMakeServiceWithConnectInfo`
162-
let tower_service = router.call(remote_addr).await.unwrap();
163-
164-
let span = tracing::debug_span!("accept tcp connection");
165-
tokio::spawn(
166-
async move {
167-
let span = tracing::trace_span!(
168-
"accept tls connection",
169-
"otel.kind" = ?SpanKind::Server,
170-
"otel.status_code" = Empty,
171-
"otel.status_message" = Empty,
172-
"client.address" = remote_addr.ip().to_string(),
173-
"client.port" = remote_addr.port() as i64,
174-
"server.address" = Empty,
175-
"server.port" = Empty,
176-
"network.peer.address" = remote_addr.ip().to_string(),
177-
"network.peer.port" = remote_addr.port() as i64,
178-
"network.local.address" = Empty,
179-
"network.local.port" = Empty,
180-
"network.transport" = "tcp",
181-
"network.type" = self.socket_addr.semantic_convention_network_type(),
182-
);
183-
184-
if let Ok(local_addr) = tcp_stream.local_addr() {
185-
let addr = &local_addr.ip().to_string();
186-
let port = local_addr.port();
187-
span.record("server.address", addr)
188-
.record("server.port", port as i64)
189-
.record("network.local.address", addr)
190-
.record("network.local.port", port as i64);
191-
}
192-
193-
// Wait for tls handshake to happen
194-
let tls_stream = match tls_acceptor
195-
.accept(tcp_stream)
196-
.instrument(span.clone())
197-
.await
198-
{
199-
Ok(tls_stream) => tls_stream,
160+
// This is cancellation-safe. If this branch is cancelled, no connection has been accepted (and lost) by this call.
161+
tcp_connection = tcp_listener.accept() => {
162+
let (tcp_stream, remote_addr) = match tcp_connection {
163+
Ok((stream, addr)) => (stream, addr),
200164
Err(err) => {
201-
span.record("otel.status_code", "Error")
202-
.record("otel.status_message", err.to_string());
203-
tracing::trace!(%remote_addr, "error during tls handshake connection");
204-
return;
165+
tracing::trace!(%err, "failed to accept incoming TCP connection");
166+
continue;
205167
}
206168
};
207169

208-
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
209-
// `TokioIo` converts between them.
210-
let tls_stream = TokioIo::new(tls_stream);
211-
212-
// Hyper also has its own `Service` trait and doesn't use tower. We can use
213-
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
214-
// `tower::Service::call`.
215-
let hyper_service = service_fn(move |request: Request<Incoming>| {
216-
// This carries the current context with the trace id so that the TraceLayer can use that as a parent
217-
let otel_context = Span::current().context();
218-
// We need to clone here, because oneshot consumes self
219-
tower_service
220-
.clone()
221-
.oneshot(request)
222-
.with_context(otel_context)
223-
});
224-
225-
let span = tracing::debug_span!("serve connection");
226-
hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
227-
.serve_connection_with_upgrades(tls_stream, hyper_service)
228-
.instrument(span.clone())
229-
.await
230-
.unwrap_or_else(|err| {
231-
span.record("otel.status_code", "Error")
232-
.record("otel.status_message", err.to_string());
233-
tracing::warn!(%err, %remote_addr, "failed to serve connection");
234-
})
170+
// Here, the connect info is extracted by calling Tower's Service
171+
// trait function on `IntoMakeServiceWithConnectInfo`
172+
let tower_service = router.call(remote_addr).await.unwrap();
173+
174+
let span = tracing::debug_span!("accept tcp connection");
175+
tokio::spawn(async move {
176+
Self::handle_request(tcp_stream, remote_addr, tls_acceptor, tower_service, self.socket_addr)
177+
}.instrument(span));
235178
}
236-
.instrument(span),
237-
);
179+
};
238180
}
239181
}
240182

241-
async fn run_certificate_rotation_loop(cert_resolver: Arc<CertificateResolver>) -> Result<()> {
242-
let mut interval = interval(*WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
243-
// Let the interval tick once, so that the first loop iteration does not start immediately,
244-
// thus generating a new cert.
245-
interval.tick().await;
183+
async fn handle_request(
184+
tcp_stream: TcpStream,
185+
remote_addr: SocketAddr,
186+
tls_acceptor: TlsAcceptor,
187+
tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
188+
socket_addr: SocketAddr,
189+
) {
190+
let span = tracing::trace_span!(
191+
"accept tls connection",
192+
"otel.kind" = ?SpanKind::Server,
193+
"otel.status_code" = Empty,
194+
"otel.status_message" = Empty,
195+
"client.address" = remote_addr.ip().to_string(),
196+
"client.port" = remote_addr.port() as i64,
197+
"server.address" = Empty,
198+
"server.port" = Empty,
199+
"network.peer.address" = remote_addr.ip().to_string(),
200+
"network.peer.port" = remote_addr.port() as i64,
201+
"network.local.address" = Empty,
202+
"network.local.port" = Empty,
203+
"network.transport" = "tcp",
204+
"network.type" = socket_addr.semantic_convention_network_type(),
205+
);
206+
207+
if let Ok(local_addr) = tcp_stream.local_addr() {
208+
let addr = &local_addr.ip().to_string();
209+
let port = local_addr.port();
210+
span.record("server.address", addr)
211+
.record("server.port", port as i64)
212+
.record("network.local.address", addr)
213+
.record("network.local.port", port as i64);
214+
}
246215

247-
loop {
248-
interval.tick().await;
216+
// Wait for the TLS handshake to complete
217+
let tls_stream = match tls_acceptor
218+
.accept(tcp_stream)
219+
.instrument(span.clone())
220+
.await
221+
{
222+
Ok(tls_stream) => tls_stream,
223+
Err(err) => {
224+
span.record("otel.status_code", "Error")
225+
.record("otel.status_message", err.to_string());
226+
tracing::trace!(%remote_addr, "error during tls handshake connection");
227+
return;
228+
}
229+
};
249230

250-
cert_resolver
251-
.rotate_certificate()
252-
.await
253-
.context(RotateCertificateSnafu)?;
254-
}
231+
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
232+
// `TokioIo` converts between them.
233+
let tls_stream = TokioIo::new(tls_stream);
234+
235+
// Hyper also has its own `Service` trait and doesn't use tower. We can use
236+
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
237+
// `tower::Service::call`.
238+
let hyper_service = service_fn(move |request: Request<Incoming>| {
239+
// This carries the current context with the trace id so that the TraceLayer can use that as a parent
240+
let otel_context = Span::current().context();
241+
// We need to clone here, because oneshot consumes self
242+
tower_service
243+
.clone()
244+
.oneshot(request)
245+
.with_context(otel_context)
246+
});
247+
248+
let span = tracing::debug_span!("serve connection");
249+
hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
250+
.serve_connection_with_upgrades(tls_stream, hyper_service)
251+
.instrument(span.clone())
252+
.await
253+
.unwrap_or_else(|err| {
254+
span.record("otel.status_code", "Error")
255+
.record("otel.status_message", err.to_string());
256+
tracing::warn!(%err, %remote_addr, "failed to serve connection");
257+
})
255258
}
256259
}
257260

0 commit comments

Comments
 (0)