Skip to content

Commit d2b2f53

Browse files
authored
fix (cli): error handling (#337)
1 parent 37814c7 commit d2b2f53

File tree

4 files changed

+125
-68
lines changed

4 files changed

+125
-68
lines changed

batcher/client/Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

batcher/client/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,5 @@ batcher = { path = "../" }
1616
bytes = "1.6.0"
1717
log = "0.4.21"
1818
env_logger = "0.11.3"
19-
anyhow = "1.0.83"
2019
alloy-primitives = "0.7.4"
2120
clap = { version = "4.5.4", features = ["derive"] }

batcher/client/src/errors.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use core::fmt;
2+
use std::io;
3+
use std::path::PathBuf;
4+
5+
pub enum BatcherClientError {
6+
MissingParameter(String),
7+
InvalidUrl(url::ParseError, String),
8+
InvalidProvingSystem(String),
9+
ConnectionError(tokio_tungstenite::tungstenite::Error),
10+
IoError(PathBuf, io::Error),
11+
SerdeError(serde_json::Error),
12+
}
13+
14+
impl From<tokio_tungstenite::tungstenite::Error> for BatcherClientError {
15+
fn from(e: tokio_tungstenite::tungstenite::Error) -> Self {
16+
BatcherClientError::ConnectionError(e)
17+
}
18+
}
19+
20+
impl From<serde_json::Error> for BatcherClientError {
21+
fn from(e: serde_json::Error) -> Self {
22+
BatcherClientError::SerdeError(e)
23+
}
24+
}
25+
26+
impl fmt::Debug for BatcherClientError {
27+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28+
match self {
29+
BatcherClientError::MissingParameter(param) =>
30+
write!(f, "Missing parameter: {} required for this proving system", param),
31+
BatcherClientError::InvalidUrl(err, url) =>
32+
write!(f, "Invalid URL \"{}\", {}", url, err),
33+
BatcherClientError::InvalidProvingSystem(proving_system) =>
34+
write!(f, "Invalid proving system: {}", proving_system),
35+
BatcherClientError::ConnectionError(e) =>
36+
write!(f, "Web Socket Connection error: {}", e),
37+
BatcherClientError::IoError(path, e) =>
38+
write!(f, "IO error for file: \"{}\", {}", path.display(), e),
39+
BatcherClientError::SerdeError(e) =>
40+
write!(f, "Serialization error: {}", e),
41+
}
42+
}
43+
}

batcher/client/src/main.rs

Lines changed: 82 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use std::{path::PathBuf, sync::Arc, time::Duration};
1+
extern crate core;
2+
3+
mod errors;
4+
5+
use std::{path::PathBuf, sync::Arc};
26

37
use alloy_primitives::{Address, hex};
48
use env_logger::Env;
@@ -11,10 +15,11 @@ use log::info;
1115
use tokio::{net::TcpStream, sync::Mutex};
1216
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
1317

14-
use batcher::types::{parse_proving_system, BatchInclusionData, VerificationData};
18+
use batcher::types::{parse_proving_system, BatchInclusionData, VerificationData, ProvingSystemId};
1519

1620
use clap::Parser;
1721
use tungstenite::Message;
22+
use crate::errors::BatcherClientError;
1823

1924
#[derive(Parser, Debug)]
2025
#[command(version, about, long_about = None)]
@@ -35,19 +40,17 @@ struct Args {
3540
#[arg(
3641
name = "Public input file name",
3742
long = "public_input",
38-
default_value = "."
3943
)]
40-
pub_input_file_name: PathBuf,
44+
pub_input_file_name: Option<PathBuf>,
4145

42-
#[arg(name = "Verification key file name", long = "vk", default_value = ".")]
43-
verification_key_file_name: PathBuf,
46+
#[arg(name = "Verification key file name", long = "vk")]
47+
verification_key_file_name: Option<PathBuf>,
4448

4549
#[arg(
4650
name = "VM prgram code file name",
4751
long = "vm_program",
48-
default_value = "."
4952
)]
50-
vm_program_code_file_name: PathBuf,
53+
vm_program_code_file_name: Option<PathBuf>,
5154

5255
#[arg(
5356
name = "Number of repetitions",
@@ -65,87 +68,49 @@ struct Args {
6568
}
6669

