
Commit 1a885d6

Flash Attention: vectorized query + fix metal wmma load from global memory + fix main compilation (#1069)
1 parent: 80f3613

10 files changed (+72, -41 lines)

crates/cubecl-attention/src/components/global/simple/attention.rs

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ impl<
     ) -> QueryReader<AP> {
         let layout = AttentionGlobalLayout::new(&query, batch_index, config.query_gmem_config);

-        QueryReader::<AP>::new(stage_q_offset, query.view(layout))
+        QueryReader::<AP>::new(stage_q_offset, query.view(layout), config.query_gmem_config)
     }

     fn init_key_reader(
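
Note: this call-site change pairs with the QueryReader change in the next file; the query's global-memory config is now captured by the reader at construction time. A minimal sketch of the CubeCL pattern involved, a comptime value stored on a CubeType and supplied through a #[comptime] constructor argument, using a hypothetical ExampleReader:

    use cubecl_core as cubecl;
    use cubecl_core::prelude::*;

    // Hypothetical type, for illustration only: a reader capturing a
    // compile-time value the same way QueryReader captures gmem_config.
    #[derive(CubeType)]
    pub struct ExampleReader {
        #[cube(comptime)]
        line_size: u32,
    }

    #[cube]
    impl ExampleReader {
        pub fn new(#[comptime] line_size: u32) -> Self {
            ExampleReader { line_size }
        }
    }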
Lines changed: 37 additions & 23 deletions
@@ -1,30 +1,34 @@
 use crate::components::{AttentionTileSize, attention_types::*};
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
-use cubecl_matmul::components::MatrixLayout;
+use cubecl_matmul::components::global::memory::GlobalMemoryConfig;
 use cubecl_matmul::components::tile::StridedTile;
-use cubecl_std::{
-    Swizzle,
-    tensor::{View, layout::Coords2d},
-};
+use cubecl_std::Swizzle;
+use cubecl_std::tensor::{View, layout::Coords2d};

 use crate::components::AttentionPrecision;
-use crate::components::stage::{AttentionPartitioner, StageAttentionConfig};
+use crate::components::stage::AttentionPartitioner;

 #[derive(CubeType)]
 pub struct QueryReader<AP: AttentionPrecision> {
     query: View<Line<QG<AP>>, Coords2d>,
+    #[cube(comptime)]
+    gmem_config: GlobalMemoryConfig,
 }

 #[cube]
 impl<AP: AttentionPrecision> QueryReader<AP> {
-    pub fn new(stage_q_offset: u32, query: View<Line<QG<AP>>, Coords2d>) -> Self {
+    pub fn new(
+        stage_q_offset: u32,
+        query: View<Line<QG<AP>>, Coords2d>,
+        #[comptime] gmem_config: GlobalMemoryConfig,
+    ) -> Self {
         let query = query.slice((stage_q_offset, 0), query.shape());

-        QueryReader::<AP> { query }
+        QueryReader::<AP> { query, gmem_config }
     }

-    pub fn get_tile<P: AttentionPartitioner, S: StageAttentionConfig>(
+    pub fn get_tile<P: AttentionPartitioner>(
         &self,
         tile: Coords2d,
         #[comptime] attention_tile_size: AttentionTileSize,
@@ -35,22 +39,32 @@ impl<AP: AttentionPrecision> QueryReader<AP> {

         let row = row_in_partition + P::seq_q_index() * partition_seq_q;

+        let line_size = self.gmem_config.line_size;
+
+        let slice = self
+            .query
+            .slice(
+                (
+                    row * attention_tile_size.seq_q,
+                    col * attention_tile_size.head_dim,
+                ),
+                (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(),
+            )
+            .to_linear_slice();
+
+        let start = 0;
+        let length = attention_tile_size.seq_q * attention_tile_size.head_dim / line_size;
+        let end = start + length;
+        let stride = partition_head_dim * attention_tile_size.head_dim / line_size;
+
         StridedTile::<QG<AP>>::new_strided(
-            self.query
-                .slice(
-                    (
-                        row * attention_tile_size.seq_q,
-                        col * attention_tile_size.head_dim,
-                    ),
-                    (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(),
-                )
-                .to_linear_slice(),
-            0,
-            attention_tile_size.seq_q * attention_tile_size.head_dim,
-            partition_head_dim * attention_tile_size.head_dim,
+            slice,
+            start,
+            end,
+            stride,
             Swizzle::none(),
-            MatrixLayout::RowMajor,
-            1u32,
+            self.gmem_config.matrix_layout,
+            line_size,
         )
     }
 }
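
Note: the rewritten body expresses the tile's start, end, and stride in line units. Once the global-memory line size exceeds 1, the underlying slice holds Line<QG> values, so element counts must be divided by line_size (the old code hard-coded element-unit bounds and a line size of 1u32). A worked sketch of the arithmetic with illustrative numbers, not taken from the commit:

    fn line_unit_bounds_example() {
        // Illustrative values: an 8x64 query tile, 2 head-dim tiles per
        // partition, and a vectorization (line size) of 4 elements.
        let (seq_q, head_dim, partition_head_dim, line_size) = (8u32, 64u32, 2u32, 4u32);

        // Length and stride in *line* units rather than element units.
        let length = seq_q * head_dim / line_size; // 128 lines of 4 elements
        let stride = partition_head_dim * head_dim / line_size; // 32 lines

        assert_eq!((length, stride), (128, 32));
    }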

crates/cubecl-attention/src/components/stage/partition_attention.rs

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ impl<
         #[unroll]
         for hd in 0..partition_head_dim {
             let tile_to_write = registers.get_at_mut(q, hd, config);
-            let tile_read = reader.get_tile::<P, Self::Config>(
+            let tile_read = reader.get_tile::<P>(
                 (q, hd).runtime(),
                 attention_tile_size,
                 partition_seq_q,

crates/cubecl-attention/src/components/tile/accelerated/attention.rs

Lines changed: 0 additions & 1 deletion
@@ -127,7 +127,6 @@ impl<AP: AttentionPrecision> TileAttention<AP> for BlackboxAcceleratedTileAttent

     fn load_query<E: Numeric>(tile: &StridedTile<E>, fragment: &mut Self::Query) {
         let (slice, stride) = tile.as_unlined();
-
         cmma::load(fragment, &slice, stride);
     }


crates/cubecl-attention/src/components/tile/unit_register/attention.rs

Lines changed: 12 additions & 3 deletions
@@ -332,10 +332,19 @@ fn strided_tile_to_unit_tile<E: Numeric, E2: Numeric>(
     strided_tile: &StridedTile<E>,
     unit_tile: &mut UnitTile<E2>,
 ) {
+    let line_size = strided_tile.line_size;
+    assert!(unit_tile.layout.num_cols % line_size == 0);
+
+    let col_iterations = comptime!(unit_tile.layout.num_cols / strided_tile.line_size);
+
     for row in 0..unit_tile.layout.num_rows {
-        for col in 0..unit_tile.layout.num_cols {
-            unit_tile.data[row * unit_tile.layout.num_cols + col] =
-                E2::cast_from(strided_tile.get_line(row, col))
+        for col in 0..col_iterations {
+            let line_read = strided_tile.get_line(row, col);
+            #[unroll]
+            for i in 0..line_size {
+                unit_tile.data[row * unit_tile.layout.num_cols + col * line_size + i] =
+                    E2::cast_from(line_read[i]);
+            }
         }
     }
 }
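
Note: since the strided tile can now carry a line size greater than 1, each get_line call returns a whole vector of elements; the rewritten loop reads one line per iteration and the unrolled inner loop scatters its lanes into the unit tile. A plain-Rust analogue of the same indexing, modeling Line<E> as a fixed-size array (names and sizes are illustrative):

    const LINE_SIZE: usize = 4; // illustrative; must divide num_cols

    // Each source read yields a whole line; its lanes land in consecutive
    // destination columns, mirroring the kernel loop above.
    fn copy_lined(src: &[[f32; LINE_SIZE]], dst: &mut [f32], num_rows: usize, num_cols: usize) {
        assert_eq!(num_cols % LINE_SIZE, 0);
        let cols_in_lines = num_cols / LINE_SIZE;
        for row in 0..num_rows {
            for col in 0..cols_in_lines {
                let line = src[row * cols_in_lines + col];
                for i in 0..LINE_SIZE {
                    dst[row * num_cols + col * LINE_SIZE + i] = line[i];
                }
            }
        }
    }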

crates/cubecl-attention/src/kernels/blackbox_accelerated.rs

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@ impl Algorithm for BlackboxAcceleratedAlgorithm {
     type GlobalAttention = SimpleGlobalAttentionFamily<Self::StageAttention>;
     type BatchAttention = SimpleBatchAttentionFamily<Self::GlobalAttention>;

-    fn filter_line_sizes(_available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
+    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
         AvailableLineSizes {
-            query: vec![1],
+            query: available_line_sizes.query,
             key: vec![1],
             value: vec![1],
             mask: vec![1],

crates/cubecl-attention/src/kernels/unit.rs

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@ impl Algorithm for UnitAlgorithm {
     type GlobalAttention = SimpleGlobalAttentionFamily<Self::StageAttention>;
     type BatchAttention = SimpleBatchAttentionFamily<Self::GlobalAttention>;

-    fn filter_line_sizes(_available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
+    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
         AvailableLineSizes {
-            query: vec![1],
+            query: available_line_sizes.query,
             key: vec![1],
             value: vec![1],
             mask: vec![1],
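
Note: both the unit and blackbox-accelerated algorithms previously pinned the query line size to 1; they now forward whatever line sizes the runtime reports as available for the query tensor, while key, value, and mask stay scalar. That is what lets the vectorized query path above select a wider load. A hedged sketch of a downstream consumer; widest_query_line_size is a hypothetical helper, and the u8 element type of the candidate list is an assumption, not shown in the diff:

    // Hypothetical helper, for illustration: prefer the widest query
    // vectorization on offer. Assumes `query` holds candidate line sizes
    // (Vec<u8> assumed), as the `query: vec![1]` literal suggests.
    fn widest_query_line_size(sizes: &AvailableLineSizes) -> Option<u8> {
        sizes.query.iter().copied().max()
    }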

crates/cubecl-attention/src/tests/macros/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ pub fn attention_test_launch<A: Algorithm, R: Runtime>(
         two_rows_in_array_tile: test_options.two_rows_in_array_tile,
     };

-    test_attention_algorithm::<A, (f32, f32), R>(client, problem, selection);
+    test_attention_algorithm::<A, (half::f16, half::f16), R>(client, problem, selection);
 }

 #[macro_export]

crates/cubecl-cpp/src/metal/dialect.rs

Lines changed: 13 additions & 4 deletions
@@ -1072,10 +1072,19 @@ impl DialectWmmaCompiler<Self> for MslDialect {
         let item = value.item();
         if item.vectorization > 1 {
             let elem = item.elem;
-            writeln!(
-                f,
-                "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
-            )
+            match value {
+                Variable::GlobalInputArray(..) => writeln!(
+                    f,
+                    "simdgroup_load({frag}, (device {elem}*)({value} + {offset}), {stride}, 0, {transpose});"
+                ),
+                Variable::SharedMemory(..) => writeln!(
+                    f,
+                    "simdgroup_load({frag}, reinterpret_cast<threadgroup {elem} *>({value} + {offset}), {stride}, 0, {transpose});"
+                ),
+                _ => panic!(
+                    "Vectorized wmma load is only supported from global or shared memory."
+                ),
+            }
         } else {
             writeln!(
                 f,
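
Note: the fix dispatches on the source pointer's address space. In Metal, global buffers live in the device address space and shared memory in the threadgroup address space; a pointer cannot simply be reinterpreted across the two, which is why the previous single-arm code, which always cast to threadgroup, produced broken wmma loads from global memory. Substituting illustrative placeholder values into the two format strings above (frag_0, input_0, shared_0, half, a stride of 64, and transpose false are all made up for the example), the emitted MSL would look like:

    // load from a global (device) buffer
    simdgroup_load(frag_0, (device half*)(input_0 + 0), 64, 0, false);
    // load from shared (threadgroup) memory
    simdgroup_load(frag_0, reinterpret_cast<threadgroup half *>(shared_0 + 0), 64, 0, false);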

crates/cubecl-std/src/tests/event.rs

Lines changed: 3 additions & 3 deletions
@@ -150,7 +150,7 @@ fn launch_test_3(output: &mut Array<f32>) {
     test_3(output.to_slice_mut());
 }

-pub fn event_test_1<R: Runtime>(client: ComputeClient<R::Server>) {
+pub fn event_test_1<R: Runtime>(client: ComputeClient<R>) {
     let output = client.empty(8);

     unsafe {
@@ -168,7 +168,7 @@ pub fn event_test_1<R: Runtime>(client: ComputeClient<R::Server>) {
     assert_eq!(actual, &[20.0, 50.0]);
 }

-pub fn event_test_2<R: Runtime>(client: ComputeClient<R::Server>) {
+pub fn event_test_2<R: Runtime>(client: ComputeClient<R>) {
     let output = client.empty(8);

     unsafe {
@@ -186,7 +186,7 @@ pub fn event_test_2<R: Runtime>(client: ComputeClient<R::Server>) {
     assert_eq!(actual, &[15.0, 30.0]);
 }

-pub fn event_test_3<R: Runtime>(client: ComputeClient<R::Server>) {
+pub fn event_test_3<R: Runtime>(client: ComputeClient<R>) {
     let output = client.empty(12);

     unsafe {
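
Note: these signature updates track a ComputeClient that is now generic over the Runtime itself rather than its Server associated type, which is the "fix main compilation" part of the commit. A minimal usage sketch under the new signatures; run_event_suite is a hypothetical wrapper, and it assumes ComputeClient is a cheaply clonable handle, which the by-value parameters suggest but the diff does not show:

    // Hypothetical wrapper, not part of the commit.
    pub fn run_event_suite<R: Runtime>(client: ComputeClient<R>) {
        event_test_1::<R>(client.clone()); // assumes ComputeClient: Clone
        event_test_2::<R>(client.clone());
        event_test_3::<R>(client);
    }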
