@@ -7,7 +7,7 @@ use std::{
77 atomic:: { AtomicU64 , Ordering } ,
88 } ,
99 task:: { Context , Poll } ,
10- time:: { Duration , SystemTime , UNIX_EPOCH } ,
10+ time:: { Duration , Instant } ,
1111} ;
1212
1313use anyhow:: { Context as _, Result } ;
@@ -110,19 +110,39 @@ async fn copy_with_idle_timeout(
110110) -> Result < ( ) > {
111111 let tracker = Arc :: new ( SharedIdleTracker :: new ( ) ) ;
112112
113- let mut timeout_client = IdleTimeoutStream :: new ( client, tracker. clone ( ) , IDLE_TIMEOUT ) ;
114- let mut timeout_proxy = IdleTimeoutStream :: new ( proxy, tracker, IDLE_TIMEOUT ) ;
113+ let mut active_client = ActiveStream :: new ( client, tracker. clone ( ) ) ;
114+ let mut active_proxy = ActiveStream :: new ( proxy, tracker. clone ( ) ) ;
115115
116- match copy_bidirectional ( & mut timeout_client, & mut timeout_proxy) . await {
117- Ok ( ( up, down) ) => {
118- stats:: update_metrics ( runtime, Protocol :: Tcp , proxy_name, target_addr, up, down) ;
119- Ok ( ( ) )
116+ let copy_task = copy_bidirectional ( & mut active_client, & mut active_proxy) ;
117+
118+ let monitor_task = async {
119+ loop {
120+ let elapsed = tracker. elapsed ( ) ;
121+ if elapsed >= IDLE_TIMEOUT {
122+ break ;
123+ }
124+ tokio:: time:: sleep ( IDLE_TIMEOUT - elapsed) . await ;
125+ }
126+ } ;
127+
128+ let ( up, down) = tokio:: select! {
129+ res = copy_task => {
130+ match res {
131+ Ok ( ( up, down) ) => ( up, down) ,
132+ Err ( e) => {
133+ debug!( "TCP relay error or closed: {}" , e) ;
134+ ( active_proxy. written_bytes, active_client. written_bytes)
135+ }
136+ }
120137 }
121- Err ( e ) => {
122- debug ! ( "TCP relay error: {}" , e ) ;
123- Ok ( ( ) )
138+ _ = monitor_task => {
139+ debug!( "TCP relay idle timeout for {}" , target_addr ) ;
140+ ( active_proxy . written_bytes , active_client . written_bytes )
124141 }
125- }
142+ } ;
143+
144+ stats:: update_metrics ( runtime, Protocol :: Tcp , proxy_name, target_addr, up, down) ;
145+ Ok ( ( ) )
126146}
127147
128148async fn find_session_target (
@@ -141,131 +161,86 @@ async fn find_session_target(
141161}
142162
143163struct SharedIdleTracker {
164+ base_instant : Instant ,
144165 last_activity : Arc < AtomicU64 > ,
145166}
146167
147168impl SharedIdleTracker {
148169 fn new ( ) -> Self {
149- let now_millis = SystemTime :: now ( )
150- . duration_since ( UNIX_EPOCH )
151- . unwrap ( )
152- . as_millis ( ) as u64 ;
153-
154170 Self {
155- last_activity : Arc :: new ( AtomicU64 :: new ( now_millis) ) ,
171+ base_instant : Instant :: now ( ) ,
172+ last_activity : Arc :: new ( AtomicU64 :: new ( 0 ) ) ,
156173 }
157174 }
158175
159176 fn update_activity ( & self ) {
160- let now_millis = SystemTime :: now ( )
161- . duration_since ( UNIX_EPOCH )
162- . unwrap ( )
163- . as_millis ( ) as u64 ;
164- self . last_activity . store ( now_millis, Ordering :: Relaxed ) ;
177+ let elapsed = self . base_instant . elapsed ( ) . as_millis ( ) as u64 ;
178+ self . last_activity . store ( elapsed, Ordering :: Relaxed ) ;
165179 }
166180
167181 fn elapsed ( & self ) -> Duration {
168182 let last_millis = self . last_activity . load ( Ordering :: Relaxed ) ;
169- let now_millis = SystemTime :: now ( )
170- . duration_since ( UNIX_EPOCH )
171- . unwrap ( )
172- . as_millis ( ) as u64 ;
183+ let now_millis = self . base_instant . elapsed ( ) . as_millis ( ) as u64 ;
173184 Duration :: from_millis ( now_millis. saturating_sub ( last_millis) )
174185 }
175-
176- fn is_idle ( & self , timeout : Duration ) -> bool {
177- self . elapsed ( ) > timeout
178- }
179186}
180187
181- struct IdleTimeoutStream < T > {
188+ struct ActiveStream < T > {
182189 inner : T ,
183190 tracker : Arc < SharedIdleTracker > ,
184- timeout : Duration ,
191+ written_bytes : u64 ,
185192}
186193
187- impl < T > IdleTimeoutStream < T > {
188- fn new ( inner : T , tracker : Arc < SharedIdleTracker > , timeout : Duration ) -> Self {
194+ impl < T > ActiveStream < T > {
195+ fn new ( inner : T , tracker : Arc < SharedIdleTracker > ) -> Self {
189196 Self {
190197 inner,
191198 tracker,
192- timeout ,
199+ written_bytes : 0 ,
193200 }
194201 }
195202
196203 fn update_activity ( & self ) {
197204 self . tracker . update_activity ( ) ;
198205 }
199-
200- fn check_idle ( & self ) -> tokio:: io:: Result < ( ) > {
201- if self . tracker . is_idle ( self . timeout ) {
202- return Err ( tokio:: io:: Error :: new (
203- tokio:: io:: ErrorKind :: TimedOut ,
204- "idle timeout - no activity on either side" ,
205- ) ) ;
206- }
207- Ok ( ( ) )
208- }
209-
210- fn is_normal_close ( e : & std:: io:: Error ) -> bool {
211- matches ! (
212- e. kind( ) ,
213- std:: io:: ErrorKind :: BrokenPipe
214- | std:: io:: ErrorKind :: ConnectionReset
215- | std:: io:: ErrorKind :: ConnectionAborted
216- | std:: io:: ErrorKind :: UnexpectedEof
217- )
218- }
219206}
220207
221- impl < T : AsyncRead + Unpin > AsyncRead for IdleTimeoutStream < T > {
208+ impl < T : AsyncRead + Unpin > AsyncRead for ActiveStream < T > {
222209 fn poll_read (
223210 mut self : Pin < & mut Self > ,
224211 cx : & mut Context < ' _ > ,
225212 buf : & mut tokio:: io:: ReadBuf < ' _ > ,
226213 ) -> Poll < std:: io:: Result < ( ) > > {
227- self . check_idle ( ) ?;
228-
229214 let initial_filled = buf. filled ( ) . len ( ) ;
230215 let poll = Pin :: new ( & mut self . inner ) . poll_read ( cx, buf) ;
231216
232- match & poll {
233- Poll :: Ready ( Ok ( ( ) ) ) if buf. filled ( ) . len ( ) > initial_filled => {
217+ if let Poll :: Ready ( Ok ( ( ) ) ) = & poll {
218+ let n = buf. filled ( ) . len ( ) - initial_filled;
219+ if n > 0 {
234220 self . update_activity ( ) ;
235221 }
236- Poll :: Ready ( Err ( e) ) if Self :: is_normal_close ( e) => {
237- return Poll :: Ready ( Ok ( ( ) ) ) ;
238- }
239- _ => { }
240222 }
241223
242224 poll
243225 }
244226}
245227
246- impl < T : AsyncWrite + Unpin > AsyncWrite for IdleTimeoutStream < T > {
228+ impl < T : AsyncWrite + Unpin > AsyncWrite for ActiveStream < T > {
247229 fn poll_write (
248230 mut self : Pin < & mut Self > ,
249231 cx : & mut Context < ' _ > ,
250232 buf : & [ u8 ] ,
251233 ) -> Poll < Result < usize , Error > > {
252- self . check_idle ( ) ?;
253-
254234 let poll = Pin :: new ( & mut self . inner ) . poll_write ( cx, buf) ;
255235
256- match poll {
257- Poll :: Ready ( Ok ( n) ) => {
258- if n > 0 {
259- self . update_activity ( ) ;
260- }
261- Poll :: Ready ( Ok ( n) )
262- }
263- Poll :: Ready ( Err ( e) ) if Self :: is_normal_close ( & e) => {
264- // Treat normal close as successful write of all bytes
265- Poll :: Ready ( Ok ( buf. len ( ) ) )
266- }
267- _ => poll,
236+ if let Poll :: Ready ( Ok ( n) ) = & poll
237+ && * n > 0
238+ {
239+ self . update_activity ( ) ;
240+ self . written_bytes += * n as u64 ;
268241 }
242+
243+ poll
269244 }
270245
271246 fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Error > > {
0 commit comments