@@ -20,6 +20,7 @@ use s2_sdk::{
2020 } ,
2121} ;
2222use tokio:: time:: Instant ;
23+ use tokio:: sync:: mpsc;
2324use xxhash_rust:: xxh3:: Xxh3Default ;
2425
2526use crate :: error:: { CliError , OpKind } ;
@@ -493,8 +494,36 @@ pub async fn run(
493494 write_done_records. clone ( ) ,
494495 bench_start,
495496 ) ;
496- let mut write_stream = std:: pin:: pin!( write_stream) ;
497- let mut read_stream = std:: pin:: pin!( read_stream) ;
497+
498+ enum BenchEvent {
499+ Write ( Result < BenchWriteSample , CliError > ) ,
500+ Read ( Result < BenchReadSample , CliError > ) ,
501+ WriteDone ,
502+ ReadDone ,
503+ }
504+
505+ let ( tx, mut rx) = mpsc:: unbounded_channel ( ) ;
506+ let write_tx = tx. clone ( ) ;
507+ let write_handle = tokio:: spawn ( async move {
508+ let mut write_stream = std:: pin:: pin!( write_stream) ;
509+ while let Some ( sample) = write_stream. next ( ) . await {
510+ if write_tx. send ( BenchEvent :: Write ( sample) ) . is_err ( ) {
511+ return ;
512+ }
513+ }
514+ let _ = write_tx. send ( BenchEvent :: WriteDone ) ;
515+ } ) ;
516+ let read_tx = tx. clone ( ) ;
517+ let read_handle = tokio:: spawn ( async move {
518+ let mut read_stream = std:: pin:: pin!( read_stream) ;
519+ while let Some ( sample) = read_stream. next ( ) . await {
520+ if read_tx. send ( BenchEvent :: Read ( sample) ) . is_err ( ) {
521+ return ;
522+ }
523+ }
524+ let _ = read_tx. send ( BenchEvent :: ReadDone ) ;
525+ } ) ;
526+ drop ( tx) ;
498527
499528 let deadline = bench_start + duration;
500529 let mut write_done = false ;
@@ -511,47 +540,56 @@ pub async fn run(
511540 _ = tokio:: signal:: ctrl_c( ) , if !stop. load( Ordering :: Relaxed ) => {
512541 stop. store( true , Ordering :: Relaxed ) ;
513542 }
514- result = write_stream . next ( ) , if !write_done => {
515- match result {
516- Some ( Ok ( sample) ) => {
543+ event = rx . recv ( ) => {
544+ match event {
545+ Some ( BenchEvent :: Write ( Ok ( sample) ) ) => {
517546 update_bench_bar( & write_bar, & sample) ;
518547 all_ack_latencies. extend( sample. ack_latencies. iter( ) . copied( ) ) ;
519548 if let Some ( hash) = sample. run_hash {
520549 write_run_hash = Some ( hash) ;
521550 }
522551 write_sample = Some ( sample) ;
523552 }
524- Some ( Err ( e) ) => {
553+ Some ( BenchEvent :: Write ( Err ( e) ) ) => {
525554 write_bar. finish_and_clear( ) ;
526555 read_bar. finish_and_clear( ) ;
556+ stop. store( true , Ordering :: Relaxed ) ;
557+ write_handle. abort( ) ;
558+ read_handle. abort( ) ;
527559 return Err ( e) ;
528560 }
529- None => {
561+ Some ( BenchEvent :: WriteDone ) => {
530562 write_done = true ;
531563 }
532- }
533- }
534- result = read_stream. next( ) , if !read_done => {
535- match result {
536- Some ( Ok ( sample) ) => {
564+ Some ( BenchEvent :: Read ( Ok ( sample) ) ) => {
537565 update_bench_bar( & read_bar, & sample) ;
538566 all_e2e_latencies. extend( sample. e2e_latencies. iter( ) . copied( ) ) ;
539567 if let Some ( hash) = sample. run_hash {
540568 read_run_hash = Some ( hash) ;
541569 }
542570 read_sample = Some ( sample) ;
543571 }
544- Some ( Err ( e) ) => {
572+ Some ( BenchEvent :: Read ( Err ( e) ) ) => {
545573 write_bar. finish_and_clear( ) ;
546574 read_bar. finish_and_clear( ) ;
575+ stop. store( true , Ordering :: Relaxed ) ;
576+ write_handle. abort( ) ;
577+ read_handle. abort( ) ;
547578 return Err ( e) ;
548579 }
549- None => read_done = true ,
580+ Some ( BenchEvent :: ReadDone ) => read_done = true ,
581+ None => {
582+ write_done = true ;
583+ read_done = true ;
584+ }
550585 }
551586 }
552587 }
553588 }
554589
590+ let _ = write_handle. await ;
591+ let _ = read_handle. await ;
592+
555593 write_bar. finish_and_clear ( ) ;
556594 read_bar. finish_and_clear ( ) ;
557595
0 commit comments