Skip to content

Commit c1bf43e

Browse files
committed
feat(gpu): add a function to set a CudaLweList to 0
1 parent 95863e1 commit c1bf43e

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,19 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
220220
pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus<T> {
221221
self.0.ciphertext_modulus
222222
}
223+
224+
/// # Safety
225+
///
226+
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
227+
/// not be dropped until stream is synchronised
228+
pub unsafe fn set_to_zero_async(&mut self, streams: &CudaStreams) {
229+
self.0.d_vec.memset_async(0u64, streams, 0);
230+
}
231+
232+
pub fn set_to_zero(&mut self, streams: &CudaStreams) {
233+
unsafe {
234+
self.set_to_zero_async(streams);
235+
streams.synchronize_one(0);
236+
}
237+
}
223238
}

tfhe/src/core_crypto/gpu/vec.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,13 @@ impl<T: Numeric> CudaVec<T> {
175175
///
176176
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
177177
/// not be dropped until streams is synchronised
178-
pub unsafe fn memset_async(&mut self, value: T, streams: &CudaStreams, stream_index: u32)
179-
where
180-
T: Into<u64>,
181-
{
178+
pub unsafe fn memset_async(&mut self, value: u64, streams: &CudaStreams, stream_index: u32) {
182179
let size = self.len() * std::mem::size_of::<T>();
183180
// We check that self is not empty to avoid invalid pointers
184181
if size > 0 {
185182
cuda_memset_async(
186183
self.as_mut_c_ptr(stream_index),
187-
value.into(),
184+
value,
188185
size as u64,
189186
streams.ptr[stream_index as usize],
190187
streams.gpu_indexes[stream_index as usize].0,

0 commit comments

Comments
 (0)