Skip to content

Commit b9a26e1

Browse files
authored
refactor: multi (#78)
1 parent 22ad34e commit b9a26e1

File tree

9 files changed

+115
-30
lines changed

9 files changed

+115
-30
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ edition = "2021"
55

66
[dependencies]
77
anyhow = "1.0"
8-
log = "0.4"
98
serde = { version = "1", features = ["derive"] }
109
serde_json = "1.0"
1110
ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" }
@@ -21,8 +20,8 @@ tokio = { version = "1.37.0", features = ["full"] }
2120
async-trait = "0.1"
2221
http = "1.1.0"
2322
clap = { version = "4.5", features = ["derive"] }
24-
tracing = "0.1.40"
25-
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
23+
tracing = "0.1"
24+
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
2625
axum = "0.6.0"
2726
dotenv = "0.15"
2827
rocksdb = "0.23.0"

src/coordinator_handler/api.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use reqwest::{header::CONTENT_TYPE, Url};
88
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
99
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
1010
use serde::Serialize;
11+
use tracing::Level;
1112

1213
pub struct Api {
1314
pub base_url: Url,
@@ -37,6 +38,7 @@ impl Api {
3738
self.base_url.join(method).map_err(|e| anyhow::anyhow!(e))
3839
}
3940

41+
#[instrument(target = "coordinator_client", skip(self, req, token), level = Level::DEBUG)]
4042
async fn post_with_token<Req, Resp>(
4143
&self,
4244
method: &str,
@@ -51,8 +53,8 @@ impl Api {
5153
let request_body = serde_json::to_string(req)?;
5254
let size = request_body.len();
5355

54-
log::info!("[coordinator client], {method}, sent request");
55-
log::debug!("[coordinator client], {method}, request: {request_body}, token: {token}, request size: {size}");
56+
info!("sent request");
57+
trace!(token = %token, request_body = %request_body, size = %size);
5658
let response = self
5759
.client
5860
.post(url)
@@ -72,8 +74,8 @@ impl Api {
7274

7375
let response_body = response.text().await?;
7476

75-
log::info!("[coordinator client], {method}, received response");
76-
log::debug!("[coordinator client], {method}, response: {response_body}");
77+
info!("received response");
78+
trace!(response_body = %response_body);
7779
serde_json::from_str(&response_body).map_err(|e| anyhow::anyhow!(e))
7880
}
7981

@@ -109,7 +111,7 @@ impl Api {
109111
token: &String,
110112
) -> anyhow::Result<Response<GetTaskResponseData>> {
111113
let method = "/coordinator/v1/get_task";
112-
if self.send_timeout < core::time::Duration::from_secs(600) {
114+
if self.send_timeout < Duration::from_secs(600) {
113115
tracing::warn!(
114116
"get_task API is time-consuming, timeout setting is too low ({}), set it to more than 600s",
115117
self.send_timeout.as_secs(),

src/coordinator_handler/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl ErrorCode {
3636
50000 => ErrorCode::ErrJWTCommonErr,
3737
50001 => ErrorCode::ErrJWTTokenExpired,
3838
_ => {
39-
log::error!("get unexpected error code from coordinator: {v}");
39+
error!("get unexpected error code from coordinator: {v}");
4040
ErrorCode::Undefined(v)
4141
}
4242
}

src/db.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::coordinator_handler::GetTaskResponseData;
22
use rocksdb::DB;
3+
use tracing::Level;
34

45
pub struct Db {
56
db: DB,
@@ -11,21 +12,21 @@ impl Db {
1112
Ok(Self { db })
1213
}
1314

15+
#[instrument(skip(self), level = Level::DEBUG)]
1416
pub fn get_task(&self, public_key: String) -> (Option<GetTaskResponseData>, Option<String>) {
15-
log::debug!("[db], get task, public_key: {public_key}");
1617
(
1718
self.get_coordinator_task_by_public_key(public_key.clone()),
1819
self.get_proving_task_id_by_public_key(public_key),
1920
)
2021
}
2122

23+
#[instrument(skip_all, fields(public_key = %public_key), level = Level::DEBUG)]
2224
pub fn set_task(
2325
&self,
2426
public_key: String,
2527
coordinator_task: &GetTaskResponseData,
2628
proving_task_id: String,
2729
) {
28-
log::debug!("[db], set task, public_key: {public_key}");
2930
self.set_coordinator_task_by_public_key(public_key.clone(), coordinator_task);
3031
self.set_proving_task_id_by_public_key(public_key, proving_task_id);
3132
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#[macro_use]
2+
extern crate tracing;
3+
14
pub mod config;
25
pub mod coordinator_handler;
36
pub mod db;

src/prover/mod.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use axum::{routing::get, Router};
1212
use proving_service::{ProveRequest, QueryTaskRequest, TaskStatus};
1313
use std::net::SocketAddr;
1414
use std::str::FromStr;
15-
use std::thread;
1615
use tokio::time::{sleep, Duration};
1716
use tokio::{sync::RwLock, task::JoinSet};
17+
use tracing::Level;
1818
use tracing::{error, info, instrument};
1919

2020
pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};
@@ -36,7 +36,7 @@ where
3636
Backend: ProvingService + Send + Sync + 'static,
3737
{
3838
pub async fn run(self) {
39-
assert!(self.n_workers == self.coordinator_clients.len());
39+
assert_eq!(self.n_workers, self.coordinator_clients.len());
4040

4141
self.test_coordinator_connection().await;
4242

@@ -53,7 +53,7 @@ where
5353
provers.spawn(async move {
5454
self_clone.working_loop(i).await;
5555
});
56-
thread::sleep(Duration::from_secs(3)); // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
56+
tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
5757
}
5858

5959
tokio::select! {
@@ -67,7 +67,7 @@ where
6767
tasks: &[String],
6868
task_type: ProofType,
6969
) -> bool {
70-
assert!(self.n_workers == self.coordinator_clients.len());
70+
assert_eq!(self.n_workers, self.coordinator_clients.len());
7171

7272
self.test_coordinator_connection().await;
7373

@@ -102,17 +102,16 @@ where
102102
}
103103
i
104104
});
105-
thread::sleep(Duration::from_secs(3)); // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
105+
tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
106106
}
107107

108108
// wait until all tasks has been done
109109
while let Some(r) = provers.join_next().await {
110-
if r.is_err() {
110+
let Ok(r) = r else {
111111
// quit since one task has failed
112112
return false;
113-
} else {
114-
log::info!("worker {} has completed", r.unwrap());
115-
}
113+
};
114+
info!("worker {r} has completed");
116115
}
117116
true
118117
}
@@ -124,7 +123,7 @@ where
124123
.expect("Failed to login to coordinator");
125124
}
126125

127-
#[instrument(skip(self))]
126+
#[instrument(skip(self), level = Level::DEBUG)]
128127
async fn working_loop(&self, i: usize) {
129128
loop {
130129
let coordinator_client = &self.coordinator_clients[i];
@@ -152,7 +151,7 @@ where
152151
.unwrap_or_default()
153152
{
154153
let task_id = coordinator_task.clone().task_id;
155-
log::debug!("got previous task from db, task_id: {task_id}");
154+
debug!(task_id = %task_id, "got previous task from db");
156155
if self.proving_service.read().await.is_local() {
157156
let proving_task = self
158157
.request_proving(coordinator_client, &coordinator_task)

src/prover/proving_service.rs

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::ProofType;
22
use async_trait::async_trait;
3+
use std::fmt;
34

45
#[async_trait]
56
pub trait ProvingService {
@@ -9,27 +10,38 @@ pub trait ProvingService {
910
async fn query_task(&mut self, req: QueryTaskRequest) -> QueryTaskResponse;
1011
}
1112

12-
#[derive(Default)]
13+
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
1314
pub struct GetVkRequest {
1415
pub proof_types: Vec<ProofType>,
1516
pub circuit_version: String,
1617
}
1718

18-
#[derive(Default)]
19+
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
1920
pub struct GetVkResponse {
2021
pub vks: Vec<String>,
2122
pub error: Option<String>,
2223
}
2324

24-
#[derive(Default, Clone)]
25+
#[derive(Default, Clone, PartialEq, Eq, Hash)]
2526
pub struct ProveRequest {
2627
pub proof_type: ProofType,
2728
pub circuit_version: String,
2829
pub hard_fork_name: String,
2930
pub input: String,
3031
}
3132

32-
#[derive(Default)]
33+
impl fmt::Debug for ProveRequest {
34+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35+
f.debug_struct("ProveRequest")
36+
.field("proof_type", &self.proof_type)
37+
.field("circuit_version", &self.circuit_version)
38+
.field("hard_fork_name", &self.hard_fork_name)
39+
.field("input", &"...")
40+
.finish()
41+
}
42+
}
43+
44+
#[derive(Default, Clone, PartialEq)]
3345
pub struct ProveResponse {
3446
pub task_id: String,
3547
pub proof_type: ProofType,
@@ -46,12 +58,46 @@ pub struct ProveResponse {
4658
pub error: Option<String>,
4759
}
4860

49-
#[derive(Default)]
61+
impl fmt::Debug for ProveResponse {
62+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63+
let mut fmt = f.debug_struct("ProveResponse");
64+
fmt.field("task_id", &self.task_id)
65+
.field("proof_type", &self.proof_type)
66+
.field("circuit_version", &self.circuit_version)
67+
.field("hard_fork_name", &self.hard_fork_name)
68+
.field("status", &self.status)
69+
.field("created_at", &self.created_at);
70+
if let Some(started_at) = &self.started_at {
71+
fmt.field("started_at", started_at);
72+
}
73+
if let Some(finished_at) = &self.finished_at {
74+
fmt.field("finished_at", finished_at);
75+
}
76+
if let Some(compute_time_sec) = &self.compute_time_sec {
77+
fmt.field("compute_time_sec", compute_time_sec);
78+
}
79+
if self.input.is_some() {
80+
fmt.field("input", &"..."); // Hide actual input for brevity
81+
}
82+
if self.proof.is_some() {
83+
fmt.field("proof", &"..."); // Hide actual proof for brevity
84+
}
85+
if let Some(vk) = &self.vk {
86+
fmt.field("vk", vk);
87+
}
88+
if let Some(error) = &self.error {
89+
fmt.field("error", error);
90+
}
91+
fmt.finish()
92+
}
93+
}
94+
95+
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
5096
pub struct QueryTaskRequest {
5197
pub task_id: String,
5298
}
5399

54-
#[derive(Default)]
100+
#[derive(Default, Clone, PartialEq)]
55101
pub struct QueryTaskResponse {
56102
pub task_id: String,
57103
pub proof_type: ProofType,
@@ -68,7 +114,41 @@ pub struct QueryTaskResponse {
68114
pub error: Option<String>,
69115
}
70116

71-
#[derive(Debug, PartialEq, Default)]
117+
impl fmt::Debug for QueryTaskResponse {
118+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119+
let mut fmt = f.debug_struct("QueryTaskResponse");
120+
fmt.field("task_id", &self.task_id)
121+
.field("proof_type", &self.proof_type)
122+
.field("circuit_version", &self.circuit_version)
123+
.field("hard_fork_name", &self.hard_fork_name)
124+
.field("status", &self.status)
125+
.field("created_at", &self.created_at);
126+
if let Some(started_at) = &self.started_at {
127+
fmt.field("started_at", started_at);
128+
}
129+
if let Some(finished_at) = &self.finished_at {
130+
fmt.field("finished_at", finished_at);
131+
}
132+
if let Some(compute_time_sec) = &self.compute_time_sec {
133+
fmt.field("compute_time_sec", compute_time_sec);
134+
}
135+
if self.input.is_some() {
136+
fmt.field("input", &"..."); // Hide actual input for brevity
137+
}
138+
if self.proof.is_some() {
139+
fmt.field("proof", &"..."); // Hide actual proof for brevity
140+
}
141+
if let Some(vk) = &self.vk {
142+
fmt.field("vk", vk);
143+
}
144+
if let Some(error) = &self.error {
145+
fmt.field("error", error);
146+
}
147+
fmt.finish()
148+
}
149+
}
150+
151+
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
72152
pub enum TaskStatus {
73153
#[default]
74154
Queued,

src/utils.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use tracing_subscriber::filter::{EnvFilter, LevelFilter};
2+
use tracing_subscriber::fmt::format::FmtSpan;
23

34
static DEFAULT_COMMIT: &str = "unknown";
45
static VERSION: std::sync::OnceLock<String> = std::sync::OnceLock::new();
@@ -27,6 +28,7 @@ pub fn init_tracing() {
2728
.with_ansi(false)
2829
.with_level(true)
2930
.with_target(true)
31+
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
3032
.try_init()
3133
.expect("Failed to initialize tracing subscriber");
3234
}

0 commit comments

Comments
 (0)