Skip to content

Commit 4059b2c

Browse files
committed
feat: aggregate proofs in chunks of 512 proofs
1 parent 02987a4 commit 4059b2c

File tree

3 files changed

+87
-15
lines changed

3 files changed

+87
-15
lines changed

aggregation_mode/src/aggregators/mod.rs

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use std::fmt::Display;
55

66
use risc0_aggregator::{
77
AlignedRisc0VerificationError, Risc0AggregationError, Risc0ProofReceiptAndImageId,
8+
Risc0ProofType,
89
};
910
use sp1_aggregator::{
10-
AlignedSP1VerificationError, SP1AggregationError, SP1ProofWithPubValuesAndElf,
11+
AlignedSP1VerificationError, SP1AggregationError, SP1ProofType, SP1ProofWithPubValuesAndElf,
1112
};
1213

1314
#[derive(Clone, Debug)]
@@ -59,16 +60,38 @@ impl ZKVMEngine {
5960
) -> Result<(AlignedProof, [u8; 32]), ProofAggregationError> {
6061
let res = match self {
6162
ZKVMEngine::SP1 => {
62-
let proofs = proofs
63+
let proofs: Vec<SP1ProofWithPubValuesAndElf> = proofs
6364
.into_iter()
6465
.filter_map(|proof| match proof {
6566
AlignedProof::SP1(proof) => Some(*proof),
6667
_ => None,
6768
})
6869
.collect();
6970

70-
let mut agg_proof = sp1_aggregator::aggregate_proofs(proofs)
71-
.map_err(ProofAggregationError::SP1Aggregation)?;
71+
// we run the aggregator in chunks of 512 proofs
72+
let chunks = proofs.chunks(512);
73+
let mut agg_proofs: Vec<SP1ProofWithPubValuesAndElf> = vec![];
74+
75+
let agg_chunks_type = if chunks.len() == 1 {
76+
SP1ProofType::Groth16
77+
} else {
78+
SP1ProofType::Compressed
79+
};
80+
81+
for chunk in chunks {
82+
let agg_proof =
83+
sp1_aggregator::aggregate_proofs(chunk, agg_chunks_type.clone())
84+
.map_err(ProofAggregationError::SP1Aggregation)?;
85+
86+
agg_proofs.push(agg_proof);
87+
}
88+
89+
let mut agg_proof = if agg_proofs.len() > 1 {
90+
sp1_aggregator::aggregate_proofs(&agg_proofs, SP1ProofType::Groth16)
91+
.map_err(ProofAggregationError::SP1Aggregation)?
92+
} else {
93+
agg_proofs.pop().unwrap()
94+
};
7295

7396
let merkle_root: [u8; 32] = agg_proof
7497
.proof_with_pub_values
@@ -78,16 +101,37 @@ impl ZKVMEngine {
78101
(AlignedProof::SP1(agg_proof.into()), merkle_root)
79102
}
80103
ZKVMEngine::RISC0 => {
81-
let proofs = proofs
104+
let proofs: Vec<Risc0ProofReceiptAndImageId> = proofs
82105
.into_iter()
83106
.filter_map(|proof| match proof {
84107
AlignedProof::Risc0(proof) => Some(*proof),
85108
_ => None,
86109
})
87110
.collect();
88111

89-
let agg_proof = risc0_aggregator::aggregate_proofs(proofs)
90-
.map_err(ProofAggregationError::Risc0Aggregation)?;
112+
let chunks = proofs.chunks(512);
113+
let mut agg_proofs: Vec<Risc0ProofReceiptAndImageId> = vec![];
114+
115+
let agg_chunks_type = if chunks.len() == 1 {
116+
Risc0ProofType::Groth16
117+
} else {
118+
Risc0ProofType::Composite
119+
};
120+
121+
for chunk in chunks {
122+
let agg_proof =
123+
risc0_aggregator::aggregate_proofs(chunk, agg_chunks_type.clone())
124+
.map_err(ProofAggregationError::Risc0Aggregation)?;
125+
126+
agg_proofs.push(agg_proof);
127+
}
128+
129+
let agg_proof = if agg_proofs.len() > 1 {
130+
risc0_aggregator::aggregate_proofs(&agg_proofs, Risc0ProofType::Groth16)
131+
.map_err(ProofAggregationError::Risc0Aggregation)?
132+
} else {
133+
agg_proofs.pop().unwrap()
134+
};
91135

92136
// Note: journal.decode() won't work here as risc0 deserializer works under u32 words
93137
let public_input_bytes = agg_proof.receipt.journal.as_ref();

aggregation_mode/src/aggregators/risc0_aggregator.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,16 @@ pub enum Risc0AggregationError {
4646
Verification(String),
4747
}
4848

49+
#[derive(Debug, Clone)]
50+
pub enum Risc0ProofType {
51+
Groth16,
52+
Composite,
53+
Succinct,
54+
}
55+
4956
pub(crate) fn aggregate_proofs(
50-
proofs: Vec<Risc0ProofReceiptAndImageId>,
57+
proofs: &[Risc0ProofReceiptAndImageId],
58+
to_proof_type: Risc0ProofType,
5159
) -> Result<Risc0ProofReceiptAndImageId, Risc0AggregationError> {
5260
let mut env_builder = ExecutorEnv::builder();
5361

@@ -58,7 +66,7 @@ pub(crate) fn aggregate_proofs(
5866
image_id: proof.image_id,
5967
public_inputs: proof.receipt.journal.bytes.clone(),
6068
});
61-
env_builder.add_assumption(proof.receipt);
69+
env_builder.add_assumption(proof.receipt.clone());
6270
}
6371

6472
// write input data
@@ -75,8 +83,14 @@ pub(crate) fn aggregate_proofs(
7583

7684
let prover = default_prover();
7785

86+
let opts = match to_proof_type {
87+
Risc0ProofType::Groth16 => ProverOpts::groth16(),
88+
Risc0ProofType::Composite => ProverOpts::composite(),
89+
Risc0ProofType::Succinct => ProverOpts::succinct(),
90+
};
91+
7892
let receipt = prover
79-
.prove_with_opts(env, RISC0_AGGREGATOR_PROGRAM_ELF, &ProverOpts::groth16())
93+
.prove_with_opts(env, RISC0_AGGREGATOR_PROGRAM_ELF, &opts)
8094
.map_err(|e| Risc0AggregationError::Prove(e.to_string()))?
8195
.receipt;
8296

aggregation_mode/src/aggregators/sp1_aggregator.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,16 @@ pub enum SP1AggregationError {
3838
UnsupportedProof,
3939
}
4040

41+
#[derive(Debug, Clone)]
42+
pub enum SP1ProofType {
43+
Groth16,
44+
Compressed,
45+
Core,
46+
}
47+
4148
pub(crate) fn aggregate_proofs(
42-
proofs: Vec<SP1ProofWithPubValuesAndElf>,
49+
proofs: &[SP1ProofWithPubValuesAndElf],
50+
to_proof_type: SP1ProofType,
4351
) -> Result<SP1ProofWithPubValuesAndElf, SP1AggregationError> {
4452
let mut stdin = SP1Stdin::new();
4553

@@ -62,7 +70,8 @@ pub(crate) fn aggregate_proofs(
6270
for input_proof in proofs {
6371
let vk = input_proof.vk().vk;
6472
// we only support sp1 Compressed proofs for now
65-
let sp1_sdk::SP1Proof::Compressed(proof) = input_proof.proof_with_pub_values.proof else {
73+
let sp1_sdk::SP1Proof::Compressed(proof) = input_proof.proof_with_pub_values.proof.clone()
74+
else {
6675
return Err(SP1AggregationError::UnsupportedProof);
6776
};
6877
stdin.write_proof(*proof, vk);
@@ -75,9 +84,14 @@ pub(crate) fn aggregate_proofs(
7584
let client = ProverClient::builder().mock().build();
7685

7786
let (pk, vk) = client.setup(PROGRAM_ELF);
78-
let proof = client
79-
.prove(&pk, &stdin)
80-
.groth16()
87+
let proof_builder = client.prove(&pk, &stdin);
88+
let proof_builder = match to_proof_type {
89+
SP1ProofType::Groth16 => proof_builder.groth16(),
90+
SP1ProofType::Compressed => proof_builder.compressed(),
91+
SP1ProofType::Core => proof_builder.core(),
92+
};
93+
94+
let proof = proof_builder
8195
.run()
8296
.map_err(|e| SP1AggregationError::Prove(e.to_string()))?;
8397

0 commit comments

Comments
 (0)