Skip to content

Commit 80f3613

Browse files
authored
refactor: Move Runtime to cubecl-runtime (#1068)
1 parent 3904c27 commit 80f3613

File tree

231 files changed

+1495
-1401
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

231 files changed

+1495
-1401
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,7 @@ tracel-llvm = { version = "20.1.4-5", features = ["mlir-helpers"] }
103103
# tracel-llvm = { git = "https://github.com/tracel-ai/tracel-llvm.git", branch = "fix/linux", package = "tracel-llvm", features = ["mlir-helpers"] }
104104
# tracel-llvm = { path = "../tracel-llvm/crates/tracel-llvm", features = ["mlir-helpers"] }
105105

106-
cudarc = { version = "0.17.7", features = [
106+
cudarc = { version = "0.18.1", features = [
107107
"std",
108108
"driver",
109109
"nvrtc",

crates/cubecl-attention/src/base.rs

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,15 +25,15 @@ pub enum Strategy {
2525
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
2626
pub fn launch<R: Runtime>(
2727
strategy: &Strategy,
28-
client: &ComputeClient<R::Server>,
28+
client: &ComputeClient<R>,
2929
query: TensorHandle<R>,
3030
key: TensorHandle<R>,
3131
value: TensorHandle<R>,
3232
mask: Option<TensorHandle<R>>,
3333
out: TensorHandle<R>,
3434
attention_elems: AttentionElems,
3535
) -> Result<(), AttentionSetupError> {
36-
launch_ref::<R>(
36+
launch_ref(
3737
strategy,
3838
client,
3939
&query.as_ref(),
@@ -48,7 +48,7 @@ pub fn launch<R: Runtime>(
4848
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
4949
pub fn launch_ref<R: Runtime>(
5050
strategy: &Strategy,
51-
client: &ComputeClient<R::Server>,
51+
client: &ComputeClient<R>,
5252
query: &TensorHandleRef<R>,
5353
key: &TensorHandleRef<R>,
5454
value: &TensorHandleRef<R>,
@@ -79,15 +79,16 @@ pub fn launch_ref<R: Runtime>(
7979
}
8080

8181
pub fn launch_attention<R: Runtime, A: Algorithm>(
82-
client: &ComputeClient<R::Server>,
82+
client: &ComputeClient<R>,
8383
query: &TensorHandleRef<R>,
8484
key: &TensorHandleRef<R>,
8585
value: &TensorHandleRef<R>,
8686
mask: &Option<TensorHandleRef<R>>,
8787
out: &TensorHandleRef<R>,
8888
attention_elems: &AttentionElems,
8989
) -> Result<(), AttentionSetupError> {
90-
let line_sizes = AvailableLineSizes::from_elem_types::<R>(
90+
let line_sizes = AvailableLineSizes::from_elem_types(
91+
client,
9192
query.elem_size,
9293
attention_elems.mask.size(),
9394
out.elem_size,
@@ -135,7 +136,7 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
135136
two_rows_in_array_tile: false,
136137
};
137138

138-
let config = BlackboxAcceleratedAlgorithm::setup::<R>(
139+
let config = BlackboxAcceleratedAlgorithm::setup(
139140
client,
140141
&problem,
141142
&selection,

crates/cubecl-attention/src/components/batch/base.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ pub trait BatchAttentionFamily: Send + Sync + 'static {
2727
/// Out-of-bounds can happen
2828
#[allow(clippy::too_many_arguments)]
2929
unsafe fn launch_unchecked<'a, AA: AttentionArgs, R: Runtime>(
30-
client: &ComputeClient<<R as Runtime>::Server>,
30+
client: &ComputeClient<R>,
3131
cube_dim: CubeDim,
3232
cube_count: CubeCount,
3333
input: InputRuntimeArg<'a, AA, R>,
@@ -41,7 +41,7 @@ pub trait BatchAttentionFamily: Send + Sync + 'static {
4141
///
4242
/// This function may return an error if the configuration cannot be supported on the current runtime.
4343
fn setup<R: Runtime>(
44-
client: &ComputeClient<R::Server>,
44+
client: &ComputeClient<R>,
4545
problem: &AttentionProblem,
4646
selection: &AttentionSelection,
4747
line_sizes: &AttentionLineSizes,

crates/cubecl-attention/src/components/batch/simple/setup.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,13 @@ impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFam
2323
type Config = SimpleBatchConfig<GA::Config>;
2424

2525
fn setup<R: cubecl_core::Runtime>(
26-
client: &ComputeClient<R::Server>,
26+
client: &ComputeClient<R>,
2727
problem: &AttentionProblem,
2828
selection: &AttentionSelection,
2929
line_sizes: &AttentionLineSizes,
3030
dtypes: &AttentionElems,
3131
) -> Result<Self::Config, crate::components::AttentionSetupError> {
32-
let global_config = GA::setup::<R>(client, problem, selection, line_sizes, dtypes)?;
32+
let global_config = GA::setup(client, problem, selection, line_sizes, dtypes)?;
3333

3434
SimpleBatchConfig::new(
3535
global_config,
@@ -41,7 +41,7 @@ impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFam
4141
}
4242

4343
unsafe fn launch_unchecked<'a, AA: AttentionArgs, R: cubecl_core::Runtime>(
44-
client: &cubecl_core::prelude::ComputeClient<<R as cubecl_core::Runtime>::Server>,
44+
client: &cubecl_core::prelude::ComputeClient<R>,
4545
cube_dim: cubecl_core::CubeDim,
4646
cube_count: cubecl_core::CubeCount,
4747
input: InputRuntimeArg<'a, AA, R>,

crates/cubecl-attention/src/components/global/base.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ pub trait GlobalAttentionFamily: Send + Sync + 'static {
2323
///
2424
/// This function may return an error if the configuration cannot be supported on the current runtime.
2525
fn setup<R: Runtime>(
26-
client: &ComputeClient<R::Server>,
26+
client: &ComputeClient<R>,
2727
problem: &AttentionProblem,
2828
selection: &AttentionSelection,
2929
line_sizes: &AttentionLineSizes,

crates/cubecl-attention/src/components/global/simple/setup.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,13 @@ impl<
4040
type Config = SimpleGlobalAttentionConfig<SA::Config>;
4141

4242
fn setup<R: cubecl_core::Runtime>(
43-
client: &ComputeClient<R::Server>,
43+
client: &ComputeClient<R>,
4444
problem: &AttentionProblem,
4545
selection: &AttentionSelection,
4646
line_sizes: &AttentionLineSizes,
4747
dtypes: &AttentionElems,
4848
) -> Result<Self::Config, AttentionSetupError> {
49-
let stage_config = SA::setup::<R>(client, problem, selection, line_sizes, dtypes)?;
49+
let stage_config = SA::setup(client, problem, selection, line_sizes, dtypes)?;
5050

5151
let precompute_job = LoadingPrecomputeStrategy::Never.into();
5252
let plane_dim = stage_config.plane_dim();

crates/cubecl-attention/src/components/line_size.rs

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
use std::fmt::Debug;
22

3-
use cubecl_core::{LineSizeError, Runtime, tensor_line_size_parallel};
3+
use cubecl_core::{LineSizeError, Runtime, client::ComputeClient, tensor_line_size_parallel};
44

55
use crate::components::{AttentionIdent, AttentionSetupError};
66

@@ -29,10 +29,17 @@ pub struct AvailableLineSizes {
2929
}
3030

3131
impl AvailableLineSizes {
32-
pub fn from_elem_types<R: Runtime>(elem_in: usize, elem_mask: usize, elem_out: usize) -> Self {
33-
let in_available: Vec<u8> = R::io_optimized_line_sizes_unchecked(elem_in).collect();
34-
let mask_available: Vec<u8> = R::io_optimized_line_sizes_unchecked(elem_mask).collect();
35-
let out_available = R::io_optimized_line_sizes_unchecked(elem_out).collect();
32+
pub fn from_elem_types<R: Runtime>(
33+
client: &ComputeClient<R>,
34+
elem_in: usize,
35+
elem_mask: usize,
36+
elem_out: usize,
37+
) -> Self {
38+
let in_available: Vec<u8> = client.io_optimized_line_sizes_unchecked(elem_in).collect();
39+
let mask_available: Vec<u8> = client
40+
.io_optimized_line_sizes_unchecked(elem_mask)
41+
.collect();
42+
let out_available = client.io_optimized_line_sizes_unchecked(elem_out).collect();
3643

3744
AvailableLineSizes {
3845
query: in_available.clone(),

crates/cubecl-attention/src/components/stage/base.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ pub trait StageAttentionFamily: Send + Sync + 'static {
4545
///
4646
/// This function may return an error if the configuration cannot be supported on the current runtime.
4747
fn setup<R: Runtime>(
48-
client: &ComputeClient<R::Server>,
48+
client: &ComputeClient<R>,
4949
problem: &AttentionProblem,
5050
selection: &AttentionSelection,
5151
line_sizes: &AttentionLineSizes,

crates/cubecl-attention/src/components/stage/plane/setup.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl<
5353
type Config = PartitionAttentionConfig<TA::Config>;
5454

5555
fn setup<R: cubecl_core::Runtime>(
56-
client: &ComputeClient<R::Server>,
56+
client: &ComputeClient<R>,
5757
problem: &AttentionProblem,
5858
selection: &AttentionSelection,
5959
line_sizes: &AttentionLineSizes,
@@ -62,8 +62,7 @@ impl<
6262
let num_planes = selection.tiling_scheme.stage_size.seq_q
6363
* TA::computation_resources()?.num_planes(selection.plane_dim)?;
6464

65-
let tile_config =
66-
TA::setup::<R>(client, problem, selection, line_sizes, num_planes, dtypes)?;
65+
let tile_config = TA::setup(client, problem, selection, line_sizes, num_planes, dtypes)?;
6766

6867
let key_smem_config = StageMemoryConfig {
6968
num_planes,

crates/cubecl-attention/src/components/stage/unit/setup.rs

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ impl<
5353
type Config = PartitionAttentionConfig<TA::Config>;
5454

5555
fn setup<R: cubecl_core::Runtime>(
56-
client: &ComputeClient<R::Server>,
56+
client: &ComputeClient<R>,
5757
problem: &AttentionProblem,
5858
selection: &AttentionSelection,
5959
line_sizes: &AttentionLineSizes,
@@ -70,8 +70,7 @@ impl<
7070
};
7171

7272
let num_planes = compute_resources.num_planes(selection.plane_dim)?;
73-
let tile_config =
74-
TA::setup::<R>(client, problem, selection, line_sizes, num_planes, dtypes)?;
73+
let tile_config = TA::setup(client, problem, selection, line_sizes, num_planes, dtypes)?;
7574

7675
let key_smem_config = StageMemoryConfig {
7776
num_planes,

0 commit comments

Comments
 (0)