Skip to content

Commit 782a9e1

Browse files
committed
perf(zk): use wnaf for g2 msm in wasm
1 parent c51b376 commit 782a9e1

File tree

2 files changed

+172
-77
lines changed

2 files changed

+172
-77
lines changed

tfhe-zk-pok/src/curve_api.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,9 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
416416
#[cfg(target_family = "wasm")]
417417
{
418418
if wasm_par_mq::is_pool_initialized() {
419-
return msm::cross_origin::msm_g2_cross_origin(bases, scalars);
419+
return msm::cross_origin::msm_wnaf_g2_446_cross_origin(bases, scalars);
420420
}
421-
return Self::Affine::multi_mul_scalar(bases, scalars);
421+
msm::msm_wnaf_g2_446(bases, scalars)
422422
}
423423
#[cfg(not(target_family = "wasm"))]
424424
Self::Affine::multi_mul_scalar(bases, scalars)

tfhe-zk-pok/src/curve_api/msm.rs

Lines changed: 170 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use ark_ec::short_weierstrass::Affine;
1+
use ark_ec::short_weierstrass::{Affine, Projective, SWCurveConfig};
22
use ark_ec::AffineRepr;
3-
use ark_ff::{AdditiveGroup, BigInteger, Field, Fp, PrimeField};
3+
use ark_ff::{AdditiveGroup, BigInteger, Field, PrimeField};
44
use rayon::prelude::*;
55

66
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
@@ -45,19 +45,38 @@ fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<
4545
})
4646
}
4747

