Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ pub struct ProverConfig {
pub circuit_version: String,
#[serde(default = "default_n_workers")]
pub n_workers: usize,
/// Interval between polling the coordinator for new tasks.
#[serde(default = "default_poll_interval_sec")]
pub poll_interval_sec: u64,
/// Delay the timer by a randomly selected, evenly distributed amount of time between 0 and the
/// specified time value. Defaults to 0, indicating that no randomized delay shall be applied.
#[serde(default)]
pub randomized_delay_sec: u64,
#[serde(default)]
pub suppress_empty_task_error: bool,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DbConfig {}
Expand All @@ -38,10 +47,14 @@ fn default_health_listener_addr() -> String {
"0.0.0.0:80".to_string()
}

fn default_n_workers() -> usize {
const fn default_n_workers() -> usize {
1
}

const fn default_poll_interval_sec() -> u64 {
20
}

impl Config {
pub fn from_reader<R>(reader: R) -> Result<Self>
where
Expand Down Expand Up @@ -110,6 +123,17 @@ impl Config {
self.db_path = Option::from(val);
}

if let Some(val) = Self::get_env_var("POLL_INTERVAL_SEC")? {
self.prover.poll_interval_sec = val.parse()?;
}
if let Some(val) = Self::get_env_var("RANDOMIZED_DELAY_SEC")? {
self.prover.randomized_delay_sec = val.parse()?;
}

if Self::get_env_var("SUPPRESS_EMPTY_TASK_ERR")?.is_some() {
self.prover.suppress_empty_task_error = true;
}

Ok(())
}

Expand Down
6 changes: 3 additions & 3 deletions src/coordinator_handler/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Api {
self.base_url.join(method).map_err(|e| eyre::eyre!(e))
}

#[instrument(target = "coordinator_client", skip(self, req, token), level = Level::DEBUG)]
#[instrument(skip(self, req, token), level = Level::DEBUG)]
async fn post_with_token<Req, Resp>(
&self,
method: &str,
Expand All @@ -53,7 +53,7 @@ impl Api {
let request_body = serde_json::to_string(req)?;
let size = request_body.len();

info!("sent request");
debug!("sent request");
trace!(token = %token, request_body = %request_body, size = %size);
let response = self
.client
Expand All @@ -74,7 +74,7 @@ impl Api {

let response_body = response.text().await?;

info!("received response");
debug!("received response");
trace!(response_body = %response_body);
serde_json::from_str(&response_body).map_err(|e| eyre::eyre!(e))
}
Expand Down
3 changes: 3 additions & 0 deletions src/prover/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ where
.as_ref()
.map(|path| Db::new(path.as_str()))
.transpose()?,
poll_interval_sec: self.cfg.prover.poll_interval_sec,
randomized_delay_sec: self.cfg.prover.randomized_delay_sec,
suppress_empty_task_error: self.cfg.prover.suppress_empty_task_error,
})
}
}
87 changes: 57 additions & 30 deletions src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ pub mod builder;
pub mod proving_service;
pub mod types;

use std::future::IntoFuture;
use crate::{
coordinator_handler::{
CoordinatorClient, ErrorCode, GetTaskRequest, GetTaskResponseData, ProofFailureType,
Expand All @@ -12,18 +11,18 @@ use crate::{
};
use axum::{routing::get, Router};
use proving_service::{ProveRequest, QueryTaskRequest, TaskStatus};
use std::future::IntoFuture;
use std::net::SocketAddr;
use std::str::FromStr;
use rand::Rng;
use tokio::net::TcpListener;
use tokio::time::{sleep, Duration};
use tokio::{sync::RwLock, task::JoinSet};
use tokio::net::TcpListener;
use tracing::Level;
use tracing::{error, info, instrument};

pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};

const WORKER_SLEEP_SEC: u64 = 20;

pub struct Prover<Backend: ProvingService + Send + Sync + 'static> {
proof_types: Vec<ProofType>,
circuit_version: String,
Expand All @@ -32,6 +31,9 @@ pub struct Prover<Backend: ProvingService + Send + Sync + 'static> {
n_workers: usize,
health_listener_addr: String,
db: Option<Db>,
poll_interval_sec: u64,
randomized_delay_sec: u64,
suppress_empty_task_error: bool,
}

impl<Backend> Prover<Backend>
Expand Down Expand Up @@ -59,7 +61,6 @@ where
provers.spawn(async move {
self_clone.working_loop(i).await;
});
tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
}

