Skip to content

Commit 5d78e56

Browse files
refactor(agg-mode): reduce types, abstractions and double merkle root in PI (#1889)
Co-authored-by: MauroFab <[email protected]>
1 parent 3ef2976 commit 5d78e56

File tree

11 files changed

+138
-150
lines changed

11 files changed

+138
-150
lines changed

aggregation_mode/aggregation_programs/risc0/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,4 @@ impl Risc0ImageIdAndPubInputs {
2424
#[derive(Serialize, Deserialize)]
2525
pub struct Input {
2626
pub proofs_image_id_and_pub_inputs: Vec<Risc0ImageIdAndPubInputs>,
27-
pub merkle_root: [u8; 32],
2827
}

aggregation_mode/aggregation_programs/risc0/src/main.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,5 @@ fn main() {
5454

5555
let merkle_root = compute_merkle_root(&input.proofs_image_id_and_pub_inputs);
5656

57-
assert_eq!(merkle_root, input.merkle_root);
58-
5957
env::commit_slice(&merkle_root);
6058
}

aggregation_mode/aggregation_programs/sp1/src/lib.rs

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,7 @@ impl SP1VkAndPubInputs {
1818
}
1919
}
2020

21-
#[derive(Serialize, Deserialize)]
22-
pub enum ProofVkAndPubInputs {
23-
SP1Compressed(SP1VkAndPubInputs),
24-
}
25-
26-
impl ProofVkAndPubInputs {
27-
pub fn hash(&self) -> [u8; 32] {
28-
match self {
29-
ProofVkAndPubInputs::SP1Compressed(proof_data) => proof_data.hash(),
30-
}
31-
}
32-
}
33-
3421
#[derive(Serialize, Deserialize)]
3522
pub struct Input {
36-
pub proofs_vk_and_pub_inputs: Vec<ProofVkAndPubInputs>,
37-
pub merkle_root: [u8; 32],
23+
pub proofs_vk_and_pub_inputs: Vec<SP1VkAndPubInputs>,
3824
}

aggregation_mode/aggregation_programs/sp1/src/main.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ sp1_zkvm::entrypoint!(main);
33

44
use sha2::{Digest, Sha256};
55
use sha3::Keccak256;
6-
use sp1_aggregation_program::{Input, ProofVkAndPubInputs};
6+
use sp1_aggregation_program::{Input, SP1VkAndPubInputs};
77

88
fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
99
let mut hasher = Keccak256::new();
@@ -13,7 +13,7 @@ fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
1313
}
1414