6770
#[tokio::main]
68-
async fn main() {
71+
async fn main() -> Result<(), errors::BatcherClientError> {
6972
let args = Args::parse();
7073

7174
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
7275

73-
let url = url::Url::parse(&args.connect_addr).unwrap();
74-
let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");
75-
info!("WebSocket handshake has been successfully completed");
76+
let url = url::Url::parse(&args.connect_addr)
77+
.map_err(|e| errors::BatcherClientError::InvalidUrl(e, args.connect_addr.clone()))?;
7678

77-
let (mut ws_write, ws_read) = ws_stream.split();
78-
79-
let proving_system = parse_proving_system(&args.proving_system_flag).unwrap();
80-
81-
// Read proof file
82-
let proof = std::fs::read(&args.proof_file_name)
83-
.unwrap_or_else(|_| panic!("Failed to read .proof file: {:?}", args.proof_file_name));
79+
let (ws_stream, _) = connect_async(url).await?;
8480

85-
// Read public input file
86-
let mut pub_input: Option<Vec<u8>> = None;
87-
if let Ok(data) = std::fs::read(args.pub_input_file_name) {
88-
pub_input = Some(data);
89-
} else {
90-
info!("No public input file provided, continuing without public input...");
91-
}
92-
93-
let mut verification_key: Option<Vec<u8>> = None;
94-
if let Ok(data) = std::fs::read(args.verification_key_file_name) {
95-
verification_key = Some(data);
96-
} else {
97-
info!("No verification key file provided, continuing without verification key...");
98-
}
81+
info!("WebSocket handshake has been successfully completed");
9982

100-
let mut vm_program_code: Option<Vec<u8>> = None;
101-
if let Ok(data) = std::fs::read(args.vm_program_code_file_name) {
102-
vm_program_code = Some(data);
103-
} else {
104-
info!("No VM program code file provided, continuing without VM program code...");
105-
}
83+
let (mut ws_write, ws_read) = ws_stream.split();
10684

107-
let proof_generator_addr: Address =
108-
Address::parse_checksummed(&args.proof_generator_addr, None).unwrap();
85+
let repetitions = args.repetitions;
86+
let verification_data = verification_data_from_args(args)?;
10987

110-
let verification_data = VerificationData {
111-
proving_system,
112-
proof,
113-
pub_input,
114-
verification_key,
115-
vm_program_code,
116-
proof_generator_addr,
117-
};
118-
119-
let json_data = serde_json::to_string(&verification_data).expect("Failed to serialize task");
120-
for _ in 0..args.repetitions {
121-
// NOTE(marian): This sleep is only for ease of testing interactions between client and batcher,
122-
// it can be removed.
123-
std::thread::sleep(Duration::from_millis(500));
124-
ws_write
125-
.send(tungstenite::Message::Text(json_data.to_string()))
126-
.await
127-
.unwrap();
88+
let json_data = serde_json::to_string(&verification_data)?;
89+
for _ in 0..repetitions {
90+
ws_write.send(Message::Text(json_data.to_string())).await?;
12891
info!("Message sent...")
12992
}
13093

13194
let num_responses = Arc::new(Mutex::new(0));
13295
let ws_write = Arc::new(Mutex::new(ws_write));
13396

134-
receive(ws_read, ws_write, args.repetitions, num_responses).await;
97+
receive(ws_read, ws_write, repetitions, num_responses).await?;
98+
99+
Ok(())
135100
}
136101

137102
async fn receive(
138103
ws_read: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
139104
ws_write: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
140105
total_messages: usize,
141106
num_responses: Arc<Mutex<usize>>,
142-
) {
107+
) -> Result<(), BatcherClientError> {
143108
ws_read
144109
.try_filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
145-
.for_each(|msg| async {
110+
.try_for_each(|msg| async {
146111
let mut num_responses_lock = num_responses.lock().await;
147112
*num_responses_lock += 1;
148-
let data = msg.unwrap().into_data();
113+
let data = msg.into_data();
149114
let deserialized_data: BatchInclusionData = serde_json::from_slice(&data).unwrap();
150115
info!("Batcher response received: {}", deserialized_data);
151116

@@ -157,8 +122,59 @@ async fn receive(
157122

158123
if *num_responses_lock == total_messages {
159124
info!("All messages responded. Closing connection...");
160-
ws_write.lock().await.close().await.unwrap();
125+
ws_write.lock().await.close().await?;
161126
}
162-
})
163-
.await;
127+
128+
Ok(())
129+
}).await?;
130+
131+
Ok(())
132+
}
133+
134+
fn verification_data_from_args(args: Args) -> Result<VerificationData, BatcherClientError> {
135+
let proving_system = parse_proving_system(&args.proving_system_flag)
136+
.map_err(|_| errors::BatcherClientError::InvalidProvingSystem(args.proving_system_flag))?;
137+
138+
// Read proof file
139+
let proof = read_file(args.proof_file_name)?;
140+
141+
let mut pub_input: Option<Vec<u8>> = None;
142+
let mut verification_key: Option<Vec<u8>> = None;
143+
let mut vm_program_code: Option<Vec<u8>> = None;
144+
145+
match proving_system {
146+
ProvingSystemId::SP1 => {
147+
vm_program_code = Some(read_file_option("--vm_program", args.vm_program_code_file_name)?);
148+
}
149+
ProvingSystemId::Halo2KZG
150+
| ProvingSystemId::Halo2IPA
151+
| ProvingSystemId::GnarkPlonkBls12_381
152+
| ProvingSystemId::GnarkPlonkBn254
153+
| ProvingSystemId::Groth16Bn254 => {
154+
verification_key = Some(read_file_option("--vk", args.verification_key_file_name)?);
155+
pub_input = Some(read_file_option("--public_input", args.pub_input_file_name)?);
156+
}
157+
}
158+
159+
let proof_generator_addr: Address =
160+
Address::parse_checksummed(&args.proof_generator_addr, None).unwrap();
161+
162+
Ok(VerificationData {
163+
proving_system,
164+
proof,
165+
pub_input,
166+
verification_key,
167+
vm_program_code,
168+
proof_generator_addr,
169+
})
170+
}
171+
172+
fn read_file(file_name: PathBuf) -> Result<Vec<u8>, BatcherClientError> {
173+
std::fs::read(&file_name)
174+
.map_err(|e| BatcherClientError::IoError(file_name, e))
175+
}
176+
177+
fn read_file_option(param_name: &str, file_name: Option<PathBuf>) -> Result<Vec<u8>, BatcherClientError> {
178+
let file_name = file_name.ok_or(BatcherClientError::MissingParameter(param_name.to_string()))?;
179+
read_file(file_name)
164180
}

0 commit comments

Comments
 (0)