Skip to content

Commit b8e458e

Browse files
Parallelize SHPLONK multi-open prover (privacy-ethereum#114)
* feat: parallelize (cpu) shplonk prover * shplonk: improve `construct_intermediate_sets` using `BTreeSet` and `BTreeMap` more aggressively * shplonk: add `Send` and `Sync` to `Query` trait for more parallelization * fix: ensure the order of the collection of rotation sets is independent of the values of the opening points Co-authored-by: Jonathan Wang <[email protected]>
1 parent 0af4611 commit b8e458e

File tree

4 files changed

+69
-72
lines changed

4 files changed

+69
-72
lines changed

halo2_proofs/src/poly/commitment.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ pub trait ParamsProver<'params, C: CurveAffine>: Params<'params, C> {
9999
pub trait ParamsVerifier<'params, C: CurveAffine>: Params<'params, C> {}
100100

101101
/// Multi scalar multiplication engine
102-
pub trait MSM<C: CurveAffine>: Clone + Debug {
102+
pub trait MSM<C: CurveAffine>: Clone + Debug + Send + Sync {
103103
/// Add arbitrary term (the scalar and the point)
104104
fn append_term(&mut self, scalar: C::Scalar, point: C::CurveExt);
105105

halo2_proofs/src/poly/kzg/multiopen/shplonk.rs

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ use crate::{
99
poly::{query::Query, Coeff, Polynomial},
1010
transcript::ChallengeScalar,
1111
};
12-
12+
use rayon::prelude::*;
1313
use std::{
14-
collections::{btree_map::Entry, BTreeMap, BTreeSet},
14+
collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet},
1515
marker::PhantomData,
16+
sync::Arc,
1617
};
1718

1819
#[derive(Clone, Copy, Debug)]
@@ -49,7 +50,7 @@ struct RotationSet<F: FieldExt, T: PartialEq + Clone> {
4950
#[derive(Debug, PartialEq)]
5051
struct IntermediateSets<F: FieldExt, Q: Query<F>> {
5152
rotation_sets: Vec<RotationSet<F, Q::Commitment>>,
52-
super_point_set: Vec<F>,
53+
super_point_set: BTreeSet<F>,
5354
}
5455

5556
fn construct_intermediate_sets<F: FieldExt, I, Q: Query<F, Eval = F>>(
@@ -69,18 +70,8 @@ where
6970
.get_eval()
7071
};
7172

72-
// Order points according to their rotation
73-
let mut rotation_point_map = BTreeMap::new();
74-
for query in queries.clone() {
75-
let point = rotation_point_map
76-
.entry(query.get_point())
77-
.or_insert_with(|| query.get_point());
78-
79-
// Assert rotation point matching consistency
80-
assert_eq!(*point, query.get_point());
81-
}
82-
// All points appear in queries
83-
let super_point_set: Vec<F> = rotation_point_map.values().cloned().collect();
73+
// All points that appear in queries
74+
let mut super_point_set = BTreeSet::new();
8475

8576
// Collect rotation sets for each commitment
8677
// Example elements in the vector:
@@ -89,19 +80,21 @@ where
8980
// (C_2, {r_2, r_3, r_4}),
9081
// (C_3, {r_2, r_3, r_4}),
9182
// ...
92-
let mut commitment_rotation_set_map: Vec<(Q::Commitment, Vec<F>)> = vec![];
93-
for query in queries.clone() {
83+
let mut commitment_rotation_set_map: Vec<(Q::Commitment, BTreeSet<F>)> = vec![];
84+
for query in queries.iter() {
9485
let rotation = query.get_point();
95-
if let Some(pos) = commitment_rotation_set_map
96-
.iter()
97-
.position(|(commitment, _)| *commitment == query.get_commitment())
86+
super_point_set.insert(rotation);
87+
if let Some(commitment_rotation_set) = commitment_rotation_set_map
88+
.iter_mut()
89+
.find(|(commitment, _)| *commitment == query.get_commitment())
9890
{
99-
let (_, rotation_set) = &mut commitment_rotation_set_map[pos];
100-
if !rotation_set.contains(&rotation) {
101-
rotation_set.push(rotation);
102-
}
91+
let (_, rotation_set) = commitment_rotation_set;
92+
rotation_set.insert(rotation);
10393
} else {
104-
commitment_rotation_set_map.push((query.get_commitment(), vec![rotation]));
94+
commitment_rotation_set_map.push((
95+
query.get_commitment(),
96+
BTreeSet::from_iter(std::iter::once(rotation)),
97+
));
10598
};
10699
}
107100

@@ -111,41 +104,38 @@ where
111104
// {r_1, r_2, r_3} : [C_1]
112105
// {r_2, r_3, r_4} : [C_2, C_3],
113106
// ...
114-
let mut rotation_set_commitment_map = Vec::<(Vec<_>, Vec<Q::Commitment>)>::new();
115-
for (commitment, rotation_set) in commitment_rotation_set_map.iter() {
116-
if let Some(pos) = rotation_set_commitment_map.iter().position(|(set, _)| {
117-
BTreeSet::<F>::from_iter(set.iter().cloned())
118-
== BTreeSet::<F>::from_iter(rotation_set.iter().cloned())
119-
}) {
120-
let (_, commitments) = &mut rotation_set_commitment_map[pos];
121-
if !commitments.contains(commitment) {
122-
commitments.push(*commitment);
123-
}
107+
// NOTE: we want to make the order of the collection of rotation sets independent of the opening points, to ease the verifier computation
108+
let mut rotation_set_commitment_map: Vec<(BTreeSet<F>, Vec<Q::Commitment>)> = vec![];
109+
for (commitment, rotation_set) in commitment_rotation_set_map.into_iter() {
110+
if let Some(rotation_set_commitment) = rotation_set_commitment_map
111+
.iter_mut()
112+
.find(|(set, _)| set == &rotation_set)
113+
{
114+
let (_, commitments) = rotation_set_commitment;
115+
commitments.push(commitment);
124116
} else {
125-
rotation_set_commitment_map.push((rotation_set.clone(), vec![*commitment]))
126-
}
117+
rotation_set_commitment_map.push((rotation_set, vec![commitment]));
118+
};
127119
}
128120

129121
let rotation_sets = rotation_set_commitment_map
130-
.into_iter()
122+
.into_par_iter()
131123
.map(|(rotations, commitments)| {
124+
let rotations_vec = rotations.iter().collect::<Vec<_>>();
132125
let commitments: Vec<Commitment<F, Q::Commitment>> = commitments
133-
.iter()
126+
.into_par_iter()
134127
.map(|commitment| {
135-
let evals: Vec<F> = rotations
136-
.iter()
137-
.map(|rotation| get_eval(*commitment, *rotation))
128+
let evals: Vec<F> = rotations_vec
129+
.par_iter()
130+
.map(|&&rotation| get_eval(commitment, rotation))
138131
.collect();
139-
Commitment((*commitment, evals))
132+
Commitment((commitment, evals))
140133
})
141134
.collect();
142135

143136
RotationSet {
144137
commitments,
145-
points: rotations
146-
.iter()
147-
.map(|rotation| *rotation_point_map.get(rotation).unwrap())
148-
.collect(),
138+
points: rotations.into_iter().collect(),
149139
}
150140
})
151141
.collect::<Vec<RotationSet<_, _>>>();

halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use ff::Field;
1717
use group::Curve;
1818
use halo2curves::pairing::Engine;
1919
use rand_core::RngCore;
20+
use rayon::prelude::*;
2021
use std::fmt::Debug;
2122
use std::io::{self, Write};
2223
use std::marker::PhantomData;
@@ -36,8 +37,8 @@ struct CommitmentExtension<'a, C: CurveAffine> {
3637
}
3738

3839
impl<'a, C: CurveAffine> Commitment<C::Scalar, PolynomialPointer<'a, C>> {
39-
fn extend(&self, points: Vec<C::Scalar>) -> CommitmentExtension<'a, C> {
40-
let poly = lagrange_interpolate(&points[..], &self.evals()[..]);
40+
fn extend(&self, points: &[C::Scalar]) -> CommitmentExtension<'a, C> {
41+
let poly = lagrange_interpolate(points, &self.evals()[..]);
4142

4243
let low_degree_equivalent = Polynomial {
4344
values: poly,
@@ -79,10 +80,10 @@ struct RotationSetExtension<'a, C: CurveAffine> {
7980
}
8081

8182
impl<'a, C: CurveAffine> RotationSet<C::Scalar, PolynomialPointer<'a, C>> {
82-
fn extend(&self, commitments: Vec<CommitmentExtension<'a, C>>) -> RotationSetExtension<'a, C> {
83+
fn extend(self, commitments: Vec<CommitmentExtension<'a, C>>) -> RotationSetExtension<'a, C> {
8384
RotationSetExtension {
8485
commitments,
85-
points: self.points.clone(),
86+
points: self.points,
8687
}
8788
}
8889
}
@@ -136,15 +137,17 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
136137
// [P_i_0(X) - R_i_0(X), P_i_1(X) - R_i_1(X), ... ]
137138
let numerators = rotation_set
138139
.commitments
139-
.iter()
140-
.map(|commitment| commitment.quotient_contribution());
140+
.par_iter()
141+
.map(|commitment| commitment.quotient_contribution())
142+
.collect::<Vec<_>>();
141143

142144
// define numerator polynomial as
143145
// N_i_j(X) = (P_i_j(X) - R_i_j(X))
144146
// and combine polynomials with same evaluation point set
145147
// N_i(X) = linear_combinination(y, N_i_j(X))
146148
// where y is random scalar to combine numerator polynomials
147149
let n_x = numerators
150+
.into_iter()
148151
.zip(powers(*y))
149152
.map(|(numerator, power_of_y)| numerator * power_of_y)
150153
.reduce(|acc, numerator| acc + &numerator)
@@ -171,22 +174,26 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
171174
);
172175

173176
let rotation_sets: Vec<RotationSetExtension<E::G1Affine>> = rotation_sets
174-
.iter()
177+
.into_par_iter()
175178
.map(|rotation_set| {
176179
let commitments: Vec<CommitmentExtension<E::G1Affine>> = rotation_set
177180
.commitments
178-
.iter()
179-
.map(|commitment_data| commitment_data.extend(rotation_set.points.clone()))
181+
.par_iter()
182+
.map(|commitment_data| commitment_data.extend(&rotation_set.points))
180183
.collect();
181184
rotation_set.extend(commitments)
182185
})
183186
.collect();
184187

185188
let v: ChallengeV<_> = transcript.squeeze_challenge_scalar();
186189

187-
let quotient_polynomials = rotation_sets.iter().map(quotient_contribution);
190+
let quotient_polynomials = rotation_sets
191+
.par_iter()
192+
.map(quotient_contribution)
193+
.collect::<Vec<_>>();
188194

189195
let h_x: Polynomial<E::Scalar, Coeff> = quotient_polynomials
196+
.into_iter()
190197
.zip(powers(*v))
191198
.map(|(poly, power_of_v)| poly * power_of_v)
192199
.reduce(|acc, poly| acc + &poly)
@@ -196,18 +203,15 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
196203
transcript.write_point(h)?;
197204
let u: ChallengeU<_> = transcript.squeeze_challenge_scalar();
198205

199-
let zt_eval = evaluate_vanishing_polynomial(&super_point_set[..], *u);
200-
201206
let linearisation_contribution =
202207
|rotation_set: RotationSetExtension<E::G1Affine>| -> (Polynomial<E::Scalar, Coeff>, E::Scalar) {
203-
let diffs: Vec<E::Scalar> = super_point_set
204-
.iter()
205-
.filter(|point| !rotation_set.points.contains(point))
206-
.copied()
207-
.collect();
208+
let mut diffs = super_point_set.clone();
209+
for point in rotation_set.points.iter() {
210+
diffs.remove(point);
211+
}
212+
let diffs = diffs.into_iter().collect::<Vec<_>>();
208213

209214
// calculate difference vanishing polynomial evaluation
210-
211215
let z_i = evaluate_vanishing_polynomial(&diffs[..], *u);
212216

213217
// inner linearisation contibutions are
@@ -216,15 +220,15 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
216220
// where u is random evaluation point
217221
let inner_contributions = rotation_set
218222
.commitments
219-
.iter()
220-
.map(|commitment| commitment.linearisation_contribution(*u));
223+
.par_iter()
224+
.map(|commitment| commitment.linearisation_contribution(*u)).collect::<Vec<_>>();
221225

222226
// define inner contributor polynomial as
223227
// L_i_j(X) = (P_i_j(X) - r_i_j)
224228
// and combine polynomials with same evaluation point set
225229
// L_i(X) = linear_combinination(y, L_i_j(X))
226230
// where y is random scalar to combine inner contibutors
227-
let l_x: Polynomial<E::Scalar, Coeff> = inner_contributions.zip(powers(*y)).map(|(poly, power_of_y)| poly * power_of_y).reduce(|acc, poly| acc + &poly).unwrap();
231+
let l_x: Polynomial<E::Scalar, Coeff> = inner_contributions.into_iter().zip(powers(*y)).map(|(poly, power_of_y)| poly * power_of_y).reduce(|acc, poly| acc + &poly).unwrap();
228232

229233
// finally scale l_x by difference vanishing polynomial evaluation z_i
230234
(l_x * z_i, z_i)
@@ -235,7 +239,7 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
235239
Vec<Polynomial<E::Scalar, Coeff>>,
236240
Vec<E::Scalar>,
237241
) = rotation_sets
238-
.into_iter()
242+
.into_par_iter()
239243
.map(linearisation_contribution)
240244
.unzip();
241245

@@ -246,9 +250,12 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
246250
.reduce(|acc, poly| acc + &poly)
247251
.unwrap();
248252

253+
let super_point_set = super_point_set.into_iter().collect::<Vec<_>>();
254+
let zt_eval = evaluate_vanishing_polynomial(&super_point_set[..], *u);
249255
let l_x = l_x - &(h_x * zt_eval);
250256

251257
// sanity check
258+
#[cfg(debug_assertions)]
252259
{
253260
let must_be_zero = eval_polynomial(&l_x.values[..], *u);
254261
assert_eq!(must_be_zero, E::Scalar::zero());

halo2_proofs/src/poly/query.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use crate::{
88
use ff::Field;
99
use halo2curves::CurveAffine;
1010

11-
pub trait Query<F>: Sized + Clone {
12-
type Commitment: PartialEq + Copy;
11+
pub trait Query<F>: Sized + Clone + Send + Sync {
12+
type Commitment: PartialEq + Copy + Send + Sync;
1313
type Eval: Clone + Default + Debug;
1414

1515
fn get_point(&self) -> F;

0 commit comments

Comments
 (0)