11use super :: * ;
22
3+ use std:: future:: Future ;
4+ use std:: ops:: Range ;
35use std:: sync:: atomic:: AtomicUsize ;
46use std:: sync:: atomic:: Ordering ;
57use tokio:: io:: AsyncBufReadExt ;
@@ -344,29 +346,23 @@ async fn resume_max_5_times() {
344346 let request_count = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
345347 let rc = request_count. clone ( ) ;
346348
347- let listener = TcpListener :: bind ( "localhost:0" ) . await . unwrap ( ) ;
348- let port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
349- let server_task = tokio:: spawn ( async move {
350- while let Ok ( ( mut stream, _addr) ) = listener. accept ( ) . await {
351- let response_task = async move {
352- let ( _, mut writer) = stream. split ( ) ;
353- // Always respond with only first chunk, never completing the response, triggering retries
354- let header = "\
349+ let handler = move |_, _| {
350+ let rc = rc. clone ( ) ;
351+ async move {
352+ rc. fetch_add ( 1 , Ordering :: SeqCst ) ;
353+ // Always respond with only first chunk, never completing the response, triggering retries
354+ let header = "\
355355 HTTP/1.1 200 OK\r \n \
356356 transfer-encoding: chunked\r \n \
357357 connection: close\r \n \
358358 content-type: application/octet-stream\r \n \
359359 accept-ranges: bytes\r \n ";
360360
361- let body = "AAAA" ;
362- let msg = format ! ( "{header}\r \n 4\r \n {body}\r \n " ) ;
363- writer. write_all ( msg. as_bytes ( ) ) . await . unwrap ( ) ;
364- writer. flush ( ) . await . unwrap ( ) ;
365- } ;
366- tokio:: spawn ( response_task) ;
367- rc. fetch_add ( 1 , Ordering :: SeqCst ) ;
361+ let body = "AAAA" ;
362+ format ! ( "{header}\r \n 4\r \n {body}\r \n " )
368363 }
369- } ) ;
364+ } ;
365+ let ( port, server_task) = spawn_server ( handler) . await ;
370366
371367 // Wait until task binds a listener on the TCP port
372368 tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 50 ) ) . await ;
@@ -395,38 +391,30 @@ async fn only_retry_until_success() {
395391 let request_count = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
396392 let rc = request_count. clone ( ) ;
397393
398- let listener = TcpListener :: bind ( "localhost:0" ) . await . unwrap ( ) ;
399- let port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
400- let server_task = tokio:: spawn ( async move {
401- while let Ok ( ( mut stream, _addr) ) = listener. accept ( ) . await {
402- let rc_num = rc. load ( Ordering :: SeqCst ) ;
403- let response_task = async move {
404- let ( _, mut writer) = stream. split ( ) ;
405- // Always respond with only first chunk, never completing the response, triggering retries
406- let header = "\
394+ let handler = move |_, request_count| {
395+ let rc = rc. clone ( ) ;
396+ async move {
397+ rc. fetch_add ( 1 , Ordering :: SeqCst ) ;
398+ // Always respond with only first chunk, never completing the response, triggering retries
399+ let header = "\
407400 HTTP/1.1 200 OK\r \n \
408401 transfer-encoding: chunked\r \n \
409402 connection: close\r \n \
410403 content-type: application/octet-stream\r \n \
411404 accept-ranges: bytes\r \n ";
412405
413- // On the 2nd request, send the full response, should trigger only 1 retry
414- let msg = if rc_num == 0 {
415- let body = "AAAA" ;
416- format ! ( "{header}\r \n 4\r \n {body}\r \n " )
417- } else {
418- let body = file;
419- let len = file. len ( ) ;
420- format ! ( "{header}\r \n {len:x}\r \n {body}\r \n 0\r \n \r \n " )
421- } ;
422-
423- writer. write_all ( msg. as_bytes ( ) ) . await . unwrap ( ) ;
424- writer. flush ( ) . await . unwrap ( ) ;
425- } ;
426- tokio:: spawn ( response_task) ;
427- rc. fetch_add ( 1 , Ordering :: SeqCst ) ;
406+ // On the 2nd request, send the full response, should trigger only 1 retry
407+ if request_count == 0 {
408+ let body = "AAAA" ;
409+ format ! ( "{header}\r \n 4\r \n {body}\r \n " )
410+ } else {
411+ let body = file;
412+ let len = file. len ( ) ;
413+ format ! ( "{header}\r \n {len:x}\r \n {body}\r \n 0\r \n \r \n " )
414+ }
428415 }
429- } ) ;
416+ } ;
417+ let ( port, server_task) = spawn_server ( handler) . await ;
430418
431419 // Wait until task binds a listener on the TCP port
432420 tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 50 ) ) . await ;
@@ -447,3 +435,52 @@ async fn only_retry_until_success() {
447435
448436 assert_eq ! ( request_count. load( Ordering :: SeqCst ) , 2 ) ;
449437}
438+
439+ async fn spawn_server < F , Fut > ( handle_request : F ) -> ( u16 , tokio:: task:: JoinHandle < ( ) > )
440+ where
441+ F : Fn ( Option < Range < usize > > , u32 ) -> Fut + ' static + Send + Sync + Clone ,
442+ Fut : Future < Output = String > + Send + Sync + ' static ,
443+ {
444+ let listener = TcpListener :: bind ( "localhost:0" ) . await . unwrap ( ) ;
445+ let port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
446+ let server_task = tokio:: spawn ( async move {
447+ let mut request_count = 0 ;
448+ while let Ok ( ( mut stream, _addr) ) = listener. accept ( ) . await {
449+ let handler = handle_request. clone ( ) ;
450+ tokio:: spawn ( async move {
451+ let ( reader, mut writer) = stream. split ( ) ;
452+ let request_range = parse_request_range ( reader) . await ;
453+
454+ let response: String = handler ( request_range, request_count) . await ;
455+
456+ writer. write_all ( response. as_bytes ( ) ) . await . unwrap ( ) ;
457+ writer. flush ( ) . await . unwrap ( ) ;
458+ } ) ;
459+ request_count += 1 ;
460+ }
461+ } ) ;
462+ ( port, server_task)
463+ }
464+
465+ /// Parse HTTP request from a BufReader, stopping at empty line
466+ async fn parse_request_range ( reader : tokio:: net:: tcp:: ReadHalf < ' _ > ) -> Option < Range < usize > > {
467+ let reader = BufReader :: new ( reader) ;
468+ let mut lines = reader. lines ( ) ;
469+ let mut range = None ;
470+
471+ while let Ok ( Some ( line) ) = lines. next_line ( ) . await {
472+ if line. to_ascii_lowercase ( ) . contains ( "range:" ) {
473+ let ( _, bytes) = line. split_once ( '=' ) . unwrap ( ) ;
474+ let ( start, end) = bytes. split_once ( '-' ) . unwrap ( ) ;
475+ let start = start. parse ( ) . unwrap_or ( 0 ) ;
476+ let end = end. parse ( ) . unwrap_or ( 0 ) ;
477+ range = Some ( start..end)
478+ }
479+ // On `\r\n\r\n` (empty line) stop reading the request
480+ if line. is_empty ( ) {
481+ break ;
482+ }
483+ }
484+
485+ range
486+ }
0 commit comments