@@ -10,7 +10,7 @@ use hyper_util::rt::{TokioExecutor, TokioIo};
1010use opentelemetry:: trace:: { FutureExt , SpanKind } ;
1111use snafu:: { ResultExt , Snafu } ;
1212use stackable_operator:: time:: Duration ;
13- use tokio:: { net:: TcpListener , sync:: mpsc, time:: interval} ;
13+ use tokio:: { net:: TcpListener , select , sync:: mpsc, time:: interval} ;
1414use tokio_rustls:: {
1515 TlsAcceptor ,
1616 rustls:: {
@@ -48,6 +48,9 @@ pub enum TlsServerError {
4848
4949 #[ snafu( display( "failed to set safe TLS protocol versions" ) ) ]
5050 SetSafeTlsProtocolVersions { source : tokio_rustls:: rustls:: Error } ,
51+
52+ #[ snafu( display( "failed to run certificate rotation loop" ) ) ]
53+ RunCertificateRotationLoop { source : tokio:: task:: JoinError } ,
5154}
5255
5356/// A server which terminates TLS connections and allows clients to communicate
@@ -98,7 +101,8 @@ impl TlsServer {
98101 /// TLS stream get handled by a Hyper service, which in turn is an Axum
99102 /// router.
100103 pub async fn run ( self ) -> Result < ( ) > {
101- tokio:: spawn ( async { Self :: run_certificate_rotation_loop ( self . cert_resolver ) . await } ) ;
104+ let certificate_rotation_loop =
105+ tokio:: spawn ( async { Self :: run_certificate_rotation_loop ( self . cert_resolver ) . await } ) ;
102106
103107 let tls_acceptor = TlsAcceptor :: from ( Arc :: new ( self . config ) ) ;
104108 let tcp_listener =
@@ -123,12 +127,21 @@ impl TlsServer {
123127 . router
124128 . into_make_service_with_connect_info :: < SocketAddr > ( ) ;
125129
126- pin_mut ! ( tcp_listener ) ;
130+ pin_mut ! ( certificate_rotation_loop ) ;
127131 loop {
128132 let tls_acceptor = tls_acceptor. clone ( ) ;
129133
130- // Wait for new tcp connection
131- let ( tcp_stream, remote_addr) = match tcp_listener. accept ( ) . await {
134+ // Wait for either a new TCP connection or the certificate rotation loop to exit
135+ let tcp_stream = select ! {
136+ loop_result = & mut certificate_rotation_loop => {
137+ return loop_result. context( RunCertificateRotationLoopSnafu ) ?;
138+ }
139+ tcp_stream = tcp_listener. accept( ) => {
140+ tcp_stream
141+ }
142+ } ;
143+
144+ let ( tcp_stream, remote_addr) = match tcp_stream {
132145 Ok ( ( stream, addr) ) => ( stream, addr) ,
133146 Err ( err) => {
134147 tracing:: trace!( %err, "failed to accept incoming TCP connection" ) ;
0 commit comments