
Commit a508f4c

fix(gpu): enforce tighter bounds on compression output

1 parent dad278c commit a508f4c

File tree

  backends/tfhe-cuda-backend/cuda/include/integer/compression/compression_utilities.h
  backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cuh
  backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh
  tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
  tfhe/src/integer/gpu/list_compression/server_keys.rs

5 files changed: +71 -78 lines changed

backends/tfhe-cuda-backend/cuda/include/integer/compression/compression_utilities.h

Lines changed: 1 addition & 3 deletions

@@ -102,9 +102,7 @@ template <typename Torus> struct int_decompression {
     // Example: in the 2_2 case we are mapping a 2 bits message onto a 4 bits
     // space, we want to keep the original 2 bits value in the 4 bits space,
     // so we apply the identity and the encoding will rescale it for us.
-    auto decompression_rescale_f = [encryption_params](Torus x) -> Torus {
-      return x;
-    };
+    auto decompression_rescale_f = [](Torus x) -> Torus { return x; };
 
     auto effective_compression_message_modulus =
         encryption_params.carry_modulus;

backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cuh

Lines changed: 0 additions & 1 deletion

@@ -57,7 +57,6 @@ host_sample_extract(cudaStream_t stream, uint32_t gpu_index,
                     uint32_t const *nth_array, uint32_t num_nths,
                     uint32_t lwe_per_glwe, uint32_t glwe_dimension) {
   cuda_set_device(gpu_index);
-
   dim3 grid(num_nths);
   dim3 thds(params::degree / params::opt);
   sample_extract<Torus, params><<<grid, thds, 0, stream>>>(

backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh

Lines changed: 61 additions & 51 deletions

@@ -43,6 +43,9 @@ __global__ void pack(Torus *array_out, Torus *array_in, uint32_t log_modulus,
   }
 }
 
+/// Packs `num_lwes` LWE-ciphertext contained in `num_glwes` GLWE-ciphertext in
+/// a compressed array This function follows the naming used in the CPU
+/// implementation
 template <typename Torus>
 __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
                         Torus *array_out, Torus *array_in, uint32_t num_glwes,
@@ -55,26 +58,23 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
 
   auto log_modulus = mem_ptr->storage_log_modulus;
   // [0..num_glwes-1) GLWEs
-  auto in_len = (compression_params.glwe_dimension + 1) *
-                compression_params.polynomial_size;
+  auto in_len = num_glwes * compression_params.glwe_dimension *
+                    compression_params.polynomial_size +
+                num_lwes;
+
   auto number_bits_to_pack = in_len * log_modulus;
-  auto nbits = sizeof(Torus) * 8;
+
   // number_bits_to_pack.div_ceil(Scalar::BITS)
+  auto nbits = sizeof(Torus) * 8;
   auto out_len = (number_bits_to_pack + nbits - 1) / nbits;
 
-  // Last GLWE
-  number_bits_to_pack = in_len * log_modulus;
-  auto last_out_len = (number_bits_to_pack + nbits - 1) / nbits;
-
-  auto num_coeffs = (num_glwes - 1) * out_len + last_out_len;
-
   int num_blocks = 0, num_threads = 0;
-  getNumBlocksAndThreads(num_coeffs, 1024, num_blocks, num_threads);
+  getNumBlocksAndThreads(out_len, 1024, num_blocks, num_threads);
 
   dim3 grid(num_blocks);
   dim3 threads(num_threads);
   pack<Torus><<<grid, threads, 0, stream>>>(array_out, array_in, log_modulus,
-                                            num_coeffs, in_len, out_len);
+                                            out_len, in_len, out_len);
   check_cuda_error(cudaGetLastError());
 }
 
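The tightened bound above packs only the GLWE masks plus the num_lwes bodies that are actually used. A minimal, standalone sketch of that arithmetic is shown below; the parameter values are made up for illustration and are not library defaults.

#include <cstdint>
#include <iostream>

// Sketch of the packed-length bound used by host_pack after this commit.
// glwe_dimension, polynomial_size, num_glwes, num_lwes and log_modulus below
// are hypothetical example values.
int main() {
  uint64_t glwe_dimension = 1, polynomial_size = 2048;
  uint64_t num_glwes = 2, num_lwes = 300, log_modulus = 39;
  uint64_t nbits = sizeof(uint64_t) * 8;

  // Masks of every GLWE plus only the bodies that are actually used.
  uint64_t in_len = num_glwes * glwe_dimension * polynomial_size + num_lwes;
  uint64_t number_bits_to_pack = in_len * log_modulus;
  // Equivalent of number_bits_to_pack.div_ceil(nbits) on the CPU side.
  uint64_t out_len = (number_bits_to_pack + nbits - 1) / nbits;

  std::cout << "in_len=" << in_len << " out_len=" << out_len << std::endl;
  // Prints: in_len=4396 out_len=2679
  return 0;
}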
@@ -99,14 +99,13 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
   uint32_t lwe_in_size = input_lwe_dimension + 1;
   uint32_t glwe_out_size = (compression_params.glwe_dimension + 1) *
                            compression_params.polynomial_size;
-  uint32_t num_glwes_for_compression =
-      num_radix_blocks / mem_ptr->lwe_per_glwe + 1;
+  uint32_t num_glwes =
+      (num_radix_blocks + mem_ptr->lwe_per_glwe - 1) / mem_ptr->lwe_per_glwe;
 
   // Keyswitch LWEs to GLWE
   auto tmp_glwe_array_out = mem_ptr->tmp_glwe_array_out;
   cuda_memset_async(tmp_glwe_array_out, 0,
-                    num_glwes_for_compression *
-                        (compression_params.glwe_dimension + 1) *
+                    num_glwes * (compression_params.glwe_dimension + 1) *
                         compression_params.polynomial_size * sizeof(Torus),
                     streams[0], gpu_indexes[0]);
   auto fp_ks_buffer = mem_ptr->fp_ks_buffer;
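Replacing num_radix_blocks / lwe_per_glwe + 1 with a ceiling division removes the extra GLWE that the old formula allocated whenever the block count was an exact multiple of lwe_per_glwe. A quick sketch, with made-up values:

#include <cstdint>
#include <iostream>

// Old vs new GLWE-count bound (the argument values are made-up examples).
static uint32_t old_bound(uint32_t num_radix_blocks, uint32_t lwe_per_glwe) {
  // Over-allocates by one GLWE whenever num_radix_blocks is a multiple of
  // lwe_per_glwe.
  return num_radix_blocks / lwe_per_glwe + 1;
}

static uint32_t new_bound(uint32_t num_radix_blocks, uint32_t lwe_per_glwe) {
  // Ceiling division: exactly enough GLWEs to hold the blocks.
  return (num_radix_blocks + lwe_per_glwe - 1) / lwe_per_glwe;
}

int main() {
  // With a hypothetical lwe_per_glwe of 256, 256 blocks fit in a single GLWE.
  std::cout << old_bound(256, 256) << std::endl; // 2
  std::cout << new_bound(256, 256) << std::endl; // 1
  std::cout << old_bound(257, 256) << std::endl; // 2
  std::cout << new_bound(257, 256) << std::endl; // 2
  return 0;
}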
@@ -131,23 +130,21 @@ host_integer_compress(cudaStream_t const *streams, uint32_t const *gpu_indexes,
   // Modulus switch
   host_modulus_switch_inplace<Torus>(
       streams[0], gpu_indexes[0], tmp_glwe_array_out,
-      num_glwes_for_compression * (compression_params.glwe_dimension + 1) *
-          compression_params.polynomial_size,
+      num_glwes * compression_params.glwe_dimension *
+              compression_params.polynomial_size +
+          num_radix_blocks,
       mem_ptr->storage_log_modulus);
 
   host_pack<Torus>(streams[0], gpu_indexes[0], glwe_array_out,
-                   tmp_glwe_array_out, num_glwes_for_compression,
-                   num_radix_blocks, mem_ptr);
+                   tmp_glwe_array_out, num_glwes, num_radix_blocks, mem_ptr);
 }
 
 template <typename Torus>
 __global__ void extract(Torus *glwe_array_out, Torus const *array_in,
-                        uint32_t index, uint32_t log_modulus,
-                        uint32_t input_len, uint32_t initial_out_len) {
+                        uint32_t log_modulus, uint32_t initial_out_len) {
   auto nbits = sizeof(Torus) * 8;
 
   auto i = threadIdx.x + blockIdx.x * blockDim.x;
-  auto chunk_array_in = array_in + index * input_len;
   if (i < initial_out_len) {
     // Unpack
     Torus mask = ((Torus)1 << log_modulus) - 1;
@@ -161,12 +158,11 @@ __global__ void extract(Torus *glwe_array_out, Torus const *array_in,
 
     Torus unpacked_i;
     if (start_block == end_block_inclusive) {
-      auto single_part = chunk_array_in[start_block] >> start_remainder;
+      auto single_part = array_in[start_block] >> start_remainder;
       unpacked_i = single_part & mask;
     } else {
-      auto first_part = chunk_array_in[start_block] >> start_remainder;
-      auto second_part = chunk_array_in[start_block + 1]
-                         << (nbits - start_remainder);
+      auto first_part = array_in[start_block] >> start_remainder;
+      auto second_part = array_in[start_block + 1] << (nbits - start_remainder);
 
       unpacked_i = (first_part | second_part) & mask;
     }
@@ -177,6 +173,7 @@ __global__ void extract(Torus *glwe_array_out, Torus const *array_in,
 }
 
 /// Extracts the glwe_index-nth GLWE ciphertext
+/// This function follows the naming used in the CPU implementation
 template <typename Torus>
 __host__ void host_extract(cudaStream_t stream, uint32_t gpu_index,
                            Torus *glwe_array_out, Torus const *array_in,
@@ -188,36 +185,51 @@ __host__ void host_extract(cudaStream_t stream, uint32_t gpu_index,
   cuda_set_device(gpu_index);
 
   auto compression_params = mem_ptr->compression_params;
-
   auto log_modulus = mem_ptr->storage_log_modulus;
+  auto glwe_ciphertext_size = (compression_params.glwe_dimension + 1) *
+                              compression_params.polynomial_size;
+
+  uint32_t body_count = mem_ptr->body_count;
+  auto num_glwes = (body_count + compression_params.polynomial_size - 1) /
+                   compression_params.polynomial_size;
+
+  // Compressed length of the compressed GLWE we want to extract
+  if (mem_ptr->body_count % compression_params.polynomial_size == 0)
+    body_count = compression_params.polynomial_size;
+  else if (glwe_index == num_glwes - 1)
+    body_count = mem_ptr->body_count % compression_params.polynomial_size;
+  else
+    body_count = compression_params.polynomial_size;
 
-  uint32_t body_count =
-      std::min(mem_ptr->body_count, compression_params.polynomial_size);
   auto initial_out_len =
       compression_params.glwe_dimension * compression_params.polynomial_size +
       body_count;
 
-  auto compressed_glwe_accumulator_size =
-      (compression_params.glwe_dimension + 1) *
-      compression_params.polynomial_size;
-  auto number_bits_to_unpack = compressed_glwe_accumulator_size * log_modulus;
+  // Calculates how many bits this particular GLWE shall use
+  auto number_bits_to_unpack = initial_out_len * log_modulus;
   auto nbits = sizeof(Torus) * 8;
-  // number_bits_to_unpack.div_ceil(Scalar::BITS)
   auto input_len = (number_bits_to_unpack + nbits - 1) / nbits;
 
-  // We assure the tail of the glwe is zeroed
-  auto zeroed_slice = glwe_array_out + initial_out_len;
-  cuda_memset_async(zeroed_slice, 0,
-                    (compression_params.polynomial_size - body_count) *
-                        sizeof(Torus),
-                    stream, gpu_index);
+  // Calculates how many bits a full-packed GLWE shall use
+  number_bits_to_unpack = glwe_ciphertext_size * log_modulus;
+  auto len = (number_bits_to_unpack + nbits - 1) / nbits;
+  // Uses that length to set the input pointer
+  auto chunk_array_in = array_in + glwe_index * len;
+
+  // Ensure the tail of the GLWE is zeroed
+  if (initial_out_len < glwe_ciphertext_size) {
+    auto zeroed_slice = glwe_array_out + initial_out_len;
+    cuda_memset_async(glwe_array_out, 0,
+                      (glwe_ciphertext_size - initial_out_len) * sizeof(Torus),
+                      stream, gpu_index);
+  }
+
   int num_blocks = 0, num_threads = 0;
   getNumBlocksAndThreads(initial_out_len, 128, num_blocks, num_threads);
   dim3 grid(num_blocks);
   dim3 threads(num_threads);
-  extract<Torus><<<grid, threads, 0, stream>>>(glwe_array_out, array_in,
-                                               glwe_index, log_modulus,
-                                               input_len, initial_out_len);
+  extract<Torus><<<grid, threads, 0, stream>>>(glwe_array_out, chunk_array_in,
                                               log_modulus, initial_out_len);
   check_cuda_error(cudaGetLastError());
 }
 
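The per-GLWE body count computed above lets the last, partially filled GLWE be unpacked with its actual number of bodies instead of a full polynomial. A standalone sketch of that rule, using hypothetical sizes:

#include <cstdint>
#include <iostream>

// Per-GLWE body count as computed in host_extract: every GLWE except possibly
// the last one is full; the last one keeps only the remainder. total_bodies
// and polynomial_size below are hypothetical example values.
static uint32_t bodies_in_glwe(uint32_t glwe_index, uint32_t total_bodies,
                               uint32_t polynomial_size) {
  uint32_t num_glwes = (total_bodies + polynomial_size - 1) / polynomial_size;
  if (total_bodies % polynomial_size == 0)
    return polynomial_size;
  if (glwe_index == num_glwes - 1)
    return total_bodies % polynomial_size;
  return polynomial_size;
}

int main() {
  // 300 bodies with polynomial_size = 256 -> 2 GLWEs holding 256 and 44 bodies.
  std::cout << bodies_in_glwe(0, 300, 256) << std::endl; // 256
  std::cout << bodies_in_glwe(1, 300, 256) << std::endl; // 44
  return 0;
}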
@@ -241,8 +253,7 @@ __host__ void host_integer_decompress(
   PANIC("Cuda error: wrong number of LWEs in decompress: the number of LWEs "
         "should be the same as indexes_array_size.")
 
-  // the first element is the last index in h_indexes_array that lies in the
-  // related GLWE
+  // the first element is the number of LWEs that lies in the related GLWE
   std::vector<std::pair<int, Torus *>> glwe_vec;
 
   // Extract all GLWEs
@@ -253,7 +264,7 @@ __host__ void host_integer_decompress(
   auto extracted_glwe = h_mem_ptr->tmp_extracted_glwe;
   host_extract<Torus>(streams[0], gpu_indexes[0], extracted_glwe,
                       d_packed_glwe_in, current_glwe_index, h_mem_ptr);
-  glwe_vec.push_back(std::make_pair(0, extracted_glwe));
+  glwe_vec.push_back(std::make_pair(1, extracted_glwe));
   for (int i = 1; i < indexes_array_size; i++) {
     auto glwe_index = h_indexes_array[i] / lwe_per_glwe;
     if (glwe_index != current_glwe_index) {
@@ -262,9 +273,9 @@ __host__ void host_integer_decompress(
       // Extracts a new GLWE
       host_extract<Torus>(streams[0], gpu_indexes[0], extracted_glwe,
                           d_packed_glwe_in, glwe_index, h_mem_ptr);
-      glwe_vec.push_back(std::make_pair(i, extracted_glwe));
+      glwe_vec.push_back(std::make_pair(1, extracted_glwe));
     } else {
-      // Updates the index
+      // Updates the quantity
       ++glwe_vec.back().first;
     }
   }
@@ -275,17 +286,16 @@ __host__ void host_integer_decompress(
   uint32_t current_idx = 0;
   auto d_indexes_array_chunk = d_indexes_array;
   for (const auto &max_idx_and_glwe : glwe_vec) {
-    const uint32_t last_idx = max_idx_and_glwe.first;
+    const auto num_lwes = max_idx_and_glwe.first;
     extracted_glwe = max_idx_and_glwe.second;
 
-    auto num_lwes = last_idx + 1 - current_idx;
     cuda_glwe_sample_extract_64(
         streams[0], gpu_indexes[0], extracted_lwe, extracted_glwe,
         d_indexes_array_chunk, num_lwes, compression_params.polynomial_size,
        compression_params.glwe_dimension, compression_params.polynomial_size);
     d_indexes_array_chunk += num_lwes;
     extracted_lwe += num_lwes * lwe_accumulator_size;
-    current_idx = last_idx;
+    current_idx += num_lwes;
   }
 
   // Reset
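With this change each glwe_vec entry stores how many requested LWEs fall into the corresponding extracted GLWE instead of the last matching index. A simplified host-only sketch of the grouping (GLWE pointers omitted; the index array and lwe_per_glwe are arbitrary example values):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Groups a sorted host index array into (num_lwes, glwe_index) pairs, the way
// host_integer_decompress counts LWEs per extracted GLWE after this commit.
int main() {
  std::vector<uint32_t> h_indexes_array = {0, 1, 2, 255, 256, 300};
  uint32_t lwe_per_glwe = 256;

  std::vector<std::pair<int, uint32_t>> glwe_vec; // (num_lwes, glwe_index)
  uint32_t current_glwe_index = h_indexes_array[0] / lwe_per_glwe;
  glwe_vec.emplace_back(1, current_glwe_index);
  for (size_t i = 1; i < h_indexes_array.size(); i++) {
    auto glwe_index = h_indexes_array[i] / lwe_per_glwe;
    if (glwe_index != current_glwe_index) {
      current_glwe_index = glwe_index;
      glwe_vec.emplace_back(1, glwe_index); // new GLWE, starts with one LWE
    } else {
      ++glwe_vec.back().first; // same GLWE, bump the count
    }
  }

  for (auto &entry : glwe_vec)
    std::cout << "glwe " << entry.second << ": " << entry.first << " lwes\n";
  // Prints: glwe 0: 4 lwes, then glwe 1: 2 lwes
  return 0;
}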

tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs

Lines changed: 2 additions & 16 deletions

@@ -2,7 +2,7 @@ use crate::core_crypto::entities::packed_integers::PackedIntegers;
 use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
 use crate::core_crypto::gpu::CudaStreams;
 use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext;
-use crate::core_crypto::prelude::{glwe_ciphertext_size, CiphertextCount, LweCiphertextCount};
+use crate::core_crypto::prelude::{CiphertextCount, LweCiphertextCount};
 use crate::integer::ciphertext::{CompressedCiphertextList, DataKind};
 use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
 use crate::integer::gpu::ciphertext::{
@@ -357,25 +357,11 @@ impl CompressedCiphertextList {
         let message_modulus = self.packed_list.message_modulus;
         let carry_modulus = self.packed_list.carry_modulus;
 
-        let mut flat_cpu_data = modulus_switched_glwe_ciphertext_list
+        let flat_cpu_data = modulus_switched_glwe_ciphertext_list
             .iter()
             .flat_map(|ct| ct.packed_integers.packed_coeffs.clone())
             .collect_vec();
 
-        let glwe_ciphertext_count = self.packed_list.modulus_switched_glwe_ciphertext_list.len();
-        let glwe_size = self.packed_list.modulus_switched_glwe_ciphertext_list[0]
-            .glwe_dimension()
-            .to_glwe_size();
-        let polynomial_size =
-            self.packed_list.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
-
-        // FIXME: have a more precise memory handling, this is too long and should be "just" the
-        // original flat_cpu_data.len()
-        let unpacked_glwe_ciphertext_flat_len =
-            glwe_ciphertext_count * glwe_ciphertext_size(glwe_size, polynomial_size);
-
-        flat_cpu_data.resize(unpacked_glwe_ciphertext_flat_len, 0u64);
-
         let flat_gpu_data = unsafe {
             let v = CudaVec::from_cpu_async(flat_cpu_data.as_slice(), streams, 0);
             streams.synchronize();

tfhe/src/integer/gpu/list_compression/server_keys.rs

Lines changed: 7 additions & 7 deletions

@@ -3,8 +3,8 @@ use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::vec::CudaVec;
 use crate::core_crypto::gpu::CudaStreams;
 use crate::core_crypto::prelude::{
-    glwe_ciphertext_size, CiphertextModulus, CiphertextModulusLog, GlweCiphertextCount,
-    LweCiphertextCount, PolynomialSize,
+    glwe_ciphertext_size, glwe_mask_size, CiphertextModulus, CiphertextModulusLog,
+    GlweCiphertextCount, LweCiphertextCount, PolynomialSize,
 };
 use crate::integer::ciphertext::DataKind;
 use crate::integer::compression_keys::CompressionKey;
@@ -173,12 +173,12 @@ impl CudaCompressionKey {
             .sum();
 
         let num_glwes = num_lwes.div_ceil(self.lwe_per_glwe.0);
-        let glwe_ciphertext_size =
-            glwe_ciphertext_size(compressed_glwe_size, compressed_polynomial_size);
+        let glwe_mask_size = glwe_mask_size(
+            compressed_glwe_size.to_glwe_dimension(),
+            compressed_polynomial_size,
+        );
         // The number of u64 (both mask and bodies)
-        // FIXME: have a more precise memory handling, this is too long and should be
-        // num_glwes * glwe_mask_size + num_lwes
-        let uncompressed_len = num_glwes * glwe_ciphertext_size;
+        let uncompressed_len = num_glwes * glwe_mask_size + num_lwes;
         let number_bits_to_pack = uncompressed_len * self.storage_log_modulus.0;
         let compressed_len = number_bits_to_pack.div_ceil(u64::BITS as usize);
         let mut packed_glwe_list = CudaVec::new(compressed_len, streams, 0);
