
Commit 9e70085

refactor: agg mode (#1871)
1 parent 104109d commit 9e70085


5 files changed: +42 -35 lines


aggregation_mode/aggregation_programs/sp1/src/lib.rs

Lines changed: 8 additions & 8 deletions
@@ -2,37 +2,37 @@ use serde::{Deserialize, Serialize};
 use sha3::{Digest, Keccak256};
 
 #[derive(Serialize, Deserialize)]
-pub struct SP1ProofInput {
+pub struct SP1VkAndPubInputs {
     pub vk: [u32; 8],
     pub public_inputs: Vec<u8>,
 }
 
-impl SP1ProofInput {
+impl SP1VkAndPubInputs {
     pub fn hash(&self) -> [u8; 32] {
         let mut hasher = Keccak256::new();
         for &word in &self.vk {
-            hasher.update(word.to_le_bytes());
+            hasher.update(word.to_be_bytes());
         }
         hasher.update(&self.public_inputs);
         hasher.finalize().into()
     }
 }
 
 #[derive(Serialize, Deserialize)]
-pub enum ProofInput {
-    SP1Compressed(SP1ProofInput),
+pub enum ProofVkAndPubInputs {
+    SP1Compressed(SP1VkAndPubInputs),
 }
 
-impl ProofInput {
+impl ProofVkAndPubInputs {
     pub fn hash(&self) -> [u8; 32] {
         match self {
-            ProofInput::SP1Compressed(proof) => proof.hash(),
+            ProofVkAndPubInputs::SP1Compressed(proof_data) => proof_data.hash(),
         }
     }
 }
 
 #[derive(Serialize, Deserialize)]
 pub struct Input {
-    pub proofs: Vec<ProofInput>,
+    pub proofs_vk_and_pub_inputs: Vec<ProofVkAndPubInputs>,
     pub merkle_root: [u8; 32],
 }
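
The leaf commitment now serializes each vk word big-endian before hashing. A minimal standalone sketch of the resulting digest (same sha3 crate as the program; the free function name is illustrative):

use sha3::{Digest, Keccak256};

// Standalone sketch of the renamed SP1VkAndPubInputs::hash(): Keccak256 over
// the 8 vk words serialized big-endian (previously little-endian), followed by
// the raw public inputs.
fn vk_and_pub_inputs_hash(vk: &[u32; 8], public_inputs: &[u8]) -> [u8; 32] {
    let mut hasher = Keccak256::new();
    for &word in vk {
        hasher.update(word.to_be_bytes());
    }
    hasher.update(public_inputs);
    hasher.finalize().into()
}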

aggregation_mode/aggregation_programs/sp1/src/main.rs

Lines changed: 5 additions & 5 deletions
@@ -3,7 +3,7 @@ sp1_zkvm::entrypoint!(main);
 
 use sha2::{Digest, Sha256};
 use sha3::Keccak256;
-use sp1_aggregation_program::{Input, ProofInput};
+use sp1_aggregation_program::{Input, ProofVkAndPubInputs};
 
 fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
     let mut hasher = Keccak256::new();
@@ -13,7 +13,7 @@ fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
 }
 
 /// Computes the merkle root for the given proofs using the vk
-fn compute_merkle_root(proofs: &[ProofInput]) -> [u8; 32] {
+fn compute_merkle_root(proofs: &[ProofVkAndPubInputs]) -> [u8; 32] {
     let mut leaves: Vec<[u8; 32]> = proofs
         .chunks(2)
         .map(|chunk| match chunk {
@@ -41,9 +41,9 @@ pub fn main() {
     let input = sp1_zkvm::io::read::<Input>();
 
     // Verify the proofs.
-    for proof in input.proofs.iter() {
+    for proof in input.proofs_vk_and_pub_inputs.iter() {
         match proof {
-            ProofInput::SP1Compressed(proof) => {
+            ProofVkAndPubInputs::SP1Compressed(proof) => {
                 let vkey = proof.vk;
                 let public_values = &proof.public_inputs;
                 let public_values_digest = Sha256::digest(public_values);
@@ -52,7 +52,7 @@ pub fn main() {
         }
     }
 
-    let merkle_root = compute_merkle_root(&input.proofs);
+    let merkle_root = compute_merkle_root(&input.proofs_vk_and_pub_inputs);
 
     assert_eq!(merkle_root, input.merkle_root);
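
compute_merkle_root builds leaves from the renamed field and folds them pairwise with combine_hashes. A hedged sketch of such a pairwise Keccak256 reduction follows; the body of combine_hashes and the handling of an odd number of nodes are not fully visible in these hunks, so hashing a then b and duplicating a lone node are assumptions for illustration only.

use sha3::{Digest, Keccak256};

// Sketch of a pairwise Keccak256 merkle fold over precomputed leaf hashes.
// Assumed details (not shown in the diff): combine_hashes absorbs a then b,
// and a lone trailing node is combined with itself.
fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
    let mut hasher = Keccak256::new();
    hasher.update(hash_a);
    hasher.update(hash_b);
    hasher.finalize().into()
}

fn fold_to_root(mut level: Vec<[u8; 32]>) -> [u8; 32] {
    assert!(!level.is_empty(), "at least one leaf is required");
    while level.len() > 1 {
        level = level
            .chunks(2)
            .map(|pair| match pair {
                [a, b] => combine_hashes(a, b),
                [a] => combine_hashes(a, a), // assumed odd-node rule
                _ => unreachable!(),
            })
            .collect();
    }
    level[0]
}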

aggregation_mode/src/aggregators/sp1_aggregator.rs

Lines changed: 6 additions & 7 deletions
@@ -1,7 +1,7 @@
 use std::sync::LazyLock;
 
 use alloy::primitives::Keccak256;
-use sp1_aggregation_program::{ProofInput, SP1ProofInput};
+use sp1_aggregation_program::{ProofVkAndPubInputs, SP1VkAndPubInputs};
 use sp1_sdk::{
     EnvProver, HashableKey, Prover, ProverClient, SP1ProofWithPublicValues, SP1Stdin,
     SP1VerifyingKey,
@@ -22,9 +22,8 @@ pub struct SP1ProofWithPubValuesAndElf {
 impl SP1ProofWithPubValuesAndElf {
     pub fn hash_vk_and_pub_inputs(&self) -> [u8; 32] {
         let mut hasher = Keccak256::new();
-        for &word in &self.vk().hash_u32() {
-            hasher.update(word.to_le_bytes());
-        }
+        let vk_bytes = &self.vk().hash_bytes();
+        hasher.update(vk_bytes);
         hasher.update(self.proof_with_pub_values.public_values.as_slice());
         hasher.finalize().into()
     }
@@ -45,15 +44,15 @@ pub(crate) fn aggregate_proofs(
     let mut stdin = SP1Stdin::new();
 
     let mut program_input = sp1_aggregation_program::Input {
-        proofs: vec![],
+        proofs_vk_and_pub_inputs: vec![],
         merkle_root: input.merkle_root,
     };
 
     // write vk + public inputs
     for proof in input.proofs.iter() {
         program_input
-            .proofs
-            .push(ProofInput::SP1Compressed(SP1ProofInput {
+            .proofs_vk_and_pub_inputs
+            .push(ProofVkAndPubInputs::SP1Compressed(SP1VkAndPubInputs {
                 public_inputs: proof.proof_with_pub_values.public_values.to_vec(),
                 vk: proof.vk().hash_u32(),
            }));
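
The host now absorbs vk.hash_bytes() into its local Keccak while still passing vk.hash_u32() to the guest, which hashes the words big-endian. A hedged sketch of the consistency this appears to rely on; whether sp1_sdk's hash_bytes() is exactly the big-endian serialization of hash_u32() is an assumption here, not something the diff confirms.

use sp1_sdk::{HashableKey, SP1VerifyingKey};

// Hedged consistency sketch: compares the big-endian serialization of
// hash_u32() against hash_bytes(). That the two coincide is an assumption of
// this sketch; verify against sp1_sdk before relying on it.
fn vk_bytes_match_be_words(vk: &SP1VerifyingKey) -> bool {
    let from_words: Vec<u8> = vk
        .hash_u32()
        .iter()
        .flat_map(|word| word.to_be_bytes())
        .collect();
    vk.hash_bytes().as_slice() == from_words.as_slice()
}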

aggregation_mode/src/backend/fetcher.rs

Lines changed: 4 additions & 4 deletions
@@ -17,8 +17,8 @@ use tracing::{error, info};
 
 #[derive(Debug)]
 pub enum ProofsFetcherError {
-    QueryingLogs,
-    BlockNumber,
+    GetLogs(String),
+    GetBlockNumber(String),
 }
 
 pub struct ProofsFetcher {
@@ -59,7 +59,7 @@ impl ProofsFetcher {
             .from_block(from_block)
             .query()
             .await
-            .map_err(|_| ProofsFetcherError::QueryingLogs)?;
+            .map_err(|e| ProofsFetcherError::GetLogs(e.to_string()))?;
 
         info!("Logs collected {}", logs.len());
 
@@ -124,7 +124,7 @@ impl ProofsFetcher {
             .rpc_provider
             .get_block_number()
             .await
-            .map_err(|_| ProofsFetcherError::BlockNumber)?;
+            .map_err(|e| ProofsFetcherError::GetBlockNumber(e.to_string()))?;
 
         let number_of_blocks_in_the_past = self.fetch_from_secs_ago / self.block_time_secs;
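
The fetcher errors now carry the stringified source error instead of bare unit variants. An illustrative consumer (this helper is not part of the commit):

use tracing::error;

// Illustrative only: shows the error context now carried by the
// ProofsFetcherError variants defined above.
fn log_fetcher_error(err: &ProofsFetcherError) {
    match err {
        ProofsFetcherError::GetLogs(msg) => error!("failed to query logs: {msg}"),
        ProofsFetcherError::GetBlockNumber(msg) => error!("failed to fetch block number: {msg}"),
    }
}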

aggregation_mode/src/backend/s3.rs

Lines changed: 19 additions & 11 deletions
@@ -1,12 +1,13 @@
 use aligned_sdk::core::types::VerificationData;
 
 #[derive(Debug)]
+#[allow(dead_code)]
 pub enum GetBatchProofsError {
-    Fetching,
-    Deserialization,
-    EmptyBody,
-    StatusFailed,
-    ReqwestClientFailed,
+    FetchingS3Batch(String),
+    Deserialization(String),
+    EmptyBody(String),
+    StatusFailed((u16, String)),
+    ReqwestClientFailed(String),
 }
 
 // needed to make S3 bucket work
@@ -18,25 +19,32 @@ pub async fn get_aligned_batch_from_s3(
     let client = reqwest::Client::builder()
         .user_agent(DEFAULT_USER_AGENT)
         .build()
-        .map_err(|_| GetBatchProofsError::ReqwestClientFailed)?;
+        .map_err(|e| GetBatchProofsError::ReqwestClientFailed(e.to_string()))?;
 
     let response = client
         .get(url)
         .send()
         .await
-        .map_err(|_| GetBatchProofsError::Fetching)?;
+        .map_err(|e| GetBatchProofsError::FetchingS3Batch(e.to_string()))?;
     if !response.status().is_success() {
-        return Err(GetBatchProofsError::StatusFailed);
+        return Err(GetBatchProofsError::StatusFailed((
+            response.status().as_u16(),
+            response
+                .status()
+                .canonical_reason()
+                .unwrap_or("")
+                .to_string(),
+        )));
     }
 
     let bytes = response
         .bytes()
         .await
-        .map_err(|_| GetBatchProofsError::EmptyBody)?;
+        .map_err(|e| GetBatchProofsError::EmptyBody(e.to_string()))?;
     let bytes: &[u8] = bytes.iter().as_slice();
 
-    let data: Vec<VerificationData> =
-        ciborium::from_reader(bytes).map_err(|_| GetBatchProofsError::Deserialization)?;
+    let data: Vec<VerificationData> = ciborium::from_reader(bytes)
+        .map_err(|e| GetBatchProofsError::Deserialization(e.to_string()))?;
 
     Ok(data)
 }
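
A hedged usage sketch of the richer errors; the full signature of get_aligned_batch_from_s3 is not shown in these hunks, so the single &str URL parameter and the Result<Vec<VerificationData>, GetBatchProofsError> return type are assumptions.

// Hedged caller sketch, assumed to live in the same module as the function and
// enum above. The helper name and the assumed signature are illustrative.
async fn print_batch_size(url: &str) {
    match get_aligned_batch_from_s3(url).await {
        Ok(batch) => println!("batch contains {} verification entries", batch.len()),
        Err(GetBatchProofsError::StatusFailed((code, reason))) => {
            eprintln!("S3 returned status {code}: {reason}");
        }
        Err(other) => eprintln!("failed to fetch batch: {other:?}"),
    }
}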
