1
1
use super :: * ;
2
2
use crate :: util:: Socket ;
3
3
use std:: pin:: Pin ;
4
- use std:: sync:: atomic:: { AtomicU32 , Ordering } ;
4
+ use std:: sync:: atomic:: { AtomicBool , AtomicU32 , Ordering } ;
5
5
use std:: sync:: Arc ;
6
6
use std:: task:: { Context , Poll } ;
7
7
use tempfile:: tempdir;
8
8
use tokio:: io:: { duplex, AsyncRead , AsyncWrite , DuplexStream } ;
9
9
use tower:: Service ;
10
+ use std:: time:: Duration ;
10
11
11
12
#[ tokio:: test]
12
13
async fn test_sync_context_push_frame ( ) {
@@ -131,6 +132,50 @@ async fn test_sync_context_corrupted_metadata() {
131
132
assert_eq ! ( sync_ctx. generation( ) , 1 ) ;
132
133
}
133
134
135
+ #[ tokio:: test]
136
+ async fn test_sync_context_retry_on_error ( ) {
137
+ // Pause time to control it manually
138
+ tokio:: time:: pause ( ) ;
139
+
140
+ let server = MockServer :: start ( ) ;
141
+ let temp_dir = tempdir ( ) . unwrap ( ) ;
142
+ let db_path = temp_dir. path ( ) . join ( "test.db" ) ;
143
+
144
+ let sync_ctx = SyncContext :: new (
145
+ server. connector ( ) ,
146
+ db_path. to_str ( ) . unwrap ( ) . to_string ( ) ,
147
+ server. url ( ) ,
148
+ None ,
149
+ )
150
+ . await
151
+ . unwrap ( ) ;
152
+
153
+ let mut sync_ctx = sync_ctx;
154
+ let frame = Bytes :: from ( "test frame data" ) ;
155
+
156
+ // Set server to return errors
157
+ server. return_error . store ( true , Ordering :: SeqCst ) ;
158
+
159
+ // First attempt should fail but retry
160
+ let result = sync_ctx. push_one_frame ( frame. clone ( ) , 1 , 0 ) . await ;
161
+ assert ! ( result. is_err( ) ) ;
162
+
163
+ // Advance time to trigger retries faster
164
+ tokio:: time:: advance ( Duration :: from_secs ( 2 ) ) . await ;
165
+
166
+ // Verify multiple requests were made (retries occurred)
167
+ assert ! ( server. request_count( ) > 1 ) ;
168
+
169
+ // Allow the server to succeed
170
+ server. return_error . store ( false , Ordering :: SeqCst ) ;
171
+
172
+ // Next attempt should succeed
173
+ let durable_frame = sync_ctx. push_one_frame ( frame, 1 , 0 ) . await . unwrap ( ) ;
174
+ sync_ctx. write_metadata ( ) . await . unwrap ( ) ;
175
+ assert_eq ! ( durable_frame, 1 ) ;
176
+ assert_eq ! ( server. frame_count( ) , 1 ) ;
177
+ }
178
+
134
179
#[ test]
135
180
fn test_hash_verification ( ) {
136
181
let mut metadata = MetadataJson {
@@ -212,11 +257,15 @@ struct MockServer {
212
257
url : String ,
213
258
frame_count : Arc < AtomicU32 > ,
214
259
connector : ConnectorService ,
260
+ return_error : Arc < AtomicBool > ,
261
+ request_count : Arc < AtomicU32 > ,
215
262
}
216
263
217
264
impl MockServer {
218
265
fn start ( ) -> Self {
219
266
let frame_count = Arc :: new ( AtomicU32 :: new ( 0 ) ) ;
267
+ let return_error = Arc :: new ( AtomicBool :: new ( false ) ) ;
268
+ let request_count = Arc :: new ( AtomicU32 :: new ( 0 ) ) ;
220
269
221
270
// Create the mock connector with Some(client_stream)
222
271
let ( tx, mut rx) = tokio:: sync:: mpsc:: channel ( 1 ) ;
@@ -227,23 +276,43 @@ impl MockServer {
227
276
url : "http://mock.server" . to_string ( ) ,
228
277
frame_count : frame_count. clone ( ) ,
229
278
connector,
279
+ return_error : return_error. clone ( ) ,
280
+ request_count : request_count. clone ( ) ,
230
281
} ;
231
282
232
283
// Spawn the server handler
233
284
let frame_count_clone = frame_count. clone ( ) ;
285
+ let return_error_clone = return_error. clone ( ) ;
286
+ let request_count_clone = request_count. clone ( ) ;
234
287
235
288
tokio:: spawn ( async move {
236
289
while let Some ( server_stream) = rx. recv ( ) . await {
237
290
let frame_count_clone = frame_count_clone. clone ( ) ;
291
+ let return_error_clone = return_error_clone. clone ( ) ;
292
+ let request_count_clone = request_count_clone. clone ( ) ;
238
293
239
294
tokio:: spawn ( async move {
240
295
use hyper:: server:: conn:: Http ;
241
296
use hyper:: service:: service_fn;
242
297
243
298
let frame_count_clone = frame_count_clone. clone ( ) ;
299
+ let return_error_clone = return_error_clone. clone ( ) ;
300
+ let request_count_clone = request_count_clone. clone ( ) ;
244
301
let service = service_fn ( move |req : http:: Request < Body > | {
245
302
let frame_count = frame_count_clone. clone ( ) ;
303
+ let return_error = return_error_clone. clone ( ) ;
304
+ let request_count = request_count_clone. clone ( ) ;
246
305
async move {
306
+ request_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
307
+ if return_error. load ( Ordering :: SeqCst ) {
308
+ return Ok :: < _ , hyper:: Error > (
309
+ http:: Response :: builder ( )
310
+ . status ( 500 )
311
+ . body ( Body :: from ( "Internal Server Error" ) )
312
+ . unwrap ( ) ,
313
+ ) ;
314
+ }
315
+
247
316
let current_count = frame_count. fetch_add ( 1 , Ordering :: SeqCst ) ;
248
317
249
318
if req. uri ( ) . path ( ) . contains ( "/sync/" ) {
@@ -287,6 +356,10 @@ impl MockServer {
287
356
fn frame_count ( & self ) -> u32 {
288
357
self . frame_count . load ( Ordering :: SeqCst )
289
358
}
359
+
360
+ fn request_count ( & self ) -> u32 {
361
+ self . request_count . load ( Ordering :: SeqCst )
362
+ }
290
363
}
291
364
292
365
// Mock connection that implements the Socket trait
0 commit comments