Skip to content

Commit 06fb1d0

Browse files
authored
Flash attention: lines for mask and value (#1072)
1 parent 4f9a7d8 commit 06fb1d0

File tree

7 files changed

+60
-34
lines changed

7 files changed

+60
-34
lines changed

crates/cubecl-attention/src/base.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,27 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
8787
out: &TensorHandleRef<R>,
8888
attention_elems: &AttentionElems,
8989
) -> Result<(), AttentionSetupError> {
90-
let line_sizes = AvailableLineSizes::from_elem_types(
91-
client,
92-
query.elem_size,
93-
attention_elems.mask.size(),
94-
out.elem_size,
95-
);
96-
let line_sizes = A::filter_line_sizes(line_sizes)
97-
.filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
98-
.filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
99-
.filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
100-
.filter_with_tensor(AttentionIdent::Out, out.strides, out.shape)
101-
.pick_max()
102-
.unwrap();
90+
let line_sizes = {
91+
let ls = AvailableLineSizes::from_elem_types(
92+
client,
93+
query.elem_size,
94+
attention_elems.mask.size(),
95+
out.elem_size,
96+
);
97+
let ls = A::filter_line_sizes(ls)
98+
.filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
99+
.filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
100+
.filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
101+
.filter_with_tensor(AttentionIdent::Out, out.strides, out.shape);
102+
103+
if let Some(mask) = mask.as_ref() {
104+
ls.filter_with_tensor(AttentionIdent::Mask, mask.strides, mask.shape)
105+
} else {
106+
ls
107+
}
108+
}
109+
.pick_max()
110+
.unwrap();
103111

104112
let problem = AttentionProblem {
105113
batch: query.shape[0],

crates/cubecl-attention/src/components/tile/accelerated/local_tile.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ impl<E: Numeric> LocalTile<E> {
7070
}
7171

7272
pub fn load_from_strided_tile<E2: Numeric>(&mut self, strided_tile: &StridedTile<E2>) {
73+
// Assumes line size == 1
7374
for r in 0..self.layout.unit_size.0 {
7475
for c in 0..self.layout.unit_size.1 {
7576
let (row, col) = self.layout.absolute_pos((r, c));

crates/cubecl-attention/src/components/tile/accelerated/setup.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use cubecl_core::client::ComputeClient;
22
use cubecl_matmul::components::ComputeResources;
33

44
use crate::components::AttentionElems;
5+
use crate::components::AttentionIdent;
56
use crate::components::AttentionTileSize;
7+
use crate::components::AvailableLineSizes;
68
use crate::components::tile::SharedTileAttentionConfig;
79
use crate::components::tile::TileAttentionConfig;
810
use crate::components::tile::accelerated::BlackboxAcceleratedTileAttention;
@@ -86,6 +88,11 @@ impl TileAttentionFamily for BlackboxAcceleratedTileAttention {
8688
selection.reuse_key_value,
8789
)
8890
}
91+
92+
fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
93+
// Vectorized mask not supported
94+
available_line_sizes.filter(|ls| *ls == 1, AttentionIdent::Mask)
95+
}
8996
}
9097

9198
fn validate(

crates/cubecl-attention/src/components/tile/unit_register/attention.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ fn strided_tile_to_unit_tile<E: Numeric, E2: Numeric>(
344344
let line_size = strided_tile.line_size;
345345
assert!(unit_tile.layout.num_cols % line_size == 0);
346346

347-
let col_iterations = comptime!(unit_tile.layout.num_cols / strided_tile.line_size);
347+
let col_iterations = comptime!(unit_tile.layout.num_cols / line_size);
348348

349349
for row in 0..unit_tile.layout.num_rows {
350350
for col in 0..col_iterations {

crates/cubecl-attention/src/kernels/blackbox_accelerated.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use cubecl_matmul::components::{global::PartitionedStageFamily, stage::StridedStageFamily};
22

33
use crate::components::stage::plane::PlanePartitionStageAttentionFamily;
4+
use crate::components::tile::TileAttentionFamily;
45
use crate::components::tile::accelerated::BlackboxAcceleratedTileAttention;
56
use crate::{
67
components::{
@@ -24,12 +25,14 @@ impl Algorithm for BlackboxAcceleratedAlgorithm {
2425
type BatchAttention = SimpleBatchAttentionFamily<Self::GlobalAttention>;
2526

2627
fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
27-
AvailableLineSizes {
28+
let supported = AvailableLineSizes {
2829
query: available_line_sizes.query,
2930
key: vec![1],
30-
value: vec![1],
31-
mask: vec![1],
31+
value: available_line_sizes.value,
32+
mask: available_line_sizes.mask,
3233
out: available_line_sizes.out,
33-
}
34+
};
35+
36+
Self::TileAttention::filter_line_sizes(supported)
3437
}
3538
}

crates/cubecl-attention/src/kernels/unit.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ impl Algorithm for UnitAlgorithm {
2727
AvailableLineSizes {
2828
query: available_line_sizes.query,
2929
key: vec![1],
30-
value: vec![1],
31-
mask: vec![1],
30+
value: available_line_sizes.value,
31+
mask: available_line_sizes.mask,
3232
out: available_line_sizes.out,
3333
}
3434
}

crates/cubecl-attention/src/tests/attention_test_launcher.rs

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,27 @@ pub fn test_attention_algorithm<A, P, R>(
5656
let out = tensor_raw_parts_output::<P, R>(&client, &problem);
5757

5858
let attention_elems = AttentionElems::new::<P::AP>();
59-
let line_sizes = AvailableLineSizes::from_elem_types(
60-
&client,
61-
attention_elems.query_global.size(),
62-
attention_elems.mask.size(),
63-
attention_elems.out_global.size(),
64-
);
65-
let line_sizes = A::filter_line_sizes(line_sizes);
66-
let line_sizes = line_sizes
67-
.filter_with_tensor(AttentionIdent::Query, &query.strides, &query.shape)
68-
.filter_with_tensor(AttentionIdent::Key, &key.strides, &key.shape)
69-
.filter_with_tensor(AttentionIdent::Value, &value.strides, &value.shape)
70-
.filter_with_tensor(AttentionIdent::Out, &out.strides, &out.shape)
71-
.pick_max()
72-
.unwrap();
59+
let line_sizes = {
60+
let ls = AvailableLineSizes::from_elem_types(
61+
&client,
62+
attention_elems.query_global.size(),
63+
attention_elems.mask.size(),
64+
attention_elems.out_global.size(),
65+
);
66+
let ls = A::filter_line_sizes(ls)
67+
.filter_with_tensor(AttentionIdent::Query, &query.strides, &query.shape)
68+
.filter_with_tensor(AttentionIdent::Key, &key.strides, &key.shape)
69+
.filter_with_tensor(AttentionIdent::Value, &value.strides, &value.shape)
70+
.filter_with_tensor(AttentionIdent::Out, &out.strides, &out.shape);
71+
72+
if let Some(mask) = mask.as_ref() {
73+
ls.filter_with_tensor(AttentionIdent::Mask, &mask.strides, &mask.shape)
74+
} else {
75+
ls
76+
}
77+
}
78+
.pick_max()
79+
.unwrap();
7380

7481
let config = match A::setup(&client, &problem, &selection, &line_sizes, &attention_elems) {
7582
Ok(config) => config,

0 commit comments

Comments (0)