tokio::select! {
Expand Down Expand Up @@ -94,6 +95,9 @@ where
let task_str = task.to_string();
let i = work_set.pop().expect("can not be empty");
provers.spawn(async move {
// Soft start delay to stagger the provers
sleep(self_clone.poll_delay()).await;

let coordinator_client = &self_clone.coordinator_clients[i];
let prover_name = &coordinator_client.prover_name;

Expand All @@ -108,7 +112,6 @@ where
}
i
});
tokio::time::sleep(Duration::from_secs(3)).await; // Sleep for 3 seconds to avoid overwhelming the l2geth/coordinator with requests.
}

// wait until all tasks has been done
Expand All @@ -131,20 +134,18 @@ where

#[instrument(skip(self), level = Level::DEBUG)]
async fn working_loop(&self, i: usize) {
// Soft start delay to stagger the provers
sleep(self.poll_delay()).await;
loop {
let coordinator_client = &self.coordinator_clients[i];
let prover_name = &coordinator_client.prover_name;

info!(?prover_name, "Getting task from coordinator");

if let Err(e) = self.handle_task(coordinator_client, None).await {
error!(prover_name, error = e.to_string(), "Error handling task");
error!(prover_name = %coordinator_client.prover_name, error = e.to_string(), "Error handling task");
}

sleep(Duration::from_secs(WORKER_SLEEP_SEC)).await;
sleep(self.poll_delay()).await;
}
}

#[instrument(skip(self, coordinator_client), level = Level::DEBUG)]
async fn handle_task(
&self,
coordinator_client: &CoordinatorClient,
Expand Down Expand Up @@ -174,9 +175,13 @@ where
get_task_request.task_types = vec![t];
get_task_request.task_id.replace(s.to_string());
}
let coordinator_task = self
let Some(coordinator_task) = self
.get_coordinator_task(coordinator_client, &get_task_request)
.await?;
.await?
else {
return Ok(());
};
info!(prover_name = %coordinator_client.prover_name, "Got task from coordinator");
let proving_task = self
.request_proving(coordinator_client, &coordinator_task)
.await?;
Expand All @@ -188,9 +193,15 @@ where
&self,
coordinator_client: &CoordinatorClient,
request: &GetTaskRequest,
) -> eyre::Result<GetTaskResponseData> {
) -> eyre::Result<Option<GetTaskResponseData>> {
let coordinator_task = coordinator_client.get_task(request).await?;

if coordinator_task.errcode == ErrorCode::ErrCoordinatorEmptyProofData
&& self.suppress_empty_task_error
{
return Ok(None);
}

if coordinator_task.errcode != ErrorCode::Success {
eyre::bail!(
"Failed to get task, errcode: {:?}, errmsg: {:?}",
Expand All @@ -199,9 +210,7 @@ where
);
}

coordinator_task
.data
.ok_or_else(|| eyre::eyre!("No task available"))
Ok(coordinator_task.data)
}

async fn request_proving(
Expand Down Expand Up @@ -269,6 +278,9 @@ where
let coordinator_task_uuid = &coordinator_task.uuid;
let coordinator_task_id = &coordinator_task.task_id;

// Track last observed status to avoid spamming logs when status hasn't changed.
let mut last_status: Option<TaskStatus> = None;

loop {
let task = self
.proving_service
Expand All @@ -279,25 +291,30 @@ where
})
.await;

match task.status {
let current_status = task.status;

match current_status {
TaskStatus::Queued | TaskStatus::Proving => {
info!(
?prover_name,
?task_type,
?coordinator_task_uuid,
?coordinator_task_id,
?proving_service_task_id,
status = ?task.status,
"Task status update"
);
if last_status != Some(current_status) {
info!(
?prover_name,
?task_type,
?coordinator_task_uuid,
?coordinator_task_id,
?proving_service_task_id,
status = ?current_status,
"Task status update"
);
}
last_status.replace(current_status);
if let Some(db) = &self.db {
db.set_task(
public_key.clone(),
coordinator_task,
proving_service_task_id.clone(),
);
}
sleep(Duration::from_secs(WORKER_SLEEP_SEC)).await;
sleep(self.poll_delay()).await;
}
TaskStatus::Success => {
info!(
Expand Down Expand Up @@ -436,6 +453,16 @@ where
input: task.task_data.clone(),
})
}

fn poll_delay(&self) -> Duration {
let base_delay = Duration::from_secs(self.poll_interval_sec);
if self.randomized_delay_sec == 0 {
return base_delay;
}
let mut rng = rand::rng();
let random_delay = rng.random_range(0..self.randomized_delay_sec * 1000);
base_delay + Duration::from_millis(random_delay)
}
}

#[cfg(test)]
Expand Down