Skip to content

Commit db18fca

Browse files
committed
fix: add clear_gpu_thread_locals
This function is used to clear gpu thread locals. This is mainly useful to counter the 'bug' where a rayon pool does not wait for its threads to exit, which creates sync problems between the cuda driver and the cpu threads' thread_locals.
1 parent 93bee27 commit db18fca

File tree

5 files changed

+94
-11
lines changed

5 files changed

+94
-11
lines changed

tfhe/src/high_level_api/global_state.rs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ where
187187
})
188188
}
189189

190+
#[cfg(feature = "gpu")]
191+
pub use gpu::clear_gpu_thread_locals;
190192
#[cfg(feature = "gpu")]
191193
pub(in crate::high_level_api) use gpu::with_thread_local_cuda_streams_for_gpu_indexes;
192194

@@ -222,6 +224,21 @@ mod gpu {
222224
.collect(),
223225
}
224226
}
227+
228+
fn clear(&mut self) {
229+
self.custom.take();
230+
// "Reset" the lazycells instead of emptying the vec as this allows to reuse the
231+
// StreamPool; the streams are going to get re-created lazily again
232+
for (index, cell) in self.single.iter_mut().enumerate() {
233+
let ctor =
234+
Box::new(move || CudaStreams::new_single_gpu(GpuIndex::new(index as u32)));
235+
*cell = LazyCell::new(ctor as Box<dyn Fn() -> CudaStreams>);
236+
}
237+
}
238+
}
239+
240+
thread_local! {
241+
static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
225242
}
226243

227244
pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes<
@@ -231,10 +248,6 @@ mod gpu {
231248
gpu_indexes: &[GpuIndex],
232249
func: F,
233250
) -> R {
234-
thread_local! {
235-
static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
236-
}
237-
238251
if gpu_indexes.len() == 1 {
239252
POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize]))
240253
} else {
@@ -296,6 +309,13 @@ mod gpu {
296309
}
297310
}
298311
}
312+
313+
/// Clears all the thread_locals that store Cuda related items
314+
/// this means keys, streams, and other internal data used
315+
pub fn clear_gpu_thread_locals() {
316+
unset_server_key();
317+
POOL.with_borrow_mut(|pool| pool.clear());
318+
}
299319
}
300320

301321
#[cfg(feature = "hpu")]

