Skip to content

Commit 7a4d97e

Browse files
Fix deadlock when copy (#1041)
1 parent 7e96681 commit 7a4d97e

File tree

4 files changed

+65
-24
lines changed

4 files changed

+65
-24
lines changed

crates/cubecl-common/src/device.rs

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ mod context {
184184
/// Handle for accessing a [DeviceState] associated with a specific device.
185185
pub struct DeviceContext<S: DeviceState> {
186186
lock: DeviceStateLock,
187+
lock_kind: Arc<ReentrantMutex<()>>,
187188
device_id: DeviceId,
188189
_phantom: PhantomData<S>,
189190
}
@@ -195,6 +196,7 @@ mod context {
195196
fn clone(&self) -> Self {
196197
Self {
197198
lock: self.lock.clone(),
199+
lock_kind: self.lock_kind.clone(),
198200
_phantom: self._phantom,
199201
device_id: self.device_id,
200202
}
@@ -295,6 +297,14 @@ mod context {
295297
Ok(lock)
296298
}
297299

300+
/// Locks all devices under the same kind.
301+
///
302+
/// This is useful when you need mutable access to multiple devices at once, which can lead
303+
/// to deadlocks.
304+
pub fn lock_device_kind(&self) -> ReentrantMutexGuard<'_, ()> {
305+
self.lock_kind.lock()
306+
}
307+
298308
/// Locks the current device making sure this device can be used.
299309
pub fn lock_device(&self) -> DeviceGuard<'_> {
300310
let state = self.lock.lock.lock();
@@ -367,8 +377,14 @@ mod context {
367377

368378
static GLOBAL: spin::Mutex<DeviceLocator> = spin::Mutex::new(DeviceLocator { state: None });
369379

380+
#[derive(Default)]
381+
struct DeviceLocatorState {
382+
device: HashMap<Key, DeviceStateLock>,
383+
device_kind: HashMap<TypeId, Arc<ReentrantMutex<()>>>,
384+
}
385+
370386
struct DeviceLocator {
371-
state: Option<HashMap<Key, DeviceStateLock>>,
387+
state: Option<DeviceLocatorState>,
372388
}
373389

374390
#[derive(Clone)]
@@ -383,18 +399,19 @@ mod context {
383399
impl DeviceStateLock {
384400
fn locate<D: Device + 'static, S: DeviceState>(device: &D) -> DeviceContext<S> {
385401
let id = device.to_id();
402+
let kind = TypeId::of::<D>();
386403
let key = (id, TypeId::of::<D>());
387404
let mut global = GLOBAL.lock();
388405

389-
let map = match &mut global.state {
406+
let locator_state = match &mut global.state {
390407
Some(state) => state,
391408
None => {
392-
global.state = Some(HashMap::default());
409+
global.state = Some(Default::default());
393410
global.state.as_mut().expect("Just created Option::Some")
394411
}
395412
};
396413

397-
let lock = match map.get(&key) {
414+
let lock = match locator_state.device.get(&key) {
398415
Some(value) => value.clone(),
399416
None => {
400417
let state = DeviceStateMap::new();
@@ -403,13 +420,31 @@ mod context {
403420
lock: Arc::new(ReentrantMutex::new(state)),
404421
};
405422

406-
map.insert(key, value);
407-
map.get(&key).expect("Just inserted the key/value").clone()
423+
locator_state.device.insert(key, value);
424+
locator_state
425+
.device
426+
.get(&key)
427+
.expect("Just inserted the key/value")
428+
.clone()
429+
}
430+
};
431+
let lock_kind = match locator_state.device_kind.get(&kind) {
432+
Some(value) => value.clone(),
433+
None => {
434+
locator_state
435+
.device_kind
436+
.insert(kind, Arc::new(ReentrantMutex::new(())));
437+
locator_state
438+
.device_kind
439+
.get(&kind)
440+
.expect("Just inserted the key/value")
441+
.clone()
408442
}
409443
};
410444

411445
DeviceContext {
412446
lock,
447+
lock_kind,
413448
device_id: id,
414449
_phantom: PhantomData,
415450
}

crates/cubecl-cuda/src/compute/context.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,19 @@ impl CudaContext {
162162
let repr = kernel_compiled.repr.unwrap();
163163

164164
if let Some(cache) = &mut self.ptx_cache {
165-
cache
166-
.insert(
167-
name.unwrap(),
168-
PtxCacheEntry {
169-
entrypoint_name: kernel_compiled.entrypoint_name.clone(),
170-
cube_dim: (cube_dim.x, cube_dim.y, cube_dim.z),
171-
shared_mem_bytes: repr.shared_memory_size(),
172-
cluster_dim: cluster_dim.map(|cluster| (cluster.x, cluster.y, cluster.z)),
173-
ptx: ptx.clone(),
174-
},
175-
)
176-
.unwrap();
165+
let result = cache.insert(
166+
name.unwrap(),
167+
PtxCacheEntry {
168+
entrypoint_name: kernel_compiled.entrypoint_name.clone(),
169+
cube_dim: (cube_dim.x, cube_dim.y, cube_dim.z),
170+
shared_mem_bytes: repr.shared_memory_size(),
171+
cluster_dim: cluster_dim.map(|cluster| (cluster.x, cluster.y, cluster.z)),
172+
ptx: ptx.clone(),
173+
},
174+
);
175+
if let Err(err) = result {
176+
log::warn!("Unable to save the ptx {err:?}");
177+
}
177178
}
178179

179180
self.load_ptx(

crates/cubecl-runtime/src/client.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,17 +485,22 @@ where
485485
dst_server: &Self,
486486
) -> Allocation {
487487
if Server::SERVER_COMM_ENABLED {
488+
let guard = self.context.lock_device_kind();
488489
let mut server_src = self.context.lock();
489490
let mut server_dst = dst_server.context.lock();
490491

491-
Server::copy(
492+
let copied = Server::copy(
492493
server_src.deref_mut(),
493494
server_dst.deref_mut(),
494495
src_descriptor,
495496
self.stream_id(),
496497
dst_server.stream_id(),
497498
)
498-
.unwrap()
499+
.unwrap();
500+
core::mem::drop(server_src);
501+
core::mem::drop(server_dst);
502+
core::mem::drop(guard);
503+
copied
499504
} else {
500505
let alloc_desc = AllocationDescriptor::new(
501506
AllocationKind::Optimized,

crates/cubecl-std/src/tests/trigonometry.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub fn test_to_degrees<R: Runtime>(client: ComputeClient<R::Server>) {
1515
let input_data = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU];
1616
let expected = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0];
1717

18-
let input = client.create(f32::as_bytes(&input_data));
18+
let input = client.create_from_slice(f32::as_bytes(&input_data));
1919
let output = client.empty(input_data.len() * core::mem::size_of::<f32>());
2020

2121
unsafe {
@@ -53,7 +53,7 @@ pub fn test_to_radians<R: Runtime>(client: ComputeClient<R::Server>) {
5353
let input_data = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0];
5454
let expected = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU];
5555

56-
let input = client.create(f32::as_bytes(&input_data));
56+
let input = client.create_from_slice(f32::as_bytes(&input_data));
5757
let output = client.empty(input_data.len() * core::mem::size_of::<f32>());
5858

5959
unsafe {
@@ -92,8 +92,8 @@ pub fn test_hypot<R: Runtime>(client: ComputeClient<R::Server>) {
9292
let y_data = vec![4.0, 1.0, 1.0, 12.0, 0.0];
9393
let expected = vec![5.0, 1.0, 1.4142135623730951, 13.0, 0.0];
9494

95-
let x = client.create(f32::as_bytes(&x_data));
96-
let y = client.create(f32::as_bytes(&y_data));
95+
let x = client.create_from_slice(f32::as_bytes(&x_data));
96+
let y = client.create_from_slice(f32::as_bytes(&y_data));
9797
let output = client.empty(x_data.len() * core::mem::size_of::<f32>());
9898

9999
unsafe {

0 commit comments

Comments
 (0)