Skip to content

Commit ae51a6f

Browse files
Cuda device change optimization (#864)
1 parent d37ef7f commit ae51a6f

File tree

41 files changed

+940
-210
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+940
-210
lines changed

crates/cubecl-common/src/device.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use core::cmp::Ordering;
2+
3+
/// The device id.
4+
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
5+
pub struct DeviceId {
6+
/// The type id identifies the type of the device.
7+
pub type_id: u16,
8+
/// The index id identifies the device number.
9+
pub index_id: u32,
10+
}
11+
12+
/// Device trait for all cubecl devices.
13+
pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync {
14+
/// Create a device from its [id](DeviceId).
15+
fn from_id(device_id: DeviceId) -> Self;
16+
/// Retrieve the [device id](DeviceId) from the device.
17+
fn to_id(&self) -> DeviceId;
18+
/// Returns the number of devices available under the provided type id.
19+
fn device_count(type_id: u16) -> usize;
20+
/// Returns the total number of devices that can be handled by the runtime.
21+
fn device_count_total() -> usize {
22+
Self::device_count(0)
23+
}
24+
}
25+
26+
impl core::fmt::Display for DeviceId {
27+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28+
f.write_fmt(format_args!("{self:?}"))
29+
}
30+
}
31+
32+
impl Ord for DeviceId {
33+
fn cmp(&self, other: &Self) -> Ordering {
34+
match self.type_id.cmp(&other.type_id) {
35+
Ordering::Equal => self.index_id.cmp(&other.index_id),
36+
other => other,
37+
}
38+
}
39+
}
40+
41+
impl PartialOrd for DeviceId {
42+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43+
Some(self.cmp(other))
44+
}
45+
}

crates/cubecl-common/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ extern crate derive_new;
1212
/// std environments.
1313
pub mod rand;
1414

15+
/// Device module.
16+
pub mod device;
17+
1518
/// Stub module contains types for stubs for non-std environments and for std environments.
1619
pub mod stub;
1720

crates/cubecl-core/src/id.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use cubecl_runtime::{client::ComputeClient, id::DeviceId};
1+
use cubecl_common::device::{Device, DeviceId};
2+
use cubecl_runtime::client::ComputeClient;
23

34
/// ID used to identify a Just-in-Time environment.
45
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
@@ -14,7 +15,7 @@ impl CubeTuneId {
1415
device: &R::Device,
1516
) -> Self {
1617
Self {
17-
device: R::device_id(device),
18+
device: device.to_id(),
1819
name: R::name(client),
1920
}
2021
}

crates/cubecl-core/src/runtime.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::codegen::Compiler;
22
use crate::compute::CubeTask;
3+
use cubecl_common::device::Device;
34
use cubecl_ir::{StorageType, TargetProperties};
4-
use cubecl_runtime::id::DeviceId;
55
use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
66

77
pub use cubecl_runtime::channel;
@@ -18,10 +18,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
1818
/// The channel used to communicate with the compute server.
1919
type Channel: ComputeChannel<Self::Server>;
2020
/// The device used to retrieve the compute client.
21-
type Device: Default + Clone + core::fmt::Debug + Send + Sync;
22-
23-
/// Fetch the id for the given device.
24-
fn device_id(device: &Self::Device) -> DeviceId;
21+
type Device: Device;
2522

2623
/// Retrieve the compute client from the runtime device.
2724
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>;
@@ -50,8 +47,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
5047

5148
fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool;
5249

53-
fn device_count() -> usize;
54-
50+
/// Returns the properties of the target hardware architecture.
5551
fn target_properties() -> TargetProperties;
5652
}
5753

