1- use std:: { path:: PathBuf , sync:: Arc , time:: Duration } ;
1+ extern crate core;
2+
3+ mod errors;
4+
5+ use std:: { path:: PathBuf , sync:: Arc } ;
26
37use alloy_primitives:: { Address , hex} ;
48use env_logger:: Env ;
@@ -11,10 +15,11 @@ use log::info;
1115use tokio:: { net:: TcpStream , sync:: Mutex } ;
1216use tokio_tungstenite:: { connect_async, MaybeTlsStream , WebSocketStream } ;
1317
14- use batcher:: types:: { parse_proving_system, BatchInclusionData , VerificationData } ;
18+ use batcher:: types:: { parse_proving_system, BatchInclusionData , VerificationData , ProvingSystemId } ;
1519
1620use clap:: Parser ;
1721use tungstenite:: Message ;
22+ use crate :: errors:: BatcherClientError ;
1823
1924#[ derive( Parser , Debug ) ]
2025#[ command( version, about, long_about = None ) ]
@@ -35,19 +40,17 @@ struct Args {
3540 #[ arg(
3641 name = "Public input file name" ,
3742 long = "public_input" ,
38- default_value = "."
3943 ) ]
40- pub_input_file_name : PathBuf ,
44+ pub_input_file_name : Option < PathBuf > ,
4145
42- #[ arg( name = "Verification key file name" , long = "vk" , default_value = "." ) ]
43- verification_key_file_name : PathBuf ,
46+ #[ arg( name = "Verification key file name" , long = "vk" ) ]
47+ verification_key_file_name : Option < PathBuf > ,
4448
4549 #[ arg(
4650 name = "VM prgram code file name" ,
4751 long = "vm_program" ,
48- default_value = "."
4952 ) ]
50- vm_program_code_file_name : PathBuf ,
53+ vm_program_code_file_name : Option < PathBuf > ,
5154
5255 #[ arg(
5356 name = "Number of repetitions" ,
@@ -65,87 +68,49 @@ struct Args {
6568}
6669
6770#[ tokio:: main]
68- async fn main ( ) {
71+ async fn main ( ) -> Result < ( ) , errors :: BatcherClientError > {
6972 let args = Args :: parse ( ) ;
7073
7174 env_logger:: Builder :: from_env ( Env :: default ( ) . default_filter_or ( "info" ) ) . init ( ) ;
7275
73- let url = url:: Url :: parse ( & args. connect_addr ) . unwrap ( ) ;
74- let ( ws_stream, _) = connect_async ( url) . await . expect ( "Failed to connect" ) ;
75- info ! ( "WebSocket handshake has been successfully completed" ) ;
76+ let url = url:: Url :: parse ( & args. connect_addr )
77+ . map_err ( |e| errors:: BatcherClientError :: InvalidUrl ( e, args. connect_addr . clone ( ) ) ) ?;
7678
77- let ( mut ws_write, ws_read) = ws_stream. split ( ) ;
78-
79- let proving_system = parse_proving_system ( & args. proving_system_flag ) . unwrap ( ) ;
80-
81- // Read proof file
82- let proof = std:: fs:: read ( & args. proof_file_name )
83- . unwrap_or_else ( |_| panic ! ( "Failed to read .proof file: {:?}" , args. proof_file_name) ) ;
79+ let ( ws_stream, _) = connect_async ( url) . await ?;
8480
85- // Read public input file
86- let mut pub_input: Option < Vec < u8 > > = None ;
87- if let Ok ( data) = std:: fs:: read ( args. pub_input_file_name ) {
88- pub_input = Some ( data) ;
89- } else {
90- info ! ( "No public input file provided, continuing without public input..." ) ;
91- }
92-
93- let mut verification_key: Option < Vec < u8 > > = None ;
94- if let Ok ( data) = std:: fs:: read ( args. verification_key_file_name ) {
95- verification_key = Some ( data) ;
96- } else {
97- info ! ( "No verification key file provided, continuing without verification key..." ) ;
98- }
81+ info ! ( "WebSocket handshake has been successfully completed" ) ;
9982
100- let mut vm_program_code: Option < Vec < u8 > > = None ;
101- if let Ok ( data) = std:: fs:: read ( args. vm_program_code_file_name ) {
102- vm_program_code = Some ( data) ;
103- } else {
104- info ! ( "No VM program code file provided, continuing without VM program code..." ) ;
105- }
83+ let ( mut ws_write, ws_read) = ws_stream. split ( ) ;
10684
107- let proof_generator_addr : Address =
108- Address :: parse_checksummed ( & args. proof_generator_addr , None ) . unwrap ( ) ;
85+ let repetitions = args . repetitions ;
86+ let verification_data = verification_data_from_args ( args) ? ;
10987
110- let verification_data = VerificationData {
111- proving_system,
112- proof,
113- pub_input,
114- verification_key,
115- vm_program_code,
116- proof_generator_addr,
117- } ;
118-
119- let json_data = serde_json:: to_string ( & verification_data) . expect ( "Failed to serialize task" ) ;
120- for _ in 0 ..args. repetitions {
121- // NOTE(marian): This sleep is only for ease of testing interactions between client and batcher,
122- // it can be removed.
123- std:: thread:: sleep ( Duration :: from_millis ( 500 ) ) ;
124- ws_write
125- . send ( tungstenite:: Message :: Text ( json_data. to_string ( ) ) )
126- . await
127- . unwrap ( ) ;
88+ let json_data = serde_json:: to_string ( & verification_data) ?;
89+ for _ in 0 ..repetitions {
90+ ws_write. send ( Message :: Text ( json_data. to_string ( ) ) ) . await ?;
12891 info ! ( "Message sent..." )
12992 }
13093
13194 let num_responses = Arc :: new ( Mutex :: new ( 0 ) ) ;
13295 let ws_write = Arc :: new ( Mutex :: new ( ws_write) ) ;
13396
134- receive ( ws_read, ws_write, args. repetitions , num_responses) . await ;
97+ receive ( ws_read, ws_write, repetitions, num_responses) . await ?;
98+
99+ Ok ( ( ) )
135100}
136101
137102async fn receive (
138103 ws_read : SplitStream < WebSocketStream < MaybeTlsStream < TcpStream > > > ,
139104 ws_write : Arc < Mutex < SplitSink < WebSocketStream < MaybeTlsStream < TcpStream > > , Message > > > ,
140105 total_messages : usize ,
141106 num_responses : Arc < Mutex < usize > > ,
142- ) {
107+ ) -> Result < ( ) , BatcherClientError > {
143108 ws_read
144109 . try_filter ( |msg| future:: ready ( msg. is_text ( ) || msg. is_binary ( ) ) )
145- . for_each ( |msg| async {
110+ . try_for_each ( |msg| async {
146111 let mut num_responses_lock = num_responses. lock ( ) . await ;
147112 * num_responses_lock += 1 ;
148- let data = msg. unwrap ( ) . into_data ( ) ;
113+ let data = msg. into_data ( ) ;
149114 let deserialized_data: BatchInclusionData = serde_json:: from_slice ( & data) . unwrap ( ) ;
150115 info ! ( "Batcher response received: {}" , deserialized_data) ;
151116
@@ -157,8 +122,59 @@ async fn receive(
157122
158123 if * num_responses_lock == total_messages {
159124 info ! ( "All messages responded. Closing connection..." ) ;
160- ws_write. lock ( ) . await . close ( ) . await . unwrap ( ) ;
125+ ws_write. lock ( ) . await . close ( ) . await ? ;
161126 }
162- } )
163- . await ;
127+
128+ Ok ( ( ) )
129+ } ) . await ?;
130+
131+ Ok ( ( ) )
132+ }
133+
134+ fn verification_data_from_args ( args : Args ) -> Result < VerificationData , BatcherClientError > {
135+ let proving_system = parse_proving_system ( & args. proving_system_flag )
136+ . map_err ( |_| errors:: BatcherClientError :: InvalidProvingSystem ( args. proving_system_flag ) ) ?;
137+
138+ // Read proof file
139+ let proof = read_file ( args. proof_file_name ) ?;
140+
141+ let mut pub_input: Option < Vec < u8 > > = None ;
142+ let mut verification_key: Option < Vec < u8 > > = None ;
143+ let mut vm_program_code: Option < Vec < u8 > > = None ;
144+
145+ match proving_system {
146+ ProvingSystemId :: SP1 => {
147+ vm_program_code = Some ( read_file_option ( "--vm_program" , args. vm_program_code_file_name ) ?) ;
148+ }
149+ ProvingSystemId :: Halo2KZG
150+ | ProvingSystemId :: Halo2IPA
151+ | ProvingSystemId :: GnarkPlonkBls12_381
152+ | ProvingSystemId :: GnarkPlonkBn254
153+ | ProvingSystemId :: Groth16Bn254 => {
154+ verification_key = Some ( read_file_option ( "--vk" , args. verification_key_file_name ) ?) ;
155+ pub_input = Some ( read_file_option ( "--public_input" , args. pub_input_file_name ) ?) ;
156+ }
157+ }
158+
159+ let proof_generator_addr: Address =
160+ Address :: parse_checksummed ( & args. proof_generator_addr , None ) . unwrap ( ) ;
161+
162+ Ok ( VerificationData {
163+ proving_system,
164+ proof,
165+ pub_input,
166+ verification_key,
167+ vm_program_code,
168+ proof_generator_addr,
169+ } )
170+ }
171+
172+ fn read_file ( file_name : PathBuf ) -> Result < Vec < u8 > , BatcherClientError > {
173+ std:: fs:: read ( & file_name)
174+ . map_err ( |e| BatcherClientError :: IoError ( file_name, e) )
175+ }
176+
177+ fn read_file_option ( param_name : & str , file_name : Option < PathBuf > ) -> Result < Vec < u8 > , BatcherClientError > {
178+ let file_name = file_name. ok_or ( BatcherClientError :: MissingParameter ( param_name. to_string ( ) ) ) ?;
179+ read_file ( file_name)
164180}
0 commit comments