Skip to content

Commit 8962d1f

Browse files
committed
chore(gpu): refactor full propagation to track noise / degree
1 parent f7655cc commit 8962d1f

File tree

11 files changed

+92
-62
lines changed

11 files changed

+92
-62
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ void scratch_cuda_full_propagation_64(
113113

114114
void cuda_full_propagation_64_inplace(void *const *streams,
115115
uint32_t const *gpu_indexes,
116-
uint32_t gpu_count, void *input_blocks,
116+
uint32_t gpu_count,
117+
CudaRadixCiphertextFFI *input_blocks,
117118
int8_t *mem_ptr, void *const *ksks,
118119
void *const *bsks, uint32_t num_blocks);
119120

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -841,8 +841,8 @@ template <typename Torus> struct int_fullprop_buffer {
841841

842842
int_radix_lut<Torus> *lut;
843843

844-
Torus *tmp_small_lwe_vector;
845-
Torus *tmp_big_lwe_vector;
844+
CudaRadixCiphertextFFI *tmp_small_lwe_vector;
845+
CudaRadixCiphertextFFI *tmp_big_lwe_vector;
846846

847847
int_fullprop_buffer(cudaStream_t const *streams, uint32_t const *gpu_indexes,
848848
uint32_t gpu_count, int_radix_params params,
@@ -889,17 +889,14 @@ template <typename Torus> struct int_fullprop_buffer {
889889

890890
lut->broadcast_lut(streams, gpu_indexes, 0);
891891

892-
// Temporary arrays
893-
Torus small_vector_size =
894-
2 * (params.small_lwe_dimension + 1) * sizeof(Torus);
895-
Torus big_vector_size =
896-
2 * (params.glwe_dimension * params.polynomial_size + 1) *
897-
sizeof(Torus);
898-
899-
tmp_small_lwe_vector = (Torus *)cuda_malloc_async(
900-
small_vector_size, streams[0], gpu_indexes[0]);
901-
tmp_big_lwe_vector = (Torus *)cuda_malloc_async(
902-
big_vector_size, streams[0], gpu_indexes[0]);
892+
tmp_small_lwe_vector = new CudaRadixCiphertextFFI;
893+
create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
894+
tmp_small_lwe_vector, 2,
895+
params.small_lwe_dimension);
896+
tmp_big_lwe_vector = new CudaRadixCiphertextFFI;
897+
create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
898+
tmp_big_lwe_vector, 2,
899+
params.big_lwe_dimension);
903900
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
904901
free(h_lwe_indexes);
905902
}
@@ -911,8 +908,10 @@ template <typename Torus> struct int_fullprop_buffer {
911908
lut->release(streams, gpu_indexes, 1);
912909
delete lut;
913910

914-
cuda_drop_async(tmp_small_lwe_vector, streams[0], gpu_indexes[0]);
915-
cuda_drop_async(tmp_big_lwe_vector, streams[0], gpu_indexes[0]);
911+
release_radix_ciphertext(streams[0], gpu_indexes[0], tmp_small_lwe_vector);
912+
delete tmp_small_lwe_vector;
913+
release_radix_ciphertext(streams[0], gpu_indexes[0], tmp_big_lwe_vector);
914+
delete tmp_big_lwe_vector;
916915
}
917916
};
918917

backends/tfhe-cuda-backend/cuda/src/integer/integer.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44

55
void cuda_full_propagation_64_inplace(void *const *streams,
66
uint32_t const *gpu_indexes,
7-
uint32_t gpu_count, void *input_blocks,
7+
uint32_t gpu_count,
8+
CudaRadixCiphertextFFI *input_blocks,
89
int8_t *mem_ptr, void *const *ksks,
910
void *const *bsks, uint32_t num_blocks) {
1011

1112
int_fullprop_buffer<uint64_t> *buffer =
1213
(int_fullprop_buffer<uint64_t> *)mem_ptr;
1314

14-
host_full_propagate_inplace<uint64_t>(
15-
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
16-
static_cast<uint64_t *>(input_blocks), buffer, (uint64_t **)(ksks), bsks,
17-
num_blocks);
15+
host_full_propagate_inplace<uint64_t>((cudaStream_t *)(streams), gpu_indexes,
16+
gpu_count, input_blocks, buffer,
17+
(uint64_t **)(ksks), bsks, num_blocks);
1818
}
1919

