Skip to content

Commit 29c0312

Browse files
authored
Flash Attention: all lines (#1073)
1 parent 06fb1d0 commit 29c0312

File tree

3 files changed

+18
-25
lines changed

3 files changed

+18
-25
lines changed

crates/cubecl-attention/src/components/tile/unit_register/attention.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,22 @@ fn strided_tile_to_transposed_unit_tile<E: Numeric, E2: Numeric>(
363 363
strided_tile: &StridedTile<E>,
364 364
unit_tile: &mut UnitTile<E2>,
365 365
) {
366-
for row in 0..unit_tile.layout.num_rows {
367-
for col in 0..unit_tile.layout.num_cols {
368-
unit_tile.data[row * unit_tile.layout.num_cols + col] =
369-
E2::cast_from(strided_tile.get_line(col, row))
366+
let line_size = strided_tile.line_size;
367+
assert!(unit_tile.layout.num_cols % line_size == 0);
368+
369+
let input_num_rows = unit_tile.layout.num_cols;
370+
let input_num_cols = unit_tile.layout.num_rows;
371+
let line_iterations = comptime!(input_num_cols / line_size);
372+
373+
for input_row in 0..input_num_rows {
374+
for input_col_line in 0..line_iterations {
375+
let line_read = strided_tile.get_line(input_row, input_col_line);
376+
377+
#[unroll]
378+
for i in 0..line_size {
379+
unit_tile.data[(input_col_line + i) * input_num_rows + input_row] =
380+
E2::cast_from(line_read[i]);
381+
}
370 382
}
371 383
}
372 384
}

crates/cubecl-attention/src/kernels/blackbox_accelerated.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@ impl Algorithm for BlackboxAcceleratedAlgorithm {
25 25
type BatchAttention = SimpleBatchAttentionFamily<Self::GlobalAttention>;
26 26

27 27
fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
28-
let supported = AvailableLineSizes {
29-
query: available_line_sizes.query,
30-
key: vec![1],
31-
value: available_line_sizes.value,
32-
mask: available_line_sizes.mask,
33-
out: available_line_sizes.out,
34-
};
35-
36-
Self::TileAttention::filter_line_sizes(supported)
28+
Self::TileAttention::filter_line_sizes(available_line_sizes)
37 29
}
38 30
}

crates/cubecl-attention/src/kernels/unit.rs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ use crate::components::stage::unit::UnitPartitionStageAttentionFamily;
4 4
use crate::components::tile::unit_register::UnitRegisterTileAttention;
5 5
use crate::{
6 6
components::{
7-
AvailableLineSizes, batch::simple::SimpleBatchAttentionFamily,
8-
global::simple::SimpleGlobalAttentionFamily,
7+
batch::simple::SimpleBatchAttentionFamily, global::simple::SimpleGlobalAttentionFamily,
9 8
},
10 9
kernels::Algorithm,
11 10
};
@@ -22,14 +21,4 @@ impl Algorithm for UnitAlgorithm {
22 21
>;
23 22
type GlobalAttention = SimpleGlobalAttentionFamily<Self::StageAttention>;
24 23
type BatchAttention = SimpleBatchAttentionFamily<Self::GlobalAttention>;
25-
26-
fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
27-
AvailableLineSizes {
28-
query: available_line_sizes.query,
29-
key: vec![1],
30-
value: available_line_sizes.value,
31-
mask: available_line_sizes.mask,
32-
out: available_line_sizes.out,
33-
}
34-
}
35 24
}

0 commit comments

Comments
 (0)