Commit 477dc3b

Flash Attention: fix all-masked rows (#1070)

1 parent 1a885d6 commit 477dc3b

File tree: 12 files changed, +146 −48 lines changed

crates/cubecl-attention/src/components/global/simple/attention.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -171,7 +171,7 @@ impl<
                 mask.view(layout),
                 step,
                 seq_kv_shape,
-                config.mask_gmem_config.view_direction,
+                config.mask_gmem_config,
             )
         }
         CubeOption::None => MaskReader::new_logical(stage_q_offset + partition_q_offset, step),
```

crates/cubecl-attention/src/components/global/simple/reader/mask.rs

Lines changed: 29 additions & 16 deletions
```diff
@@ -2,8 +2,7 @@ use crate::components::tile::TileAttentionConfig;
 use crate::components::{AttentionTileSize, attention_types::*};
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
-use cubecl_matmul::components::MatrixLayout;
-use cubecl_matmul::components::global::memory::{GlobalIterator, ViewDirection};
+use cubecl_matmul::components::global::memory::{GlobalIterator, GlobalMemoryConfig};
 use cubecl_matmul::components::tile::StridedTile;
 use cubecl_std::tensor::{View, layout::Coords2d};
 use cubecl_std::{Swizzle, tensor::layout::Coordinates};
@@ -44,6 +43,8 @@ pub struct MaterializedMaskReader<M: Numeric> {
     logical_iter: LogicalIterator,
     // TODO not sure if mandatory, but i need for the stride when reading in global memory
     seq_kv_shape: u32,
+    #[cube(comptime)]
+    gmem_config: GlobalMemoryConfig,
 }

 #[derive(CubeType)]
@@ -64,15 +65,16 @@ impl<AP: AttentionPrecision> MaskReader<AP> {
         mask: View<Line<MSK<AP>>, Coords2d>,
         step: u32,
         seq_kv_shape: u32,
-        #[comptime] view_direction: ViewDirection,
+        #[comptime] gmem_config: GlobalMemoryConfig,
     ) -> Self {
         let mask = mask.slice((stage_q_offset, 0), mask.shape());
-        let global_iter = GlobalIterator::new(mask, step, view_direction, false);
+        let global_iter = GlobalIterator::new(mask, step, gmem_config.view_direction, false);

         MaskReader::<AP>::new_Materialized(MaterializedMaskReader::new(
             global_iter,
             LogicalIterator::init(partition_q_offset, step),
             seq_kv_shape,
+            gmem_config,
         ))
     }

@@ -117,11 +119,13 @@ impl<M: Numeric> MaterializedMaskReader<M> {
         global_iter: GlobalIterator<Line<M>>,
         logical_iter: LogicalIterator,
         seq_kv_shape: u32,
+        #[comptime] gmem_config: GlobalMemoryConfig,
     ) -> Self {
         MaterializedMaskReader::<M> {
             global_iter,
             logical_iter,
             seq_kv_shape,
+            gmem_config,
         }
     }

@@ -135,20 +139,29 @@ impl<M: Numeric> MaterializedMaskReader<M> {

         let row = row_offset + P::seq_q_index() * elements_in_partition_seq_q;

+        let slice = self
+            .global_iter
+            .view()
+            .slice(
+                (row, col.runtime()),
+                (attention_tile_size.seq_q, attention_tile_size.seq_kv).runtime(),
+            )
+            .to_linear_slice();
+
+        let line_size = self.gmem_config.line_size;
+        let start = 0;
+        let length = attention_tile_size.seq_q * attention_tile_size.seq_kv / line_size;
+        let end = start + length;
+        let stride = self.seq_kv_shape / line_size;
+
         StridedTile::<M>::new_strided(
-            self.global_iter
-                .view()
-                .slice(
-                    (row, col.runtime()),
-                    (attention_tile_size.seq_q, attention_tile_size.seq_kv).runtime(),
-                )
-                .to_linear_slice(),
-            0,
-            attention_tile_size.seq_q * attention_tile_size.seq_kv,
-            self.seq_kv_shape,
+            slice,
+            start,
+            end,
+            stride,
             Swizzle::none(),
-            MatrixLayout::RowMajor,
-            1u32,
+            self.gmem_config.matrix_layout,
+            line_size,
         )
     }
```
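A quick sanity check on the new line-unit arithmetic (illustration only, not from the commit; the tile shape, `seq_kv_shape`, and `line_size = 4` below are made-up values). Before this change the reader hard-coded a line size of `1u32` and `MatrixLayout::RowMajor`; both now come from the mask's `GlobalMemoryConfig`, so `start`, `end`, and `stride` are expressed in units of lines rather than scalars:

```rust
// Hypothetical sizes, chosen only to make the arithmetic visible.
fn main() {
    let (seq_q, seq_kv) = (8u32, 8u32); // tile shape in scalars
    let seq_kv_shape = 32u32; // full mask row length in global memory
    let line_size = 4u32; // vector width of each `Line`

    // The tile body covers seq_q * seq_kv scalars, i.e. this many lines:
    let length = seq_q * seq_kv / line_size;
    let (start, end) = (0u32, start + length);

    // Consecutive tile rows sit seq_kv_shape scalars apart in the
    // linear slice, i.e. this many lines:
    let stride = seq_kv_shape / line_size;

    assert_eq!((start, end, stride), (0, 16, 8));
}
```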

crates/cubecl-attention/src/components/tile/accelerated/local_tile.rs

Lines changed: 11 additions & 3 deletions
```diff
@@ -4,8 +4,8 @@ use cubecl_matmul::components::tile::StridedTile;
 use cubecl_std::tensor::layout::Coords2d;

 use crate::components::tile::{
-    FragmentAccumulator, FragmentAccumulatorExpand, FragmentMask, FragmentMaskExpand, RowVal,
-    RowWise, RowwiseFormat, RowwiseFormatExpand,
+    FragmentAccumulator, FragmentAccumulatorExpand, FragmentMask, FragmentMaskExpand, LOGIT_MASKED,
+    RowVal, RowWise, RowwiseFormat, RowwiseFormatExpand,
 };

 use crate::components::tile::{FragmentLayout, FragmentLayoutExpand};
@@ -226,13 +226,21 @@ impl<E: Float> RowwiseFormat<E> for LocalTile<E> {
     }

     fn exp_diff(&mut self, val: &RowWise<E>) {
+        let threshold = E::new(LOGIT_MASKED);
+
         #[unroll]
         for r in 0..self.layout.unit_size.0 {
             let row_offset = r * self.layout.unit_size.1;
+
+            let val = val.index(r);
+
             #[unroll]
             for c in 0..self.layout.unit_size.1 {
                 let index = row_offset + c;
-                self.array[index] = Exp::exp(self.array[index] - val.index(r));
+
+                let safe_val = Max::max(val, threshold);
+                let not_masked = E::cast_from(val >= threshold);
+                self.array[index] = not_masked * Exp::exp(self.array[index] - safe_val);
             }
         }
     }
```
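The core of the all-masked-row fix is visible above: when every logit in a row is masked, the running row maximum is itself a masked value, and the old `exp(x - max)` could produce garbage (with literal `-inf` inputs, `exp(-inf - -inf)` is `NaN`). Clamping the max keeps the exponent finite and the `not_masked` factor zeroes the row. A scalar sketch of the same math in plain `f32` Rust (mine, not the commit's):

```rust
const LOGIT_MASKED: f32 = -1e5;

/// Scalar model of `exp_diff` for one element: `x` is a logit,
/// `row_max` the running row maximum from the online softmax.
fn exp_diff_scalar(x: f32, row_max: f32) -> f32 {
    let safe_max = row_max.max(LOGIT_MASKED); // keep the exponent finite
    let not_masked = (row_max >= LOGIT_MASKED) as u32 as f32; // 1.0 or 0.0
    not_masked * (x - safe_max).exp()
}

fn main() {
    // Normal row: behaves exactly like exp(x - max).
    assert_eq!(exp_diff_scalar(2.0, 2.0), 1.0);
    // Fully masked row: row_max is effectively -inf. Previously this
    // produced exp(NaN) = NaN; now the row is cleanly zeroed.
    assert_eq!(exp_diff_scalar(f32::NEG_INFINITY, f32::NEG_INFINITY), 0.0);
}
```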

crates/cubecl-attention/src/components/tile/base.rs

Lines changed: 7 additions & 0 deletions
```diff
@@ -14,6 +14,13 @@ use crate::components::{
 use std::fmt::Debug;
 use std::hash::Hash;

+/// Logits below this are considered masked (effectively -inf)
+pub(crate) const LOGIT_MASKED: f32 = -1e5;
+
+/// Any value smaller than this is considered numerically zero
+/// (used for fully-masked rows or tiny contributions)
+pub(crate) const FULLY_MASKED_ROW_THRESHOLD: f32 = 1e-7;
+
 #[cube]
 pub trait TileAttention<AP: AttentionPrecision>: Send + Sync + 'static {
     type Config: TileAttentionConfig;
```
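The two constants compose (my reasoning, not stated in the commit): anything at or below `LOGIT_MASKED` already underflows `exp` to exactly zero in `f32`, so a fully-masked row's softmax denominator is `0.0`, which in turn falls below `FULLY_MASKED_ROW_THRESHOLD` and is treated as numerically zero downstream:

```rust
const LOGIT_MASKED: f32 = -1e5;
const FULLY_MASKED_ROW_THRESHOLD: f32 = 1e-7;

fn main() {
    // exp(-1e5) is far below f32::MIN_POSITIVE (~1.2e-38): exact zero.
    assert_eq!(LOGIT_MASKED.exp(), 0.0);

    // So a fully-masked row sums to 0.0, which the "numerically zero"
    // threshold catches in recip_inplace (see rowwise.rs below).
    let row_sum: f32 = 0.0;
    assert!(row_sum < FULLY_MASKED_ROW_THRESHOLD);
}
```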

crates/cubecl-attention/src/components/tile/rowwise.rs

Lines changed: 13 additions & 2 deletions
```diff
@@ -1,6 +1,8 @@
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;

+use crate::components::tile::FULLY_MASKED_ROW_THRESHOLD;
+
 #[derive(CubeType)]
 /// Contains one value per row of a fragment for which the unit contributes
 ///
@@ -188,13 +190,22 @@ impl<E: Float> RowWise<E> {
         }
     }

-    /// Changes the value v at each row for 1/v
+    /// Replaces each value `v` (v >= 0) in a row with `1/v`.
+    ///
+    /// If `v = 0`, the result is set to `0` instead of `1/0`.
+    /// This occurs when the entire row is masked, meaning it should
+    /// contribute no information, and ensures numerical stability.
     pub fn recip_inplace(&mut self) {
         let mut i = comptime![0u32];
         #[unroll]
         for _ in 0..self.num_rows {
             let row_val = self.vals.index_mut(i);
-            row_val.val = Recip::recip(row_val.val);
+
+            let epsilon = E::new(FULLY_MASKED_ROW_THRESHOLD);
+            let not_masked = E::cast_from(row_val.val >= epsilon);
+            let safe_val = Max::max(row_val.val, epsilon);
+            let recip = Recip::recip(safe_val);
+            row_val.val = not_masked * recip;

             comptime![i += 1];
         }
```

crates/cubecl-attention/src/components/tile/unit_register/attention.rs

Lines changed: 28 additions & 6 deletions
```diff
@@ -7,6 +7,7 @@ use cubecl_std::tensor::layout::Coords2d;

 use crate::components::AttentionPrecision;
 use crate::components::attention_types::*;
+use crate::components::tile::LOGIT_MASKED;
 use crate::components::tile::RowVal;
 use crate::components::tile::RowWise;
 use crate::components::tile::unit_register::setup::UnitTileAttentionConfig;
@@ -138,13 +139,21 @@ impl<E: Float> RowwiseFormat<E> for UnitTile<E> {
     }

     fn exp_diff(&mut self, val: &RowWise<E>) {
+        let threshold = E::new(LOGIT_MASKED);
+
         #[unroll]
         for r in 0..self.layout.num_rows {
             let row_offset = r * self.layout.num_cols;
+
+            let val = val.index(r);
+
             #[unroll]
             for c in 0..self.layout.num_cols {
                 let index = row_offset + c;
-                self.data[index] = Exp::exp(self.data[index] - val.index(r));
+
+                let safe_val = Max::max(val, threshold);
+                let not_masked = E::cast_from(val >= threshold);
+                self.data[index] = not_masked * Exp::exp(self.data[index] - safe_val);
             }
         }
     }
@@ -323,7 +332,7 @@ impl<AP: AttentionPrecision> TileAttention<AP> for UnitRegisterTileAttention {
         slice: &mut SliceMut<Line<E>>,
         #[comptime] _config: Self::Config,
     ) {
-        array_tile_to_slice(out, slice)
+        unit_tile_to_slice(out, slice)
     }
 }

@@ -363,14 +372,27 @@ fn strided_tile_to_transposed_unit_tile<E: Numeric, E2: Numeric>(
 }

 #[cube]
-fn array_tile_to_slice<E: Numeric, E2: Numeric>(
+fn unit_tile_to_slice<E: Numeric, E2: Numeric>(
     unit_tile: &UnitTile<E>,
     slice: &mut SliceMut<Line<E2>>,
 ) {
+    let line_size = slice.line_size();
+    assert!(unit_tile.layout.num_cols % line_size == 0);
+
+    let col_iterations = comptime!(unit_tile.layout.num_cols / line_size);
+
     for row in 0..unit_tile.layout.num_rows {
-        for col in 0..unit_tile.layout.num_cols {
-            let index = row * unit_tile.layout.num_cols + col;
-            slice[index] = Line::cast_from(unit_tile.data[index]);
+        for col in 0..col_iterations {
+            let mut out_line = Line::empty(line_size);
+
+            #[unroll]
+            for i in 0..line_size {
+                let index = row * unit_tile.layout.num_cols + col * line_size + i;
+                out_line[i] = E2::cast_from(unit_tile.data[index]);
+            }
+
+            let line_index = row * col_iterations + col;
+            slice[line_index] = out_line;
         }
     }
 }
```
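The rewritten write-out packs `line_size` scalars per store instead of one. The indexing is easy to model in plain Rust with `Vec<[f32; N]>` standing in for `SliceMut<Line<_>>` (names and types here are illustrative, not the crate's API):

```rust
/// Plain-Rust model of the vectorized write-out: pack a row-major
/// `num_rows x num_cols` scalar tile into lines of LINE_SIZE elements.
fn tile_to_lines<const LINE_SIZE: usize>(
    data: &[f32],
    num_rows: usize,
    num_cols: usize,
) -> Vec<[f32; LINE_SIZE]> {
    assert_eq!(num_cols % LINE_SIZE, 0); // mirrors the assert in the kernel
    let col_iterations = num_cols / LINE_SIZE;
    let mut out = vec![[0.0; LINE_SIZE]; num_rows * col_iterations];

    for row in 0..num_rows {
        for col in 0..col_iterations {
            let mut line = [0.0; LINE_SIZE];
            for i in 0..LINE_SIZE {
                line[i] = data[row * num_cols + col * LINE_SIZE + i];
            }
            out[row * col_iterations + col] = line;
        }
    }
    out
}

fn main() {
    let tile: Vec<f32> = (0..8).map(|x| x as f32).collect(); // 2 x 4 tile
    let lines = tile_to_lines::<2>(&tile, 2, 4);
    assert_eq!(lines, vec![[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]);
}
```

This vectorized write-out is also why the two `Algorithm` impls below stop forcing `out: vec![1]`: the output line size no longer has to be restricted to 1.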

crates/cubecl-attention/src/kernels/blackbox_accelerated.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ impl Algorithm for BlackboxAcceleratedAlgorithm {
             key: vec![1],
             value: vec![1],
             mask: vec![1],
-            out: vec![1],
+            out: available_line_sizes.out,
         }
     }
 }
```

crates/cubecl-attention/src/kernels/unit.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ impl Algorithm for UnitAlgorithm {
             key: vec![1],
             value: vec![1],
             mask: vec![1],
-            out: vec![1],
+            out: available_line_sizes.out,
         }
     }
 }
```

crates/cubecl-attention/src/tests/attention_test_launcher.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,6 @@ pub fn test_attention_algorithm<A, P, R>(
     let panic_on_launch_err = match env {
         Ok(val) => match val.as_str() {
             "panic" => true,
-            "skip" => false,
             _ => false,
         },
         Err(_) => false,
@@ -162,6 +161,7 @@ where
     let handle = T::sample(client, &tensor_shape, sample_seed);
     let data = client.read_one(handle.handle);
     let data = T::from_bytes(&data);
+
     let original_data = data.to_owned();
     let data_bytes = T::as_bytes(&original_data);
     let shape = tensor_shape.as_slice();
```

crates/cubecl-attention/src/tests/macros/mod.rs

Lines changed: 2 additions & 1 deletion
```diff
@@ -52,7 +52,8 @@ pub fn attention_test_launch<A: Algorithm, R: Runtime>(
         two_rows_in_array_tile: test_options.two_rows_in_array_tile,
     };

-    test_attention_algorithm::<A, (half::f16, half::f16), R>(client, problem, selection);
+    test_attention_algorithm::<A, (f32, f32), R>(client, problem, selection);
+    // test_attention_algorithm::<A, (half::f16, half::f16), R>(client, problem, selection);
 }

 #[macro_export]
```