tfhe/src/high_level_api/integers/oprf.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,8 +827,8 @@ mod test {
827827
use crate::prelude::check_valid_cuda_malloc_assert_oom;
828828
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
829829
use crate::{
830-
unset_server_key, ClientKey, CompressedServerKey, FheInt128, FheUint32, FheUint64,
831-
GpuIndex,
830+
clear_gpu_thread_locals, ClientKey, CompressedServerKey, FheInt128, FheUint32,
831+
FheUint64, GpuIndex,
832832
};
833833
use rayon::iter::IndexedParallelIterator;
834834
use rayon::prelude::{IntoParallelRefIterator, ParallelSlice};
@@ -902,7 +902,7 @@ mod test {
902902
let idx: Vec<usize> = (0..sample_count).collect();
903903
let pool = ThreadPoolBuilder::new()
904904
.num_threads(8 * num_gpus)
905-
.exit_handler(|_| unset_server_key())
905+
.exit_handler(|_| clear_gpu_thread_locals())
906906
.build()
907907
.unwrap();
908908
let real_values: Vec<u64> = pool.install(|| {

tfhe/src/high_level_api/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ use crate::{error, Error, Versionize};
5454
use backward_compatibility::compressed_ciphertext_list::SquashedNoiseCiphertextStateVersions;
5555
pub use config::{Config, ConfigBuilder};
5656
#[cfg(feature = "gpu")]
57+
pub use global_state::clear_gpu_thread_locals;
58+
#[cfg(feature = "gpu")]
5759
pub use global_state::CudaGpuChoice;
5860
#[cfg(feature = "gpu")]
5961
pub use global_state::CustomMultiGpuIndexes;

tfhe/src/high_level_api/tests/gpu_selection.rs renamed to tfhe/src/high_level_api/tests/gpu.rs

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,67 @@ use crate::core_crypto::gpu::get_number_of_gpus;
44
use crate::high_level_api::global_state::CustomMultiGpuIndexes;
55
use crate::prelude::*;
66
use crate::{
7-
set_server_key, unset_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
8-
FheUint32, GpuIndex,
7+
clear_gpu_thread_locals, set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
8+
FheUint32, FheUint8, GpuIndex,
99
};
10+
use rayon::iter::{IntoParallelIterator, ParallelIterator};
11+
use rayon::ThreadPoolBuilder;
12+
13+
/// Regression test: dropping a rayon pool whose threads holds some thread-local GPU
14+
/// data (keys, streams, etc) can cause issues if not properly cleaned up.
15+
///
16+
/// This is because rayon does not seem to wait for the thread destruction
17+
/// which then creates ordering problems with the CUDA driver
18+
///
19+
/// The scenario is:
20+
/// 1. Create a custom rayon thread pool.
21+
/// 2. On each thread, set a GPU server key (which stores CUDA resources in thread-locals).
22+
/// 3. Decrypt, as decryption initializes and uses a different set of cuda stream thread locals
23+
/// 4. Drop the pool
24+
#[test]
25+
fn test_drop_rayon_pool_with_gpu_server_key_thread_locals() {
26+
let config = ConfigBuilder::default().build();
27+
let cks = ClientKey::generate(config);
28+
29+
let num_gpus = get_number_of_gpus() as usize;
30+
31+
let compressed_sks = CompressedServerKey::new(&cks);
32+
let sks_vec: Vec<_> = (0..num_gpus)
33+
.map(|i| compressed_sks.decompress_to_specific_gpu(GpuIndex::new(i as u32)))
34+
.collect();
35+
36+
let pool = ThreadPoolBuilder::new()
37+
.num_threads(4 * num_gpus)
38+
.exit_handler(|_| clear_gpu_thread_locals())
39+
.build()
40+
.unwrap();
41+
42+
let results: Vec<u8> = pool.install(|| {
43+
(0..4 * num_gpus)
44+
.into_par_iter()
45+
.map_init(
46+
|| {
47+
let gpu_index = rayon::current_thread_index().unwrap_or(0) % num_gpus;
48+
set_server_key(sks_vec[gpu_index].clone());
49+
},
50+
|(), _| {
51+
let ct = FheUint8::encrypt_trivial(42u8);
52+
let result: u8 = ct.decrypt(&cks);
53+
result
54+
},
55+
)
56+
.collect()
57+
});
58+
59+
for val in &results {
60+
assert_eq!(*val, 42u8);
61+
}
62+
63+
// Explicitly drop the pool — this is where the bug manifests:
64+
// rayon threads are joined, their thread-locals (holding GPU server keys
65+
// referencing CUDA resources) are dropped.
66+
drop(pool);
67+
}
1068

1169
#[test]
1270
fn test_gpu_selection() {
@@ -187,6 +245,9 @@ fn test_specific_gpu_selection() {
187245
assert_eq!(c.current_device(), Device::CudaGpu);
188246
assert_eq!(c.gpu_indexes(), &[first_gpu]);
189247
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
190-
unset_server_key();
248+
// unset_server_key is sufficient but we use clear_gpu_thread_locals
249+
// in order to test that after calling it, the thread is still usable
250+
// (the needed thread locals will lazily recreate themselves, nothing prevents them)
251+
clear_gpu_thread_locals();
191252
}
192253
}

tfhe/src/high_level_api/tests/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
mod cpk_re_randomization;
22
#[cfg(feature = "gpu")]
3-
mod gpu_selection;
3+
mod gpu;
44
mod noise_distribution;
55
mod noise_squashing;
66
mod tags_on_entities;

0 commit comments

Comments
 (0)