Skip to content

Commit f54c9cf

Browse files
committed
chore(bench): use find_optimal_batch in core crypto benchmarks
1 parent 195ca69 commit f54c9cf

File tree

8 files changed

+711
-634
lines changed

8 files changed

+711
-634
lines changed

tfhe-benchmark/benches/core_crypto/ks_bench.rs

Lines changed: 118 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use benchmark::params::{
44
benchmark_compression_parameters, benchmark_parameters, multi_bit_benchmark_parameters,
55
};
66
use benchmark::utilities::{
7-
get_bench_type, get_param_type, throughput_num_threads, write_to_json, BenchmarkType,
8-
CryptoParametersRecord, OperatorType, ParamType,
7+
get_bench_type, get_param_type, write_to_json, BenchmarkType, CryptoParametersRecord,
8+
OperatorType, ParamType,
99
};
1010
use criterion::{black_box, Criterion, Throughput};
1111
use itertools::Itertools;
@@ -84,50 +84,61 @@ fn keyswitch<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
8484
}
8585
BenchmarkType::Throughput => {
8686
bench_id = format!("{bench_name}::throughput::{name}");
87-
let blocks: usize = 1;
88-
let elements = throughput_num_threads(blocks, 1); // FIXME This number of element do not staturate the target machine
89-
bench_group.throughput(Throughput::Elements(elements));
90-
bench_group.bench_function(&bench_id, |b| {
91-
let setup_encrypted_values = || {
92-
let input_cts = (0..elements)
93-
.map(|_| {
94-
allocate_and_encrypt_new_lwe_ciphertext(
95-
&big_lwe_sk,
96-
Plaintext(Scalar::ONE),
97-
params.lwe_noise_distribution.unwrap(),
98-
params.ciphertext_modulus.unwrap(),
99-
&mut encryption_generator,
100-
)
101-
})
102-
.collect::<Vec<_>>();
103-
104-
let output_cts = (0..elements)
105-
.map(|_| {
106-
LweCiphertext::new(
107-
Scalar::ZERO,
108-
lwe_sk.lwe_dimension().to_lwe_size(),
109-
params.ciphertext_modulus.unwrap(),
110-
)
111-
})
112-
.collect::<Vec<_>>();
87+
let mut setup = |batch_size: usize| {
88+
let input_cts = (0..batch_size)
89+
.map(|_| {
90+
allocate_and_encrypt_new_lwe_ciphertext(
91+
&big_lwe_sk,
92+
Plaintext(Scalar::ONE),
93+
params.lwe_noise_distribution.unwrap(),
94+
params.ciphertext_modulus.unwrap(),
95+
&mut encryption_generator,
96+
)
97+
})
98+
.collect::<Vec<_>>();
11399

114-
(input_cts, output_cts)
115-
};
100+
let output_cts = (0..batch_size)
101+
.map(|_| {
102+
LweCiphertext::new(
103+
Scalar::ZERO,
104+
lwe_sk.lwe_dimension().to_lwe_size(),
105+
params.ciphertext_modulus.unwrap(),
106+
)
107+
})
108+
.collect::<Vec<_>>();
116109

117-
b.iter_batched(
118-
setup_encrypted_values,
119-
|(input_cts, mut output_cts)| {
120-
input_cts
121-
.par_iter()
122-
.zip(output_cts.par_iter_mut())
123-
.for_each(|(input_ct, output_ct)| {
124-
keyswitch_lwe_ciphertext(
125-
&ksk_big_to_small,
126-
input_ct,
127-
output_ct,
128-
);
129-
})
110+
(input_cts, output_cts)
111+
};
112+
type Res<Scalar> = (
113+
Vec<LweCiphertext<Vec<Scalar>>>, // input_cts
114+
Vec<LweCiphertext<Vec<Scalar>>>, // output_cts
115+
);
116+
let run = |inputs: &mut Res<Scalar>| {
117+
inputs.0.par_iter().zip(inputs.1.par_iter_mut()).for_each(
118+
|(input_ct, output_ct)| {
119+
keyswitch_lwe_ciphertext(&ksk_big_to_small, input_ct, output_ct);
130120
},
121+
)
122+
};
123+
let elements = {
124+
#[cfg(any(feature = "gpu", feature = "hpu"))]
125+
{
126+
use benchmark::utilities::throughput_num_threads;
127+
let blocks: usize = 1;
128+
throughput_num_threads(blocks, 1) // FIXME This number of element do not
129+
// staturate the target machine
130+
}
131+
#[cfg(not(any(feature = "gpu", feature = "hpu")))]
132+
{
133+
use benchmark::find_optimal_batch::find_optimal_batch;
134+
find_optimal_batch(|inputs, _batch_size| run(inputs), &mut setup) as u64
135+
}
136+
};
137+
bench_group.throughput(Throughput::Elements(elements));
138+
bench_group.bench_function(&bench_id, |b| {
139+
b.iter_batched(
140+
|| setup(elements as usize),
141+
|mut inputs| run(&mut inputs),
131142
criterion::BatchSize::SmallInput,
132143
)
133144
});
@@ -242,61 +253,76 @@ fn packing_keyswitch<Scalar, F>(
242253
}
243254
BenchmarkType::Throughput => {
244255
bench_id = format!("{bench_name}::throughput::{name}");
245-
let blocks: usize = 1;
246-
let elements = throughput_num_threads(blocks, 1);
247-
bench_group.throughput(Throughput::Elements(elements));
248-
bench_group.bench_function(&bench_id, |b| {
249-
let setup_encrypted_values = || {
250-
let input_lwe_lists = (0..elements)
251-
.map(|_| {
252-
let mut input_lwe_list = LweCiphertextList::new(
253-
Scalar::ZERO,
254-
lwe_sk.lwe_dimension().to_lwe_size(),
255-
count,
256-
ciphertext_modulus,
257-
);
256+
let mut setup = |batch_size: usize| {
257+
let input_lwe_lists = (0..batch_size)
258+
.map(|_| {
259+
let mut input_lwe_list = LweCiphertextList::new(
260+
Scalar::ZERO,
261+
lwe_sk.lwe_dimension().to_lwe_size(),
262+
count,
263+
ciphertext_modulus,
264+
);
258265

259-
let plaintext_list = PlaintextList::new(
260-
Scalar::ZERO,
261-
PlaintextCount(input_lwe_list.lwe_ciphertext_count().0),
262-
);
266+
let plaintext_list = PlaintextList::new(
267+
Scalar::ZERO,
268+
PlaintextCount(input_lwe_list.lwe_ciphertext_count().0),
269+
);
263270

264-
encrypt_lwe_ciphertext_list(
265-
&lwe_sk,
266-
&mut input_lwe_list,
267-
&plaintext_list,
268-
params.lwe_noise_distribution.unwrap(),
269-
&mut encryption_generator,
270-
);
271+
encrypt_lwe_ciphertext_list(
272+
&lwe_sk,
273+
&mut input_lwe_list,
274+
&plaintext_list,
275+
params.lwe_noise_distribution.unwrap(),
276+
&mut encryption_generator,
277+
);
271278

272-
input_lwe_list
273-
})
274-
.collect::<Vec<_>>();
275-
276-
let output_glwes = (0..elements)
277-
.map(|_| {
278-
GlweCiphertext::new(
279-
Scalar::ZERO,
280-
glwe_sk.glwe_dimension().to_glwe_size(),
281-
glwe_sk.polynomial_size(),
282-
ciphertext_modulus,
283-
)
284-
})
285-
.collect::<Vec<_>>();
279+
input_lwe_list
280+
})
281+
.collect::<Vec<_>>();
286282

287-
(input_lwe_lists, output_glwes)
288-
};
283+
let output_glwes = (0..batch_size)
284+
.map(|_| {
285+
GlweCiphertext::new(
286+
Scalar::ZERO,
287+
glwe_sk.glwe_dimension().to_glwe_size(),
288+
glwe_sk.polynomial_size(),
289+
ciphertext_modulus,
290+
)
291+
})
292+
.collect::<Vec<_>>();
289293

290-
b.iter_batched(
291-
setup_encrypted_values,
292-
|(input_lwe_lists, mut output_glwes)| {
293-
input_lwe_lists
294-
.par_iter()
295-
.zip(output_glwes.par_iter_mut())
296-
.for_each(|(input_lwe_list, output_glwe)| {
297-
ks_op(&pksk, input_lwe_list, output_glwe);
298-
})
294+
(input_lwe_lists, output_glwes)
295+
};
296+
type Res<Scalar> = (
297+
Vec<LweCiphertextList<Vec<Scalar>>>, // input_lwe_lists
298+
Vec<GlweCiphertext<Vec<Scalar>>>, // output_glwes
299+
);
300+
let run = |inputs: &mut Res<Scalar>| {
301+
inputs.0.par_iter().zip(inputs.1.par_iter_mut()).for_each(
302+
|(input_lwe_list, output_glwe)| {
303+
ks_op(&pksk, input_lwe_list, output_glwe);
299304
},
305+
)
306+
};
307+
let elements = {
308+
#[cfg(any(feature = "gpu", feature = "hpu"))]
309+
{
310+
use benchmark::utilities::throughput_num_threads;
311+
let blocks: usize = 1;
312+
throughput_num_threads(blocks, 1) // FIXME This number of element do not
313+
// staturate the target machine
314+
}
315+
#[cfg(not(any(feature = "gpu", feature = "hpu")))]
316+
{
317+
use benchmark::find_optimal_batch::find_optimal_batch;
318+
find_optimal_batch(|inputs, _batch_size| run(inputs), &mut setup) as u64
319+
}
320+
};
321+
bench_group.throughput(Throughput::Elements(elements));
322+
bench_group.bench_function(&bench_id, |b| {
323+
b.iter_batched(
324+
|| setup(elements as usize),
325+
|mut inputs| run(&mut inputs),
300326
criterion::BatchSize::SmallInput,
301327
)
302328
});

0 commit comments

Comments
 (0)