@@ -7,7 +7,7 @@ use std::{
77 atomic:: { AtomicU64 , Ordering } ,
88 } ,
99 task:: { Context , Poll } ,
10- time:: { Duration , Instant } ,
10+ time:: Duration ,
1111} ;
1212
1313use anyhow:: { Context as _, Result } ;
@@ -109,19 +109,23 @@ async fn copy_with_idle_timeout(
109109 target_addr : & str ,
110110) -> Result < ( ) > {
111111 let tracker = Arc :: new ( SharedIdleTracker :: new ( ) ) ;
112+ let client_written = Arc :: new ( AtomicU64 :: new ( 0 ) ) ;
113+ let proxy_written = Arc :: new ( AtomicU64 :: new ( 0 ) ) ;
112114
113- let mut active_client = ActiveStream :: new ( client, tracker. clone ( ) ) ;
114- let mut active_proxy = ActiveStream :: new ( proxy, tracker. clone ( ) ) ;
115+ let mut active_client = ActiveStream :: new ( client, tracker. clone ( ) , client_written . clone ( ) ) ;
116+ let mut active_proxy = ActiveStream :: new ( proxy, tracker. clone ( ) , proxy_written . clone ( ) ) ;
115117
116118 let copy_task = copy_bidirectional ( & mut active_client, & mut active_proxy) ;
117119
118120 let monitor_task = async {
119121 loop {
120- let elapsed = tracker. elapsed ( ) ;
121- if elapsed >= IDLE_TIMEOUT {
122+ let last = tracker. last_activity_instant ( ) ;
123+ let deadline = last + IDLE_TIMEOUT ;
124+ tokio:: time:: sleep_until ( deadline) . await ;
125+ if tracker. last_activity_instant ( ) <= last {
126+ // No activity since we woke up (or last activity is older) -> idle timeout
122127 break ;
123128 }
124- tokio:: time:: sleep ( IDLE_TIMEOUT - elapsed) . await ;
125129 }
126130 } ;
127131
@@ -131,13 +135,13 @@ async fn copy_with_idle_timeout(
131135 Ok ( ( up, down) ) => ( up, down) ,
132136 Err ( e) => {
133137 debug!( "TCP relay error or closed: {}" , e) ;
134- ( active_proxy . written_bytes , active_client . written_bytes )
138+ ( proxy_written . load ( Ordering :: Relaxed ) , client_written . load ( Ordering :: Relaxed ) )
135139 }
136140 }
137141 }
138142 _ = monitor_task => {
139143 debug!( "TCP relay idle timeout for {}" , target_addr) ;
140- ( active_proxy . written_bytes , active_client . written_bytes )
144+ ( proxy_written . load ( Ordering :: Relaxed ) , client_written . load ( Ordering :: Relaxed ) )
141145 }
142146 } ;
143147
@@ -161,42 +165,41 @@ async fn find_session_target(
161165}
162166
163167struct SharedIdleTracker {
164- base_instant : Instant ,
165- last_activity : Arc < AtomicU64 > ,
168+ base_instant : tokio :: time :: Instant ,
169+ last_activity_millis : Arc < AtomicU64 > ,
166170}
167171
168172impl SharedIdleTracker {
169173 fn new ( ) -> Self {
170174 Self {
171- base_instant : Instant :: now ( ) ,
172- last_activity : Arc :: new ( AtomicU64 :: new ( 0 ) ) ,
175+ base_instant : tokio :: time :: Instant :: now ( ) ,
176+ last_activity_millis : Arc :: new ( AtomicU64 :: new ( 0 ) ) ,
173177 }
174178 }
175179
176180 fn update_activity ( & self ) {
177181 let elapsed = self . base_instant . elapsed ( ) . as_millis ( ) as u64 ;
178- self . last_activity . store ( elapsed, Ordering :: Relaxed ) ;
182+ self . last_activity_millis . store ( elapsed, Ordering :: Relaxed ) ;
179183 }
180184
181- fn elapsed ( & self ) -> Duration {
182- let last_millis = self . last_activity . load ( Ordering :: Relaxed ) ;
183- let now_millis = self . base_instant . elapsed ( ) . as_millis ( ) as u64 ;
184- Duration :: from_millis ( now_millis. saturating_sub ( last_millis) )
185+ fn last_activity_instant ( & self ) -> tokio:: time:: Instant {
186+ let millis = self . last_activity_millis . load ( Ordering :: Relaxed ) ;
187+ self . base_instant + Duration :: from_millis ( millis)
185188 }
186189}
187190
188191struct ActiveStream < T > {
189192 inner : T ,
190193 tracker : Arc < SharedIdleTracker > ,
191- written_bytes : u64 ,
194+ written_bytes : Arc < AtomicU64 > ,
192195}
193196
194197impl < T > ActiveStream < T > {
195- fn new ( inner : T , tracker : Arc < SharedIdleTracker > ) -> Self {
198+ fn new ( inner : T , tracker : Arc < SharedIdleTracker > , written_bytes : Arc < AtomicU64 > ) -> Self {
196199 Self {
197200 inner,
198201 tracker,
199- written_bytes : 0 ,
202+ written_bytes,
200203 }
201204 }
202205
@@ -233,11 +236,11 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for ActiveStream<T> {
233236 ) -> Poll < Result < usize , Error > > {
234237 let poll = Pin :: new ( & mut self . inner ) . poll_write ( cx, buf) ;
235238
236- if let Poll :: Ready ( Ok ( n) ) = & poll
237- && * n > 0
238- {
239- self . update_activity ( ) ;
240- self . written_bytes += * n as u64 ;
239+ if let Poll :: Ready ( Ok ( n) ) = & poll {
240+ if * n > 0 {
241+ self . update_activity ( ) ;
242+ self . written_bytes . fetch_add ( * n as u64 , Ordering :: Relaxed ) ;
243+ }
241244 }
242245
243246 poll
0 commit comments