48-
pub fn compute_window(
48+
trait MsmAffine: Copy + Send + Sync {
49+
type Config: SWCurveConfig;
50+
fn to_ark_affine(&self) -> &Affine<Self::Config>;
51+
}
52+
53+
impl MsmAffine for super::bls12_446::G1Affine {
54+
type Config = crate::curve_446::g1::Config;
55+
#[inline]
56+
fn to_ark_affine(&self) -> &Affine<Self::Config> {
57+
&self.inner
58+
}
59+
}
60+
61+
impl MsmAffine for super::bls12_446::G2Affine {
62+
type Config = crate::curve_446::g2::Config;
63+
#[inline]
64+
fn to_ark_affine(&self) -> &Affine<Self::Config> {
65+
&self.inner
66+
}
67+
}
68+
69+
#[track_caller]
70+
fn compute_window<Aff: MsmAffine>(
4971
i: usize,
50-
bases: &[super::bls12_446::G1Affine],
72+
bases: &[Aff],
5173
scalar_digits: &[i64],
5274
digits_count: usize,
53-
) -> super::bls12_446::G1 {
54-
use super::bls12_446::*;
75+
) -> Projective<Aff::Config> {
76+
type BaseField<Aff> = <<Aff as MsmAffine>::Config as ark_ec::CurveConfig>::BaseField;
5577

56-
type BaseField = Fp<ark_ff::MontBackend<crate::curve_446::FqConfig, 7>, 7>;
78+
let zero = Affine::<Aff::Config>::zero();
5779

58-
let zero = G1Affine {
59-
inner: Affine::zero(),
60-
};
6180
let size = bases.len();
6281

6382
let c = if size < 32 {
@@ -69,12 +88,11 @@ pub fn compute_window(
6988

7089
let n = 1 << c;
7190
let mut indices = vec![vec![]; n];
72-
let mut d = vec![BaseField::ZERO; n + 1];
73-
let mut e = vec![BaseField::ZERO; n + 1];
91+
let mut d = vec![BaseField::<Aff>::ZERO; n + 1];
92+
let mut e = vec![BaseField::<Aff>::ZERO; n + 1];
7493

7594
for (idx, digits) in scalar_digits.chunks(digits_count).enumerate() {
7695
use core::cmp::Ordering;
77-
// digits is the digits thing of the first scalar?
7896
let scalar = digits[i];
7997
match 0.cmp(&scalar) {
8098
Ordering::Less => indices[(scalar - 1) as usize].push(idx),
@@ -83,26 +101,33 @@ pub fn compute_window(
83101
}
84102
}
85103

104+
let get_base = |idx: usize| -> Affine<Aff::Config> {
105+
if idx >> (usize::BITS - 1) == 1 {
106+
let base = bases[!idx].to_ark_affine();
107+
Affine::<Aff::Config> {
108+
x: base.x,
109+
y: -base.y,
110+
infinity: base.infinity,
111+
}
112+
} else {
113+
*bases[idx].to_ark_affine()
114+
}
115+
};
116+
86117
let mut buckets = vec![zero; 1 << c];
87118

88119
loop {
89-
d[0] = BaseField::ONE;
120+
d[0] = BaseField::<Aff>::ONE;
90121
for (k, (bucket, idx)) in core::iter::zip(&mut buckets, &mut indices).enumerate() {
91-
if let Some(idx) = idx.last().copied() {
92-
let value = if idx >> (usize::BITS - 1) == 1 {
93-
let mut val = bases[!idx];
94-
val.inner.y = -val.inner.y;
95-
val
96-
} else {
97-
bases[idx]
98-
};
122+
if let Some(&idx) = idx.last() {
123+
let value = get_base(idx);
99124

100-
if !bucket.inner.infinity {
101-
let a = value.inner.x - bucket.inner.x;
102-
if a != BaseField::ZERO {
125+
if !bucket.infinity {
126+
let a = value.x - bucket.x;
127+
if a != BaseField::<Aff>::ZERO {
103128
d[k + 1] = d[k] * a;
104-
} else if value.inner.y == bucket.inner.y {
105-
d[k + 1] = d[k] * value.inner.y.double();
129+
} else if value.y == bucket.y {
130+
d[k + 1] = d[k] * value.y.double();
106131
} else {
107132
d[k + 1] = d[k];
108133
}
@@ -117,21 +142,15 @@ pub fn compute_window(
117142
.enumerate()
118143
.rev()
119144
{
120-
if let Some(idx) = idx.last().copied() {
121-
let value = if idx >> (usize::BITS - 1) == 1 {
122-
let mut val = bases[!idx];
123-
val.inner.y = -val.inner.y;
124-
val
125-
} else {
126-
bases[idx]
127-
};
145+
if let Some(&idx) = idx.last() {
146+
let value = get_base(idx);
128147

129-
if !bucket.inner.infinity {
130-
let a = value.inner.x - bucket.inner.x;
131-
if a != BaseField::ZERO {
148+
if !bucket.infinity {
149+
let a = value.x - bucket.x;
150+
if a != BaseField::<Aff>::ZERO {
132151
e[k] = e[k + 1] * a;
133-
} else if value.inner.y == bucket.inner.y {
134-
e[k] = e[k + 1] * value.inner.y.double();
152+
} else if value.y == bucket.y {
153+
e[k] = e[k + 1] * value.y.double();
135154
} else {
136155
e[k] = e[k + 1];
137156
}
@@ -151,24 +170,18 @@ pub fn compute_window(
151170
) {
152171
empty &= idx.len() <= 1;
153172
if let Some(idx) = idx.pop() {
154-
let value = if idx >> (usize::BITS - 1) == 1 {
155-
let mut val = bases[!idx];
156-
val.inner.y = -val.inner.y;
157-
val
158-
} else {
159-
bases[idx]
160-
};
173+
let value = get_base(idx);
161174

162-
if !bucket.inner.infinity {
163-
let x1: BaseField = bucket.inner.x;
164-
let x2 = value.inner.x;
165-
let y1 = bucket.inner.y;
166-
let y2 = value.inner.y;
175+
if !bucket.infinity {
176+
let x1 = bucket.x;
177+
let x2 = value.x;
178+
let y1 = bucket.y;
179+
let y2 = value.y;
167180

168181
let eq_x = x1 == x2;
169182

170183
if eq_x && y1 != y2 {
171-
bucket.inner.infinity = true;
184+
bucket.infinity = true;
172185
} else {
173186
let r = d * e;
174187
let m = if eq_x {
@@ -181,8 +194,8 @@ pub fn compute_window(
181194

182195
let x3 = m.square() - x1 - x2;
183196
let y3 = m * (x1 - x3) - y1;
184-
bucket.inner.x = x3;
185-
bucket.inner.y = y3;
197+
bucket.x = x3;
198+
bucket.y = y3;
186199
}
187200
} else {
188201
*bucket = value;
@@ -195,22 +208,17 @@ pub fn compute_window(
195208
}
196209
}
197210

198-
let mut running_sum = G1::ZERO;
199-
let mut res = G1::ZERO;
211+
let mut running_sum = Projective::<Aff::Config>::ZERO;
212+
let mut res = Projective::<Aff::Config>::ZERO;
200213
buckets.into_iter().rev().for_each(|b| {
201-
running_sum.inner += b.inner;
214+
running_sum += b;
202215
res += running_sum;
203216
});
204217
res
205218
}
206219

207-
// Compute msm using windowed non-adjacent form
208-
#[track_caller]
209-
pub fn msm_wnaf_g1_446(
210-
bases: &[super::bls12_446::G1Affine],
211-
scalars: &[super::bls12_446::Zp],
212-
) -> super::bls12_446::G1 {
213-
use super::bls12_446::*;
220+
fn msm_wnaf<A: MsmAffine>(bases: &[A], scalars: &[super::bls12_446::Zp]) -> Projective<A::Config> {
221+
// Bit size of the scalars: the modulus used for FrConfig in curve_446/mod.rs is 299 bits.
214222
let num_bits = 299usize;
215223

216224
assert_eq!(bases.len(), scalars.len());
@@ -224,7 +232,6 @@ pub fn msm_wnaf_g1_446(
224232
let c = if size < 32 {
225233
3
226234
} else {
227-
// natural log approx
228235
(size.ilog2() as usize * 69 / 100) + 2
229236
};
230237

@@ -236,26 +243,44 @@ pub fn msm_wnaf_g1_446(
236243

237244
let window_sums: Vec<_> = (0..digits_count)
238245
.into_par_iter()
239-
.map(|i| compute_window(i, bases, &scalar_digits, digits_count))
246+
.map(|i| compute_window::<A>(i, bases, &scalar_digits, digits_count))
240247
.collect();
241248

242-
// We store the sum for the lowest window.
243249
let lowest = *window_sums.first().unwrap();
244250

245-
// We're traversing windows from high to low.
246251
lowest
247252
+ window_sums[1..]
248253
.iter()
249254
.rev()
250-
.fold(G1::ZERO, |mut total, &sum_i| {
255+
.fold(Projective::<A::Config>::ZERO, |mut total, &sum_i| {
251256
total += sum_i;
252257
for _ in 0..c {
253-
total = total.double();
258+
total.double_in_place();
254259
}
255260
total
256261
})
257262
}
258263

264+
#[track_caller]
265+
pub fn msm_wnaf_g1_446(
266+
bases: &[super::bls12_446::G1Affine],
267+
scalars: &[super::bls12_446::Zp],
268+
) -> super::bls12_446::G1 {
269+
super::bls12_446::G1 {
270+
inner: msm_wnaf(bases, scalars),
271+
}
272+
}
273+
274+
#[track_caller]
275+
pub fn msm_wnaf_g2_446(
276+
bases: &[super::bls12_446::G2Affine],
277+
scalars: &[super::bls12_446::Zp],
278+
) -> super::bls12_446::G2 {
279+
super::bls12_446::G2 {
280+
inner: msm_wnaf(bases, scalars),
281+
}
282+
}
283+
259284
#[cfg(target_family = "wasm")]
260285
pub mod cross_origin {
261286
use crate::serialization::{
@@ -265,10 +290,9 @@ pub mod cross_origin {
265290
use serde::{Deserialize, Serialize};
266291
use wasm_par_mq::{par_fn, register_fn, IntoParallelIterator, ParallelIterator};
267292

293+
use super::{msm_wnaf_g1_446, msm_wnaf_g2_446};
268294
use crate::curve_api::bls12_446::{G1Affine, G2Affine, Zp, G1, G2};
269295

270-
use super::msm_wnaf_g1_446;
271-
272296
/// Input for parallel MSM: a chunk of bases and scalars
273297
/// Each worker computes a partial MSM on its subset, then results are summed.
274298
#[derive(Serialize, Deserialize)]
@@ -320,7 +344,7 @@ pub mod cross_origin {
320344
.map(|s| s.try_into().unwrap())
321345
.collect();
322346

323-
let result = G2Affine::multi_mul_scalar(&bases, &scalars);
347+
let result = msm_wnaf_g2_446(&bases, &scalars);
324348
result.inner.into()
325349
}
326350
register_fn!(
@@ -363,7 +387,7 @@ pub mod cross_origin {
363387
.expect("worker returned invalid projective point")
364388
}
365389

366-
pub fn msm_g2_cross_origin(bases: &[G2Affine], scalars: &[Zp]) -> G2 {
390+
pub fn msm_wnaf_g2_446_cross_origin(bases: &[G2Affine], scalars: &[Zp]) -> G2 {
367391
assert_eq!(bases.len(), scalars.len());
368392

369393
let num_workers = wasm_par_mq::num_workers().max(1);
@@ -397,3 +421,74 @@ pub mod cross_origin {
397421
.expect("worker returned invalid projective point")
398422
}
399423
}
424+
425+
#[cfg(test)]
426+
mod tests {
427+
use ark_ec::CurveGroup;
428+
use rand::rngs::StdRng;
429+
use rand::{thread_rng, Rng, SeedableRng};
430+
431+
use super::*;
432+
use crate::curve_api::bls12_446::{G1Affine, G2Affine, Zp, G1, G2};
433+
434+
const MSM_SIZES: [usize; 7] = [1, 2, 7, 32, 64, 256, 1024];
435+
const TEST_COUNT: usize = 10;
436+
437+
fn random_g1_affine_points(rng: &mut dyn rand::RngCore, n: usize) -> Vec<G1Affine> {
438+
(0..n)
439+
.map(|_| G1Affine {
440+
inner: G1::GENERATOR.mul_scalar(Zp::rand(rng)).inner.into_affine(),
441+
})
442+
.collect()
443+
}
444+
445+
fn random_g2_affine_points(rng: &mut dyn rand::RngCore, n: usize) -> Vec<G2Affine> {
446+
(0..n)
447+
.map(|_| G2Affine {
448+
inner: G2::GENERATOR.mul_scalar(Zp::rand(rng)).inner.into_affine(),
449+
})
450+
.collect()
451+
}
452+
453+
fn random_scalars(rng: &mut dyn rand::RngCore, n: usize) -> Vec<Zp> {
454+
(0..n).map(|_| Zp::rand(rng)).collect()
455+
}
456+
457+
#[test]
458+
fn test_wnaf_msm_g1_matches_arkworks() {
459+
let seed = thread_rng().gen();
460+
println!("test_wnaf_msm_g1_matches_arkworks seed: {seed:x}");
461+
let rng = &mut StdRng::seed_from_u64(seed);
462+
463+
for _ in 0..TEST_COUNT {
464+
for n in MSM_SIZES {
465+
let bases = random_g1_affine_points(rng, n);
466+
let scalars = random_scalars(rng, n);
467+
468+
let wnaf_result = msm_wnaf_g1_446(&bases, &scalars);
469+
let ark_result = G1Affine::multi_mul_scalar(&bases, &scalars);
470+
471+
assert_eq!(wnaf_result, ark_result, "G1 MSM mismatch for n={n}");
472+
}
473+
}
474+
}
475+
476+
#[test]
477+
fn test_wnaf_msm_g2_matches_arkworks() {
478+
let seed = thread_rng().gen();
479+
println!("test_wnaf_msm_g2_matches_arkworks seed: {seed:x}");
480+
let rng = &mut StdRng::seed_from_u64(seed);
481+
482+
for _ in 0..TEST_COUNT {
483+
for n in MSM_SIZES {
484+
let bases = random_g2_affine_points(rng, n);
485+
let scalars = random_scalars(rng, n);
486+
487+
let wnaf_result = msm_wnaf_g2_446(&bases, &scalars);
488+
let ark_result = G2Affine::multi_mul_scalar(&bases, &scalars);
489+
490+
assert_eq!(wnaf_result, ark_result, "G2 MSM mismatch for n={n}");
491+
}
492+
}
493+
}
494+
}

0 commit comments

Comments
 (0)