
Commit 6ff5e4b

Merge branch 'main' into rhypot
2 parents bd31af7 + 7dcb116


232 files changed (+5915, -6102 lines)


Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ edition = "2024"
 license = "MIT OR Apache-2.0"
 readme = "README.md"
 rust-version = "1.88"
-version = "0.9.0-pre.1"
+version = "0.9.0-pre.2"
 
 [workspace.dependencies]
 bitflags = { version = "2.9.1", features = ["serde"] }

crates/cubecl-attention/Cargo.toml

Lines changed: 6 additions & 6 deletions
@@ -19,12 +19,12 @@ attention_tests = []
 
 [dependencies]
 bytemuck = { workspace = true }
-cubecl-common = { path = "../cubecl-common", version = "0.9.0-pre.1", default-features = false }
-cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.1", default-features = false }
-cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0-pre.1", default-features = false }
-cubecl-std = { path = "../cubecl-std", version = "0.9.0-pre.1", default-features = false }
-cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0-pre.1", default-features = false }
-cubecl-random = { path = "../cubecl-random", version = "0.9.0-pre.1", default-features = false }
+cubecl-common = { path = "../cubecl-common", version = "0.9.0-pre.2", default-features = false }
+cubecl-core = { path = "../cubecl-core", version = "0.9.0-pre.2", default-features = false }
+cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0-pre.2", default-features = false }
+cubecl-std = { path = "../cubecl-std", version = "0.9.0-pre.2", default-features = false }
+cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0-pre.2", default-features = false }
+cubecl-random = { path = "../cubecl-random", version = "0.9.0-pre.2", default-features = false }
 half = { workspace = true, features = ["bytemuck"] }
 pretty_assertions = { workspace = true, optional = true }
 serde = { workspace = true }

crates/cubecl-attention/src/components/batch/hypercube/base.rs

Lines changed: 5 additions & 2 deletions
@@ -26,8 +26,11 @@ impl HypercubeConfig {
         selection: &AttentionSelection,
     ) -> CubeCountPlan {
         CubeCountPlan {
-            inner: (problem.seq_q as u32)
-                .div_ceil(selection.tiling_scheme.elements_in_stage_seq_q()),
+            inner: (problem.seq_q as u32).div_ceil(
+                selection.tiling_scheme.tile_size.seq_q
+                    * selection.tiling_scheme.partition_size.seq_q
+                    * selection.tiling_scheme.stage_size.seq_q,
+            ),
             outer: (problem.batch * problem.num_heads) as u32,
         }
     }
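
Note: the new inline expression spells out what elements_in_stage_seq_q() computed: the seq_q rows covered per cube are the product of the tile, partition, and stage sizes along seq_q. A minimal standalone sketch of the arithmetic (made-up sizes, not the crate's real AttentionTilingScheme types):

// Sketch of the `inner` cube-count arithmetic with hypothetical sizes:
// each cube covers tile * partition * stage rows along seq_q, and a
// partial remainder still needs one more cube, hence div_ceil.
fn inner_cube_count(seq_q: u32, tile: u32, partition: u32, stage: u32) -> u32 {
    seq_q.div_ceil(tile * partition * stage)
}

fn main() {
    // e.g. seq_q = 1000 with 8 * 2 * 4 = 64 rows per cube -> ceil(1000/64) = 16
    assert_eq!(inner_cube_count(1000, 8, 2, 4), 16);
}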

crates/cubecl-attention/src/components/batch/simple/attention.rs

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@ use crate::components::{
         BatchAttention, BatchAttentionConfig, CubeCountInput, simple::config::SimpleBatchConfig,
     },
     global::{GlobalAttention, GlobalAttentionConfig as _},
+    stage::StageAttentionConfig as _,
 };
 
 pub struct SimpleBatchAttention<AP: AttentionPrecision, GA: GlobalAttention<AP>> {
@@ -35,7 +36,7 @@ impl<GA: GlobalAttention<AP>, AP: AttentionPrecision> BatchAttention<AP>
         let q_index = CUBE_POS_X;
         let batch_index = CUBE_POS_Y;
 
-        let stage_q_offset = q_index * global_config.tiling_scheme().elements_in_stage_seq_q();
+        let stage_q_offset = q_index * global_config.stage_config().elements_in_stage_seq_q();
 
         // Assume [batch, num_heads, seq_*, head_dim] layout
         let seq_q = query.shape(2);

crates/cubecl-attention/src/components/batch/simple/config.rs

Lines changed: 1 addition & 3 deletions
@@ -10,7 +10,6 @@ use crate::components::{
 pub struct SimpleBatchConfig<G: GlobalAttentionConfig> {
     global_config: G,
     hypercube_config: HypercubeConfig,
-    seq_kv: u32,
 }
 
 impl<G: GlobalAttentionConfig> BatchAttentionConfig for SimpleBatchConfig<G> {
@@ -30,11 +29,10 @@ impl<G: GlobalAttentionConfig> BatchAttentionConfig for SimpleBatchConfig<G> {
 }
 
 impl<G: GlobalAttentionConfig> SimpleBatchConfig<G> {
-    pub fn new(global_config: G, hypercube_config: HypercubeConfig, seq_kv: u32) -> Self {
+    pub fn new(global_config: G, hypercube_config: HypercubeConfig) -> Self {
         Self {
             global_config,
             hypercube_config,
-            seq_kv,
         }
     }
 
crates/cubecl-attention/src/components/batch/simple/setup.rs

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@ impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFam
             selection
                 .hypercube_selection
                 .to_hypercube_config(problem, client.properties().hardware.max_cube_count.clone()),
-            problem.seq_kv as u32,
         )
         .validate(problem)
     }
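
Note: with this pair of changes, seq_kv is no longer baked into the comptime batch config; the kernel already receives it as a runtime value (see the seq_kv.div_ceil(...) loop bound in the global attention below), so storing it in the config duplicated state. A hypothetical minimal version of the slimmed config, for illustration only:

// Hypothetical sketch (invented types): the batch config keeps only
// comptime-stable pieces; seq_kv arrives at launch time instead.
struct HypercubeConfig; // stand-in for the real type

struct SimpleBatchConfig<G> {
    global_config: G,
    hypercube_config: HypercubeConfig,
    // seq_kv: u32  <- removed; supplied at runtime
}

impl<G> SimpleBatchConfig<G> {
    pub fn new(global_config: G, hypercube_config: HypercubeConfig) -> Self {
        Self { global_config, hypercube_config }
    }
}

fn main() {
    let _config = SimpleBatchConfig::new((), HypercubeConfig);
}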

crates/cubecl-attention/src/components/global/base.rs

Lines changed: 3 additions & 13 deletions
@@ -2,13 +2,12 @@ use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 
 use crate::components::{AttentionElems, global::simple::AttentionWriter};
-use cubecl_matmul::components::{global::memory::GlobalMemoryConfig, stage::StageMemoryConfig};
 use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor};
 
 use crate::components::{
-    AttentionIdent, AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
-    AttentionSetupError, AttentionTilingScheme, AvailableLineSizes, attention_types::*,
-    global::simple::QueryReader, stage::StageAttentionConfig,
+    AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
+    AttentionSetupError, AvailableLineSizes, attention_types::*, global::simple::QueryReader,
+    stage::StageAttentionConfig,
 };
 use std::{fmt::Debug, hash::Hash};
 
@@ -107,14 +106,5 @@ pub trait GlobalAttentionConfig:
     type StageConfig: StageAttentionConfig;
 
     fn stage_config(&self) -> Self::StageConfig;
-    fn key_stage_memory_config(&self) -> StageMemoryConfig;
-    fn value_stage_memory_config(&self) -> StageMemoryConfig;
-
     fn cube_dim(&self) -> CubeDim;
-    fn plane_dim(&self) -> u32;
-    fn global_memory_config(&self, ident: AttentionIdent) -> GlobalMemoryConfig;
-
-    fn tiling_scheme(&self) -> AttentionTilingScheme;
-
-    fn causal_mask(&self) -> bool;
 }
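
Note: the trait now exposes only stage_config() and cube_dim(); the removed accessors (stage/global memory configs, plane_dim, tiling_scheme, causal_mask) are reachable through the stage config or the concrete config type instead. A compilable sketch of the direction, using hypothetical trait names rather than the crate's:

// Hypothetical sketch: callers that used to ask the global config for the
// tiling scheme derive the same numbers from the stage config it owns.
trait StageConfig {
    fn elements_in_partition_seq_kv(&self) -> u32;
}

trait GlobalConfig {
    type Stage: StageConfig;
    fn stage_config(&self) -> Self::Stage;
}

// formerly: seq_kv.div_ceil(config.tiling_scheme().elements_in_partition_seq_kv())
fn num_stage_iterations<G: GlobalConfig>(config: &G, seq_kv: u32) -> u32 {
    seq_kv.div_ceil(config.stage_config().elements_in_partition_seq_kv())
}

#[derive(Clone, Copy)]
struct Stage(u32);

impl StageConfig for Stage {
    fn elements_in_partition_seq_kv(&self) -> u32 {
        self.0
    }
}

struct Global(Stage);

impl GlobalConfig for Global {
    type Stage = Stage;
    fn stage_config(&self) -> Stage {
        self.0
    }
}

fn main() {
    // 100 seq_kv elements in partitions of 32 -> 4 iterations
    assert_eq!(num_stage_iterations(&Global(Stage(32)), 100), 4);
}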

crates/cubecl-attention/src/components/global/layout.rs

Lines changed: 2 additions & 5 deletions
@@ -44,7 +44,7 @@ impl Layout for AttentionGlobalLayout {
     type SourceCoordinates = Coords1d;
 
     fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
-        let line_size = comptime![self.config.line_size()];
+        let line_size = comptime![self.config.line_size];
         let (row, col) = coords;
         let idx = self.batch_offset + row * self.stride_row + col * self.stride_col;
 
@@ -62,10 +62,7 @@
     fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
         let (row, col) = pos;
 
-        match comptime!((
-            self.config.check_row_bounds(),
-            self.config.check_col_bounds()
-        )) {
+        match comptime!((self.config.check_row_bounds, self.config.check_col_bounds)) {
             (true, true) => row < self.rows && col < self.columns,
             (true, false) => row < self.rows,
             (false, true) => col < self.columns,

crates/cubecl-attention/src/components/global/simple/attention.rs

Lines changed: 41 additions & 65 deletions
@@ -7,14 +7,15 @@ use cubecl_std::{CubeOption, CubeOptionExpand};
 use std::marker::PhantomData;
 
 use crate::components::attention_types::*;
-use crate::components::global::base::GlobalAttentionConfig;
+use crate::components::global::simple::QueryReader;
 use crate::components::global::simple::{AttentionWriter, AttentionWriterExpand, MaskReader};
 use crate::components::global::{AttentionGlobalLayout, simple::DummyKeyValueReader};
-use crate::components::stage::{AttentionPartitioner, AttentionTilingLayout, StageAttention};
-use crate::components::{AttentionIdent, global::simple::QueryReader};
+use crate::components::stage::{
+    AttentionPartitioner, AttentionTilingLayout, StageAttention, StageAttentionConfig as _,
+};
 use crate::components::{
     AttentionPrecision,
-    global::{GlobalAttention, simple::config::SimpleGlobalConfig},
+    global::{GlobalAttention, simple::config::SimpleGlobalAttentionConfig},
 };
 
 pub struct SimpleGlobalAttention<AP: AttentionPrecision, SA: StageAttention<AP>> {
@@ -32,13 +33,13 @@ impl<
     AP: AttentionPrecision,
 > GlobalAttention<AP> for SimpleGlobalAttention<AP, SA>
 {
-    type KeyReader = DummyKeyValueReader<KG<AP>, KS<AP>, Self::Config>;
-    type ValueReader = DummyKeyValueReader<VG<AP>, VS<AP>, Self::Config>;
+    type KeyReader = DummyKeyValueReader<KG<AP>, KS<AP>>;
+    type ValueReader = DummyKeyValueReader<VG<AP>, VS<AP>>;
     type MaskReader = MaskReader<AP>;
 
     type Writer = <SA::Partitioner as AttentionPartitioner>::Writer<OS<AP>, OG<AP>>;
 
-    type Config = SimpleGlobalConfig<SA::Config>;
+    type Config = SimpleGlobalAttentionConfig<SA::Config>;
 
     fn execute(
         query_reader: QueryReader<AP>,
@@ -51,32 +52,32 @@
         #[comptime] config: Self::Config,
     ) {
         // Init staging shared memories
-        let mut key_stage = key_reader.init_stage(config.key_stage_memory_config());
-        let mut value_stage = value_reader.init_stage(config.value_stage_memory_config());
+        let mut key_stage = key_reader.init_stage();
+        let mut value_stage = value_reader.init_stage();
 
         // Load queries which stay alive in registers for all the kernel
-        let mut query_registers = SA::init_query(config.stage_config());
-        SA::read_query(&query_reader, &mut query_registers, config.stage_config());
+        let mut query_registers = SA::init_query(config.stage_config);
+        SA::read_query(&query_reader, &mut query_registers, config.stage_config);
 
         // Init registers that will change inside global loop
-        let mut key_value_registers = SA::init_key_value(config.stage_config());
+        let mut key_value_registers = SA::init_key_value(config.stage_config);
         let mut mask_registers =
-            SA::init_mask(CubeOption::new_Some((seq_q, seq_kv)), config.stage_config());
-        let mut softmax_registers = SA::init_softmax(config.stage_config());
-        let mut accumulator_registers = SA::init_accumulator(config.stage_config());
+            SA::init_mask(CubeOption::new_Some((seq_q, seq_kv)), config.stage_config);
+        let mut softmax_registers = SA::init_softmax(config.stage_config);
+        let mut accumulator_registers = SA::init_accumulator(config.stage_config);
 
         // Init running state
-        let mut stage_state = SA::init_state(config.stage_config());
+        let mut stage_state = SA::init_state(config.stage_config);
 
         // Define number of global iterations
         let num_stage_iterations =
-            seq_kv.div_ceil(config.tiling_scheme().elements_in_partition_seq_kv());
+            seq_kv.div_ceil(config.stage_config.elements_in_partition_seq_kv());
 
         // Global loop over seq_kv
        for _ in 0..num_stage_iterations {
             // Put key and value into stage
-            key_reader.read_global(&mut key_stage, config);
-            value_reader.read_global(&mut value_stage, config);
+            key_reader.read_global(&mut key_stage);
+            value_reader.read_global(&mut value_stage);
 
             sync_cube();
 
@@ -91,7 +92,7 @@
                 &mut softmax_registers,
                 &mut accumulator_registers,
                 &mut stage_state,
-                config.stage_config(),
+                config.stage_config,
             );
 
             sync_cube();
@@ -103,19 +104,15 @@
         }
 
         // Accumulators must be rescaled using running state
-        SA::rescale(
-            &mut accumulator_registers,
-            stage_state,
-            config.stage_config(),
-        );
+        SA::rescale(&mut accumulator_registers, stage_state, config.stage_config);
 
         // Write accumulators to output
         let mut out_stage = writer.stage();
         SA::write::<Self::Writer, Self::Config>(
             &accumulator_registers,
             &mut out_stage,
             &mut writer,
-            config.stage_config(),
+            config.stage_config,
         )
     }
 
@@ -125,11 +122,7 @@
         query: VirtualTensor<QG<AP>>,
         #[comptime] config: Self::Config,
     ) -> QueryReader<AP> {
-        let layout = AttentionGlobalLayout::new(
-            &query,
-            batch_index,
-            config.global_memory_config(AttentionIdent::Query),
-        );
+        let layout = AttentionGlobalLayout::new(&query, batch_index, config.query_gmem_config);
 
         QueryReader::<AP>::new(stage_q_offset, query.view(layout))
     }
@@ -139,27 +132,21 @@
         key: VirtualTensor<KG<AP>>,
         #[comptime] config: Self::Config,
     ) -> Self::KeyReader {
-        let step = reduction_step::<Self::Config>(config);
-        let layout = AttentionGlobalLayout::new(
-            &key,
-            batch_index,
-            config.global_memory_config(AttentionIdent::Key),
-        );
-        DummyKeyValueReader::new(key.view(layout), step, AttentionIdent::Key)
+        let step = config.stage_config.elements_in_partition_seq_kv().runtime();
+        let layout =
+            AttentionGlobalLayout::new(&key, batch_index, config.key_reader_config.gmem_config);
+        DummyKeyValueReader::new(key.view(layout), step, config.key_reader_config)
     }
 
     fn init_value_reader(
         batch_index: u32,
         value: VirtualTensor<VG<AP>>,
         #[comptime] config: Self::Config,
     ) -> Self::ValueReader {
-        let step = reduction_step::<Self::Config>(config);
-        let layout = AttentionGlobalLayout::new(
-            &value,
-            batch_index,
-            config.global_memory_config(AttentionIdent::Value),
-        );
-        DummyKeyValueReader::new(value.view(layout), step, AttentionIdent::Value)
+        let step = config.stage_config.elements_in_partition_seq_kv().runtime();
+        let layout =
+            AttentionGlobalLayout::new(&value, batch_index, config.value_reader_config.gmem_config);
+        DummyKeyValueReader::new(value.view(layout), step, config.value_reader_config)
    }
 
     fn init_mask_reader(
@@ -169,24 +156,22 @@
         seq_kv_shape: u32,
         #[comptime] config: Self::Config,
     ) -> Self::MaskReader {
-        let step = reduction_step::<Self::Config>(config);
+        let step = config.stage_config.elements_in_partition_seq_kv().runtime();
         let partition_q_offset = <SA::Partitioner as AttentionPartitioner>::seq_q_index()
-            * config.tiling_scheme().elements_in_partition_seq_q();
+            * config.stage_config.elements_in_partition_seq_q();
 
         match mask {
             CubeOption::Some(mask) => {
-                let layout = AttentionGlobalLayout::new(
-                    &mask,
-                    batch_index,
-                    config.global_memory_config(AttentionIdent::Mask),
-                );
+                let layout =
+                    AttentionGlobalLayout::new(&mask, batch_index, config.mask_gmem_config);
 
                 MaskReader::new_materialized(
                     stage_q_offset,
                     partition_q_offset,
                     mask.view(layout),
                     step,
                     seq_kv_shape,
+                    config.mask_gmem_config.view_direction,
                 )
             }
             CubeOption::None => MaskReader::new_logical(stage_q_offset + partition_q_offset, step),
@@ -199,22 +184,13 @@
         out: VirtualTensor<OG<AP>, ReadWrite>,
         #[comptime] config: Self::Config,
     ) -> Self::Writer {
-        let conf = config.global_memory_config(AttentionIdent::Out);
-        let layout = AttentionGlobalLayout::new(&out, batch_index, conf);
+        let layout =
+            AttentionGlobalLayout::new(&out, batch_index, config.writer_config.gmem_config);
         let out = out.view_mut(layout);
 
-        Self::Writer::new::<SA::Config>(
+        Self::Writer::init::<SA::Config>(
             out.slice_mut_unchecked((stage_q_offset, 0), out.shape()),
-            conf,
-            config.stage_config(),
+            config.writer_config,
         )
     }
 }
-
-#[cube]
-fn reduction_step<C: GlobalAttentionConfig>(#[comptime] config: C) -> u32 {
-    config
-        .tiling_scheme()
-        .elements_in_partition_seq_kv()
-        .runtime()
-}
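
Note: the recurring mechanical change in this file is config.stage_config() becoming config.stage_config: the renamed SimpleGlobalAttentionConfig evidently carries its sub-configs as plain fields (stage_config, key_reader_config, writer_config, ...) rather than behind trait accessors, which is also what lets the old reduction_step helper collapse into a one-liner at each call site. A rough sketch of that shape, with invented internals:

// Hypothetical field-style config (names echo the diff, details invented):
#[derive(Clone, Copy)]
struct StageCfg {
    partition_seq_kv: u32,
}

impl StageCfg {
    fn elements_in_partition_seq_kv(&self) -> u32 {
        self.partition_seq_kv
    }
}

#[derive(Clone, Copy)]
struct SimpleGlobalAttentionCfg {
    stage_config: StageCfg, // public field: read as `config.stage_config`
}

fn main() {
    let config = SimpleGlobalAttentionCfg {
        stage_config: StageCfg { partition_seq_kv: 32 },
    };
    // formerly reduction_step::<C>(config); now inlined at each call site
    let step = config.stage_config.elements_in_partition_seq_kv();
    assert_eq!(step, 32);
}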
