
Commit ce158a7

Merge branch 'main' into rhypot

2 parents 6ff5e4b + 9b08383

342 files changed: +5218 -3385 lines changed


Cargo.toml

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@ edition = "2024"
 license = "MIT OR Apache-2.0"
 readme = "README.md"
 rust-version = "1.88"
-version = "0.9.0-pre.2"
+version = "0.9.0-pre.3"
 
 [workspace.dependencies]
 bitflags = { version = "2.9.1", features = ["serde"] }
@@ -35,7 +35,7 @@ serde_json = { version = "1.0.119", default-features = false }
 toml = "0.9.1"
 variadics_please = "1"
 
-# no_std compatiblity
+# no_std compatibility
 dashmap = "6.1.0"
 foldhash = { version = "0.1.2", default-features = false }
 hashbrown = "0.15.5"
@@ -103,7 +103,7 @@ tracel-llvm = { version = "20.1.4-5", features = ["mlir-helpers"] }
 # tracel-llvm = { git = "https://github.com/tracel-ai/tracel-llvm.git", branch = "fix/linux", package = "tracel-llvm", features = ["mlir-helpers"] }
 # tracel-llvm = { path = "../tracel-llvm/crates/tracel-llvm", features = ["mlir-helpers"] }
 
-cudarc = { version = "0.17.7", features = [
+cudarc = { version = "0.18.1", features = [
     "std",
     "driver",
     "nvrtc",

crates/cubecl-attention/Cargo.toml

Lines changed: 16 additions & 7 deletions
@@ -15,16 +15,25 @@ default = ["std", "cubecl-runtime/default", "cubecl-core/default"]
 export_tests = ["pretty_assertions"]
 std = ["cubecl-runtime/std", "cubecl-core/std"]
 
-attention_tests = []
+attention_tests_f16 = []
+attention_tests_f32 = []
+attention_tests_unit = []
+attention_tests_blackbox_accelerated = []
+attention_tests_all = [
+    "attention_tests_f16",
+    "attention_tests_f32",
+    "attention_tests_unit",
+    "attention_tests_blackbox_accelerated",
+]
 
 [dependencies]
 bytemuck = { workspace = true }
-cubecl-common = { path = "../cubecl-common", version = "0.9.0-pre.2", default-features = false }
-cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.2", default-features = false }
-cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0-pre.2", default-features = false }
-cubecl-std = { path = "../cubecl-std", version = "0.9.0-pre.2", default-features = false }
-cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0-pre.2", default-features = false }
-cubecl-random = { path = "../cubecl-random", version = "0.9.0-pre.2", default-features = false }
+cubecl-common = { path = "../cubecl-common", version = "=0.9.0-pre.3", default-features = false }
+cubecl-core = { path = "../cubecl-core", version = "=0.9.0-pre.3", default-features = false }
+cubecl-runtime = { path = "../cubecl-runtime", version = "=0.9.0-pre.3", default-features = false }
+cubecl-std = { path = "../cubecl-std", version = "=0.9.0-pre.3", default-features = false }
+cubecl-matmul = { path = "../cubecl-matmul", version = "=0.9.0-pre.3", default-features = false }
+cubecl-random = { path = "../cubecl-random", version = "=0.9.0-pre.3", default-features = false }
 half = { workspace = true, features = ["bytemuck"] }
 pretty_assertions = { workspace = true, optional = true }
 serde = { workspace = true }
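
The single `attention_tests` feature is split into per-suite flags, with `attention_tests_all` as the umbrella, and the intra-workspace dependencies are now pinned exactly (`=0.9.0-pre.3`) so mixed pre-release versions cannot be resolved. A minimal sketch of how test modules could be gated on the new flags (feature names are from the diff above; module names are illustrative):

    // Sketch: gating test suites on the split Cargo features.
    // Feature names come from the diff; module names are hypothetical.
    #[cfg(feature = "attention_tests_f16")]
    mod attention_f16_tests { /* f16 attention test suite */ }

    #[cfg(feature = "attention_tests_f32")]
    mod attention_f32_tests { /* f32 attention test suite */ }

    #[cfg(feature = "attention_tests_unit")]
    mod attention_unit_tests { /* unit-algorithm test suite */ }

    #[cfg(feature = "attention_tests_blackbox_accelerated")]
    mod attention_blackbox_tests { /* accelerated-algorithm test suite */ }

Since `attention_tests_all` enables all four flags, something like `cargo test -p cubecl-attention --features attention_tests_all` would run every suite, assuming the tests are gated as sketched.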

crates/cubecl-attention/src/base.rs

Lines changed: 38 additions & 54 deletions
@@ -4,11 +4,8 @@ use cubecl_std::tensor::TensorHandle;
 
 use crate::{
     components::{
-        AttentionElems, AttentionIdent, AttentionPartitionSize, AttentionProblem,
-        AttentionSelection, AttentionSetupError, AttentionStageSize, AttentionTileSize,
-        AttentionTilingScheme, AvailableLineSizes,
+        AttentionElems, AttentionIdent, AttentionProblem, AttentionSetupError, AvailableLineSizes,
         args::{TensorArgs, TensorInputsLaunch},
-        batch::HypercubeSelection,
     },
     kernels::{Algorithm, blackbox_accelerated::BlackboxAcceleratedAlgorithm, unit::UnitAlgorithm},
 };
@@ -25,15 +22,15 @@ pub enum Strategy {
 #[allow(clippy::result_large_err, clippy::too_many_arguments)]
 pub fn launch<R: Runtime>(
     strategy: &Strategy,
-    client: &ComputeClient<R::Server>,
+    client: &ComputeClient<R>,
     query: TensorHandle<R>,
     key: TensorHandle<R>,
     value: TensorHandle<R>,
     mask: Option<TensorHandle<R>>,
     out: TensorHandle<R>,
     attention_elems: AttentionElems,
 ) -> Result<(), AttentionSetupError> {
-    launch_ref::<R>(
+    launch_ref(
         strategy,
         client,
         &query.as_ref(),
@@ -48,7 +45,7 @@ pub fn launch_ref<R: Runtime>(
 #[allow(clippy::result_large_err, clippy::too_many_arguments)]
 pub fn launch_ref<R: Runtime>(
     strategy: &Strategy,
-    client: &ComputeClient<R::Server>,
+    client: &ComputeClient<R>,
     query: &TensorHandleRef<R>,
     key: &TensorHandleRef<R>,
     value: &TensorHandleRef<R>,
@@ -79,26 +76,35 @@ pub fn launch_ref<R: Runtime>(
 }
 
 pub fn launch_attention<R: Runtime, A: Algorithm>(
-    client: &ComputeClient<R::Server>,
+    client: &ComputeClient<R>,
     query: &TensorHandleRef<R>,
     key: &TensorHandleRef<R>,
     value: &TensorHandleRef<R>,
     mask: &Option<TensorHandleRef<R>>,
     out: &TensorHandleRef<R>,
     attention_elems: &AttentionElems,
 ) -> Result<(), AttentionSetupError> {
-    let line_sizes = AvailableLineSizes::from_elem_types::<R>(
-        query.elem_size,
-        attention_elems.mask.size(),
-        out.elem_size,
-    );
-    let line_sizes = A::filter_line_sizes(line_sizes)
-        .filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
-        .filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
-        .filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
-        .filter_with_tensor(AttentionIdent::Out, out.strides, out.shape)
-        .pick_max()
-        .unwrap();
+    let line_sizes = {
+        let ls = AvailableLineSizes::from_elem_types(
+            client,
+            query.elem_size,
+            attention_elems.mask.size(),
+            out.elem_size,
+        );
+        let ls = A::filter_line_sizes(ls)
+            .filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
+            .filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
+            .filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
+            .filter_with_tensor(AttentionIdent::Out, out.strides, out.shape);
+
+        if let Some(mask) = mask.as_ref() {
+            ls.filter_with_tensor(AttentionIdent::Mask, mask.strides, mask.shape)
+        } else {
+            ls
+        }
+    }
+    .pick_max()
+    .unwrap();
 
     let problem = AttentionProblem {
         batch: query.shape[0],
@@ -111,47 +117,22 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
         causal: false,
     };
 
-    let tile_size = AttentionTileSize {
-        seq_q: 8,
-        head_dim: 8,
-        seq_kv: 8,
-        val_dim: 8,
-    };
-
-    let selection = AttentionSelection {
-        hypercube_selection: HypercubeSelection {},
-        tiling_scheme: AttentionTilingScheme {
-            tile_size,
-            partition_size: AttentionPartitionSize {
-                seq_q: 1,
-                head_dim: 1,
-                seq_kv: 1,
-                val_dim: 1,
-            },
-            stage_size: AttentionStageSize { seq_q: 1 },
-        },
-        plane_dim: 32,
-        reuse_key_value: false,
-        two_rows_in_array_tile: false,
-    };
-
-    let config = BlackboxAcceleratedAlgorithm::setup::<R>(
+    let selection = A::selection(
         client,
         &problem,
-        &selection,
+        client.properties().hardware.plane_size_max,
         &line_sizes,
         attention_elems,
     )?;
 
+    let config = A::setup(client, &problem, &selection, &line_sizes, attention_elems)?;
+
     let cube_count_plan = config
         .hypercube_config()
         .cube_count_plan(&problem, &selection);
 
-    unsafe {
-        <BlackboxAcceleratedAlgorithm as Algorithm>::BatchAttention::launch_unchecked::<
-            TensorArgs,
-            R,
-        >(
+    let result = unsafe {
+        <A as Algorithm>::BatchAttention::launch_unchecked::<TensorArgs, R>(
             client,
             config.cube_dim(),
             cube_count_plan.resolve(),
@@ -167,8 +148,11 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
             cube_count_plan.as_args(),
             config,
             attention_elems,
-        );
-    }
+        )
+    };
 
-    Ok(())
+    match result {
+        Ok(_) => Ok(()),
+        Err(err) => Err(AttentionSetupError::Execution(err)),
+    }
 }
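
With this change, a failed kernel launch surfaces as `AttentionSetupError::Execution` instead of being silently discarded, and `launch_attention` is fully generic: the algorithm `A` supplies both `selection` and `setup`, where the old code hard-coded `BlackboxAcceleratedAlgorithm` and a fixed 8x8x8x8 tile selection. A minimal caller-side sketch of the new error path; the import path is assumed from this diff and may differ from the crate's actual re-exports:

    // Sketch: distinguishing the new runtime launch failure from
    // configuration failures when calling the attention entry points.
    use cubecl_attention::components::AttentionSetupError;

    fn report(result: Result<(), AttentionSetupError>) {
        match result {
            Ok(()) => {}
            // New in this commit: errors raised by the launch itself.
            Err(AttentionSetupError::Execution(launch_err)) => {
                eprintln!("attention launch failed: {launch_err:?}");
            }
            Err(setup_err) => eprintln!("attention setup failed: {setup_err:?}"),
        }
    }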

crates/cubecl-attention/src/components/batch/base.rs

Lines changed: 3 additions & 3 deletions
@@ -27,21 +27,21 @@ pub trait BatchAttentionFamily: Send + Sync + 'static {
     /// Out-of-bounds can happen
     #[allow(clippy::too_many_arguments)]
     unsafe fn launch_unchecked<'a, AA: AttentionArgs, R: Runtime>(
-        client: &ComputeClient<<R as Runtime>::Server>,
+        client: &ComputeClient<R>,
         cube_dim: CubeDim,
         cube_count: CubeCount,
         input: InputRuntimeArg<'a, AA, R>,
         output: OutputRuntimeArg<'a, AA, R>,
         cube_count_input: CubeCountInputArgs<'a, R>,
         config: Self::Config,
         dtypes: &AttentionElems,
-    );
+    ) -> Result<(), LaunchError>;
 
     /// Constructs the configuration based on the Attention problem, selection, and line sizes.
     ///
     /// This function may return an error if the configuration cannot be supported on the current runtime.
     fn setup<R: Runtime>(
-        client: &ComputeClient<R::Server>,
+        client: &ComputeClient<R>,
         problem: &AttentionProblem,
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,

crates/cubecl-attention/src/components/batch/simple/setup.rs

Lines changed: 6 additions & 6 deletions
@@ -1,6 +1,6 @@
 use std::marker::PhantomData;
 
-use cubecl_core::client::ComputeClient;
+use cubecl_core::{client::ComputeClient, server::LaunchError};
 
 use crate::components::{
     AttentionElems, AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
@@ -23,13 +23,13 @@ impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFam
     type Config = SimpleBatchConfig<GA::Config>;
 
     fn setup<R: cubecl_core::Runtime>(
-        client: &ComputeClient<R::Server>,
+        client: &ComputeClient<R>,
         problem: &AttentionProblem,
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,
         dtypes: &AttentionElems,
     ) -> Result<Self::Config, crate::components::AttentionSetupError> {
-        let global_config = GA::setup::<R>(client, problem, selection, line_sizes, dtypes)?;
+        let global_config = GA::setup(client, problem, selection, line_sizes, dtypes)?;
 
         SimpleBatchConfig::new(
             global_config,
@@ -41,15 +41,15 @@ impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFam
     }
 
     unsafe fn launch_unchecked<'a, AA: AttentionArgs, R: cubecl_core::Runtime>(
-        client: &cubecl_core::prelude::ComputeClient<<R as cubecl_core::Runtime>::Server>,
+        client: &cubecl_core::prelude::ComputeClient<R>,
         cube_dim: cubecl_core::CubeDim,
         cube_count: cubecl_core::CubeCount,
         input: InputRuntimeArg<'a, AA, R>,
         output: OutputRuntimeArg<'a, AA, R>,
         cube_count_input: crate::components::batch::CubeCountInputArgs<'a, R>,
         config: Self::Config,
         dtypes: &AttentionElems,
-    ) {
+    ) -> Result<(), LaunchError> {
         unsafe {
             attention::launch_unchecked::<AA, Self, R>(
                 client,
@@ -60,7 +60,7 @@ impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFam
                 cube_count_input,
                 config,
                 dtypes.into(),
-            );
+            )
         }
     }
 }

crates/cubecl-attention/src/components/error.rs

Lines changed: 7 additions & 1 deletion
@@ -1,4 +1,4 @@
-use cubecl_core::{CubeCount, CubeDim, LineSizeError};
+use cubecl_core::{CubeCount, CubeDim, LineSizeError, server::LaunchError};
 use cubecl_matmul::components::MatmulSetupError;
 use std::fmt::{Debug, Display};
 
@@ -15,6 +15,9 @@ pub enum AttentionSetupError {
 
     /// Error in underlying matmul
     MatmulSetup(MatmulSetupError),
+
+    /// An error that happened during execution.
+    Execution(LaunchError),
 }
 
 /// A specific feature required for attention is not available in the current runtime or hardware.
@@ -75,6 +78,9 @@ impl Debug for AttentionSetupError {
             AttentionSetupError::MatmulSetup(matmul_setup_error) => {
                 writeln!(f, "{matmul_setup_error:?}")
             }
+            AttentionSetupError::Execution(error) => {
+                writeln!(f, "{error:?}")
+            }
         }
     }
 }
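
In `launch_attention` the new variant is constructed with an explicit `match` (see base.rs above). A hypothetical `From` impl, not part of this commit, would let launch results be propagated with `?` instead:

    // Hypothetical sketch, NOT in this commit: an automatic conversion
    // would allow `launch_unchecked(...)?` inside launch_attention.
    use cubecl_core::server::LaunchError;

    impl From<LaunchError> for AttentionSetupError {
        fn from(err: LaunchError) -> Self {
            AttentionSetupError::Execution(err)
        }
    }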

crates/cubecl-attention/src/components/global/base.rs

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ pub trait GlobalAttentionFamily: Send + Sync + 'static {
     ///
     /// This function may return an error if the configuration cannot be supported on the current runtime.
     fn setup<R: Runtime>(
-        client: &ComputeClient<R::Server>,
+        client: &ComputeClient<R>,
         problem: &AttentionProblem,
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,