crates/cubecl-core/src/runtime_tests/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod slice;
2222
pub mod synchronization;
2323
pub mod tensor;
2424
pub mod tensormap;
25+
pub mod to_client;
2526
pub mod topology;
2627
pub mod traits;
2728
pub mod unary;
@@ -135,6 +136,8 @@ macro_rules! testgen_untyped {
135136

136137
cubecl_core::testgen_enums!();
137138
cubecl_core::testgen_comparison!();
139+
140+
cubecl_core::testgen_to_client!();
138141
};
139142
}
140143

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use cubecl_common::device::{Device, DeviceId};
2+
3+
use crate::Runtime;
4+
use crate::prelude::*;
5+
6+
pub fn test_to_client<R: Runtime>() {
7+
let type_id = 0;
8+
let device_count = R::Device::device_count(type_id);
9+
10+
if device_count < 2 {
11+
return;
12+
}
13+
14+
for (device_0, device_1) in num_combination(type_id, device_count as u32) {
15+
let device_0 = R::Device::from_id(device_0);
16+
let device_1 = R::Device::from_id(device_1);
17+
18+
println!("Moving data from {device_0:?} to {device_1:?} ...");
19+
20+
let client_0 = R::client(&device_0);
21+
let client_1 = R::client(&device_1);
22+
23+
let expected = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
24+
let input = client_0.create(f32::as_bytes(&expected));
25+
26+
let output = client_0.to_client(input, &client_1).handle;
27+
28+
let actual = client_1.read_one(output);
29+
let actual = f32::from_bytes(&actual);
30+
31+
assert_eq!(actual, expected);
32+
}
33+
}
34+
35+
fn num_combination(type_id: u16, n: u32) -> Vec<(DeviceId, DeviceId)> {
36+
let mut results = Vec::new();
37+
38+
for i in 0..n {
39+
for j in i + 1..n {
40+
results.push((DeviceId::new(type_id, i), DeviceId::new(type_id, j)));
41+
}
42+
}
43+
44+
results
45+
}
46+
47+
#[allow(missing_docs)]
48+
#[macro_export]
49+
macro_rules! testgen_to_client {
50+
() => {
51+
use super::*;
52+
53+
#[test]
54+
fn test_to_client() {
55+
cubecl_core::runtime_tests::to_client::test_to_client::<TestRuntime>();
56+
}
57+
};
58+
}

crates/cubecl-cpu/src/compute/server.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use cubecl_core::{
66
compute::CubeTask,
77
future::DynFut,
88
server::{
9-
Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor, Handle,
10-
IoError, ProfileError, ProfilingToken,
9+
Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor,
10+
DataTransferService, Handle, IoError, ProfileError, ProfilingToken,
1111
},
1212
};
1313
use cubecl_runtime::{
@@ -28,6 +28,8 @@ pub struct CpuServer {
2828
logger: ServerLogger,
2929
}
3030

31+
impl DataTransferService for CpuServer {}
32+
3133
impl CpuServer {
3234
pub fn new(ctx: CpuContext) -> Self {
3335
Self {

crates/cubecl-cpu/src/device.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,21 @@
1+
use cubecl_common::device::{Device, DeviceId};
2+
13
#[derive(new, Clone, PartialEq, Eq, Default, Hash, Debug)]
24
pub struct CpuDevice;
5+
6+
impl Device for CpuDevice {
7+
fn from_id(_device_id: DeviceId) -> Self {
8+
Self
9+
}
10+
11+
fn to_id(&self) -> DeviceId {
12+
DeviceId {
13+
type_id: 0,
14+
index_id: 0,
15+
}
16+
}
17+
18+
fn device_count(_type_id: u16) -> usize {
19+
1
20+
}
21+
}

crates/cubecl-cpu/src/runtime.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use cubecl_core::{
77
};
88
use cubecl_runtime::{
99
ComputeRuntime, DeviceProperties,
10-
id::DeviceId,
1110
memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
1211
storage::BytesStorage,
1312
};
@@ -63,6 +62,7 @@ fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
6362
let mem_properties = MemoryDeviceProperties {
6463
max_page_size: max_shared_memory_size as u64,
6564
alignment: ALIGNMENT,
65+
data_transfer_async: false,
6666
};
6767

6868
let memory_management =
@@ -106,18 +106,10 @@ impl Runtime for CpuRuntime {
106106
(u32::MAX, u32::MAX, u32::MAX)
107107
}
108108

109-
fn device_id(_device: &Self::Device) -> DeviceId {
110-
DeviceId::new(0, 0)
111-
}
112-
113109
fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
114110
is_contiguous(shape, strides)
115111
}
116112

117-
fn device_count() -> usize {
118-
1
119-
}
120-
121113
fn target_properties() -> TargetProperties {
122114
TargetProperties {
123115
// Values are irrelevant, since no wgsl backends currently support manual mma

crates/cubecl-cuda/Cargo.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ default = [
1818
"cubecl-common/default",
1919
"cubecl-core/default",
2020
"cudarc/dynamic-loading",
21-
"cudarc/cuda-12050",
21+
"cuda-12050",
2222
]
2323
ptx-wmma = []
2424
std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"]
2525

26-
cuda-12080 = []
26+
cuda-12080 = ["cudarc/cuda-12080"]
27+
cuda-12050 = ["cudarc/cuda-12050"]
2728

2829
conv_tests = ["cubecl-convolution/conv_tests"]
2930
matmul_tests_all = [
@@ -96,9 +97,6 @@ half = { workspace = true }
9697
log = { workspace = true }
9798
serde = { workspace = true }
9899

99-
[build-dependencies]
100-
cudarc = { workspace = true }
101-
102100
[dev-dependencies]
103101
cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", features = [
104102
"export_tests",

0 commit comments

Comments
 (0)