1515
/// Computes the merkle root for the given proofs using the vk
16-
fn compute_merkle_root(proofs: &[ProofVkAndPubInputs]) -> [u8; 32] {
16+
fn compute_merkle_root(proofs: &[SP1VkAndPubInputs]) -> [u8; 32] {
1717
let mut leaves: Vec<[u8; 32]> = proofs
1818
.chunks(2)
1919
.map(|chunk| match chunk {
@@ -42,19 +42,13 @@ pub fn main() {
4242

4343
// Verify the proofs.
4444
for proof in input.proofs_vk_and_pub_inputs.iter() {
45-
match proof {
46-
ProofVkAndPubInputs::SP1Compressed(proof) => {
47-
let vkey = proof.vk;
48-
let public_values = &proof.public_inputs;
49-
let public_values_digest = Sha256::digest(public_values);
50-
sp1_zkvm::lib::verify::verify_sp1_proof(&vkey, &public_values_digest.into());
51-
}
52-
}
45+
let vkey = proof.vk;
46+
let public_values = &proof.public_inputs;
47+
let public_values_digest = Sha256::digest(public_values);
48+
sp1_zkvm::lib::verify::verify_sp1_proof(&vkey, &public_values_digest.into());
5349
}
5450

5551
let merkle_root = compute_merkle_root(&input.proofs_vk_and_pub_inputs);
5652

57-
assert_eq!(merkle_root, input.merkle_root);
58-
5953
sp1_zkvm::io::commit_slice(&merkle_root);
6054
}

aggregation_mode/src/aggregators/lib.rs

Lines changed: 0 additions & 32 deletions
This file was deleted.

aggregation_mode/src/aggregators/mod.rs

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
pub mod lib;
21
pub mod risc0_aggregator;
32
pub mod sp1_aggregator;
43

54
use std::fmt::Display;
65

7-
use risc0_aggregator::{AlignedRisc0VerificationError, Risc0ProofReceiptAndImageId};
8-
use sp1_aggregator::{AlignedSP1VerificationError, SP1ProofWithPubValuesAndElf};
6+
use risc0_aggregator::{
7+
AlignedRisc0VerificationError, Risc0AggregationError, Risc0ProofReceiptAndImageId,
8+
};
9+
use sp1_aggregator::{
10+
AlignedSP1VerificationError, SP1AggregationError, SP1ProofWithPubValuesAndElf,
11+
};
912

1013
#[derive(Clone, Debug)]
1114
pub enum ZKVMEngine {
@@ -22,6 +25,13 @@ impl Display for ZKVMEngine {
2225
}
2326
}
2427

28+
#[derive(Debug)]
29+
pub enum ProofAggregationError {
30+
SP1Aggregation(SP1AggregationError),
31+
Risc0Aggregation(Risc0AggregationError),
32+
PublicInputsDeserialization,
33+
}
34+
2535
impl ZKVMEngine {
2636
pub fn from_env() -> Option<Self> {
2737
let key = "AGGREGATOR";
@@ -34,6 +44,69 @@ impl ZKVMEngine {
3444

3545
Some(engine)
3646
}
47+
48+
/// Aggregates a list of [`AlignedProof`]s into a single [`AlignedProof`].
49+
///
50+
/// Returns a tuple containing:
51+
/// - The aggregated [`AlignedProof`], representing the combined proof
52+
/// - The Merkle root computed within the ZKVM, exposed as a public input
53+
///
54+
/// This function performs proof aggregation and ensures the resulting Merkle root
55+
/// can be independently verified by external systems.
56+
pub fn aggregate_proofs(
57+
&self,
58+
proofs: Vec<AlignedProof>,
59+
) -> Result<(AlignedProof, [u8; 32]), ProofAggregationError> {
60+
let res = match self {
61+
ZKVMEngine::SP1 => {
62+
let proofs = proofs
63+
.into_iter()
64+
// Fetcher already filtered for SP1
65+
// We do this for type casting, as to avoid using generics
66+
// or macros in this function
67+
.filter_map(|proof| match proof {
68+
AlignedProof::SP1(proof) => Some(*proof),
69+
_ => None,
70+
})
71+
.collect();
72+
73+
let mut agg_proof = sp1_aggregator::aggregate_proofs(proofs)
74+
.map_err(ProofAggregationError::SP1Aggregation)?;
75+
76+
let merkle_root: [u8; 32] = agg_proof
77+
.proof_with_pub_values
78+
.public_values
79+
.read::<[u8; 32]>();
80+
81+
(AlignedProof::SP1(agg_proof.into()), merkle_root)
82+
}
83+
ZKVMEngine::RISC0 => {
84+
let proofs = proofs
85+
.into_iter()
86+
// Fetcher already filtered for Risc0
87+
// We do this for type casting, as to avoid using generics
88+
// or macros in this function
89+
.filter_map(|proof| match proof {
90+
AlignedProof::Risc0(proof) => Some(*proof),
91+
_ => None,
92+
})
93+
.collect();
94+
95+
let agg_proof = risc0_aggregator::aggregate_proofs(proofs)
96+
.map_err(ProofAggregationError::Risc0Aggregation)?;
97+
98+
// Note: journal.decode() won't work here as risc0 deserializer works under u32 words
99+
let public_input_bytes = agg_proof.receipt.journal.as_ref();
100+
let merkle_root: [u8; 32] = public_input_bytes
101+
.try_into()
102+
.map_err(|_| ProofAggregationError::PublicInputsDeserialization)?;
103+
104+
(AlignedProof::Risc0(agg_proof.into()), merkle_root)
105+
}
106+
};
107+
108+
Ok(res)
109+
}
37110
}
38111

39112
pub enum AlignedProof {
@@ -42,7 +115,7 @@ pub enum AlignedProof {
42115
}
43116

44117
impl AlignedProof {
45-
pub fn hash(&self) -> [u8; 32] {
118+
pub fn commitment(&self) -> [u8; 32] {
46119
match self {
47120
AlignedProof::SP1(proof) => proof.hash_vk_and_pub_inputs(),
48121
AlignedProof::Risc0(proof) => proof.hash_image_id_and_public_inputs(),

aggregation_mode/src/aggregators/risc0_aggregator.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ include!(concat!(env!("OUT_DIR"), "/methods.rs"));
33
use risc0_zkvm::{default_prover, ExecutorEnv, ProverOpts, Receipt};
44
use sha3::{Digest, Keccak256};
55

6-
use super::lib::{AggregatedProof, ProgramOutput, ProofAggregationError};
7-
86
/// Byte representation of the aggregator image_id, converted from `[u32; 8]` to `[u8; 32]`.
97
const RISC0_AGGREGATOR_PROGRAM_ID_BYTES: [u8; 32] = {
108
let mut res = [0u8; 32];
@@ -40,19 +38,22 @@ impl Risc0ProofReceiptAndImageId {
4038
}
4139
}
4240

43-
pub struct Risc0AggregationInput {
44-
pub proofs: Vec<Risc0ProofReceiptAndImageId>,
45-
pub merkle_root: [u8; 32],
41+
#[derive(Debug)]
42+
pub enum Risc0AggregationError {
43+
WriteInput(String),
44+
BuildExecutor(String),
45+
Prove(String),
46+
Verification(String),
4647
}
4748

4849
pub(crate) fn aggregate_proofs(
49-
input: Risc0AggregationInput,
50-
) -> Result<ProgramOutput, ProofAggregationError> {
50+
proofs: Vec<Risc0ProofReceiptAndImageId>,
51+
) -> Result<Risc0ProofReceiptAndImageId, Risc0AggregationError> {
5152
let mut env_builder = ExecutorEnv::builder();
5253

5354
// write assumptions and proof image id + pub inputs
5455
let mut proofs_image_id_and_pub_inputs = vec![];
55-
for proof in input.proofs {
56+
for proof in proofs {
5657
proofs_image_id_and_pub_inputs.push(risc0_aggregation_program::Risc0ImageIdAndPubInputs {
5758
image_id: proof.image_id,
5859
public_inputs: proof.receipt.journal.bytes.clone(),
@@ -62,34 +63,33 @@ pub(crate) fn aggregate_proofs(
6263

6364
// write input data
6465
let input = risc0_aggregation_program::Input {
65-
merkle_root: input.merkle_root,
6666
proofs_image_id_and_pub_inputs,
6767
};
6868
env_builder
6969
.write(&input)
70-
.map_err(|e| ProofAggregationError::Risc0Proving(e.to_string()))?;
70+
.map_err(|e| Risc0AggregationError::WriteInput(e.to_string()))?;
7171

7272
let env = env_builder
7373
.build()
74-
.map_err(|e| ProofAggregationError::Risc0Proving(e.to_string()))?;
74+
.map_err(|e| Risc0AggregationError::BuildExecutor(e.to_string()))?;
7575

7676
let prover = default_prover();
7777

7878
let receipt = prover
7979
.prove_with_opts(env, RISC0_AGGREGATOR_PROGRAM_ELF, &ProverOpts::groth16())
80-
.map_err(|e| ProofAggregationError::Risc0Proving(e.to_string()))?
80+
.map_err(|e| Risc0AggregationError::Prove(e.to_string()))?
8181
.receipt;
8282

8383
receipt
8484
.verify(RISC0_AGGREGATOR_PROGRAM_ID)
85-
.map_err(|e| ProofAggregationError::Risc0Proving(e.to_string()))?;
85+
.map_err(|e| Risc0AggregationError::Verification(e.to_string()))?;
8686

87-
let output = Risc0ProofReceiptAndImageId {
87+
let proof = Risc0ProofReceiptAndImageId {
8888
image_id: RISC0_AGGREGATOR_PROGRAM_ID_BYTES,
8989
receipt,
9090
};
9191

92-
Ok(ProgramOutput::new(AggregatedProof::Risc0(output.into())))
92+
Ok(proof)
9393
}
9494

9595
#[derive(Debug)]

aggregation_mode/src/aggregators/sp1_aggregator.rs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
use std::sync::LazyLock;
22

33
use alloy::primitives::Keccak256;
4-
use sp1_aggregation_program::{ProofVkAndPubInputs, SP1VkAndPubInputs};
4+
use sp1_aggregation_program::SP1VkAndPubInputs;
55
use sp1_sdk::{
66
EnvProver, HashableKey, Prover, ProverClient, SP1ProofWithPublicValues, SP1Stdin,
77
SP1VerifyingKey,
88
};
99

10-
use super::lib::{AggregatedProof, ProgramOutput, ProofAggregationError};
11-
1210
const PROGRAM_ELF: &[u8] =
1311
include_bytes!("../../aggregation_programs/sp1/elf/sp1_aggregator_program");
1412

@@ -33,38 +31,39 @@ impl SP1ProofWithPubValuesAndElf {
3331
}
3432
}
3533

36-
pub struct SP1AggregationInput {
37-
pub proofs: Vec<SP1ProofWithPubValuesAndElf>,
38-
pub merkle_root: [u8; 32],
34+
#[derive(Debug)]
35+
pub enum SP1AggregationError {
36+
Verification(sp1_sdk::SP1VerificationError),
37+
Prove(String),
38+
UnsupportedProof,
3939
}
4040

4141
pub(crate) fn aggregate_proofs(
42-
input: SP1AggregationInput,
43-
) -> Result<ProgramOutput, ProofAggregationError> {
42+
proofs: Vec<SP1ProofWithPubValuesAndElf>,
43+
) -> Result<SP1ProofWithPubValuesAndElf, SP1AggregationError> {
4444
let mut stdin = SP1Stdin::new();
4545

4646
let mut program_input = sp1_aggregation_program::Input {
4747
proofs_vk_and_pub_inputs: vec![],
48-
merkle_root: input.merkle_root,
4948
};
5049

5150
// write vk + public inputs
52-
for proof in input.proofs.iter() {
51+
for proof in proofs.iter() {
5352
program_input
5453
.proofs_vk_and_pub_inputs
55-
.push(ProofVkAndPubInputs::SP1Compressed(SP1VkAndPubInputs {
54+
.push(SP1VkAndPubInputs {
5655
public_inputs: proof.proof_with_pub_values.public_values.to_vec(),
5756
vk: proof.vk().hash_u32(),
58-
}));
57+
});
5958
}
6059
stdin.write(&program_input);
6160

6261
// write proofs
63-
for input_proof in input.proofs {
62+
for input_proof in proofs {
6463
let vk = input_proof.vk().vk;
6564
// we only support sp1 Compressed proofs for now
6665
let sp1_sdk::SP1Proof::Compressed(proof) = input_proof.proof_with_pub_values.proof else {
67-
return Err(ProofAggregationError::UnsupportedProof);
66+
return Err(SP1AggregationError::UnsupportedProof);
6867
};
6968
stdin.write_proof(*proof, vk);
7069
}
@@ -80,21 +79,19 @@ pub(crate) fn aggregate_proofs(
8079
.prove(&pk, &stdin)
8180
.groth16()
8281
.run()
83-
.map_err(|_| ProofAggregationError::SP1Proving)?;
82+
.map_err(|e| SP1AggregationError::Prove(e.to_string()))?;
8483

8584
// a sanity check, vm already performs it
8685
client
8786
.verify(&proof, &vk)
88-
.map_err(ProofAggregationError::SP1Verification)?;
87+
.map_err(SP1AggregationError::Verification)?;
8988

9089
let proof_and_elf = SP1ProofWithPubValuesAndElf {
9190
proof_with_pub_values: proof,
9291
elf: PROGRAM_ELF.to_vec(),
9392
};
9493

95-
let output = ProgramOutput::new(AggregatedProof::SP1(proof_and_elf.into()));
96-
97-
Ok(output)
94+
Ok(proof_and_elf)
9895
}
9996

10097
#[derive(Debug)]

0 commit comments

Comments
 (0)