@@ -4,8 +4,8 @@ use benchmark::params::{
44 benchmark_compression_parameters, benchmark_parameters, multi_bit_benchmark_parameters,
55} ;
66use benchmark:: utilities:: {
7- get_bench_type, get_param_type, throughput_num_threads , write_to_json, BenchmarkType ,
8- CryptoParametersRecord , OperatorType , ParamType ,
7+ get_bench_type, get_param_type, write_to_json, BenchmarkType , CryptoParametersRecord ,
8+ OperatorType , ParamType ,
99} ;
1010use criterion:: { black_box, Criterion , Throughput } ;
1111use itertools:: Itertools ;
@@ -84,50 +84,61 @@ fn keyswitch<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
8484 }
8585 BenchmarkType :: Throughput => {
8686 bench_id = format ! ( "{bench_name}::throughput::{name}" ) ;
87- let blocks: usize = 1 ;
88- let elements = throughput_num_threads ( blocks, 1 ) ; // FIXME This number of element do not staturate the target machine
89- bench_group. throughput ( Throughput :: Elements ( elements) ) ;
90- bench_group. bench_function ( & bench_id, |b| {
91- let setup_encrypted_values = || {
92- let input_cts = ( 0 ..elements)
93- . map ( |_| {
94- allocate_and_encrypt_new_lwe_ciphertext (
95- & big_lwe_sk,
96- Plaintext ( Scalar :: ONE ) ,
97- params. lwe_noise_distribution . unwrap ( ) ,
98- params. ciphertext_modulus . unwrap ( ) ,
99- & mut encryption_generator,
100- )
101- } )
102- . collect :: < Vec < _ > > ( ) ;
103-
104- let output_cts = ( 0 ..elements)
105- . map ( |_| {
106- LweCiphertext :: new (
107- Scalar :: ZERO ,
108- lwe_sk. lwe_dimension ( ) . to_lwe_size ( ) ,
109- params. ciphertext_modulus . unwrap ( ) ,
110- )
111- } )
112- . collect :: < Vec < _ > > ( ) ;
87+ let mut setup = |batch_size : usize | {
88+ let input_cts = ( 0 ..batch_size)
89+ . map ( |_| {
90+ allocate_and_encrypt_new_lwe_ciphertext (
91+ & big_lwe_sk,
92+ Plaintext ( Scalar :: ONE ) ,
93+ params. lwe_noise_distribution . unwrap ( ) ,
94+ params. ciphertext_modulus . unwrap ( ) ,
95+ & mut encryption_generator,
96+ )
97+ } )
98+ . collect :: < Vec < _ > > ( ) ;
11399
114- ( input_cts, output_cts)
115- } ;
100+ let output_cts = ( 0 ..batch_size)
101+ . map ( |_| {
102+ LweCiphertext :: new (
103+ Scalar :: ZERO ,
104+ lwe_sk. lwe_dimension ( ) . to_lwe_size ( ) ,
105+ params. ciphertext_modulus . unwrap ( ) ,
106+ )
107+ } )
108+ . collect :: < Vec < _ > > ( ) ;
116109
117- b. iter_batched (
118- setup_encrypted_values,
119- |( input_cts, mut output_cts) | {
120- input_cts
121- . par_iter ( )
122- . zip ( output_cts. par_iter_mut ( ) )
123- . for_each ( |( input_ct, output_ct) | {
124- keyswitch_lwe_ciphertext (
125- & ksk_big_to_small,
126- input_ct,
127- output_ct,
128- ) ;
129- } )
110+ ( input_cts, output_cts)
111+ } ;
112+ type Res < Scalar > = (
113+ Vec < LweCiphertext < Vec < Scalar > > > , // input_cts
114+ Vec < LweCiphertext < Vec < Scalar > > > , // output_cts
115+ ) ;
116+ let run = |inputs : & mut Res < Scalar > | {
117+ inputs. 0 . par_iter ( ) . zip ( inputs. 1 . par_iter_mut ( ) ) . for_each (
118+ |( input_ct, output_ct) | {
119+ keyswitch_lwe_ciphertext ( & ksk_big_to_small, input_ct, output_ct) ;
130120 } ,
121+ )
122+ } ;
123+ let elements = {
124+ #[ cfg( any( feature = "gpu" , feature = "hpu" ) ) ]
125+ {
126+ use benchmark:: utilities:: throughput_num_threads;
127+ let blocks: usize = 1 ;
128+ throughput_num_threads ( blocks, 1 ) // FIXME This number of element do not
129+ // staturate the target machine
130+ }
131+ #[ cfg( not( any( feature = "gpu" , feature = "hpu" ) ) ) ]
132+ {
133+ use benchmark:: find_optimal_batch:: find_optimal_batch;
134+ find_optimal_batch ( |inputs, _batch_size| run ( inputs) , & mut setup) as u64
135+ }
136+ } ;
137+ bench_group. throughput ( Throughput :: Elements ( elements) ) ;
138+ bench_group. bench_function ( & bench_id, |b| {
139+ b. iter_batched (
140+ || setup ( elements as usize ) ,
141+ |mut inputs| run ( & mut inputs) ,
131142 criterion:: BatchSize :: SmallInput ,
132143 )
133144 } ) ;
@@ -242,61 +253,76 @@ fn packing_keyswitch<Scalar, F>(
242253 }
243254 BenchmarkType :: Throughput => {
244255 bench_id = format ! ( "{bench_name}::throughput::{name}" ) ;
245- let blocks: usize = 1 ;
246- let elements = throughput_num_threads ( blocks, 1 ) ;
247- bench_group. throughput ( Throughput :: Elements ( elements) ) ;
248- bench_group. bench_function ( & bench_id, |b| {
249- let setup_encrypted_values = || {
250- let input_lwe_lists = ( 0 ..elements)
251- . map ( |_| {
252- let mut input_lwe_list = LweCiphertextList :: new (
253- Scalar :: ZERO ,
254- lwe_sk. lwe_dimension ( ) . to_lwe_size ( ) ,
255- count,
256- ciphertext_modulus,
257- ) ;
256+ let mut setup = |batch_size : usize | {
257+ let input_lwe_lists = ( 0 ..batch_size)
258+ . map ( |_| {
259+ let mut input_lwe_list = LweCiphertextList :: new (
260+ Scalar :: ZERO ,
261+ lwe_sk. lwe_dimension ( ) . to_lwe_size ( ) ,
262+ count,
263+ ciphertext_modulus,
264+ ) ;
258265
259- let plaintext_list = PlaintextList :: new (
260- Scalar :: ZERO ,
261- PlaintextCount ( input_lwe_list. lwe_ciphertext_count ( ) . 0 ) ,
262- ) ;
266+ let plaintext_list = PlaintextList :: new (
267+ Scalar :: ZERO ,
268+ PlaintextCount ( input_lwe_list. lwe_ciphertext_count ( ) . 0 ) ,
269+ ) ;
263270
264- encrypt_lwe_ciphertext_list (
265- & lwe_sk,
266- & mut input_lwe_list,
267- & plaintext_list,
268- params. lwe_noise_distribution . unwrap ( ) ,
269- & mut encryption_generator,
270- ) ;
271+ encrypt_lwe_ciphertext_list (
272+ & lwe_sk,
273+ & mut input_lwe_list,
274+ & plaintext_list,
275+ params. lwe_noise_distribution . unwrap ( ) ,
276+ & mut encryption_generator,
277+ ) ;
271278
272- input_lwe_list
273- } )
274- . collect :: < Vec < _ > > ( ) ;
275-
276- let output_glwes = ( 0 ..elements)
277- . map ( |_| {
278- GlweCiphertext :: new (
279- Scalar :: ZERO ,
280- glwe_sk. glwe_dimension ( ) . to_glwe_size ( ) ,
281- glwe_sk. polynomial_size ( ) ,
282- ciphertext_modulus,
283- )
284- } )
285- . collect :: < Vec < _ > > ( ) ;
279+ input_lwe_list
280+ } )
281+ . collect :: < Vec < _ > > ( ) ;
286282
287- ( input_lwe_lists, output_glwes)
288- } ;
283+ let output_glwes = ( 0 ..batch_size)
284+ . map ( |_| {
285+ GlweCiphertext :: new (
286+ Scalar :: ZERO ,
287+ glwe_sk. glwe_dimension ( ) . to_glwe_size ( ) ,
288+ glwe_sk. polynomial_size ( ) ,
289+ ciphertext_modulus,
290+ )
291+ } )
292+ . collect :: < Vec < _ > > ( ) ;
289293
290- b. iter_batched (
291- setup_encrypted_values,
292- |( input_lwe_lists, mut output_glwes) | {
293- input_lwe_lists
294- . par_iter ( )
295- . zip ( output_glwes. par_iter_mut ( ) )
296- . for_each ( |( input_lwe_list, output_glwe) | {
297- ks_op ( & pksk, input_lwe_list, output_glwe) ;
298- } )
294+ ( input_lwe_lists, output_glwes)
295+ } ;
296+ type Res < Scalar > = (
297+ Vec < LweCiphertextList < Vec < Scalar > > > , // input_lwe_lists
298+ Vec < GlweCiphertext < Vec < Scalar > > > , // output_glwes
299+ ) ;
300+ let run = |inputs : & mut Res < Scalar > | {
301+ inputs. 0 . par_iter ( ) . zip ( inputs. 1 . par_iter_mut ( ) ) . for_each (
302+ |( input_lwe_list, output_glwe) | {
303+ ks_op ( & pksk, input_lwe_list, output_glwe) ;
299304 } ,
305+ )
306+ } ;
307+ let elements = {
308+ #[ cfg( any( feature = "gpu" , feature = "hpu" ) ) ]
309+ {
310+ use benchmark:: utilities:: throughput_num_threads;
311+ let blocks: usize = 1 ;
312+ throughput_num_threads ( blocks, 1 ) // FIXME This number of element do not
313+ // staturate the target machine
314+ }
315+ #[ cfg( not( any( feature = "gpu" , feature = "hpu" ) ) ) ]
316+ {
317+ use benchmark:: find_optimal_batch:: find_optimal_batch;
318+ find_optimal_batch ( |inputs, _batch_size| run ( inputs) , & mut setup) as u64
319+ }
320+ } ;
321+ bench_group. throughput ( Throughput :: Elements ( elements) ) ;
322+ bench_group. bench_function ( & bench_id, |b| {
323+ b. iter_batched (
324+ || setup ( elements as usize ) ,
325+ |mut inputs| run ( & mut inputs) ,
300326 criterion:: BatchSize :: SmallInput ,
301327 )
302328 } ) ;
0 commit comments