diff --git a/src/config.rs b/src/config.rs index 6650e6a..8ac98c4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -30,6 +30,15 @@ pub struct ProverConfig { pub circuit_version: String, #[serde(default = "default_n_workers")] pub n_workers: usize, + /// Interval between polling the coordinator for new tasks. + #[serde(default = "default_poll_interval_sec")] + pub poll_interval_sec: u64, + /// Delay the timer by a randomly selected, evenly distributed amount of time between 0 and the + /// specified time value. Defaults to 0, indicating that no randomized delay shall be applied. + #[serde(default)] + pub randomized_delay_sec: u64, + #[serde(default)] + pub suppress_empty_task_error: bool, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct DbConfig {} @@ -38,10 +47,14 @@ fn default_health_listener_addr() -> String { "0.0.0.0:80".to_string() } -fn default_n_workers() -> usize { +const fn default_n_workers() -> usize { 1 } +const fn default_poll_interval_sec() -> u64 { + 20 +} + impl Config { pub fn from_reader(reader: R) -> Result where @@ -110,6 +123,17 @@ impl Config { self.db_path = Option::from(val); } + if let Some(val) = Self::get_env_var("POLL_INTERVAL_SEC")? { + self.prover.poll_interval_sec = val.parse()?; + } + if let Some(val) = Self::get_env_var("RANDOMIZED_DELAY_SEC")? { + self.prover.randomized_delay_sec = val.parse()?; + } + + if Self::get_env_var("SUPPRESS_EMPTY_TASK_ERR")?.is_some() { + self.prover.suppress_empty_task_error = true; + } + Ok(()) } diff --git a/src/coordinator_handler/api.rs b/src/coordinator_handler/api.rs index 4483f9e..5ca2335 100644 --- a/src/coordinator_handler/api.rs +++ b/src/coordinator_handler/api.rs @@ -38,7 +38,7 @@ impl Api { self.base_url.join(method).map_err(|e| eyre::eyre!(e)) } - #[instrument(target = "coordinator_client", skip(self, req, token), level = Level::DEBUG)] + #[instrument(skip(self, req, token), level = Level::DEBUG)] async fn post_with_token( &self, method: &str, @@ -53,7 +53,7 @@ impl Api { let request_body = serde_json::to_string(req)?; let size = request_body.len(); - info!("sent request"); + debug!("sent request"); trace!(token = %token, request_body = %request_body, size = %size); let response = self .client @@ -74,7 +74,7 @@ impl Api { let response_body = response.text().await?; - info!("received response"); + debug!("received response"); trace!(response_body = %response_body); serde_json::from_str(&response_body).map_err(|e| eyre::eyre!(e)) } diff --git a/src/prover/builder.rs b/src/prover/builder.rs index 9a1ecc4..1aa0048 100644 --- a/src/prover/builder.rs +++ b/src/prover/builder.rs @@ -97,6 +97,9 @@ where .as_ref() .map(|path| Db::new(path.as_str())) .transpose()?, + poll_interval_sec: self.cfg.prover.poll_interval_sec, + randomized_delay_sec: self.cfg.prover.randomized_delay_sec, + suppress_empty_task_error: self.cfg.prover.suppress_empty_task_error, }) } } diff --git a/src/prover/mod.rs b/src/prover/mod.rs index 8e23eef..5c35a6f 100644 --- a/src/prover/mod.rs +++ b/src/prover/mod.rs @@ -2,7 +2,6 @@ pub mod builder; pub mod proving_service; pub mod types; -use std::future::IntoFuture; use crate::{ coordinator_handler::{ CoordinatorClient, ErrorCode, GetTaskRequest, GetTaskResponseData, ProofFailureType, @@ -12,18 +11,18 @@ use crate::{ }; use axum::{routing::get, Router}; use proving_service::{ProveRequest, QueryTaskRequest, TaskStatus}; +use std::future::IntoFuture; use std::net::SocketAddr; use std::str::FromStr; +use rand::Rng; +use tokio::net::TcpListener; use tokio::time::{sleep, Duration}; use tokio::{sync::RwLock, task::JoinSet}; -use tokio::net::TcpListener; use tracing::Level; use tracing::{error, info, instrument}; pub use {builder::ProverBuilder, proving_service::ProvingService, types::*}; -const WORKER_SLEEP_SEC: u64 = 20; - pub struct Prover { proof_types: Vec, circuit_version: String, @@ -32,6 +31,9 @@ pub struct Prover { n_workers: usize, health_listener_addr: String, db: Option, + poll_interval_sec: u64, + randomized_delay_sec: u64, + suppress_empty_task_error: bool, } impl Prover @@ -59,7 +61,6 @@ where provers.spawn(async move { self_clone.working_loop(i).await; }); - tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests. } tokio::select! { @@ -94,6 +95,9 @@ where let task_str = task.to_string(); let i = work_set.pop().expect("can not be empty"); provers.spawn(async move { + // Soft start delay to stagger the provers + sleep(self_clone.poll_delay()).await; + let coordinator_client = &self_clone.coordinator_clients[i]; let prover_name = &coordinator_client.prover_name; @@ -108,7 +112,6 @@ where } i }); - tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests. } // wait until all tasks has been done @@ -131,20 +134,18 @@ where #[instrument(skip(self), level = Level::DEBUG)] async fn working_loop(&self, i: usize) { + // Soft start delay to stagger the provers + sleep(self.poll_delay()).await; loop { let coordinator_client = &self.coordinator_clients[i]; - let prover_name = &coordinator_client.prover_name; - - info!(?prover_name, "Getting task from coordinator"); - if let Err(e) = self.handle_task(coordinator_client, None).await { - error!(prover_name, error = e.to_string(), "Error handling task"); + error!(prover_name = %coordinator_client.prover_name, error = e.to_string(), "Error handling task"); } - - sleep(Duration::from_secs(WORKER_SLEEP_SEC)).await; + sleep(self.poll_delay()).await; } } + #[instrument(skip(self, coordinator_client), level = Level::DEBUG)] async fn handle_task( &self, coordinator_client: &CoordinatorClient, @@ -174,9 +175,13 @@ where get_task_request.task_types = vec![t]; get_task_request.task_id.replace(s.to_string()); } - let coordinator_task = self + let Some(coordinator_task) = self .get_coordinator_task(coordinator_client, &get_task_request) - .await?; + .await? + else { + return Ok(()); + }; + info!(prover_name = %coordinator_client.prover_name, "Got task from coordinator"); let proving_task = self .request_proving(coordinator_client, &coordinator_task) .await?; @@ -188,9 +193,15 @@ where &self, coordinator_client: &CoordinatorClient, request: &GetTaskRequest, - ) -> eyre::Result { + ) -> eyre::Result> { let coordinator_task = coordinator_client.get_task(request).await?; + if coordinator_task.errcode == ErrorCode::ErrCoordinatorEmptyProofData + && self.suppress_empty_task_error + { + return Ok(None); + } + if coordinator_task.errcode != ErrorCode::Success { eyre::bail!( "Failed to get task, errcode: {:?}, errmsg: {:?}", @@ -199,9 +210,7 @@ where ); } - coordinator_task - .data - .ok_or_else(|| eyre::eyre!("No task available")) + Ok(coordinator_task.data) } async fn request_proving( @@ -269,6 +278,9 @@ where let coordinator_task_uuid = &coordinator_task.uuid; let coordinator_task_id = &coordinator_task.task_id; + // Track last observed status to avoid spamming logs when status hasn't changed. + let mut last_status: Option = None; + loop { let task = self .proving_service @@ -279,17 +291,22 @@ where }) .await; - match task.status { + let current_status = task.status; + + match current_status { TaskStatus::Queued | TaskStatus::Proving => { - info!( - ?prover_name, - ?task_type, - ?coordinator_task_uuid, - ?coordinator_task_id, - ?proving_service_task_id, - status = ?task.status, - "Task status update" - ); + if last_status != Some(current_status) { + info!( + ?prover_name, + ?task_type, + ?coordinator_task_uuid, + ?coordinator_task_id, + ?proving_service_task_id, + status = ?current_status, + "Task status update" + ); + } + last_status.replace(current_status); if let Some(db) = &self.db { db.set_task( public_key.clone(), @@ -297,7 +314,7 @@ where proving_service_task_id.clone(), ); } - sleep(Duration::from_secs(WORKER_SLEEP_SEC)).await; + sleep(self.poll_delay()).await; } TaskStatus::Success => { info!( @@ -436,6 +453,16 @@ where input: task.task_data.clone(), }) } + + fn poll_delay(&self) -> Duration { + let base_delay = Duration::from_secs(self.poll_interval_sec); + if self.randomized_delay_sec == 0 { + return base_delay; + } + let mut rng = rand::rng(); + let random_delay = rng.random_range(0..self.randomized_delay_sec * 1000); + base_delay + Duration::from_millis(random_delay) + } } #[cfg(test)]