1- use std:: path:: PathBuf ;
1+ use std:: { path:: PathBuf , sync :: Arc , time :: Duration } ;
22
33use alloy_primitives:: Address ;
44use env_logger:: Env ;
5- use futures_util:: { future, SinkExt , StreamExt , TryStreamExt } ;
6- use log:: { info} ;
7- use tokio_tungstenite:: connect_async;
5+ use futures_util:: {
6+ future,
7+ stream:: { SplitSink , SplitStream } ,
8+ SinkExt , StreamExt , TryStreamExt ,
9+ } ;
10+ use log:: info;
11+ use tokio:: { net:: TcpStream , sync:: Mutex } ;
12+ use tokio_tungstenite:: { connect_async, MaybeTlsStream , WebSocketStream } ;
813
9- use batcher:: types:: { parse_proving_system, VerificationData } ;
14+ use batcher:: types:: { parse_proving_system, BatchInclusionData , VerificationData } ;
1015
1116use clap:: Parser ;
17+ use tungstenite:: Message ;
1218
1319#[ derive( Parser , Debug ) ]
1420#[ command( version, about, long_about = None ) ]
@@ -48,15 +54,14 @@ struct Args {
4854 long = "repetitions" ,
4955 default_value = "1"
5056 ) ]
51- repetitions : u32 ,
57+ repetitions : usize ,
5258
5359 #[ arg(
5460 name = "Proof generator address" ,
5561 long = "proof_generator_addr" ,
5662 default_value = "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266"
5763 ) ] // defaults to anvil address 1
5864 proof_generator_addr : String ,
59-
6065}
6166
6267#[ tokio:: main]
@@ -99,7 +104,8 @@ async fn main() {
99104 info ! ( "No VM program code file provided, continuing without VM program code..." ) ;
100105 }
101106
102- let proof_generator_addr: Address = Address :: parse_checksummed ( & args. proof_generator_addr , None ) . unwrap ( ) ;
107+ let proof_generator_addr: Address =
108+ Address :: parse_checksummed ( & args. proof_generator_addr , None ) . unwrap ( ) ;
103109
104110 let verification_data = VerificationData {
105111 proving_system,
@@ -112,23 +118,41 @@ async fn main() {
112118
113119 let json_data = serde_json:: to_string ( & verification_data) . expect ( "Failed to serialize task" ) ;
114120 for _ in 0 ..args. repetitions {
121+ // NOTE(marian): This sleep is only for ease of testing interactions between client and batcher,
122+ // it can be removed.
123+ std:: thread:: sleep ( Duration :: from_millis ( 500 ) ) ;
115124 ws_write
116125 . send ( tungstenite:: Message :: Text ( json_data. to_string ( ) ) )
117126 . await
118127 . unwrap ( ) ;
128+ info ! ( "Message sent..." )
119129 }
120130
131+ let num_responses = Arc :: new ( Mutex :: new ( 0 ) ) ;
132+ let ws_write = Arc :: new ( Mutex :: new ( ws_write) ) ;
133+
134+ receive ( ws_read, ws_write, args. repetitions , num_responses) . await ;
135+ }
136+
137+ async fn receive (
138+ ws_read : SplitStream < WebSocketStream < MaybeTlsStream < TcpStream > > > ,
139+ ws_write : Arc < Mutex < SplitSink < WebSocketStream < MaybeTlsStream < TcpStream > > , Message > > > ,
140+ total_messages : usize ,
141+ num_responses : Arc < Mutex < usize > > ,
142+ ) {
121143 ws_read
122- . try_filter ( |msg| future:: ready ( msg. is_text ( ) ) )
123- . for_each ( |msg| async move {
124- let data = msg. unwrap ( ) . into_text ( ) . unwrap ( ) ;
125- info ! ( "Batch merkle root received: {}" , data) ;
144+ . try_filter ( |msg| future:: ready ( msg. is_text ( ) || msg. is_binary ( ) ) )
145+ . for_each ( |msg| async {
146+ let mut num_responses_lock = num_responses. lock ( ) . await ;
147+ * num_responses_lock += 1 ;
148+ let data = msg. unwrap ( ) . into_data ( ) ;
149+ let deserialized_data: BatchInclusionData = serde_json:: from_slice ( & data) . unwrap ( ) ;
150+ info ! ( "Batcher response received: {}" , deserialized_data) ;
151+
152+ if * num_responses_lock == total_messages {
153+ info ! ( "All messages responded. Closing connection..." ) ;
154+ ws_write. lock ( ) . await . close ( ) . await . unwrap ( ) ;
155+ }
126156 } )
127157 . await ;
128-
129- info ! ( "Closing connection..." ) ;
130- ws_write
131- . close ( )
132- . await
133- . expect ( "Failed to close WebSocket connection" ) ;
134158}
0 commit comments