diff --git a/Cargo.lock b/Cargo.lock index 840c348..feb4996 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,28 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloy-rlp" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f70d83b765fdc080dbcd4f4db70d8d23fe4761f2f02ebfa9146b833900634b4" +dependencies = [ + "alloy-rlp-derive", + "arrayvec", + "bytes", +] + +[[package]] +name = "alloy-rlp-derive" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64b728d511962dda67c1bc7ea7c03736ec275ed2cf4c35d9585298ac9ccf3b73" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "anstream" version = "0.6.21" @@ -663,10 +685,10 @@ dependencies = [ "num_enum", "open-fastrlp", "rand 0.8.5", - "rlp 0.5.2", + "rlp", "serde", "serde_json", - "strum", + "strum 0.24.1", "tempfile", "thiserror", "tiny-keccak", @@ -1200,7 +1222,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f28220f89297a075ddc7245cd538076ee98b01f2a9c23a53a4f1105d5a322808" dependencies = [ - "rlp 0.5.2", + "rlp", ] [[package]] @@ -2005,16 +2027,6 @@ dependencies = [ "rustc-hex", ] -[[package]] -name = "rlp" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa24e92bb2a83198bb76d661a71df9f7076b8c420b8696e4d3d97d50d94479e3" -dependencies = [ - "bytes", - "rustc-hex", -] - [[package]] name = "rlp-derive" version = "0.1.0" @@ -2147,24 +2159,27 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "scroll-proving-sdk" -version = "0.1.0" +version = "0.2.0" dependencies = [ + "alloy-rlp", "async-trait", "axum", "clap", "dotenvy", "ethers-core", "eyre", + "futures", "hex", "http", "rand 0.9.2", "reqwest", "reqwest-middleware", "reqwest-retry", - "rlp 0.6.1", "rocksdb", "serde", "serde_json", + "serde_repr", + "strum 0.27.2", "tiny-keccak", "tokio", "tracing", @@ -2251,6 +2266,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2371,7 +2397,16 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" dependencies = [ - "strum_macros", + "strum_macros 0.24.3", +] + +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros 0.27.2", ] [[package]] @@ -2387,6 +2422,18 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index 718aa90..1fef419 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,20 +1,21 @@ [package] name = "scroll-proving-sdk" -version = "0.1.0" -edition = "2021" +version = "0.2.0" +edition = "2024" [dependencies] eyre = "0.6" serde = { version = "1", features = ["derive"] } serde_json = "1.0" ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "v2.0.7" } -reqwest = { version = "0.12", features = ["gzip"] } +futures = "0.3" +reqwest = { version = "0.12", features = ["gzip", "json"] } reqwest-middleware = "0.4" reqwest-retry = "0.7" hex = "0.4" tiny-keccak = { version = "2.0", features = ["sha3", "keccak"] } rand = "0.9" -rlp = "0.6" +alloy-rlp = { version = "0.3", features = ["derive"] } tokio = { version = "1.48", features = ["net", "sync"] } async-trait = "0.1" http = "1.4" @@ -24,4 +25,5 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } axum = { version = "0.8", default-features = false, features = ["tokio", "http1"] } dotenvy = "0.15" rocksdb = "0.24" - +strum = { version = "0.27", features = ["derive"] } +serde_repr = "0.1" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..6d63c40 --- /dev/null +++ b/build.rs @@ -0,0 +1,20 @@ +use std::env; + +const DEFAULT_COMMIT: &str = "unknown"; +const DEFAULT_ZK_VERSION: &str = "000000-000000"; +const DEFAULT_TAG: &str = "v0.0.0"; + +fn main() { + println!( + "cargo:rustc-env=GIT_REV={}", + env::var("GIT_REV").unwrap_or_else(|_| DEFAULT_COMMIT.to_string()), + ); + println!( + "cargo:rustc-env=GO_TAG={}", + env::var("GO_TAG").unwrap_or_else(|_| DEFAULT_TAG.to_string()), + ); + println!( + "cargo:rustc-env=ZK_VERSION={}", + env::var("ZK_VERSION").unwrap_or_else(|_| DEFAULT_ZK_VERSION.to_string()), + ); +} diff --git a/examples/cloud.rs b/examples/cloud.rs index 9ac64f0..1c0c49d 100644 --- a/examples/cloud.rs +++ b/examples/cloud.rs @@ -1,18 +1,18 @@ #![allow(dead_code)] -use eyre::{eyre, Result}; use async_trait::async_trait; use clap::Parser; +use eyre::{Result, eyre}; use reqwest::Url; use std::fs::File; use scroll_proving_sdk::{ config::Config as SdkConfig, prover::{ + ProverBuilder, ProvingService, proving_service::{ GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest, QueryTaskResponse, }, - ProverBuilder, ProvingService, }, utils::init_tracing, }; diff --git a/examples/local.rs b/examples/local.rs index 977c2fe..b7254d6 100644 --- a/examples/local.rs +++ b/examples/local.rs @@ -1,14 +1,14 @@ -use eyre::{eyre, Result}; use async_trait::async_trait; use clap::Parser; +use eyre::{Result, eyre}; use scroll_proving_sdk::{ config::Config as SdkConfig, prover::{ + ProverBuilder, ProvingService, proving_service::{ GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest, QueryTaskResponse, }, - ProverBuilder, ProvingService, }, utils::init_tracing, }; diff --git a/rust-toolchain.toml b/rust-toolchain.toml index d3e25b1..a43d994 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2025-02-14" +channel = "nightly-2025-08-18" diff --git a/src/config.rs b/src/config.rs index 8ac98c4..27011f5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ use crate::{coordinator_handler::ProverType, prover::ProofType}; -use eyre::{eyre, Result}; use dotenvy::dotenv; +use eyre::{Result, eyre}; use serde::{Deserialize, Serialize}; use serde_json; use std::fs::File; @@ -22,6 +22,8 @@ pub struct CoordinatorConfig { pub retry_count: u32, pub retry_wait_time_sec: u64, pub connection_timeout_sec: u64, + #[serde(default)] + pub suppress_empty_task_error: bool, } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -37,8 +39,6 @@ pub struct ProverConfig { /// specified time value. Defaults to 0, indicating that no randomized delay shall be applied. #[serde(default)] pub randomized_delay_sec: u64, - #[serde(default)] - pub suppress_empty_task_error: bool, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct DbConfig {} @@ -106,11 +106,12 @@ impl Config { self.prover.supported_proof_types = values_vec .iter() - .map(|value| match value.parse::() { - Ok(num) => ProofType::from_u8(num), - Err(e) => { - panic!("Failed to parse circuit type: {}", e); - } + .map(|value| { + value + .parse::() + .ok() + .and_then(ProofType::from_repr) + .expect("failed to parse circuit type") }) .collect::>(); } @@ -120,7 +121,7 @@ impl Config { } if let Some(val) = Self::get_env_var("DB_PATH")? { - self.db_path = Option::from(val); + self.db_path = Some(val); } if let Some(val) = Self::get_env_var("POLL_INTERVAL_SEC")? { @@ -131,7 +132,7 @@ impl Config { } if Self::get_env_var("SUPPRESS_EMPTY_TASK_ERR")?.is_some() { - self.prover.suppress_empty_task_error = true; + self.coordinator.suppress_empty_task_error = true; } Ok(()) diff --git a/src/coordinator_handler/api.rs b/src/coordinator_handler/api.rs index 5ca2335..bef362c 100644 --- a/src/coordinator_handler/api.rs +++ b/src/coordinator_handler/api.rs @@ -1,13 +1,15 @@ use super::{ - ChallengeResponseData, GetTaskRequest, GetTaskResponseData, LoginRequest, LoginResponseData, - Response, SubmitProofRequest, SubmitProofResponseData, + ChallengeResponse, GetTaskRequest, GetTaskResponse, LoginRequest, LoginResponse, Response, + SubmitProofRequest, }; use crate::config::CoordinatorConfig; use core::time::Duration; -use reqwest::{header::CONTENT_TYPE, Url}; +use eyre::Context; +use http::{Method, StatusCode}; +use reqwest::{Url, header::CONTENT_TYPE}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; -use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; -use serde::Serialize; +use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff}; +use serde::{Deserialize, Serialize}; use tracing::Level; pub struct Api { @@ -28,89 +30,34 @@ impl Api { .build(); Ok(Self { - base_url: Url::parse(&cfg.base_url)?, - send_timeout: core::time::Duration::from_secs(cfg.connection_timeout_sec), + base_url: Url::parse(cfg.base_url.trim_end_matches('/'))?, + send_timeout: Duration::from_secs(cfg.connection_timeout_sec), client, }) } - fn build_url(&self, method: &str) -> eyre::Result { - self.base_url.join(method).map_err(|e| eyre::eyre!(e)) + pub async fn challenge(&self) -> eyre::Result { + const PATH: &str = "/coordinator/v1/challenge"; + let response: Response = + self.request(Method::GET, PATH, None::<&()>, None).await?; + response.into_result().context("challenge request failed") } - #[instrument(skip(self, req, token), level = Level::DEBUG)] - async fn post_with_token( - &self, - method: &str, - req: &Req, - token: &String, - ) -> eyre::Result - where - Req: ?Sized + Serialize, - Resp: serde::de::DeserializeOwned, - { - let url = self.build_url(method)?; - let request_body = serde_json::to_string(req)?; - let size = request_body.len(); - - debug!("sent request"); - trace!(token = %token, request_body = %request_body, size = %size); - let response = self - .client - .post(url) - .header(CONTENT_TYPE, "application/json") - .bearer_auth(token) - .body(request_body) - .timeout(self.send_timeout) - .send() + pub async fn login(&self, req: &LoginRequest<'_>, token: &str) -> eyre::Result { + const PATH: &str = "/coordinator/v1/login"; + let response: Response = self + .request(Method::POST, PATH, Some(req), Some(token)) .await?; - - if response.status() != http::status::StatusCode::OK { - eyre::bail!( - "[coordinator client], {method}, status not ok: {}", - response.status() - ) - } - - let response_body = response.text().await?; - - debug!("received response"); - trace!(response_body = %response_body); - serde_json::from_str(&response_body).map_err(|e| eyre::eyre!(e)) - } - - pub async fn challenge(&self) -> eyre::Result> { - let method = "/coordinator/v1/challenge"; - let url = self.build_url(method)?; - - let response = self - .client - .get(url) - .header(CONTENT_TYPE, "application/json") - .timeout(self.send_timeout) - .send() - .await?; - - let response_body = response.text().await?; - - serde_json::from_str(&response_body).map_err(|e| eyre::eyre!(e)) - } - - pub async fn login( - &self, - req: &LoginRequest, - token: &String, - ) -> eyre::Result> { - let method = "/coordinator/v1/login"; - self.post_with_token(method, req, token).await + response.into_result().context("login failed") } pub async fn get_task( &self, - req: &GetTaskRequest, - token: &String, - ) -> eyre::Result> { - let method = "/coordinator/v1/get_task"; + req: &GetTaskRequest<'_>, + token: &str, + ) -> eyre::Result> { + const PATH: &str = "/coordinator/v1/get_task"; + if self.send_timeout < Duration::from_secs(600) { tracing::warn!( "get_task API is time-consuming, timeout setting is too low ({}), set it to more than 600s", @@ -118,15 +65,63 @@ impl Api { ); } - self.post_with_token(method, req, token).await + self.request(Method::POST, PATH, Some(req), Some(token)) + .await } pub async fn submit_proof( &self, - req: &SubmitProofRequest, - token: &String, - ) -> eyre::Result> { - let method = "/coordinator/v1/submit_proof"; - self.post_with_token(method, req, token).await + req: &SubmitProofRequest<'_>, + token: &str, + ) -> eyre::Result> { + const PATH: &str = "/coordinator/v1/submit_proof"; + self.request(Method::POST, PATH, Some(req), Some(token)) + .await + } + + #[instrument(skip(self, body, token), level = Level::DEBUG)] + async fn request( + &self, + method: Method, + path: &str, + body: Option<&Req>, + token: Option<&str>, + ) -> eyre::Result> + where + Req: ?Sized + Serialize, + T: for<'de> Deserialize<'de>, + { + let url = self.base_url.join(path)?; + + let mut builder = self + .client + .request(method, url) + .header(CONTENT_TYPE, "application/json") + .timeout(self.send_timeout); + + if let Some(token) = token { + trace!(token = %token); + builder = builder.bearer_auth(token); + } + + if let Some(body) = body { + let request_body = serde_json::to_string(body)?; + let size = request_body.len(); + trace!(request_body = %request_body, size = %size); + builder = builder.body(request_body); + } + + debug!("sending request"); + let response = builder.send().await?; + + if response.status() != StatusCode::OK { + eyre::bail!("{path}, status not ok: {}", response.status()) + } + + let response = response.text().await?; + debug!("received response"); + trace!(response_body = %response); + let response = serde_json::from_str(&response)?; + Ok(response) } } diff --git a/src/coordinator_handler/coordinator_client.rs b/src/coordinator_handler/coordinator_client.rs index 768f0cc..ca547dd 100644 --- a/src/coordinator_handler/coordinator_client.rs +++ b/src/coordinator_handler/coordinator_client.rs @@ -1,24 +1,29 @@ use super::{ - api::Api, error::ErrorCode, GetTaskRequest, GetTaskResponseData, KeySigner, LoginMessage, - LoginRequest, ProverType, Response, SubmitProofRequest, SubmitProofResponseData, + GetTaskRequest, GetTaskResponse, KeySigner, LoginMessage, LoginRequest, ProverType, + SubmitProofRequest, api::Api, }; -use crate::{config::CoordinatorConfig, prover::ProverProviderType, utils::get_version}; -use tokio::sync::{Mutex, MutexGuard}; +use crate::{config::CoordinatorConfig, prover::ProverProviderType, utils::VERSION}; +use std::borrow::Cow; +use std::sync::Arc; +use std::sync::Mutex; +use tokio::sync::OnceCell; pub struct CoordinatorClient { - prover_types: Vec, - vks: Vec, pub prover_name: String, pub prover_provider_type: ProverProviderType, pub key_signer: KeySigner, + prover_types: Vec, + vks: Vec, api: Api, - token: Mutex>, + suppress_empty_task_errors: bool, + token: OnceCell>>, } impl CoordinatorClient { pub fn new( cfg: CoordinatorConfig, prover_types: Vec, + suppress_empty_task_errors: bool, vks: Vec, prover_name: String, prover_provider_type: ProverProviderType, @@ -32,121 +37,119 @@ impl CoordinatorClient { prover_provider_type, key_signer, api, - token: Mutex::new(None), + suppress_empty_task_errors, + token: OnceCell::new(), }; Ok(client) } + pub async fn token(&self) -> eyre::Result> { + let mutex = self.get_lazy_init_mutex().await?; + let guard = mutex.lock().unwrap_or_else(|e| e.into_inner()); // ignore poisoned lock + Ok(guard.clone()) + } pub async fn get_task( &self, - req: &GetTaskRequest, - ) -> eyre::Result> { - let token = self.get_token(false).await?; - let response = self.api.get_task(req, &token).await?; + req: &GetTaskRequest<'_>, + ) -> eyre::Result> { + let response = self.api.get_task(req, &self.token().await?).await?; - if response.errcode == ErrorCode::ErrJWTTokenExpired { - let token = self.get_token(true).await?; - self.api.get_task(req, &token).await + let response = if response.is_jwt_token_expired() { + let token = self.refresh_token().await?; + self.api.get_task(req, &token).await? } else { - Ok(response) + response + }; + + if response.is_empty_task_error() && self.suppress_empty_task_errors { + return Ok(None); } + Ok(Some(response.into_result()?)) } - pub async fn submit_proof( - &self, - req: &SubmitProofRequest, - ) -> eyre::Result> { - let token = self.get_token(false).await?; - let response = self.api.submit_proof(req, &token).await?; + pub async fn submit_proof(&self, req: &SubmitProofRequest<'_>) -> eyre::Result<()> { + let response = self.api.submit_proof(req, &self.token().await?).await?; - if response.errcode == ErrorCode::ErrJWTTokenExpired { - let token = self.get_token(true).await?; - self.api.submit_proof(req, &token).await + if response.is_jwt_token_expired() { + let token = self.refresh_token().await?; + self.api.submit_proof(req, &token).await?.into_result()?; } else { - Ok(response) + response.into_result()?; } + Ok(()) } - /// Retrieves a token for authentication, optionally forcing a re-login. - /// - /// This function attempts to get the stored token if `force_relogin` is set to `false`. - /// - /// If the token is expired, `force_relogin` is set to `true`, or a login was never performed - /// before, it will authenticate and fetch a new token. - pub async fn get_token(&self, force_relogin: bool) -> eyre::Result { - let token_guard = self.token.lock().await; - - match *token_guard { - Some(ref token) if !force_relogin => return Ok(token.to_string()), - _ => (), - } + /// Refresh the jwt token for authentication. + pub async fn refresh_token(&self) -> eyre::Result> { + let mutex = self.get_lazy_init_mutex().await?; + + let token = refresh_token( + &self.api, + &self.prover_name, + self.prover_provider_type, + &self.prover_types, + &self.vks, + &self.key_signer, + ) + .await?; + let mut guard = mutex.lock().unwrap_or_else(|e| e.into_inner()); // ignore poisoned lock + *guard = token.clone(); - self.login(token_guard).await + Ok(token) } - async fn login( - &self, - mut token_guard: MutexGuard<'_, Option>, - ) -> eyre::Result { - let challenge_response = self - .api - .challenge() - .await - .map_err(|e| eyre::eyre!("Failed to request a challenge: {e}"))?; - - if challenge_response.errcode != ErrorCode::Success { - eyre::bail!( - "Challenge request failed with {:?} {}", - challenge_response.errcode, - challenge_response.errmsg - ); - } - - let login_response_data = challenge_response - .data - .as_ref() - .ok_or_else(|| eyre::eyre!("Missing challenge token"))?; - - let login_message = LoginMessage { - challenge: login_response_data.token.clone(), - prover_version: get_version().to_string(), - prover_name: self.prover_name.clone(), - prover_provider_type: self.prover_provider_type, - prover_types: self.prover_types.clone(), - vks: self.vks.clone(), - }; - - let buffer = rlp::encode(&login_message); - let signature = self - .key_signer - .sign_buffer(&buffer) - .map_err(|e| eyre::eyre!("Failed to sign the login message: {e}"))?; - - let login_request = LoginRequest { - message: login_message, - public_key: self.key_signer.get_public_key(), - signature, - }; - let login_response = self - .api - .login(&login_request, &login_response_data.token) + async fn get_lazy_init_mutex(&self) -> eyre::Result<&Mutex>> { + self.token + .get_or_try_init::(|| async { + let token = refresh_token( + &self.api, + &self.prover_name, + self.prover_provider_type, + &self.prover_types, + &self.vks, + &self.key_signer, + ) + .await?; + Ok(Mutex::new(token)) + }) .await - .map_err(|e| eyre::eyre!("Failed to login: {e}"))?; - - if login_response.errcode != ErrorCode::Success { - eyre::bail!( - "Login request failed with {:?} {}", - login_response.errcode, - login_response.errmsg - ); - } - let token = login_response - .data - .map(|r| r.token) - .ok_or_else(|| eyre::eyre!("Empty data in response, lack of login"))?; - - *token_guard = Some(token.clone()); - - Ok(token) } } + +async fn refresh_token( + api: &Api, + prover_name: &str, + prover_provider_type: ProverProviderType, + prover_types: &[ProverType], + vks: &[String], + key_signer: &KeySigner, +) -> eyre::Result> { + // login and get new token + let challenge = api.challenge().await?; + + let login_message = LoginMessage { + challenge: Cow::Borrowed(&challenge.token), + prover_version: VERSION.into(), + prover_name: Cow::Borrowed(prover_name), + prover_provider_type, + prover_types: prover_types.into(), + vks: vks.to_vec(), + }; + let buffer = alloy_rlp::encode(&login_message); + let signature = key_signer + .sign_buffer(&buffer) + .map_err(|e| eyre::eyre!("Failed to sign the login message: {e}"))?; + + let response = api + .login( + &LoginRequest { + message: login_message, + public_key: key_signer.get_public_key().into(), + signature: signature.into(), + }, + &challenge.token, + ) + .await?; + + Ok(Arc::from(response.token)) +} diff --git a/src/coordinator_handler/error.rs b/src/coordinator_handler/error.rs index c09a6d9..00b4ff2 100644 --- a/src/coordinator_handler/error.rs +++ b/src/coordinator_handler/error.rs @@ -1,22 +1,22 @@ use serde::{Deserialize, Deserializer}; -use std::fmt; +use strum::EnumIs; -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, EnumIs)] pub enum ErrorCode { Success, InternalServerError, - ErrProverStatsAPIParameterInvalidNo, - ErrProverStatsAPIProverTaskFailure, - ErrProverStatsAPIProverTotalRewardFailure, + ProverStatsAPIParameterInvalidNo, + ProverStatsAPIProverTaskFailure, + ProverStatsAPIProverTotalRewardFailure, - ErrCoordinatorParameterInvalidNo, - ErrCoordinatorGetTaskFailure, - ErrCoordinatorHandleZkProofFailure, - ErrCoordinatorEmptyProofData, + CoordinatorParameterInvalidNo, + CoordinatorGetTaskFailure, + CoordinatorHandleZkProofFailure, + CoordinatorEmptyProofData, - ErrJWTCommonErr, - ErrJWTTokenExpired, + JWTCommonErr, + JWTTokenExpired, Undefined(i32), } @@ -26,15 +26,15 @@ impl ErrorCode { match v { 0 => ErrorCode::Success, 500 => ErrorCode::InternalServerError, - 10001 => ErrorCode::ErrProverStatsAPIParameterInvalidNo, - 10002 => ErrorCode::ErrProverStatsAPIProverTaskFailure, - 10003 => ErrorCode::ErrProverStatsAPIProverTotalRewardFailure, - 20001 => ErrorCode::ErrCoordinatorParameterInvalidNo, - 20002 => ErrorCode::ErrCoordinatorGetTaskFailure, - 20003 => ErrorCode::ErrCoordinatorHandleZkProofFailure, - 20004 => ErrorCode::ErrCoordinatorEmptyProofData, - 50000 => ErrorCode::ErrJWTCommonErr, - 50001 => ErrorCode::ErrJWTTokenExpired, + 10001 => ErrorCode::ProverStatsAPIParameterInvalidNo, + 10002 => ErrorCode::ProverStatsAPIProverTaskFailure, + 10003 => ErrorCode::ProverStatsAPIProverTotalRewardFailure, + 20001 => ErrorCode::CoordinatorParameterInvalidNo, + 20002 => ErrorCode::CoordinatorGetTaskFailure, + 20003 => ErrorCode::CoordinatorHandleZkProofFailure, + 20004 => ErrorCode::CoordinatorEmptyProofData, + 50000 => ErrorCode::JWTCommonErr, + 50001 => ErrorCode::JWTTokenExpired, _ => { error!("get unexpected error code from coordinator: {v}"); ErrorCode::Undefined(v) @@ -52,14 +52,3 @@ impl<'de> Deserialize<'de> for ErrorCode { Ok(ErrorCode::from_i32(v)) } } - -// ==================================================== - -#[derive(Debug, Clone)] -pub struct ProofStatusNotOKError; - -impl fmt::Display for ProofStatusNotOKError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "proof status not ok") - } -} diff --git a/src/coordinator_handler/key_signer.rs b/src/coordinator_handler/key_signer.rs index e954e9a..6ac712e 100644 --- a/src/coordinator_handler/key_signer.rs +++ b/src/coordinator_handler/key_signer.rs @@ -1,8 +1,8 @@ use ethers_core::{ k256::{ - ecdsa::{signature::hazmat::PrehashSigner, RecoveryId, Signature, SigningKey}, - elliptic_curve::{sec1::ToEncodedPoint, FieldBytes}, PublicKey, Secp256k1, SecretKey, + ecdsa::{RecoveryId, Signature, SigningKey, signature::hazmat::PrehashSigner}, + elliptic_curve::{FieldBytes, sec1::ToEncodedPoint}, }, types::Signature as EthSignature, types::{H256, U256}, diff --git a/src/coordinator_handler/types.rs b/src/coordinator_handler/types.rs index 3b4eb2e..68d6019 100644 --- a/src/coordinator_handler/types.rs +++ b/src/coordinator_handler/types.rs @@ -1,120 +1,217 @@ use super::error::ErrorCode; use crate::prover::{ProofType, ProverProviderType}; -use rlp::{Encodable, RlpStream}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use alloy_rlp::{ + BufMut, Decodable, EMPTY_STRING_CODE, Encodable, Header, RlpEncodable, length_of_length, +}; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_repr::{Deserialize_repr as DeserializeRepr, Serialize_repr as SerializeRepr}; +use std::borrow::Cow; +use std::fmt; +use strum::{EnumIs, FromRepr}; -#[derive(Deserialize)] -pub struct Response { - pub errcode: ErrorCode, - pub errmsg: String, - pub data: Option, +#[derive(Debug, Clone)] +pub enum Response { + Ok(T), + Err(RpcError), } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum ProverType { - Undefined, - Chunk, - Batch, - OpenVM, -} +impl Response { + pub const fn is_ok(&self) -> bool { + matches!(self, Response::Ok(_)) + } + + pub const fn is_error(&self) -> bool { + matches!(self, Response::Err(_)) + } -impl ProverType { - pub fn from_u8(v: u8) -> Self { - match v { - 1 => ProverType::Chunk, - 2 => ProverType::Batch, - 3 => ProverType::OpenVM, - _ => ProverType::Undefined, + pub const fn is_jwt_token_expired(&self) -> bool { + match self { + Response::Err(err) => err.code.is_jwt_token_expired(), + _ => false, } } - pub fn to_u8(&self) -> u8 { + pub const fn is_empty_task_error(&self) -> bool { match self { - ProverType::Undefined => 0, - ProverType::Chunk => 1, - ProverType::Batch => 2, - ProverType::OpenVM => 3, + Response::Err(err) => err.code.is_coordinator_empty_proof_data(), + _ => false, } } -} -impl Serialize for ProverType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_u8(self.to_u8()) + pub fn into_result(self) -> Result { + match self { + Response::Ok(data) => Ok(data), + Response::Err(err) => Err(err), + } } } -impl<'de> Deserialize<'de> for ProverType { +#[derive(Deserialize)] +struct ResponseHelper { + errcode: ErrorCode, + errmsg: String, + data: Option, +} + +impl<'de, T: Deserialize<'de>> Deserialize<'de> for Response { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - let v: u8 = u8::deserialize(deserializer)?; - Ok(ProverType::from_u8(v)) + let helper = ResponseHelper::::deserialize(deserializer)?; + if helper.errcode == ErrorCode::Success { + if let Some(data) = helper.data { + Ok(Response::Ok(data)) + } else if size_of::() == 0 { + // Special handling for zero-sized types, e.g. Reposnse<()> + let zst: T = unsafe { + // SAFETY: It's always safe to synthesizing ZST + std::mem::zeroed() + }; + return Ok(Response::Ok(zst)); + } else { + Err(serde::de::Error::custom( + "Expected data field for successful response", + )) + } + } else { + Ok(Response::Err(RpcError { + code: helper.errcode, + msg: helper.errmsg, + })) + } } } -#[derive(Serialize, Deserialize)] -pub struct LoginMessage { - pub challenge: String, - pub prover_version: String, - pub prover_name: String, - pub prover_provider_type: ProverProviderType, - pub prover_types: Vec, - pub vks: Vec, +#[derive(Debug, Clone, Deserialize)] +pub struct RpcError { + code: ErrorCode, + msg: String, } -impl Encodable for LoginMessage { - fn rlp_append(&self, s: &mut RlpStream) { - let num_fields = 6; - s.begin_list(num_fields); - s.append(&self.challenge); - s.append(&self.prover_version); - s.append(&self.prover_name); - s.append(&(self.prover_provider_type as u8)); - // The ProverType in go side is an type alias of uint8 - // A uint8 slice is treated as a string when doing the rlp encoding - let prover_types = self - .prover_types - .iter() - .map(|prover_type| prover_type.to_u8()) - .collect::>(); - s.append(&prover_types); - s.begin_list(self.vks.len()); - for vk in &self.vks { - s.append(vk); +impl RpcError { + pub fn code(&self) -> &ErrorCode { + &self.code + } + + pub fn msg(&self) -> &str { + &self.msg + } +} + +impl fmt::Display for RpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}: {}", self.code, self.msg) + } +} + +impl std::error::Error for RpcError {} + +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + SerializeRepr, + DeserializeRepr, + EnumIs, + strum::Display, + FromRepr, +)] +#[repr(u8)] +pub enum ProverType { + Undefined = 0, + Chunk = 1, + Batch = 2, + OpenVM = 3, +} + +impl Encodable for ProverType { + fn encode(&self, s: &mut dyn BufMut) { + (*self as u8).encode(s); + } +} + +impl Decodable for ProverType { + fn decode(buf: &mut &[u8]) -> alloy_rlp::Result { + let v = u8::decode(buf)?; + ProverType::from_repr(v).ok_or(alloy_rlp::Error::Custom("Invalid ProverType value")) + } +} + +/// The ProverType in go side is a type alias of uint8 +/// A uint8 slice is treated as a string when doing the rlp encoding +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProverTypes<'a>(Cow<'a, [ProverType]>); + +impl<'a, T: Into>> From for ProverTypes<'a> { + fn from(value: T) -> Self { + ProverTypes(value.into()) + } +} + +/// See implementation of [`Encodable`] for [u8] +impl Encodable for ProverTypes<'_> { + #[inline] + fn encode(&self, out: &mut dyn BufMut) { + if self.0.len() != 1 || self.0[0] as u8 >= EMPTY_STRING_CODE { + Header { + list: false, + payload_length: self.0.len(), + } + .encode(out); + } + for prover_type in self.0.iter() { + out.put_u8(*prover_type as u8); } } + #[inline] + fn length(&self) -> usize { + let mut len = self.0.len(); + if len != 1 || self.0[0] as u8 >= EMPTY_STRING_CODE { + len += length_of_length(len); + } + len + } } -#[derive(Serialize, Deserialize)] -pub struct LoginRequest { - pub message: LoginMessage, - pub public_key: String, - pub signature: String, +#[derive(Debug, Clone, Serialize, RlpEncodable)] +pub struct LoginMessage<'a> { + pub challenge: Cow<'a, str>, + pub prover_version: Cow<'a, str>, + pub prover_name: Cow<'a, str>, + pub prover_provider_type: ProverProviderType, + pub prover_types: ProverTypes<'a>, + pub vks: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub struct LoginRequest<'a> { + pub message: LoginMessage<'a>, + pub public_key: Cow<'a, str>, + pub signature: Cow<'a, str>, } #[derive(Serialize, Deserialize)] -pub struct LoginResponseData { +pub struct LoginResponse { pub time: String, pub token: String, } -pub type ChallengeResponseData = LoginResponseData; +pub type ChallengeResponse = LoginResponse; -#[derive(Default, Serialize, Deserialize)] -pub struct GetTaskRequest { +#[derive(Default, Debug, Clone, Serialize)] +pub struct GetTaskRequest<'a> { pub task_types: Vec, pub prover_height: Option, pub universal: bool, - pub task_id: Option, + pub task_id: Option>, } -#[derive(Serialize, Deserialize, Clone, Default)] -pub struct GetTaskResponseData { +#[derive(Default, Clone, Serialize, Deserialize)] +pub struct GetTaskResponse { pub uuid: String, pub task_id: String, pub task_type: ProofType, @@ -122,102 +219,57 @@ pub struct GetTaskResponseData { pub hard_fork_name: String, } -#[derive(Serialize, Deserialize)] // TODO: Default? -pub struct SubmitProofRequest { - pub uuid: String, - pub task_id: String, +#[derive(Debug, Clone, Serialize)] // TODO: Default? +pub struct SubmitProofRequest<'a> { + pub uuid: Cow<'a, str>, + pub task_id: Cow<'a, str>, pub task_type: ProofType, pub status: ProofStatus, - pub proof: String, + pub proof: Cow<'a, str>, pub failure_type: Option, - pub failure_msg: Option, + pub failure_msg: Option>, pub universal: bool, } -#[derive(Serialize, Deserialize)] -pub struct SubmitProofResponseData {} - -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive( + Default, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + SerializeRepr, + DeserializeRepr, + EnumIs, + strum::Display, + FromRepr, +)] +#[repr(u8)] pub enum ProofFailureType { - Undefined, - Panic, - NoPanic, -} - -impl ProofFailureType { - fn from_u8(v: u8) -> Self { - match v { - 1 => ProofFailureType::Panic, - 2 => ProofFailureType::NoPanic, - _ => ProofFailureType::Undefined, - } - } + #[default] + Undefined = 0, + Panic = 1, + NoPanic = 2, } -impl Serialize for ProofFailureType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match *self { - ProofFailureType::Undefined => serializer.serialize_u8(0), - ProofFailureType::Panic => serializer.serialize_u8(1), - ProofFailureType::NoPanic => serializer.serialize_u8(2), - } - } -} - -impl<'de> Deserialize<'de> for ProofFailureType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let v: u8 = u8::deserialize(deserializer)?; - Ok(ProofFailureType::from_u8(v)) - } -} - -impl Default for ProofFailureType { - fn default() -> Self { - Self::Undefined - } -} - -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + SerializeRepr, + DeserializeRepr, + EnumIs, + strum::Display, + FromRepr, +)] +#[repr(u8)] pub enum ProofStatus { - Ok, - Error, -} - -impl ProofStatus { - fn from_u8(v: u8) -> Self { - match v { - 0 => ProofStatus::Ok, - _ => ProofStatus::Error, - } - } -} - -impl Serialize for ProofStatus { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match *self { - ProofStatus::Ok => serializer.serialize_u8(0), - ProofStatus::Error => serializer.serialize_u8(1), - } - } -} - -impl<'de> Deserialize<'de> for ProofStatus { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let v: u8 = u8::deserialize(deserializer)?; - Ok(ProofStatus::from_u8(v)) - } + Ok = 0, + Error = 1, } #[cfg(test)] @@ -241,15 +293,15 @@ mod tests { let key_signer = KeySigner::new_from_secret_key(private_key_hex).unwrap(); let login_message = LoginMessage { - challenge: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MjQ4Mzg0ODUsIm9yaWdfaWF0IjoxNzI0ODM0ODg1LCJyYW5kb20iOiJ6QmdNZGstNGc4UzNUNTFrVEFsYk1RTXg2TGJ4SUs4czY3ejM2SlNuSFlJPSJ9.x9PvihhNx2w4_OX5uCrv8QJCNYVQkIi-K2k8XFXYmik".to_string(), - prover_version: "v4.4.45-37af5ef5-38a68e2-1c5093c".to_string(), - prover_name: "test".to_string(), + challenge: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MjQ4Mzg0ODUsIm9yaWdfaWF0IjoxNzI0ODM0ODg1LCJyYW5kb20iOiJ6QmdNZGstNGc4UzNUNTFrVEFsYk1RTXg2TGJ4SUs4czY3ejM2SlNuSFlJPSJ9.x9PvihhNx2w4_OX5uCrv8QJCNYVQkIi-K2k8XFXYmik".into(), + prover_version: "v4.4.45-37af5ef5-38a68e2-1c5093c".into(), + prover_name: "test".into(), prover_provider_type: ProverProviderType::Internal, - prover_types: vec![ProverType::Chunk], - vks: vec!["mock_vk".to_string()], + prover_types: (&[ProverType::Chunk]).into(), + vks: vec!["mock_vk".into()], }; - let buffer = rlp::encode(&login_message); + let buffer = alloy_rlp::encode(&login_message); let signature = key_signer .sign_buffer(&buffer) .map_err(|e| eyre::eyre!("Failed to sign the login message: {e}")) diff --git a/src/db.rs b/src/db.rs index 3a382ac..019b921 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,4 +1,4 @@ -use crate::coordinator_handler::GetTaskResponseData; +use crate::coordinator_handler::GetTaskResponse; use rocksdb::DB; use tracing::Level; @@ -13,7 +13,7 @@ impl Db { } #[instrument(skip(self), level = Level::DEBUG)] - pub fn get_task(&self, public_key: String) -> (Option, Option) { + pub fn get_task(&self, public_key: String) -> (Option, Option) { ( self.get_coordinator_task_by_public_key(public_key.clone()), self.get_proving_task_id_by_public_key(public_key), @@ -24,7 +24,7 @@ impl Db { pub fn set_task( &self, public_key: String, - coordinator_task: &GetTaskResponseData, + coordinator_task: &GetTaskResponse, proving_task_id: String, ) { self.set_coordinator_task_by_public_key(public_key.clone(), coordinator_task); @@ -36,10 +36,7 @@ impl Db { self.delete_proving_task_id_by_public_key(public_key); } - fn get_coordinator_task_by_public_key( - &self, - public_key: String, - ) -> Option { + fn get_coordinator_task_by_public_key(&self, public_key: String) -> Option { self.db .get(fmt_coordinator_task_key(public_key)) .ok()? @@ -57,7 +54,7 @@ impl Db { fn set_coordinator_task_by_public_key( &self, public_key: String, - coordinator_task: &GetTaskResponseData, + coordinator_task: &GetTaskResponse, ) { let _ = serde_json::to_vec(coordinator_task) .map(|bytes| self.db.put(fmt_coordinator_task_key(public_key), bytes)); diff --git a/src/prover/builder.rs b/src/prover/builder.rs index 1aa0048..6de106d 100644 --- a/src/prover/builder.rs +++ b/src/prover/builder.rs @@ -6,8 +6,8 @@ use crate::{ coordinator_handler::{CoordinatorClient, KeySigner}, db::Db, prover::{ - proving_service::{GetVkRequest, ProvingService}, Prover, + proving_service::{GetVkRequest, ProvingService}, }, utils::format_cloud_prover_name, }; @@ -75,6 +75,7 @@ where CoordinatorClient::new( self.cfg.coordinator.clone(), self.cfg.coordinator_prover_type(), + self.cfg.coordinator.suppress_empty_task_error, get_vk_response.vks.clone(), prover_name, prover_provider_type, @@ -99,7 +100,6 @@ where .transpose()?, poll_interval_sec: self.cfg.prover.poll_interval_sec, randomized_delay_sec: self.cfg.prover.randomized_delay_sec, - suppress_empty_task_error: self.cfg.prover.suppress_empty_task_error, }) } } diff --git a/src/prover/mod.rs b/src/prover/mod.rs index 5c35a6f..87ea54a 100644 --- a/src/prover/mod.rs +++ b/src/prover/mod.rs @@ -4,19 +4,21 @@ pub mod types; use crate::{ coordinator_handler::{ - CoordinatorClient, ErrorCode, GetTaskRequest, GetTaskResponseData, ProofFailureType, - ProofStatus, SubmitProofRequest, + CoordinatorClient, GetTaskRequest, GetTaskResponse, ProofFailureType, ProofStatus, + SubmitProofRequest, }, db::Db, }; -use axum::{routing::get, Router}; +use axum::{Router, routing::get}; +use eyre::Context; use proving_service::{ProveRequest, QueryTaskRequest, TaskStatus}; +use rand::Rng; +use std::borrow::Cow; use std::future::IntoFuture; use std::net::SocketAddr; use std::str::FromStr; -use rand::Rng; use tokio::net::TcpListener; -use tokio::time::{sleep, Duration}; +use tokio::time::{Duration, sleep}; use tokio::{sync::RwLock, task::JoinSet}; use tracing::Level; use tracing::{error, info, instrument}; @@ -33,7 +35,6 @@ pub struct Prover { db: Option, poll_interval_sec: u64, randomized_delay_sec: u64, - suppress_empty_task_error: bool, } impl Prover @@ -127,7 +128,7 @@ where async fn test_coordinator_connection(&self) { self.coordinator_clients[0] - .get_token(true) + .refresh_token() .await .expect("Failed to login to coordinator"); } @@ -170,14 +171,20 @@ where .await; } - let mut get_task_request = self.build_get_task_request(None)?; + let mut get_task_request = GetTaskRequest { + task_types: self.proof_types.clone(), + prover_height: None, + universal: true, + task_id: None, + }; if let Some((t, s)) = task_spec { get_task_request.task_types = vec![t]; - get_task_request.task_id.replace(s.to_string()); + get_task_request.task_id.replace(s.into()); } - let Some(coordinator_task) = self - .get_coordinator_task(coordinator_client, &get_task_request) - .await? + let Some(coordinator_task) = coordinator_client + .get_task(&get_task_request) + .await + .context("failed to get task")? else { return Ok(()); }; @@ -189,34 +196,10 @@ where .await } - async fn get_coordinator_task( - &self, - coordinator_client: &CoordinatorClient, - request: &GetTaskRequest, - ) -> eyre::Result> { - let coordinator_task = coordinator_client.get_task(request).await?; - - if coordinator_task.errcode == ErrorCode::ErrCoordinatorEmptyProofData - && self.suppress_empty_task_error - { - return Ok(None); - } - - if coordinator_task.errcode != ErrorCode::Success { - eyre::bail!( - "Failed to get task, errcode: {:?}, errmsg: {:?}", - coordinator_task.errcode, - coordinator_task.errmsg - ); - } - - Ok(coordinator_task.data) - } - async fn request_proving( &self, coordinator_client: &CoordinatorClient, - coordinator_task: &GetTaskResponseData, + coordinator_task: &GetTaskResponse, ) -> eyre::Result { let proving_input = match self.get_proving_input(coordinator_task) { Ok(result) => result, @@ -269,7 +252,7 @@ where async fn handle_proving_progress( &self, coordinator_client: &CoordinatorClient, - coordinator_task: &GetTaskResponseData, + coordinator_task: &GetTaskResponse, proving_service_task_id: String, ) -> eyre::Result<()> { let prover_name = &coordinator_client.prover_name; @@ -371,72 +354,50 @@ where async fn submit_proof( &self, coordinator_client: &CoordinatorClient, - coordinator_task: &GetTaskResponseData, + coordinator_task: &GetTaskResponse, task: proving_service::QueryTaskResponse, status: ProofStatus, failure_msg: Option, ) -> eyre::Result<()> { let submit_proof_req = SubmitProofRequest { universal: true, - uuid: coordinator_task.uuid.clone(), - task_id: coordinator_task.task_id.clone(), + uuid: coordinator_task.uuid.as_str().into(), + task_id: coordinator_task.task_id.as_str().into(), task_type: coordinator_task.task_type, status, - proof: task.proof.unwrap_or_default(), + proof: task.proof.map(Into::into).unwrap_or(Cow::Borrowed("")), failure_type: failure_msg.as_ref().map(|_| ProofFailureType::Panic), // TODO: handle ProofFailureType::NoPanic - failure_msg, + failure_msg: failure_msg.map(Into::into), }; - let submit_proof_result = match coordinator_client.submit_proof(&submit_proof_req).await { - Ok(result) => result, - Err(e) => { + match coordinator_client.submit_proof(&submit_proof_req).await { + Ok(()) => { info!( prover_name = ?coordinator_client.prover_name, ?coordinator_task.task_type, ?coordinator_task.uuid, ?coordinator_task.task_id, ?task.task_id, - error = ?e, - "Failed to submit proof due to a http error" + "Proof submitted successfully" + ); + } + Err(e) => { + error!( + prover_name = ?coordinator_client.prover_name, + ?coordinator_task.task_type, + ?coordinator_task.uuid, + ?coordinator_task.task_id, + ?task.task_id, + error = %e, + "Failed to submit proof due to coordinator error" ); return Ok(()); } }; - - if submit_proof_result.errcode != ErrorCode::Success { - info!( - prover_name = ?coordinator_client.prover_name, - ?coordinator_task.task_type, - ?coordinator_task.uuid, - ?coordinator_task.task_id, - ?task.task_id, - errcode = ?submit_proof_result.errcode, - errmsg = ?submit_proof_result.errmsg, - "Failed to submit proof due to coordinator error" - ); - } else { - info!( - prover_name = ?coordinator_client.prover_name, - ?coordinator_task.task_type, - ?coordinator_task.uuid, - ?coordinator_task.task_id, - ?task.task_id, - "Proof submitted successfully" - ); - } Ok(()) } - fn build_get_task_request(&self, prover_height: Option) -> eyre::Result { - Ok(GetTaskRequest { - task_types: self.proof_types.clone(), - prover_height, - universal: true, - task_id: None, - }) - } - - fn get_proving_input(&self, task: &GetTaskResponseData) -> eyre::Result { + fn get_proving_input(&self, task: &GetTaskResponse) -> eyre::Result { eyre::ensure!( self.proof_types.contains(&task.task_type), "unsupported task type. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}", @@ -464,54 +425,3 @@ where base_delay + Duration::from_millis(random_delay) } } - -#[cfg(test)] -mod tests { - use crate::config::Config; - use crate::prover::{ - proving_service::{ - GetVkRequest, GetVkResponse, ProveRequest, ProveResponse, QueryTaskRequest, - QueryTaskResponse, - }, - ProverBuilder, ProvingService, - }; - use async_trait::async_trait; - use tokio; - - struct MockProver {} - - #[async_trait] - impl ProvingService for MockProver { - fn is_local(&self) -> bool { - true - } - async fn get_vks(&self, _: GetVkRequest) -> GetVkResponse { - GetVkResponse { - ..Default::default() - } - } - async fn prove(&mut self, _: ProveRequest) -> ProveResponse { - ProveResponse { - ..Default::default() - } - } - async fn query_task(&mut self, _: QueryTaskRequest) -> QueryTaskResponse { - QueryTaskResponse { - ..Default::default() - } - } - } - - #[tokio::test] - async fn test_build_get_task_request() { - let cfg = Config::from_file("conf/config.json".to_string()).unwrap(); - let prover_service = MockProver {}; - let prover = ProverBuilder::new(cfg, prover_service) - .build() - .await - .unwrap(); - - let get_task_request = prover.build_get_task_request(None); - assert!(get_task_request.is_ok()) - } -} diff --git a/src/prover/types.rs b/src/prover/types.rs index c23fad1..3e4da7a 100644 --- a/src/prover/types.rs +++ b/src/prover/types.rs @@ -1,91 +1,74 @@ -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use alloy_rlp::{BufMut, Decodable, Encodable}; +use serde_repr::{Deserialize_repr as DeserializeRepr, Serialize_repr as SerializeRepr}; +use strum::{Display, EnumIs, FromRepr}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[derive( + Default, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + SerializeRepr, + DeserializeRepr, + EnumIs, + Display, + FromRepr, +)] #[repr(u8)] pub enum ProverProviderType { #[default] - Undefined, - Internal, - External, + Undefined = 0, + Internal = 1, + External = 2, } -impl ProverProviderType { - pub fn from_u8(v: u8) -> Self { - match v { - 1 => ProverProviderType::Internal, - 2 => ProverProviderType::External, - _ => ProverProviderType::Undefined, - } +impl Encodable for ProverProviderType { + fn encode(&self, out: &mut dyn BufMut) { + (*self as u8).encode(out); } } -impl Serialize for ProverProviderType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match *self { - ProverProviderType::Undefined => serializer.serialize_u8(0), - ProverProviderType::Internal => serializer.serialize_u8(1), - ProverProviderType::External => serializer.serialize_u8(2), - } +impl Decodable for ProverProviderType { + fn decode(buf: &mut &[u8]) -> alloy_rlp::Result { + let v = u8::decode(buf)?; + Ok(ProverProviderType::from_repr(v).unwrap_or(ProverProviderType::Undefined)) } } -impl<'de> Deserialize<'de> for ProverProviderType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let v: u8 = u8::deserialize(deserializer)?; - Ok(ProverProviderType::from_u8(v)) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[derive( + Default, + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + SerializeRepr, + DeserializeRepr, + EnumIs, + Display, + FromRepr, +)] +#[repr(u8)] pub enum ProofType { #[default] - Undefined, - Chunk, - Batch, - Bundle, -} - -impl ProofType { - pub fn from_u8(v: u8) -> Self { - match v { - 1 => ProofType::Chunk, - 2 => ProofType::Batch, - 3 => ProofType::Bundle, - _ => ProofType::Undefined, - } - } - - pub fn to_u8(self) -> u8 { - match self { - ProofType::Undefined => 0, - ProofType::Chunk => 1, - ProofType::Batch => 2, - ProofType::Bundle => 3, - } - } + Undefined = 0, + Chunk = 1, + Batch = 2, + Bundle = 3, } -impl Serialize for ProofType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_u8(self.to_u8()) +impl Encodable for ProofType { + fn encode(&self, out: &mut dyn BufMut) { + (*self as u8).encode(out); } } -impl<'de> Deserialize<'de> for ProofType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let v: u8 = u8::deserialize(deserializer)?; - Ok(ProofType::from_u8(v)) +impl Decodable for ProofType { + fn decode(buf: &mut &[u8]) -> alloy_rlp::Result { + let v = u8::decode(buf)?; + Ok(ProofType::from_repr(v).unwrap_or(ProofType::Undefined)) } } diff --git a/src/utils.rs b/src/utils.rs index d2cccda..defd998 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,22 +1,13 @@ use tracing_subscriber::filter::{EnvFilter, LevelFilter}; use tracing_subscriber::fmt::format::FmtSpan; -static DEFAULT_COMMIT: &str = "unknown"; -static VERSION: std::sync::OnceLock = std::sync::OnceLock::new(); - -pub const TAG: &str = "v0.0.0"; -pub const DEFAULT_ZK_VERSION: &str = "000000-000000"; - -fn init_version() -> String { - let commit = option_env!("GIT_REV").unwrap_or(DEFAULT_COMMIT); - let tag = option_env!("GO_TAG").unwrap_or(TAG); - let zk_version = option_env!("ZK_VERSION").unwrap_or(DEFAULT_ZK_VERSION); - format!("{tag}-{commit}-{zk_version}") -} - -pub fn get_version() -> String { - VERSION.get_or_init(init_version).clone() -} +pub static VERSION: &str = concat!( + env!("GO_TAG", "semver from `./common/version.go` is required"), + "-", + env!("GIT_REV", "git rev of `scroll` is required"), + "-", + env!("ZK_VERSION", "`zkvm-prover` version and commit is required"), +); pub fn init_tracing() { tracing_subscriber::fmt()