Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ edition = "2021"

[dependencies]
anyhow = "1.0"
log = "0.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"
ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" }
Expand All @@ -21,8 +20,8 @@ tokio = { version = "1.37.0", features = ["full"] }
async-trait = "0.1"
http = "1.1.0"
clap = { version = "4.5", features = ["derive"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
axum = "0.6.0"
dotenv = "0.15"
rocksdb = "0.23.0"
Expand Down
12 changes: 7 additions & 5 deletions src/coordinator_handler/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use reqwest::{header::CONTENT_TYPE, Url};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Serialize;
use tracing::Level;

pub struct Api {
pub base_url: Url,
Expand Down Expand Up @@ -37,6 +38,7 @@ impl Api {
self.base_url.join(method).map_err(|e| anyhow::anyhow!(e))
}

#[instrument(target = "coordinator_client", skip(self, req, token), level = Level::DEBUG)]
async fn post_with_token<Req, Resp>(
&self,
method: &str,
Expand All @@ -51,8 +53,8 @@ impl Api {
let request_body = serde_json::to_string(req)?;
let size = request_body.len();

log::info!("[coordinator client], {method}, sent request");
log::debug!("[coordinator client], {method}, request: {request_body}, token: {token}, request size: {size}");
info!("sent request");
trace!(token = %token, request_body = %request_body, size = %size);
let response = self
.client
.post(url)
Expand All @@ -72,8 +74,8 @@ impl Api {

let response_body = response.text().await?;

log::info!("[coordinator client], {method}, received response");
log::debug!("[coordinator client], {method}, response: {response_body}");
info!("received response");
trace!(response_body = %response_body);
serde_json::from_str(&response_body).map_err(|e| anyhow::anyhow!(e))
}

Expand Down Expand Up @@ -109,7 +111,7 @@ impl Api {
token: &String,
) -> anyhow::Result<Response<GetTaskResponseData>> {
let method = "/coordinator/v1/get_task";
if self.send_timeout < core::time::Duration::from_secs(600) {
if self.send_timeout < Duration::from_secs(600) {
tracing::warn!(
"get_task API is time-consuming, timeout setting is too low ({}), set it to more than 600s",
self.send_timeout.as_secs(),
Expand Down
2 changes: 1 addition & 1 deletion src/coordinator_handler/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl ErrorCode {
50000 => ErrorCode::ErrJWTCommonErr,
50001 => ErrorCode::ErrJWTTokenExpired,
_ => {
log::error!("get unexpected error code from coordinator: {v}");
error!("get unexpected error code from coordinator: {v}");
ErrorCode::Undefined(v)
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::coordinator_handler::GetTaskResponseData;
use rocksdb::DB;
use tracing::Level;

pub struct Db {
db: DB,
Expand All @@ -11,21 +12,21 @@ impl Db {
Ok(Self { db })
}

/// Fetches both records stored for `public_key`.
///
/// Returns a pair: the first element is the result of
/// `get_coordinator_task_by_public_key`, the second is the result of
/// `get_proving_task_id_by_public_key`. Each element is an `Option` produced by its
/// own lookup.
#[instrument(skip(self), level = Level::DEBUG)]
pub fn get_task(&self, public_key: String) -> (Option<GetTaskResponseData>, Option<String>) {
log::debug!("[db], get task, public_key: {public_key}");
(
// The key is passed by value to each lookup, hence the clone for the first call.
self.get_coordinator_task_by_public_key(public_key.clone()),
self.get_proving_task_id_by_public_key(public_key),
)
}

/// Stores both records for `public_key`.
///
/// Forwards `coordinator_task` to `set_coordinator_task_by_public_key` and
/// `proving_task_id` to `set_proving_task_id_by_public_key`, in that order. Nothing is
/// returned, so a failure inside either setter is not visible to the caller from here.
#[instrument(skip_all, fields(public_key = %public_key), level = Level::DEBUG)]
pub fn set_task(
&self,
public_key: String,
coordinator_task: &GetTaskResponseData,
proving_task_id: String,
) {
log::debug!("[db], set task, public_key: {public_key}");
// The key is passed by value to each setter, hence the clone for the first call.
self.set_coordinator_task_by_public_key(public_key.clone(), coordinator_task);
self.set_proving_task_id_by_public_key(public_key, proving_task_id);
}
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[macro_use]
extern crate tracing;

pub mod config;
pub mod coordinator_handler;
pub mod db;
Expand Down
26 changes: 16 additions & 10 deletions src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use axum::{routing::get, Router};
use proving_service::{ProveRequest, QueryTaskRequest, TaskStatus};
use std::net::SocketAddr;
use std::str::FromStr;
use std::thread;
use tokio::time::{sleep, Duration};
use tokio::{sync::RwLock, task::JoinSet};
use tracing::Level;
use tracing::{error, info, instrument};

pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};
Expand All @@ -36,7 +36,7 @@ where
Backend: ProvingService + Send + Sync + 'static,
{
pub async fn run(self) {
assert!(self.n_workers == self.coordinator_clients.len());
assert_eq!(self.n_workers, self.coordinator_clients.len());

self.test_coordinator_connection().await;

Expand All @@ -53,7 +53,7 @@ where
provers.spawn(async move {
self_clone.working_loop(i).await;
});
thread::sleep(Duration::from_secs(3)); // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
}

tokio::select! {
Expand All @@ -67,7 +67,7 @@ where
tasks: &[String],
task_type: ProofType,
) -> bool {
assert!(self.n_workers == self.coordinator_clients.len());
assert_eq!(self.n_workers, self.coordinator_clients.len());

self.test_coordinator_connection().await;

Expand Down Expand Up @@ -97,16 +97,22 @@ where
.handle_task(coordinator_client, Some((task_type, task_str.as_str())))
.await
{
error!(?prover_name, ?e, "Error handling task");
error!(prover_name, error = e.to_string(), "Error handling task");
panic!("task fail");
}
i
});
thread::sleep(Duration::from_secs(3)); // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
}

// wait until all tasks have completed
while provers.join_next().await.is_some() {}
while let Some(r) = provers.join_next().await {
let Ok(r) = r else {
// quit since one task has failed
return false;
};
info!("worker {r} has completed");
}
true
}

Expand All @@ -117,7 +123,7 @@ where
.expect("Failed to login to coordinator");
}

#[instrument(skip(self))]
#[instrument(skip(self), level = Level::DEBUG)]
async fn working_loop(&self, i: usize) {
loop {
let coordinator_client = &self.coordinator_clients[i];
Expand All @@ -126,7 +132,7 @@ where
info!(?prover_name, "Getting task from coordinator");

if let Err(e) = self.handle_task(coordinator_client, None).await {
error!(?prover_name, ?e, "Error handling task");
error!(prover_name, error = e.to_string(), "Error handling task");
}

sleep(Duration::from_secs(WORKER_SLEEP_SEC)).await;
Expand All @@ -145,7 +151,7 @@ where
.unwrap_or_default()
{
let task_id = coordinator_task.clone().task_id;
log::debug!("got previous task from db, task_id: {task_id}");
debug!(task_id = %task_id, "got previous task from db");
if self.proving_service.read().await.is_local() {
let proving_task = self
.request_proving(coordinator_client, &coordinator_task)
Expand Down
94 changes: 87 additions & 7 deletions src/prover/proving_service.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::ProofType;
use async_trait::async_trait;
use std::fmt;

#[async_trait]
pub trait ProvingService {
Expand All @@ -9,27 +10,38 @@ pub trait ProvingService {
async fn query_task(&mut self, req: QueryTaskRequest) -> QueryTaskResponse;
}

#[derive(Default)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
/// Request carrying a list of proof types together with a circuit version.
///
/// NOTE(review): presumably the argument of a verification-key lookup on
/// `ProvingService`; the consuming method is not visible here — confirm.
pub struct GetVkRequest {
pub proof_types: Vec<ProofType>,
pub circuit_version: String,
}

#[derive(Default)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
/// Response carrying a list of `vks` strings and an optional error message.
///
/// NOTE(review): presumably `vks` lines up one-to-one with the requested proof types and
/// `error` is `Some` on failure — neither is established by the visible code; confirm.
pub struct GetVkResponse {
pub vks: Vec<String>,
pub error: Option<String>,
}

#[derive(Default, Clone)]
#[derive(Default, Clone, PartialEq, Eq, Hash)]
/// Request describing one proving job: proof type, circuit version, hard fork name and
/// the job input.
///
/// `Debug` is implemented by hand for this type and prints `input` as the placeholder
/// `"..."` instead of its contents.
pub struct ProveRequest {
pub proof_type: ProofType,
pub circuit_version: String,
pub hard_fork_name: String,
pub input: String,
}

#[derive(Default)]
impl fmt::Debug for ProveRequest {
    /// Writes a struct-style debug representation of the request.
    ///
    /// `proof_type`, `circuit_version` and `hard_fork_name` are printed as-is; the
    /// contents of `input` are replaced by the placeholder `"..."`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructuring makes the compiler flag this impl whenever a field is added.
        let Self {
            proof_type,
            circuit_version,
            hard_fork_name,
            input: _,
        } = self;

        let mut out = f.debug_struct("ProveRequest");
        out.field("proof_type", proof_type);
        out.field("circuit_version", circuit_version);
        out.field("hard_fork_name", hard_fork_name);
        // The input is deliberately not dumped; only a placeholder is shown.
        out.field("input", &"...");
        out.finish()
    }
}

#[derive(Default, Clone, PartialEq)]
pub struct ProveResponse {
pub task_id: String,
pub proof_type: ProofType,
Expand All @@ -46,12 +58,46 @@ pub struct ProveResponse {
pub error: Option<String>,
}

#[derive(Default)]
impl fmt::Debug for ProveResponse {
    /// Writes a struct-style debug representation of the response.
    ///
    /// The always-present fields come first. Optional fields are printed only when they
    /// are `Some`, and then without the `Some(..)` wrapper. The contents of `input` and
    /// `proof` are never printed; a `"..."` placeholder only signals that they are set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ProveResponse");

        // Fields that are always present.
        out.field("task_id", &self.task_id);
        out.field("proof_type", &self.proof_type);
        out.field("circuit_version", &self.circuit_version);
        out.field("hard_fork_name", &self.hard_fork_name);
        out.field("status", &self.status);
        out.field("created_at", &self.created_at);

        // Optional timing information, shown only when known.
        if let Some(value) = self.started_at.as_ref() {
            out.field("started_at", value);
        }
        if let Some(value) = self.finished_at.as_ref() {
            out.field("finished_at", value);
        }
        if let Some(value) = self.compute_time_sec.as_ref() {
            out.field("compute_time_sec", value);
        }

        // `input` and `proof` are only reported as present, never dumped.
        let redacted = [
            ("input", self.input.is_some()),
            ("proof", self.proof.is_some()),
        ];
        for (name, present) in redacted {
            if present {
                out.field(name, &"...");
            }
        }

        if let Some(value) = self.vk.as_ref() {
            out.field("vk", value);
        }
        if let Some(value) = self.error.as_ref() {
            out.field("error", value);
        }

        out.finish()
    }
}

/// Argument of `ProvingService::query_task`: identifies the task to look up by its
/// `task_id`.
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryTaskRequest {
pub task_id: String,
}

#[derive(Default)]
#[derive(Default, Clone, PartialEq)]
pub struct QueryTaskResponse {
pub task_id: String,
pub proof_type: ProofType,
Expand All @@ -68,7 +114,41 @@ pub struct QueryTaskResponse {
pub error: Option<String>,
}

#[derive(Debug, PartialEq, Default)]
impl fmt::Debug for QueryTaskResponse {
    /// Writes a struct-style debug representation of the response.
    ///
    /// The always-present fields come first. Optional fields are printed only when they
    /// are `Some`, and then without the `Some(..)` wrapper. The contents of `input` and
    /// `proof` are never printed; a `"..."` placeholder only signals that they are set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("QueryTaskResponse");

        // Fields that are always present.
        out.field("task_id", &self.task_id);
        out.field("proof_type", &self.proof_type);
        out.field("circuit_version", &self.circuit_version);
        out.field("hard_fork_name", &self.hard_fork_name);
        out.field("status", &self.status);
        out.field("created_at", &self.created_at);

        // Optional timing information, shown only when known.
        if let Some(value) = self.started_at.as_ref() {
            out.field("started_at", value);
        }
        if let Some(value) = self.finished_at.as_ref() {
            out.field("finished_at", value);
        }
        if let Some(value) = self.compute_time_sec.as_ref() {
            out.field("compute_time_sec", value);
        }

        // `input` and `proof` are only reported as present, never dumped.
        let redacted = [
            ("input", self.input.is_some()),
            ("proof", self.proof.is_some()),
        ];
        for (name, present) in redacted {
            if present {
                out.field(name, &"...");
            }
        }

        if let Some(value) = self.vk.as_ref() {
            out.field("vk", value);
        }
        if let Some(value) = self.error.as_ref() {
            out.field("error", value);
        }

        out.finish()
    }
}

#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TaskStatus {
#[default]
Queued,
Expand Down
2 changes: 2 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use tracing_subscriber::filter::{EnvFilter, LevelFilter};
use tracing_subscriber::fmt::format::FmtSpan;

static DEFAULT_COMMIT: &str = "unknown";
static VERSION: std::sync::OnceLock<String> = std::sync::OnceLock::new();
Expand Down Expand Up @@ -27,6 +28,7 @@ pub fn init_tracing() {
.with_ansi(false)
.with_level(true)
.with_target(true)
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
.try_init()
.expect("Failed to initialize tracing subscriber");
}
Expand Down