1- use ark_ec:: short_weierstrass:: Affine ;
1+ use ark_ec:: short_weierstrass:: { Affine , Projective , SWCurveConfig } ;
22use ark_ec:: AffineRepr ;
3- use ark_ff:: { AdditiveGroup , BigInteger , Field , Fp , PrimeField } ;
3+ use ark_ff:: { AdditiveGroup , BigInteger , Field , PrimeField } ;
44use rayon:: prelude:: * ;
55
66fn make_digits ( a : & impl BigInteger , w : usize , num_bits : usize ) -> impl Iterator < Item = i64 > + ' _ {
@@ -45,19 +45,38 @@ fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<
4545 } )
4646}
4747
48- pub fn compute_window (
48+ trait MsmAffine : Copy + Send + Sync {
49+ type Config : SWCurveConfig ;
50+ fn to_ark_affine ( & self ) -> & Affine < Self :: Config > ;
51+ }
52+
53+ impl MsmAffine for super :: bls12_446:: G1Affine {
54+ type Config = crate :: curve_446:: g1:: Config ;
55+ #[ inline]
56+ fn to_ark_affine ( & self ) -> & Affine < Self :: Config > {
57+ & self . inner
58+ }
59+ }
60+
61+ impl MsmAffine for super :: bls12_446:: G2Affine {
62+ type Config = crate :: curve_446:: g2:: Config ;
63+ #[ inline]
64+ fn to_ark_affine ( & self ) -> & Affine < Self :: Config > {
65+ & self . inner
66+ }
67+ }
68+
69+ #[ track_caller]
70+ fn compute_window < Aff : MsmAffine > (
4971 i : usize ,
50- bases : & [ super :: bls12_446 :: G1Affine ] ,
72+ bases : & [ Aff ] ,
5173 scalar_digits : & [ i64 ] ,
5274 digits_count : usize ,
53- ) -> super :: bls12_446 :: G1 {
54- use super :: bls12_446 :: * ;
75+ ) -> Projective < Aff :: Config > {
76+ type BaseField < Aff > = << Aff as MsmAffine > :: Config as ark_ec :: CurveConfig > :: BaseField ;
5577
56- type BaseField = Fp < ark_ff :: MontBackend < crate :: curve_446 :: FqConfig , 7 > , 7 > ;
78+ let zero = Affine :: < Aff :: Config > :: zero ( ) ;
5779
58- let zero = G1Affine {
59- inner : Affine :: zero ( ) ,
60- } ;
6180 let size = bases. len ( ) ;
6281
6382 let c = if size < 32 {
@@ -69,12 +88,11 @@ pub fn compute_window(
6988
7089 let n = 1 << c;
7190 let mut indices = vec ! [ vec![ ] ; n] ;
72- let mut d = vec ! [ BaseField :: ZERO ; n + 1 ] ;
73- let mut e = vec ! [ BaseField :: ZERO ; n + 1 ] ;
91+ let mut d = vec ! [ BaseField :: < Aff > :: ZERO ; n + 1 ] ;
92+ let mut e = vec ! [ BaseField :: < Aff > :: ZERO ; n + 1 ] ;
7493
7594 for ( idx, digits) in scalar_digits. chunks ( digits_count) . enumerate ( ) {
7695 use core:: cmp:: Ordering ;
77- // digits is the digits thing of the first scalar?
7896 let scalar = digits[ i] ;
7997 match 0 . cmp ( & scalar) {
8098 Ordering :: Less => indices[ ( scalar - 1 ) as usize ] . push ( idx) ,
@@ -83,26 +101,33 @@ pub fn compute_window(
83101 }
84102 }
85103
104+ let get_base = |idx : usize | -> Affine < Aff :: Config > {
105+ if idx >> ( usize:: BITS - 1 ) == 1 {
106+ let base = bases[ !idx] . to_ark_affine ( ) ;
107+ Affine :: < Aff :: Config > {
108+ x : base. x ,
109+ y : -base. y ,
110+ infinity : base. infinity ,
111+ }
112+ } else {
113+ * bases[ idx] . to_ark_affine ( )
114+ }
115+ } ;
116+
86117 let mut buckets = vec ! [ zero; 1 << c] ;
87118
88119 loop {
89- d[ 0 ] = BaseField :: ONE ;
120+ d[ 0 ] = BaseField :: < Aff > :: ONE ;
90121 for ( k, ( bucket, idx) ) in core:: iter:: zip ( & mut buckets, & mut indices) . enumerate ( ) {
91- if let Some ( idx) = idx. last ( ) . copied ( ) {
92- let value = if idx >> ( usize:: BITS - 1 ) == 1 {
93- let mut val = bases[ !idx] ;
94- val. inner . y = -val. inner . y ;
95- val
96- } else {
97- bases[ idx]
98- } ;
122+ if let Some ( & idx) = idx. last ( ) {
123+ let value = get_base ( idx) ;
99124
100- if !bucket. inner . infinity {
101- let a = value. inner . x - bucket. inner . x ;
102- if a != BaseField :: ZERO {
125+ if !bucket. infinity {
126+ let a = value. x - bucket. x ;
127+ if a != BaseField :: < Aff > :: ZERO {
103128 d[ k + 1 ] = d[ k] * a;
104- } else if value. inner . y == bucket. inner . y {
105- d[ k + 1 ] = d[ k] * value. inner . y . double ( ) ;
129+ } else if value. y == bucket. y {
130+ d[ k + 1 ] = d[ k] * value. y . double ( ) ;
106131 } else {
107132 d[ k + 1 ] = d[ k] ;
108133 }
@@ -117,21 +142,15 @@ pub fn compute_window(
117142 . enumerate ( )
118143 . rev ( )
119144 {
120- if let Some ( idx) = idx. last ( ) . copied ( ) {
121- let value = if idx >> ( usize:: BITS - 1 ) == 1 {
122- let mut val = bases[ !idx] ;
123- val. inner . y = -val. inner . y ;
124- val
125- } else {
126- bases[ idx]
127- } ;
145+ if let Some ( & idx) = idx. last ( ) {
146+ let value = get_base ( idx) ;
128147
129- if !bucket. inner . infinity {
130- let a = value. inner . x - bucket. inner . x ;
131- if a != BaseField :: ZERO {
148+ if !bucket. infinity {
149+ let a = value. x - bucket. x ;
150+ if a != BaseField :: < Aff > :: ZERO {
132151 e[ k] = e[ k + 1 ] * a;
133- } else if value. inner . y == bucket. inner . y {
134- e[ k] = e[ k + 1 ] * value. inner . y . double ( ) ;
152+ } else if value. y == bucket. y {
153+ e[ k] = e[ k + 1 ] * value. y . double ( ) ;
135154 } else {
136155 e[ k] = e[ k + 1 ] ;
137156 }
@@ -151,24 +170,18 @@ pub fn compute_window(
151170 ) {
152171 empty &= idx. len ( ) <= 1 ;
153172 if let Some ( idx) = idx. pop ( ) {
154- let value = if idx >> ( usize:: BITS - 1 ) == 1 {
155- let mut val = bases[ !idx] ;
156- val. inner . y = -val. inner . y ;
157- val
158- } else {
159- bases[ idx]
160- } ;
173+ let value = get_base ( idx) ;
161174
162- if !bucket. inner . infinity {
163- let x1: BaseField = bucket. inner . x ;
164- let x2 = value. inner . x ;
165- let y1 = bucket. inner . y ;
166- let y2 = value. inner . y ;
175+ if !bucket. infinity {
176+ let x1 = bucket. x ;
177+ let x2 = value. x ;
178+ let y1 = bucket. y ;
179+ let y2 = value. y ;
167180
168181 let eq_x = x1 == x2;
169182
170183 if eq_x && y1 != y2 {
171- bucket. inner . infinity = true ;
184+ bucket. infinity = true ;
172185 } else {
173186 let r = d * e;
174187 let m = if eq_x {
@@ -181,8 +194,8 @@ pub fn compute_window(
181194
182195 let x3 = m. square ( ) - x1 - x2;
183196 let y3 = m * ( x1 - x3) - y1;
184- bucket. inner . x = x3;
185- bucket. inner . y = y3;
197+ bucket. x = x3;
198+ bucket. y = y3;
186199 }
187200 } else {
188201 * bucket = value;
@@ -195,22 +208,17 @@ pub fn compute_window(
195208 }
196209 }
197210
198- let mut running_sum = G1 :: ZERO ;
199- let mut res = G1 :: ZERO ;
211+ let mut running_sum = Projective :: < Aff :: Config > :: ZERO ;
212+ let mut res = Projective :: < Aff :: Config > :: ZERO ;
200213 buckets. into_iter ( ) . rev ( ) . for_each ( |b| {
201- running_sum. inner += b. inner ;
214+ running_sum += b;
202215 res += running_sum;
203216 } ) ;
204217 res
205218}
206219
207- // Compute msm using windowed non-adjacent form
208- #[ track_caller]
209- pub fn msm_wnaf_g1_446 (
210- bases : & [ super :: bls12_446:: G1Affine ] ,
211- scalars : & [ super :: bls12_446:: Zp ] ,
212- ) -> super :: bls12_446:: G1 {
213- use super :: bls12_446:: * ;
220+ fn msm_wnaf < A : MsmAffine > ( bases : & [ A ] , scalars : & [ super :: bls12_446:: Zp ] ) -> Projective < A :: Config > {
221+ // size of the scalars, the modulus used for FrConfig in curve446::mod.rs is 299 bits
214222 let num_bits = 299usize ;
215223
216224 assert_eq ! ( bases. len( ) , scalars. len( ) ) ;
@@ -224,7 +232,6 @@ pub fn msm_wnaf_g1_446(
224232 let c = if size < 32 {
225233 3
226234 } else {
227- // natural log approx
228235 ( size. ilog2 ( ) as usize * 69 / 100 ) + 2
229236 } ;
230237
@@ -236,26 +243,44 @@ pub fn msm_wnaf_g1_446(
236243
237244 let window_sums: Vec < _ > = ( 0 ..digits_count)
238245 . into_par_iter ( )
239- . map ( |i| compute_window ( i, bases, & scalar_digits, digits_count) )
246+ . map ( |i| compute_window :: < A > ( i, bases, & scalar_digits, digits_count) )
240247 . collect ( ) ;
241248
242- // We store the sum for the lowest window.
243249 let lowest = * window_sums. first ( ) . unwrap ( ) ;
244250
245- // We're traversing windows from high to low.
246251 lowest
247252 + window_sums[ 1 ..]
248253 . iter ( )
249254 . rev ( )
250- . fold ( G1 :: ZERO , |mut total, & sum_i| {
255+ . fold ( Projective :: < A :: Config > :: ZERO , |mut total, & sum_i| {
251256 total += sum_i;
252257 for _ in 0 ..c {
253- total = total . double ( ) ;
258+ total. double_in_place ( ) ;
254259 }
255260 total
256261 } )
257262}
258263
264+ #[ track_caller]
265+ pub fn msm_wnaf_g1_446 (
266+ bases : & [ super :: bls12_446:: G1Affine ] ,
267+ scalars : & [ super :: bls12_446:: Zp ] ,
268+ ) -> super :: bls12_446:: G1 {
269+ super :: bls12_446:: G1 {
270+ inner : msm_wnaf ( bases, scalars) ,
271+ }
272+ }
273+
274+ #[ track_caller]
275+ pub fn msm_wnaf_g2_446 (
276+ bases : & [ super :: bls12_446:: G2Affine ] ,
277+ scalars : & [ super :: bls12_446:: Zp ] ,
278+ ) -> super :: bls12_446:: G2 {
279+ super :: bls12_446:: G2 {
280+ inner : msm_wnaf ( bases, scalars) ,
281+ }
282+ }
283+
259284#[ cfg( target_family = "wasm" ) ]
260285pub mod cross_origin {
261286 use crate :: serialization:: {
@@ -265,10 +290,9 @@ pub mod cross_origin {
265290 use serde:: { Deserialize , Serialize } ;
266291 use wasm_par_mq:: { par_fn, register_fn, IntoParallelIterator , ParallelIterator } ;
267292
293+ use super :: { msm_wnaf_g1_446, msm_wnaf_g2_446} ;
268294 use crate :: curve_api:: bls12_446:: { G1Affine , G2Affine , Zp , G1 , G2 } ;
269295
270- use super :: msm_wnaf_g1_446;
271-
272296 /// Input for parallel MSM: a chunk of bases and scalars
273297 /// Each worker computes a partial MSM on its subset, then results are summed.
274298 #[ derive( Serialize , Deserialize ) ]
@@ -320,7 +344,7 @@ pub mod cross_origin {
320344 . map ( |s| s. try_into ( ) . unwrap ( ) )
321345 . collect ( ) ;
322346
323- let result = G2Affine :: multi_mul_scalar ( & bases, & scalars) ;
347+ let result = msm_wnaf_g2_446 ( & bases, & scalars) ;
324348 result. inner . into ( )
325349 }
326350 register_fn ! (
@@ -363,7 +387,7 @@ pub mod cross_origin {
363387 . expect ( "worker returned invalid projective point" )
364388 }
365389
366- pub fn msm_g2_cross_origin ( bases : & [ G2Affine ] , scalars : & [ Zp ] ) -> G2 {
390+ pub fn msm_wnaf_g2_446_cross_origin ( bases : & [ G2Affine ] , scalars : & [ Zp ] ) -> G2 {
367391 assert_eq ! ( bases. len( ) , scalars. len( ) ) ;
368392
369393 let num_workers = wasm_par_mq:: num_workers ( ) . max ( 1 ) ;
@@ -397,3 +421,74 @@ pub mod cross_origin {
397421 . expect ( "worker returned invalid projective point" )
398422 }
399423}
424+
425+ #[ cfg( test) ]
426+ mod tests {
427+ use ark_ec:: CurveGroup ;
428+ use rand:: rngs:: StdRng ;
429+ use rand:: { thread_rng, Rng , SeedableRng } ;
430+
431+ use super :: * ;
432+ use crate :: curve_api:: bls12_446:: { G1Affine , G2Affine , Zp , G1 , G2 } ;
433+
434+ const MSM_SIZES : [ usize ; 7 ] = [ 1 , 2 , 7 , 32 , 64 , 256 , 1024 ] ;
435+ const TEST_COUNT : usize = 10 ;
436+
437+ fn random_g1_affine_points ( rng : & mut dyn rand:: RngCore , n : usize ) -> Vec < G1Affine > {
438+ ( 0 ..n)
439+ . map ( |_| G1Affine {
440+ inner : G1 :: GENERATOR . mul_scalar ( Zp :: rand ( rng) ) . inner . into_affine ( ) ,
441+ } )
442+ . collect ( )
443+ }
444+
445+ fn random_g2_affine_points ( rng : & mut dyn rand:: RngCore , n : usize ) -> Vec < G2Affine > {
446+ ( 0 ..n)
447+ . map ( |_| G2Affine {
448+ inner : G2 :: GENERATOR . mul_scalar ( Zp :: rand ( rng) ) . inner . into_affine ( ) ,
449+ } )
450+ . collect ( )
451+ }
452+
453+ fn random_scalars ( rng : & mut dyn rand:: RngCore , n : usize ) -> Vec < Zp > {
454+ ( 0 ..n) . map ( |_| Zp :: rand ( rng) ) . collect ( )
455+ }
456+
457+ #[ test]
458+ fn test_wnaf_msm_g1_matches_arkworks ( ) {
459+ let seed = thread_rng ( ) . gen ( ) ;
460+ println ! ( "test_wnaf_msm_g1_matches_arkworks seed: {seed:x}" ) ;
461+ let rng = & mut StdRng :: seed_from_u64 ( seed) ;
462+
463+ for _ in 0 ..TEST_COUNT {
464+ for n in MSM_SIZES {
465+ let bases = random_g1_affine_points ( rng, n) ;
466+ let scalars = random_scalars ( rng, n) ;
467+
468+ let wnaf_result = msm_wnaf_g1_446 ( & bases, & scalars) ;
469+ let ark_result = G1Affine :: multi_mul_scalar ( & bases, & scalars) ;
470+
471+ assert_eq ! ( wnaf_result, ark_result, "G1 MSM mismatch for n={n}" ) ;
472+ }
473+ }
474+ }
475+
476+ #[ test]
477+ fn test_wnaf_msm_g2_matches_arkworks ( ) {
478+ let seed = thread_rng ( ) . gen ( ) ;
479+ println ! ( "test_wnaf_msm_g2_matches_arkworks seed: {seed:x}" ) ;
480+ let rng = & mut StdRng :: seed_from_u64 ( seed) ;
481+
482+ for _ in 0 ..TEST_COUNT {
483+ for n in MSM_SIZES {
484+ let bases = random_g2_affine_points ( rng, n) ;
485+ let scalars = random_scalars ( rng, n) ;
486+
487+ let wnaf_result = msm_wnaf_g2_446 ( & bases, & scalars) ;
488+ let ark_result = G2Affine :: multi_mul_scalar ( & bases, & scalars) ;
489+
490+ assert_eq ! ( wnaf_result, ark_result, "G2 MSM mismatch for n={n}" ) ;
491+ }
492+ }
493+ }
494+ }
0 commit comments