2
2
//! server, which can be used in combination with an Axum [`Router`].
3
3
use std:: { net:: SocketAddr , sync:: Arc } ;
4
4
5
- use axum:: { Router , extract:: Request } ;
5
+ use axum:: {
6
+ Router ,
7
+ extract:: { ConnectInfo , Request } ,
8
+ middleware:: AddExtension ,
9
+ } ;
6
10
use cert_resolver:: { CertificateResolver , CertificateResolverError } ;
7
- use futures_util:: pin_mut;
8
11
use hyper:: { body:: Incoming , service:: service_fn} ;
9
12
use hyper_util:: rt:: { TokioExecutor , TokioIo } ;
10
13
use opentelemetry:: trace:: { FutureExt , SpanKind } ;
11
14
use snafu:: { ResultExt , Snafu } ;
12
15
use stackable_operator:: time:: Duration ;
13
- use tokio:: { net:: TcpListener , select, sync:: mpsc, time:: interval} ;
16
+ use tokio:: {
17
+ net:: { TcpListener , TcpStream } ,
18
+ sync:: mpsc,
19
+ } ;
14
20
use tokio_rustls:: {
15
21
TlsAcceptor ,
16
22
rustls:: {
@@ -48,9 +54,6 @@ pub enum TlsServerError {
48
54
49
55
#[ snafu( display( "failed to set safe TLS protocol versions" ) ) ]
50
56
SetSafeTlsProtocolVersions { source : tokio_rustls:: rustls:: Error } ,
51
-
52
- #[ snafu( display( "failed to run certificate rotation loop" ) ) ]
53
- RunCertificateRotationLoop { source : tokio:: task:: JoinError } ,
54
57
}
55
58
56
59
/// A server which terminates TLS connections and allows clients to communicate
@@ -109,8 +112,8 @@ impl TlsServer {
109
112
///
110
113
/// It also starts a background task to rotate the certificate as needed.
111
114
pub async fn run ( self ) -> Result < ( ) > {
112
- let certificate_rotation_loop =
113
- tokio:: spawn ( async { Self :: run_certificate_rotation_loop ( self . cert_resolver ) . await } ) ;
115
+ let start = tokio :: time :: Instant :: now ( ) + * WEBHOOK_CERTIFICATE_ROTATION_INTERVAL ;
116
+ let mut interval = tokio:: time :: interval_at ( start , * WEBHOOK_CERTIFICATE_ROTATION_INTERVAL ) ;
114
117
115
118
let tls_acceptor = TlsAcceptor :: from ( Arc :: new ( self . config ) ) ;
116
119
let tcp_listener =
@@ -135,123 +138,123 @@ impl TlsServer {
135
138
. router
136
139
. into_make_service_with_connect_info :: < SocketAddr > ( ) ;
137
140
138
- pin_mut ! ( certificate_rotation_loop) ;
139
141
loop {
140
142
let tls_acceptor = tls_acceptor. clone ( ) ;
141
143
142
- // Wait for either a new TCP connection or the certificate rotation loop to exit
143
- let tcp_stream = select ! {
144
- loop_result = & mut certificate_rotation_loop => {
145
- return loop_result. context( RunCertificateRotationLoopSnafu ) ?;
146
- }
147
- tcp_stream = tcp_listener. accept( ) => {
148
- tcp_stream
149
- }
150
- } ;
151
-
152
- let ( tcp_stream, remote_addr) = match tcp_stream {
153
- Ok ( ( stream, addr) ) => ( stream, addr) ,
154
- Err ( err) => {
155
- tracing:: trace!( %err, "failed to accept incoming TCP connection" ) ;
156
- continue ;
144
+ // Wait for either a new TCP connection or the certificate rotation interval tick
145
+ tokio:: select! {
146
+ // We opt for a biased execution of arms to make sure we always check if the
147
+ // certificate needs rotation based on the interval. This ensures, we always use
148
+ // a valid certificate for the TLS connection.
149
+ biased;
150
+
151
+ // This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed.
152
+ // As such, we will not miss rotating the certificate.
153
+ _ = interval. tick( ) => {
154
+ self . cert_resolver
155
+ . rotate_certificate( )
156
+ . await
157
+ . context( RotateCertificateSnafu ) ?
157
158
}
158
- } ;
159
159
160
- // Here, the connect info is extracted by calling Tower's Service
161
- // trait function on `IntoMakeServiceWithConnectInfo`
162
- let tower_service = router. call ( remote_addr) . await . unwrap ( ) ;
163
-
164
- let span = tracing:: debug_span!( "accept tcp connection" ) ;
165
- tokio:: spawn (
166
- async move {
167
- let span = tracing:: trace_span!(
168
- "accept tls connection" ,
169
- "otel.kind" = ?SpanKind :: Server ,
170
- "otel.status_code" = Empty ,
171
- "otel.status_message" = Empty ,
172
- "client.address" = remote_addr. ip( ) . to_string( ) ,
173
- "client.port" = remote_addr. port( ) as i64 ,
174
- "server.address" = Empty ,
175
- "server.port" = Empty ,
176
- "network.peer.address" = remote_addr. ip( ) . to_string( ) ,
177
- "network.peer.port" = remote_addr. port( ) as i64 ,
178
- "network.local.address" = Empty ,
179
- "network.local.port" = Empty ,
180
- "network.transport" = "tcp" ,
181
- "network.type" = self . socket_addr. semantic_convention_network_type( ) ,
182
- ) ;
183
-
184
- if let Ok ( local_addr) = tcp_stream. local_addr ( ) {
185
- let addr = & local_addr. ip ( ) . to_string ( ) ;
186
- let port = local_addr. port ( ) ;
187
- span. record ( "server.address" , addr)
188
- . record ( "server.port" , port as i64 )
189
- . record ( "network.local.address" , addr)
190
- . record ( "network.local.port" , port as i64 ) ;
191
- }
192
-
193
- // Wait for tls handshake to happen
194
- let tls_stream = match tls_acceptor
195
- . accept ( tcp_stream)
196
- . instrument ( span. clone ( ) )
197
- . await
198
- {
199
- Ok ( tls_stream) => tls_stream,
160
+ // This is cancellation-safe. If cancelled, no new connections are accepted.
161
+ tcp_connection = tcp_listener. accept( ) => {
162
+ let ( tcp_stream, remote_addr) = match tcp_connection {
163
+ Ok ( ( stream, addr) ) => ( stream, addr) ,
200
164
Err ( err) => {
201
- span. record ( "otel.status_code" , "Error" )
202
- . record ( "otel.status_message" , err. to_string ( ) ) ;
203
- tracing:: trace!( %remote_addr, "error during tls handshake connection" ) ;
204
- return ;
165
+ tracing:: trace!( %err, "failed to accept incoming TCP connection" ) ;
166
+ continue ;
205
167
}
206
168
} ;
207
169
208
- // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
209
- // `TokioIo` converts between them.
210
- let tls_stream = TokioIo :: new ( tls_stream) ;
211
-
212
- // Hyper also has its own `Service` trait and doesn't use tower. We can use
213
- // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
214
- // `tower::Service::call`.
215
- let hyper_service = service_fn ( move |request : Request < Incoming > | {
216
- // This carries the current context with the trace id so that the TraceLayer can use that as a parent
217
- let otel_context = Span :: current ( ) . context ( ) ;
218
- // We need to clone here, because oneshot consumes self
219
- tower_service
220
- . clone ( )
221
- . oneshot ( request)
222
- . with_context ( otel_context)
223
- } ) ;
224
-
225
- let span = tracing:: debug_span!( "serve connection" ) ;
226
- hyper_util:: server:: conn:: auto:: Builder :: new ( TokioExecutor :: new ( ) )
227
- . serve_connection_with_upgrades ( tls_stream, hyper_service)
228
- . instrument ( span. clone ( ) )
229
- . await
230
- . unwrap_or_else ( |err| {
231
- span. record ( "otel.status_code" , "Error" )
232
- . record ( "otel.status_message" , err. to_string ( ) ) ;
233
- tracing:: warn!( %err, %remote_addr, "failed to serve connection" ) ;
234
- } )
170
+ // Here, the connect info is extracted by calling Tower's Service
171
+ // trait function on `IntoMakeServiceWithConnectInfo`
172
+ let tower_service = router. call( remote_addr) . await . unwrap( ) ;
173
+
174
+ let span = tracing:: debug_span!( "accept tcp connection" ) ;
175
+ tokio:: spawn( async move {
176
+ Self :: handle_request( tcp_stream, remote_addr, tls_acceptor, tower_service, self . socket_addr)
177
+ } . instrument( span) ) ;
235
178
}
236
- . instrument ( span) ,
237
- ) ;
179
+ } ;
238
180
}
239
181
}
240
182
241
- async fn run_certificate_rotation_loop ( cert_resolver : Arc < CertificateResolver > ) -> Result < ( ) > {
242
- let mut interval = interval ( * WEBHOOK_CERTIFICATE_ROTATION_INTERVAL ) ;
243
- // Let the interval tick once, so that the first loop iteration does not start immediately,
244
- // thus generating a new cert.
245
- interval. tick ( ) . await ;
183
+ async fn handle_request (
184
+ tcp_stream : TcpStream ,
185
+ remote_addr : SocketAddr ,
186
+ tls_acceptor : TlsAcceptor ,
187
+ tower_service : AddExtension < Router , ConnectInfo < SocketAddr > > ,
188
+ socket_addr : SocketAddr ,
189
+ ) {
190
+ let span = tracing:: trace_span!(
191
+ "accept tls connection" ,
192
+ "otel.kind" = ?SpanKind :: Server ,
193
+ "otel.status_code" = Empty ,
194
+ "otel.status_message" = Empty ,
195
+ "client.address" = remote_addr. ip( ) . to_string( ) ,
196
+ "client.port" = remote_addr. port( ) as i64 ,
197
+ "server.address" = Empty ,
198
+ "server.port" = Empty ,
199
+ "network.peer.address" = remote_addr. ip( ) . to_string( ) ,
200
+ "network.peer.port" = remote_addr. port( ) as i64 ,
201
+ "network.local.address" = Empty ,
202
+ "network.local.port" = Empty ,
203
+ "network.transport" = "tcp" ,
204
+ "network.type" = socket_addr. semantic_convention_network_type( ) ,
205
+ ) ;
206
+
207
+ if let Ok ( local_addr) = tcp_stream. local_addr ( ) {
208
+ let addr = & local_addr. ip ( ) . to_string ( ) ;
209
+ let port = local_addr. port ( ) ;
210
+ span. record ( "server.address" , addr)
211
+ . record ( "server.port" , port as i64 )
212
+ . record ( "network.local.address" , addr)
213
+ . record ( "network.local.port" , port as i64 ) ;
214
+ }
246
215
247
- loop {
248
- interval. tick ( ) . await ;
216
+ // Wait for tls handshake to happen
217
+ let tls_stream = match tls_acceptor
218
+ . accept ( tcp_stream)
219
+ . instrument ( span. clone ( ) )
220
+ . await
221
+ {
222
+ Ok ( tls_stream) => tls_stream,
223
+ Err ( err) => {
224
+ span. record ( "otel.status_code" , "Error" )
225
+ . record ( "otel.status_message" , err. to_string ( ) ) ;
226
+ tracing:: trace!( %remote_addr, "error during tls handshake connection" ) ;
227
+ return ;
228
+ }
229
+ } ;
249
230
250
- cert_resolver
251
- . rotate_certificate ( )
252
- . await
253
- . context ( RotateCertificateSnafu ) ?;
254
- }
231
+ // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
232
+ // `TokioIo` converts between them.
233
+ let tls_stream = TokioIo :: new ( tls_stream) ;
234
+
235
+ // Hyper also has its own `Service` trait and doesn't use tower. We can use
236
+ // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
237
+ // `tower::Service::call`.
238
+ let hyper_service = service_fn ( move |request : Request < Incoming > | {
239
+ // This carries the current context with the trace id so that the TraceLayer can use that as a parent
240
+ let otel_context = Span :: current ( ) . context ( ) ;
241
+ // We need to clone here, because oneshot consumes self
242
+ tower_service
243
+ . clone ( )
244
+ . oneshot ( request)
245
+ . with_context ( otel_context)
246
+ } ) ;
247
+
248
+ let span = tracing:: debug_span!( "serve connection" ) ;
249
+ hyper_util:: server:: conn:: auto:: Builder :: new ( TokioExecutor :: new ( ) )
250
+ . serve_connection_with_upgrades ( tls_stream, hyper_service)
251
+ . instrument ( span. clone ( ) )
252
+ . await
253
+ . unwrap_or_else ( |err| {
254
+ span. record ( "otel.status_code" , "Error" )
255
+ . record ( "otel.status_message" , err. to_string ( ) ) ;
256
+ tracing:: warn!( %err, %remote_addr, "failed to serve connection" ) ;
257
+ } )
255
258
}
256
259
}
257
260
0 commit comments