2020
void scratch_cuda_full_propagation_64(

backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
689689
cuda_memcpy_async_to_cpu(&lut_indexes, lut->get_lut_indexes(0, 0),
690690
lut->num_blocks * sizeof(Torus), streams[0],
691691
gpu_indexes[0]);
692+
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
692693
for (uint i = 0; i < num_radix_blocks; i++) {
693694
lwe_array_out->degrees[i] = lut->degrees[lut_indexes[i]];
694695
lwe_array_out->noise_levels[i] = NoiseLevel::NOMINAL;
@@ -964,6 +965,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table_kb(
964965
cuda_memcpy_async_to_cpu(&lut_indexes, lut->get_lut_indexes(0, 0),
965966
lut->num_blocks * sizeof(Torus), streams[0],
966967
gpu_indexes[0]);
968+
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
967969
for (uint i = 0; i < lwe_array_out->num_radix_blocks; i++) {
968970
lwe_array_out->degrees[i] = lut->degrees[i % lut->num_blocks];
969971
lwe_array_out->noise_levels[i] = NoiseLevel::NOMINAL;
@@ -1173,6 +1175,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
11731175
cuda_memcpy_async_to_cpu(&lut_indexes, lut->get_lut_indexes(0, 0),
11741176
lut->num_blocks * sizeof(Torus), streams[0],
11751177
gpu_indexes[0]);
1178+
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
11761179
for (uint i = 0; i < num_radix_blocks; i++) {
11771180
lwe_array_out->degrees[i] = lut->degrees[lut_indexes[i]];
11781181
lwe_array_out->noise_levels[i] = NoiseLevel::NOMINAL;
@@ -1974,7 +1977,8 @@ void host_compute_shifted_blocks_and_borrow_states(
19741977
template <typename Torus>
19751978
void host_full_propagate_inplace(cudaStream_t const *streams,
19761979
uint32_t const *gpu_indexes,
1977-
uint32_t gpu_count, Torus *input_blocks,
1980+
uint32_t gpu_count,
1981+
CudaRadixCiphertextFFI *input_blocks,
19781982
int_fullprop_buffer<Torus> *mem_ptr,
19791983
Torus *const *ksks, void *const *bsks,
19801984
uint32_t num_blocks) {
@@ -1987,39 +1991,51 @@ void host_full_propagate_inplace(cudaStream_t const *streams,
19871991
uint32_t num_many_lut = 1;
19881992
uint32_t lut_stride = 0;
19891993
for (int i = 0; i < num_blocks; i++) {
1990-
auto cur_input_block = &input_blocks[i * big_lwe_size];
1994+
CudaRadixCiphertextFFI cur_input_block;
1995+
as_radix_ciphertext_slice<Torus>(&cur_input_block, input_blocks, i, i + 1);
19911996

19921997
/// Since the keyswitch is done on one input only, use only 1 GPU
19931998
execute_keyswitch_async<Torus>(
1994-
streams, gpu_indexes, 1, mem_ptr->tmp_small_lwe_vector,
1995-
mem_ptr->lut->lwe_trivial_indexes, cur_input_block,
1999+
streams, gpu_indexes, 1, (Torus *)(mem_ptr->tmp_small_lwe_vector->ptr),
2000+
mem_ptr->lut->lwe_trivial_indexes, (Torus *)cur_input_block.ptr,
19962001
mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension,
19972002
params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1);
19982003

1999-
cuda_memcpy_async_gpu_to_gpu(&mem_ptr->tmp_small_lwe_vector[small_lwe_size],
2000-
mem_ptr->tmp_small_lwe_vector,
2001-
small_lwe_size * sizeof(Torus), streams[0],
2002-
gpu_indexes[0]);
2004+
copy_radix_ciphertext_slice_async<Torus>(
2005+
streams[0], gpu_indexes[0], mem_ptr->tmp_small_lwe_vector, 1, 2,
2006+
mem_ptr->tmp_small_lwe_vector, 0, 1);
20032007

20042008
execute_pbs_async<Torus>(
2005-
streams, gpu_indexes, 1, mem_ptr->tmp_big_lwe_vector,
2009+
streams, gpu_indexes, 1, (Torus *)mem_ptr->tmp_big_lwe_vector->ptr,
20062010
mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec,
2007-
mem_ptr->lut->lut_indexes_vec, mem_ptr->tmp_small_lwe_vector,
2011+
mem_ptr->lut->lut_indexes_vec,
2012+
(Torus *)mem_ptr->tmp_small_lwe_vector->ptr,
20082013
mem_ptr->lut->lwe_trivial_indexes, bsks, mem_ptr->lut->buffer,
20092014
params.glwe_dimension, params.small_lwe_dimension,
20102015
params.polynomial_size, params.pbs_base_log, params.pbs_level,
20112016
params.grouping_factor, 2, params.pbs_type, num_many_lut, lut_stride);
20122017

2013-
cuda_memcpy_async_gpu_to_gpu(
2014-
(void *)cur_input_block, mem_ptr->tmp_big_lwe_vector,
2015-
big_lwe_size * sizeof(Torus), streams[0], gpu_indexes[0]);
2018+
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
2019+
&cur_input_block, 0, 1,
2020+
mem_ptr->tmp_big_lwe_vector, 0, 1);
2021+
Torus lut_indexes[mem_ptr->lut->num_blocks];
2022+
cuda_memcpy_async_to_cpu(&lut_indexes, mem_ptr->lut->get_lut_indexes(0, 0),
2023+
mem_ptr->lut->num_blocks * sizeof(Torus),
2024+
streams[0], gpu_indexes[0]);
2025+
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
2026+
input_blocks->degrees[i] = mem_ptr->lut->degrees[lut_indexes[0]];
2027+
input_blocks->noise_levels[i] = NoiseLevel::NOMINAL;
20162028

20172029
if (i < num_blocks - 1) {
2018-
auto next_input_block = &input_blocks[(i + 1) * big_lwe_size];
2019-
legacy_host_addition<Torus>(streams[0], gpu_indexes[0], next_input_block,
2020-
(Torus const *)next_input_block,
2021-
&mem_ptr->tmp_big_lwe_vector[big_lwe_size],
2022-
params.big_lwe_dimension, 1);
2030+
CudaRadixCiphertextFFI next_input_block;
2031+
as_radix_ciphertext_slice<Torus>(&next_input_block, input_blocks, i + 1,
2032+
i + 2);
2033+
CudaRadixCiphertextFFI second_input;
2034+
as_radix_ciphertext_slice<Torus>(&second_input,
2035+
mem_ptr->tmp_big_lwe_vector, 1, 2);
2036+
2037+
host_addition<Torus>(streams[0], gpu_indexes[0], &next_input_block,
2038+
&next_input_block, &second_input, 1);
20232039
}
20242040
}
20252041
}

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ unsafe extern "C" {
319319
streams: *const *mut ffi::c_void,
320320
gpu_indexes: *const u32,
321321
gpu_count: u32,
322-
input_blocks: *mut ffi::c_void,
322+
input_blocks: *mut CudaRadixCiphertextFFI,
323323
mem_ptr: *mut i8,
324324
ksks: *const *mut ffi::c_void,
325325
bsks: *const *mut ffi::c_void,

tfhe/src/integer/gpu/mod.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ pub unsafe fn unchecked_scalar_comparison_integer_radix_kb_async<T: UnsignedInte
12491249
/// is required
12501250
pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
12511251
streams: &CudaStreams,
1252-
radix_lwe_input: &mut CudaVec<T>,
1252+
radix_lwe_input: &mut CudaRadixCiphertext,
12531253
bootstrapping_key: &CudaVec<B>,
12541254
keyswitch_key: &CudaVec<T>,
12551255
lwe_dimension: LweDimension,
@@ -1267,7 +1267,7 @@ pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
12671267
) {
12681268
assert_eq!(
12691269
streams.gpu_indexes[0],
1270-
radix_lwe_input.gpu_index(0),
1270+
radix_lwe_input.d_blocks.0.d_vec.gpu_index(0),
12711271
"GPU error: all data should reside on the same GPU."
12721272
);
12731273
assert_eq!(
@@ -1281,6 +1281,23 @@ pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
12811281
"GPU error: all data should reside on the same GPU."
12821282
);
12831283
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
1284+
let mut radix_lwe_input_degrees = radix_lwe_input
1285+
.info
1286+
.blocks
1287+
.iter()
1288+
.map(|b| b.degree.0)
1289+
.collect();
1290+
let mut radix_lwe_input_noise_levels = radix_lwe_input
1291+
.info
1292+
.blocks
1293+
.iter()
1294+
.map(|b| b.noise_level.0)
1295+
.collect();
1296+
let mut cuda_ffi_radix_lwe_input = prepare_cuda_radix_ffi(
1297+
radix_lwe_input,
1298+
&mut radix_lwe_input_degrees,
1299+
&mut radix_lwe_input_noise_levels,
1300+
);
12841301
scratch_cuda_full_propagation_64(
12851302
streams.ptr.as_ptr(),
12861303
streams.gpu_indexes_ptr(),
@@ -1303,7 +1320,7 @@ pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
13031320
streams.ptr.as_ptr(),
13041321
streams.gpu_indexes_ptr(),
13051322
streams.len() as u32,
1306-
radix_lwe_input.as_mut_c_ptr(0),
1323+
&mut cuda_ffi_radix_lwe_input,
13071324
mem_ptr,
13081325
keyswitch_key.ptr.as_ptr(),
13091326
bootstrapping_key.ptr.as_ptr(),
@@ -1315,6 +1332,7 @@ pub unsafe fn full_propagate_assign_async<T: UnsignedInteger, B: Numeric>(
13151332
streams.len() as u32,
13161333
std::ptr::addr_of_mut!(mem_ptr),
13171334
);
1335+
update_noise_degree(radix_lwe_input, &cuda_ffi_radix_lwe_input);
13181336
}
13191337

13201338
#[allow(clippy::too_many_arguments)]

tfhe/src/integer/gpu/server_key/radix/add.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ impl CudaServerKey {
569569
let output_flag = OutputFlag::from_signedness(CudaSignedRadixCiphertext::IS_SIGNED);
570570

571571
let mut ct_res = lhs.duplicate_async(stream);
572-
let mut carry_out: CudaSignedRadixCiphertext = self
572+
let carry_out: CudaSignedRadixCiphertext = self
573573
.add_and_propagate_single_carry_assign_async(
574574
&mut ct_res,
575575
rhs,
@@ -578,14 +578,6 @@ impl CudaServerKey {
578578
output_flag,
579579
);
580580

581-
if lhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
582-
&& rhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
583-
{
584-
carry_out.as_mut().info = carry_out.as_ref().info.boolean_info(NoiseLevel::ZERO);
585-
} else {
586-
carry_out.as_mut().info = carry_out.as_ref().info.boolean_info(NoiseLevel::NOMINAL);
587-
}
588-
589581
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext);
590582

591583
(ct_res, ct_overflowed)

tfhe/src/integer/gpu/server_key/radix/mod.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ impl CudaServerKey {
383383
CudaBootstrappingKey::Classic(d_bsk) => {
384384
full_propagate_assign_async(
385385
streams,
386-
&mut ciphertext.d_blocks.0.d_vec,
386+
ciphertext,
387387
&d_bsk.d_vec,
388388
&self.key_switching_key.d_vec,
389389
d_bsk.input_lwe_dimension(),
@@ -403,7 +403,7 @@ impl CudaServerKey {
403403
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
404404
full_propagate_assign_async(
405405
streams,
406-
&mut ciphertext.d_blocks.0.d_vec,
406+
ciphertext,
407407
&d_multibit_bsk.d_vec,
408408
&self.key_switching_key.d_vec,
409409
d_multibit_bsk.input_lwe_dimension(),
@@ -422,14 +422,6 @@ impl CudaServerKey {
422422
}
423423
}
424424
}
425-
ciphertext.info.blocks.iter_mut().for_each(|b| {
426-
b.degree = Degree::new(b.message_modulus.0 - 1);
427-
b.noise_level = if b.noise_level == NoiseLevel::ZERO {
428-
NoiseLevel::ZERO
429-
} else {
430-
NoiseLevel::NOMINAL
431-
};
432-
});
433425
}
434426

435427
/// Prepend trivial zero LSB blocks to an existing [`CudaUnsignedRadixCiphertext`] or

tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ where
295295
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
296296
);
297297
assert_eq!(encrypted_overflow.0.degree.get(), 1);
298+
#[cfg(feature = "gpu")]
299+
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::NOMINAL);
300+
301+
#[cfg(not(feature = "gpu"))]
298302
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
299303
}
300304

tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ where
224224
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
225225
);
226226
assert_eq!(encrypted_overflow.0.degree.get(), 1);
227+
#[cfg(feature = "gpu")]
228+
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::NOMINAL);
229+
230+
#[cfg(not(feature = "gpu"))]
227231
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
228232
}
229233

0 commit comments

Comments
 (0)