Skip to content

Commit 1dd4861

Browse files
Fix/device transfer (#941)
1 parent dcea7e7 commit 1dd4861

File tree

32 files changed: +374 additions, -769 deletions (lines changed)

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,10 +86,11 @@ tracy-client = { version = "0.18.0" }
8686
strum = { version = "0.27.1", features = ["derive"] }
8787
tracel-xtask = { version = "=2.1.8" }
8888

89-
portable-atomic = { version = "1.10", default-features = false, features = [
89+
90+
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
91+
portable-atomic = { version = "1.11", default-features = false, features = [
9092
"serde",
9193
] }
92-
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
9394
pretty_assertions = "1.4"
9495

9596
# Async

crates/cubecl-common/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ fp8 = ["float8"]
2121
serde = ["serde_bytes"]
2222
std = ["rand/std", "futures-lite", "rand/thread_rng", "serde_json?/std"]
2323

24+
2425
[dependencies]
2526
# ** Please make sure all dependencies support no_std when std is disabled **
2627
bytemuck = { workspace = true, features = ["derive"] }
@@ -61,8 +62,8 @@ futures-lite = { workspace = true, features = [
6162
spin = { workspace = true, features = ["mutex", "spin_mutex"] }
6263

6364
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
64-
portable-atomic = { workspace = true }
6565
portable-atomic-util = { workspace = true }
66+
portable-atomic = { workspace = true }
6667
spin = { workspace = true, features = [
6768
"mutex",
6869
"spin_mutex",

crates/cubecl-common/src/map.rs

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,13 @@
1-
use crate::stub::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
2-
3-
#[cfg(target_has_atomic = "ptr")]
4-
use alloc::sync::Arc;
5-
6-
#[cfg(not(target_has_atomic = "ptr"))]
7-
use portable_atomic_util::Arc;
8-
1+
use crate::stub::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
92
use hashbrown::HashMap;
103

114
/// A thread-safe map that allows concurrent access to values using read-write locks.
125
pub struct SharedStateMap<K, V> {
13-
state: Mutex<Option<HashMap<K, Arc<RwLock<V>>>>>,
6+
state: Mutex<Option<State<K, V>>>,
147
}
158

9+
type State<K, V> = HashMap<K, Arc<RwLock<V>>>;
10+
1611
/// A value in the [SharedStateMap] that provides read and write access.
1712
pub struct SharedState<V> {
1813
val: Arc<RwLock<V>>,
@@ -53,7 +48,7 @@ where
5348
/// Retrieves a value associated with the given key, if it exists.
5449
pub fn get(&self, k: &K) -> Option<SharedState<V>> {
5550
let mut state = self.state.lock().unwrap();
56-
let map = get_or_init(&mut state);
51+
let map = get_or_init::<K, V>(&mut state);
5752

5853
match map.get(k) {
5954
Some(val) => Some(SharedState { val: val.clone() }),
@@ -68,7 +63,7 @@ where
6863
K: Clone,
6964
{
7065
let mut state = self.state.lock().unwrap();
71-
let map = get_or_init(&mut state);
66+
let map = get_or_init::<K, V>(&mut state);
7267

7368
match map.get(k) {
7469
Some(val) => SharedState { val: val.clone() },
@@ -84,24 +79,24 @@ where
8479
/// Inserts a key-value pair into the map.
8580
pub fn insert(&self, k: K, v: V) {
8681
let mut state = self.state.lock().unwrap();
87-
let map = get_or_init(&mut state);
82+
let map = get_or_init::<K, V>(&mut state);
8883

8984
map.insert(k, Arc::new(RwLock::new(v)));
9085
}
9186

9287
/// Clears the map, removing all key-value pairs.
9388
pub fn clear(&self) {
9489
let mut state = self.state.lock().unwrap();
95-
let map = get_or_init(&mut state);
90+
let map = get_or_init::<K, V>(&mut state);
9691
map.clear();
9792
}
9893
}
9994

100-
fn get_or_init<T: Default>(state: &mut Option<T>) -> &mut T {
95+
fn get_or_init<K, V>(state: &mut Option<State<K, V>>) -> &mut State<K, V> {
10196
match state {
10297
Some(state) => state,
10398
None => {
104-
*state = Some(T::default());
99+
*state = Some(State::<K, V>::default());
105100
state.as_mut().unwrap()
106101
}
107102
}

crates/cubecl-common/src/stub.rs

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,12 @@ pub use spin::{RwLockReadGuard, RwLockWriteGuard};
1010
#[cfg(feature = "std")]
1111
pub use std::sync::{RwLockReadGuard, RwLockWriteGuard};
1212

13+
#[cfg(target_has_atomic = "ptr")]
14+
pub use alloc::sync::Arc;
15+
16+
#[cfg(not(target_has_atomic = "ptr"))]
17+
pub use portable_atomic_util::Arc;
18+
1319
/// A mutual exclusion primitive useful for protecting shared data
1420
///
1521
/// This mutex will block threads waiting for the lock to become available. The

crates/cubecl-cpu/src/compute/server.rs

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,8 @@ use cubecl_core::{
66
compute::CubeTask,
77
future::DynFut,
88
server::{
9-
Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor,
10-
DataTransferService, Handle, IoError, ProfileError, ProfilingToken,
9+
Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor, Handle,
10+
IoError, ProfileError, ProfilingToken, ServerCommunication,
1111
},
1212
};
1313
use cubecl_runtime::{
@@ -28,8 +28,6 @@ pub struct CpuServer {
2828
logger: Arc<ServerLogger>,
2929
}
3030

31-
impl DataTransferService for CpuServer {}
32-
3331
impl CpuServer {
3432
pub fn new(ctx: CpuContext) -> Self {
3533
Self {
@@ -227,6 +225,10 @@ impl ComputeServer for CpuServer {
227225
}
228226
}
229227

228+
impl ServerCommunication for CpuServer {
229+
const SERVER_COMM_ENABLED: bool = false;
230+
}
231+
230232
impl CpuServer {
231233
fn copy_to_binding(&mut self, binding: Binding, data: &[u8]) {
232234
let mut resource = self

crates/cubecl-cpu/src/runtime.rs

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,6 @@ fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
6464
let mem_properties = MemoryDeviceProperties {
6565
max_page_size: max_shared_memory_size as u64,
6666
alignment: ALIGNMENT,
67-
data_transfer_async: false,
6867
};
6968

7069
let memory_management =

crates/cubecl-cuda/src/compute/command.rs

Lines changed: 4 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
11
use crate::{
22
CudaCompiler,
33
compute::{
4-
DataTransferItem, DataTransferRuntime, MB, context::CudaContext,
5-
io::controller::PinnedMemoryManagedAllocController, storage::gpu::GpuResource,
6-
stream::CudaStreamBackend, sync::Fence, valid_strides,
4+
MB, context::CudaContext, io::controller::PinnedMemoryManagedAllocController,
5+
storage::gpu::GpuResource, stream::CudaStreamBackend, sync::Fence, valid_strides,
76
},
87
};
98
use cubecl_common::{bytes::Bytes, stream_id::StreamId};
@@ -14,7 +13,6 @@ use cubecl_core::{
1413
server::{Binding, CopyDescriptor, Handle, IoError, ProfileError},
1514
};
1615
use cubecl_runtime::{
17-
data_service::DataTransferId,
1816
id::KernelId,
1917
logging::ServerLogger,
2018
memory_management::{MemoryAllocationMode, MemoryHandle},
@@ -25,21 +23,13 @@ use cudarc::driver::sys::{
2523
};
2624
use std::{ffi::c_void, ops::DerefMut, sync::Arc};
2725

28-
const DEVICE_TO_DEVICE: DeviceTransferStrategy = DeviceTransferStrategy::Serialized;
29-
30-
#[allow(unused)]
31-
enum DeviceTransferStrategy {
32-
Peer,
33-
Serialized,
34-
}
35-
3626
#[derive(new)]
3727
/// The `Command` struct encapsulates a CUDA context and a set of resolved CUDA streams, providing an
3828
/// interface for executing GPU-related operations such as memory allocation, data transfers, kernel
3929
/// registration, and task execution.
4030
pub struct Command<'a> {
4131
ctx: &'a mut CudaContext,
42-
streams: ResolvedStreams<'a, CudaStreamBackend>,
32+
pub(crate) streams: ResolvedStreams<'a, CudaStreamBackend>,
4333
}
4434

4535
impl<'a> Command<'a> {
@@ -241,7 +231,7 @@ impl<'a> Command<'a> {
241231
Ok((data, fences))
242232
}
243233

244-
fn copy_to_bytes(
234+
pub fn copy_to_bytes(
245235
&mut self,
246236
descriptor: CopyDescriptor<'_>,
247237
pinned: bool,
@@ -344,72 +334,6 @@ impl<'a> Command<'a> {
344334
Ok(handle)
345335
}
346336

347-
/// Registers a source for an asynchronous data transfer operation.
348-
///
349-
/// # Parameters
350-
///
351-
/// * `id` - The unique identifier for the data transfer.
352-
/// * `src` - The descriptor for the source GPU memory.
353-
pub fn data_transfer_src(&mut self, id: DataTransferId, src: CopyDescriptor<'_>) {
354-
let src_resource = self.resource(src.binding.clone()).unwrap();
355-
let client = DataTransferRuntime::client();
356-
let current = self.streams.current();
357-
358-
let handle = DataTransferItem {
359-
stream: current.sys,
360-
context: self.ctx.context,
361-
resource: src_resource,
362-
};
363-
364-
match DEVICE_TO_DEVICE {
365-
DeviceTransferStrategy::Peer => {
366-
let fence = Fence::new(current.sys);
367-
368-
client.register_src_peer(id, handle, fence);
369-
}
370-
DeviceTransferStrategy::Serialized => {
371-
client.register_src_serialized(id, handle, src.binding);
372-
}
373-
}
374-
}
375-
376-
/// Registers a destination for an asynchronous data transfer operation.
377-
///
378-
/// # Parameters
379-
///
380-
/// * `id` - The unique identifier for the data transfer.
381-
/// * `dest` - The descriptor for the destination GPU memory.
382-
pub fn data_transfer_dest(&mut self, id: DataTransferId, dest: CopyDescriptor<'_>) {
383-
let dst_resource = self.resource(dest.binding).unwrap();
384-
let current = self.streams.current();
385-
let client = DataTransferRuntime::client();
386-
387-
let item = DataTransferItem {
388-
context: self.ctx.context,
389-
stream: current.sys,
390-
resource: dst_resource,
391-
};
392-
393-
match DEVICE_TO_DEVICE {
394-
DeviceTransferStrategy::Peer => {
395-
client.register_dest_peer(id, item);
396-
}
397-
DeviceTransferStrategy::Serialized => {
398-
let num_bytes = dest.shape.iter().product::<usize>() * dest.elem_size;
399-
let bytes = self.reserve_cpu(num_bytes, true, None);
400-
401-
client.register_dest_serialized(
402-
id,
403-
item,
404-
bytes,
405-
dest.shape.to_vec(),
406-
dest.strides.to_vec(),
407-
dest.elem_size,
408-
);
409-
}
410-
}
411-
}
412-
413337
/// Synchronizes the current stream, ensuring all pending operations are complete.
414338
///
415339
/// # Returns

0 commit comments

Comments
 (0)