Skip to content

Commit bb7f7d2

Browse files
committed
fix: write proofs to sp1 stdin
1 parent c674b3b commit bb7f7d2

File tree

5 files changed

+76
-46
lines changed

5 files changed

+76
-46
lines changed

aggregation-mode/src/zk/aggregator.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
use crate::zk::backends::sp1::{self, SP1AggregatedProof};
2-
use serde::{Deserialize, Serialize};
1+
use crate::zk::backends::sp1::{self};
2+
3+
use super::backends::sp1::{SP1AggregationInput, SP1Proof};
34

4-
#[derive(Serialize, Deserialize)]
55
pub enum ProgramInput {
6-
SP1(sp1_aggregator::Input),
6+
SP1(SP1AggregationInput),
77
}
88

99
pub enum AggregatedProof {
10-
SP1(SP1AggregatedProof),
10+
SP1(SP1Proof),
1111
}
1212

1313
pub struct ProgramOutput {
@@ -29,6 +29,7 @@ impl ProgramOutput {
2929
pub enum ProofAggregationError {
3030
SP1Verification(sp1_sdk::SP1VerificationError),
3131
SP1Proving,
32+
UnsupportedProof,
3233
}
3334

3435
pub fn aggregate_proofs(input: ProgramInput) -> Result<ProgramOutput, ProofAggregationError> {

aggregation-mode/src/zk/backends/sp1.rs

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use sp1_sdk::{Prover, ProverClient, SP1ProofWithPublicValues, SP1Stdin, SP1VerifyingKey};
1+
use alloy::primitives::Keccak256;
2+
use sp1_aggregator::{ProofInput, SP1ProofInput};
3+
use sp1_sdk::{
4+
HashableKey, Prover, ProverClient, SP1ProofWithPublicValues, SP1Stdin, SP1VerifyingKey,
5+
};
26

37
use crate::zk::aggregator::{AggregatedProof, ProgramOutput, ProofAggregationError};
48

@@ -7,28 +11,55 @@ const PROGRAM_ELF: &[u8] = include_bytes!("../../../zkvm/sp1/elf/sp1_aggregator_
711
// TODO lock prover
812

913
pub struct SP1Proof {
10-
pub elf: Vec<u8>,
14+
pub vk: SP1VerifyingKey,
1115
pub proof: SP1ProofWithPublicValues,
1216
}
1317

1418
impl SP1Proof {
15-
pub fn verifying_key(&self) -> SP1VerifyingKey {
16-
let client = ProverClient::from_env();
17-
let (_pk, vk) = client.setup(&self.elf);
18-
vk
19+
pub fn hash(&self) -> [u8; 32] {
20+
let mut hasher = Keccak256::new();
21+
for &word in &self.vk.hash_u32() {
22+
hasher.update(word.to_le_bytes());
23+
}
24+
hasher.update(self.proof.public_values.as_slice());
25+
hasher.finalize().into()
1926
}
2027
}
2128

22-
pub struct SP1AggregatedProof {
23-
pub proof: SP1ProofWithPublicValues,
24-
pub vk: SP1VerifyingKey,
29+
pub struct SP1AggregationInput {
30+
proofs: Vec<SP1Proof>,
31+
merkle_root: [u8; 32],
2532
}
2633

2734
pub(crate) fn aggregate_proofs(
28-
input: sp1_aggregator::Input,
35+
input: SP1AggregationInput,
2936
) -> Result<ProgramOutput, ProofAggregationError> {
3037
let mut stdin = SP1Stdin::new();
31-
stdin.write(&input);
38+
39+
let mut program_input = sp1_aggregator::Input {
40+
proofs: vec![],
41+
merkle_root: input.merkle_root,
42+
};
43+
44+
// write vk + public inputs
45+
for proof in input.proofs.iter() {
46+
program_input
47+
.proofs
48+
.push(ProofInput::SP1Compressed(SP1ProofInput {
49+
public_inputs: proof.proof.public_values.to_vec(),
50+
vk: proof.vk.hash_u32(),
51+
}));
52+
}
53+
stdin.write(&program_input);
54+
55+
// write proofs
56+
for SP1Proof { proof, vk } in input.proofs {
57+
// we only support sp1 Compressed proofs for now
58+
let sp1_sdk::SP1Proof::Compressed(proof) = proof.proof else {
59+
return Err(ProofAggregationError::UnsupportedProof);
60+
};
61+
stdin.write_proof(*proof, vk.vk);
62+
}
3263

3364
#[cfg(feature = "prove")]
3465
let client = ProverClient::from_env();
@@ -48,7 +79,7 @@ pub(crate) fn aggregate_proofs(
4879
.verify(&proof, &vk)
4980
.map_err(ProofAggregationError::SP1Verification)?;
5081

51-
let proof = SP1AggregatedProof { proof, vk };
82+
let proof = SP1Proof { proof, vk };
5283

5384
let output = ProgramOutput::new(AggregatedProof::SP1(proof));
5485

@@ -59,10 +90,10 @@ pub enum SP1VerificationError {
5990
Verification(sp1_sdk::SP1VerificationError),
6091
}
6192

62-
pub(crate) fn verify(proof: &SP1Proof) -> Result<(), SP1VerificationError> {
93+
pub(crate) fn verify(proof: &SP1Proof, elf: &[u8]) -> Result<(), SP1VerificationError> {
6394
let client = ProverClient::from_env();
6495

65-
let (_pk, vk) = client.setup(&proof.elf);
96+
let (_pk, vk) = client.setup(elf);
6697
client
6798
.verify(&proof.proof, &vk)
6899
.map_err(SP1VerificationError::Verification)

aggregation-mode/src/zk/mod.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@ pub enum Proof {
1111
SP1(SP1Proof),
1212
}
1313

14+
impl Proof {
15+
pub fn hash(&self) -> [u8; 32] {
16+
match self {
17+
Proof::SP1(proof) => proof.hash(),
18+
}
19+
}
20+
}
21+
1422
pub enum VerificationError {
1523
SP1(SP1VerificationError),
1624
}
1725

1826
impl Proof {
19-
pub fn verify(&self) -> Result<(), VerificationError> {
27+
pub fn verify(&self, elf: &[u8]) -> Result<(), VerificationError> {
2028
match self {
21-
Proof::SP1(proof) => sp1::verify(proof).map_err(VerificationError::SP1),
29+
Proof::SP1(proof) => sp1::verify(proof, elf).map_err(VerificationError::SP1),
2230
}
2331
}
2432
}

aggregation-mode/zkvm/sp1/src/lib.rs

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,37 @@ use serde::{Deserialize, Serialize};
22
use sha3::{Digest, Keccak256};
33

44
#[derive(Serialize, Deserialize)]
5-
pub struct SP1CompressedProof {
6-
pub vk: Vec<u8>,
5+
pub struct SP1ProofInput {
6+
pub vk: [u32; 8],
77
pub public_inputs: Vec<u8>,
88
}
99

10-
impl SP1CompressedProof {
11-
pub fn vk(&self) -> [u32; 8] {
12-
assert!(self.vk.len() >= 32, "vk must be at least 32 bytes long");
13-
14-
let mut bytes = [0_32; 8];
15-
16-
for (i, chunk) in self.vk.chunks_exact(4).enumerate() {
17-
bytes[i] = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
18-
}
19-
20-
bytes
21-
}
22-
10+
impl SP1ProofInput {
2311
pub fn hash(&self) -> [u8; 32] {
2412
let mut hasher = Keccak256::new();
25-
hasher.update(&self.vk);
13+
for &word in &self.vk {
14+
hasher.update(word.to_le_bytes());
15+
}
2616
hasher.update(&self.public_inputs);
2717
hasher.finalize().into()
2818
}
2919
}
3020

3121
#[derive(Serialize, Deserialize)]
32-
pub enum Proof {
33-
SP1Compressed(SP1CompressedProof),
22+
pub enum ProofInput {
23+
SP1Compressed(SP1ProofInput),
3424
}
3525

36-
impl Proof {
26+
impl ProofInput {
3727
pub fn hash(&self) -> [u8; 32] {
3828
match self {
39-
Proof::SP1Compressed(proof) => proof.hash(),
29+
ProofInput::SP1Compressed(proof) => proof.hash(),
4030
}
4131
}
4232
}
4333

4434
#[derive(Serialize, Deserialize)]
4535
pub struct Input {
46-
pub proofs: Vec<Proof>,
36+
pub proofs: Vec<ProofInput>,
4737
pub merkle_root: [u8; 32],
4838
}

aggregation-mode/zkvm/sp1/src/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ sp1_zkvm::entrypoint!(main);
33

44
use sha2::{Digest, Sha256};
55
use sha3::Keccak256;
6-
use sp1_aggregator::{Input, Proof};
6+
use sp1_aggregator::{Input, ProofInput};
77

88
fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
99
let mut hasher = Keccak256::new();
@@ -13,7 +13,7 @@ fn combine_hashes(hash_a: &[u8; 32], hash_b: &[u8; 32]) -> [u8; 32] {
1313
}
1414

1515
/// Computes the merkle root for the given proofs using the vk
16-
fn compute_merkle_root(proofs: &[Proof]) -> [u8; 32] {
16+
fn compute_merkle_root(proofs: &[ProofInput]) -> [u8; 32] {
1717
let mut leaves: Vec<[u8; 32]> = proofs
1818
.chunks(2)
1919
.map(|chunk| match chunk {
@@ -44,8 +44,8 @@ pub fn main() {
4444
// Verify the proofs.
4545
for proof in input.proofs.iter() {
4646
match proof {
47-
Proof::SP1Compressed(proof) => {
48-
let vkey = proof.vk();
47+
ProofInput::SP1Compressed(proof) => {
48+
let vkey = proof.vk;
4949
let public_values = &proof.public_inputs;
5050
let public_values_digest = Sha256::digest(public_values);
5151
sp1_zkvm::lib::verify::verify_sp1_proof(&vkey, &public_values_digest.into());

0 commit comments

Comments
 (0)