
Commit 84c9613

Flash Attention: row-wise reductions (#946)
1 parent da0e292 commit 84c9613

42 files changed: +1565 additions, −909 deletions


crates/cubecl-attention/src/base.rs

Lines changed: 5 additions & 4 deletions
@@ -9,7 +9,7 @@ use crate::{
         AttentionTilingScheme, AvailableLineSizes, args::TensorInputsLaunch, attention_types::*,
         batch::HypercubeSelection,
     },
-    kernels::{Algorithm, dummy::DummyAlgorithm},
+    kernels::{Algorithm, dummy::DummyRegisterAlgorithm},
 };

 use crate::components::batch::BatchAttentionConfig;
@@ -66,7 +66,7 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
         &MSK::<AP>::as_type_native_unchecked(),
         &OG::<AP>::as_type_native_unchecked(),
     );
-    let line_sizes = DummyAlgorithm::filter_line_sizes(line_sizes)
+    let line_sizes = DummyRegisterAlgorithm::filter_line_sizes(line_sizes)
         .filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
         .filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
         .filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
@@ -105,16 +105,17 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
         },
         plane_dim: 32,
         reuse_key_value: false,
+        two_rows_in_array_tile: false,
     };

-    let config = DummyAlgorithm::setup::<AP, R>(client, &problem, &selection, &line_sizes)?;
+    let config = DummyRegisterAlgorithm::setup::<AP, R>(client, &problem, &selection, &line_sizes)?;

     let cube_count_plan = config
         .hypercube_config()
         .cube_count_plan(&problem, &selection);

     unsafe {
-        <DummyAlgorithm as Algorithm>::BatchAttention::launch_unchecked::<AP, R>(
+        <DummyRegisterAlgorithm as Algorithm>::BatchAttention::launch_unchecked::<AP, R>(
             client,
             config.cube_dim(),
             cube_count_plan.resolve(),

crates/cubecl-attention/src/components/mask.rs

Lines changed: 4 additions & 3 deletions
@@ -1,5 +1,6 @@
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
+use cubecl_std::tensor::layout::Coords2d;

 use crate::components::AttentionTilingScheme;

@@ -87,8 +88,8 @@ impl PartitionMask {

 #[cube]
 impl TileMask {
-    pub fn apply<E: Numeric>(&self, row: u32, col: u32) -> Line<E> {
-        let should_mask = Line::<E>::cast_from(row >= self.q_bound || col >= self.kv_bound);
-        should_mask * Line::cast_from(-999999)
+    pub fn apply<E: Numeric>(&self, pos: Coords2d) -> E {
+        let should_mask = E::cast_from(pos.0 >= self.q_bound || pos.1 >= self.kv_bound);
+        should_mask * E::min_value()
     }
 }
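Replacing the ad-hoc -999999 sentinel with E::min_value() follows the usual additive-mask convention: once the running row max is subtracted, a masked score exponentiates to exactly zero rather than merely something small. A minimal plain-Rust sketch of that behaviour (outside CubeCL; f32 stands in for E, and the bound names are illustrative):

/// Plain-Rust analogue of `TileMask::apply`: the additive mask term
/// for one (row, col) position. Names and shapes here are illustrative.
fn mask_term(pos: (u32, u32), q_bound: u32, kv_bound: u32) -> f32 {
    let should_mask = (pos.0 >= q_bound || pos.1 >= kv_bound) as u32 as f32;
    should_mask * f32::MIN
}

fn main() {
    // One softmax row with the second column out of bounds (kv_bound = 1).
    let scores = [1.0f32 + mask_term((0, 0), 1, 1), 2.0 + mask_term((0, 1), 1, 1)];
    let row_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - row_max).exp()).collect();
    assert_eq!(exps[1], 0.0); // the masked column drops out of the softmax sum
    println!("{exps:?}");
}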

crates/cubecl-attention/src/components/selection.rs

Lines changed: 1 addition & 0 deletions
@@ -8,4 +8,5 @@ pub struct AttentionSelection {
     pub plane_dim: u32,

     pub reuse_key_value: bool,
+    pub two_rows_in_array_tile: bool,
 }

crates/cubecl-attention/src/components/stage/base.rs

Lines changed: 5 additions & 7 deletions
@@ -8,14 +8,14 @@ use std::{fmt::Debug, hash::Hash};

 use crate::components::attention_types::*;
 use crate::components::stage::dummy::AttentionStageMemoryConfig;
-use crate::components::{AttentionIdent, StageMask};
 use crate::components::{
     AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
     AttentionSetupError, AvailableLineSizes,
     global::GlobalAttentionConfig,
     tile::{AttentionTilingLayout, dummy::AttentionMatmulConfig},
 };
 use crate::components::{AttentionTilingScheme, global::dummy::QueryReader};
+use crate::components::{StageMask, tile::RunningState};

 /// A family of [TileAttention] implementations that operate with any [precision](AttentionPrecision).
 pub trait StageAttentionFamily: Send + Sync + 'static {
@@ -62,14 +62,12 @@ pub trait StageAttention<AP: AttentionPrecision>: 'static + Send + Sync {
     /// The configuration type associated with this Attention.
     type Config: StageAttentionConfig;

-    type State: CubeType;
-
     type QueryPartition: CubeType;
     type KeyValuePartition: CubeType;
     type SoftmaxPartition: CubeType;
     type AccumulatorPartition: CubeType;

-    fn init_state(#[comptime] config: Self::Config) -> Self::State;
+    fn init_state(#[comptime] config: Self::Config) -> Sequence<RunningState<SM<AP>>>;

     fn execute(
         key_reader: &Self::KeyStage,
@@ -79,13 +77,13 @@ pub trait StageAttention<AP: AttentionPrecision>: 'static + Send + Sync {
         score: &mut Self::SoftmaxPartition,
         mask: StageMask,
         accumulator: &mut Self::AccumulatorPartition,
-        prev_state: &mut Self::State,
+        prev_state: &mut Sequence<RunningState<SM<AP>>>,
         #[comptime] config: Self::Config,
     );

     fn rescale(
         acc: &mut Self::AccumulatorPartition,
-        state: Self::State,
+        state: Sequence<RunningState<SM<AP>>>,
         #[comptime] config: Self::Config,
     );

@@ -123,5 +121,5 @@ pub trait StageAttentionConfig:
     fn tiling_scheme(&self) -> AttentionTilingScheme;
     fn reuse_key_value(&self) -> bool;

-    fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32;
+    fn num_rows_per_unit(&self) -> u32;
 }
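For reference, the RunningState<SM<AP>> that init_state now returns one entry of per seq_q index is the per-row running pair $(m, \ell)$ of the online softmax behind flash attention's row-wise reductions. Under that standard recurrence (notation mine, not from the diff), each new score tile $S$ updates a row as

$m' = \max\bigl(m, \operatorname{rowmax}(S)\bigr), \qquad \alpha = e^{m - m'}$

$\ell' = \alpha\,\ell + \operatorname{rowsum}\bigl(e^{S - m'}\bigr), \qquad O' = \alpha\,O + e^{S - m'}\,V$

The per-row factor $\alpha$ is what execute collects into scales and hands to accumulate_value; rescale then presumably applies the final division of $O$ by $\ell$ once every key/value tile has been folded in.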

crates/cubecl-attention/src/components/stage/dummy/attention.rs

Lines changed: 24 additions & 14 deletions
@@ -7,16 +7,15 @@ use cubecl_matmul::components::{
 };
 use std::marker::PhantomData;

-use crate::components::StageMask;
 use crate::components::attention_types::*;
 use crate::components::global::dummy::QueryReader;
 use crate::components::stage::dummy::SoftmaxPartition;
-use crate::components::stage::dummy::StageState;
 use crate::components::stage::dummy::{Accumulators, DummyStageConfig, KeyValues, Queries};
 use crate::components::stage::{StageAttention, StageAttentionConfig};
 use crate::components::tile::RowWise;
 use crate::components::tile::TileAttention;
 use crate::components::{AttentionPrecision, global::GlobalAttentionConfig};
+use crate::components::{StageMask, tile::RunningState};

 pub struct DummyStageAttention<AP: AttentionPrecision, SK, SV, SO, TA: TileAttention<AP>> {
     _phantom: PhantomData<(AP, SK, SV, SO, TA)>,
@@ -37,7 +36,6 @@ impl<
     type ValueStage = SV;
     type OutStage = SO;

-    type State = StageState<AP>;
     type QueryPartition = Queries<AP, TA, Self::Config>;
     type KeyValuePartition = KeyValues<AP, TA, Self::Config>;
     type SoftmaxPartition = SoftmaxPartition<AP, TA, Self::Config>;
@@ -51,7 +49,7 @@ impl<
         softmax_partition: &mut Self::SoftmaxPartition,
         mask: StageMask,
         accumulator_partition: &mut Self::AccumulatorPartition,
-        state: &mut Self::State,
+        state: &mut Sequence<RunningState<SM<AP>>>,
         #[comptime] config: Self::Config,
     ) {
         let partition_mask = mask.to_partition(UNIT_POS_Y);
@@ -60,6 +58,9 @@ impl<

         let mut kv = comptime![0u32];

+        let mut max_placeholder = TA::init_max_placeholder(config.num_rows_per_unit());
+        let mut sum_placeholder = TA::init_sum_placeholder(config.num_rows_per_unit());
+
         #[unroll]
         #[allow(clippy::explicit_counter_loop)]
         for _ in 0..p.seq_kv {
@@ -80,7 +81,7 @@ impl<
             }

             let mut q = comptime![0u32];
-            let mut scales = Sequence::<RowWise<ACC<AP>>>::new();
+            let mut scales = Sequence::<RowWise<SM<AP>>>::new();

             #[unroll]
             #[allow(clippy::explicit_counter_loop)]
@@ -101,16 +102,17 @@ impl<
                     comptime![hd += 1];
                 }

-                let state_q = state.get_at_mut(q);
+                let state_q = state.index_mut(q);

-                let accumulator_scale = TA::softmax(
+                scales.push(TA::softmax(
                     softmax_tile,
                     partition_mask.to_tile(q, kv),
                     state_q,
+                    &mut max_placeholder,
+                    &mut sum_placeholder,
                     config.tiling_scheme().elements_in_partition_head_dim(),
-                );
-
-                scales.push(accumulator_scale);
+                    config.tile_config(),
+                ));

                 comptime![q += 1];
             }
@@ -162,7 +164,7 @@ impl<

     fn rescale(
         acc: &mut Self::AccumulatorPartition,
-        state: Self::State,
+        state: Sequence<RunningState<SM<AP>>>,
         #[comptime] config: Self::Config,
     ) {
         let p = config.tiling_scheme().partition_size;
@@ -179,7 +181,7 @@ impl<
         for _ in 0..p.val_dim {
             TA::rescale(
                 Self::AccumulatorPartition::get_at_mut(acc, q, vd, config),
-                state.get_at(q),
+                state.index(q),
                 config.tile_config(),
             );

@@ -190,8 +192,16 @@ impl<
         }
     }

-    fn init_state(#[comptime] config: Self::Config) -> Self::State {
-        StageState::<AP>::init::<Self::Config>(config)
+    fn init_state(#[comptime] config: Self::Config) -> Sequence<RunningState<SM<AP>>> {
+        let p = config.tiling_scheme().partition_size;
+        let mut sequence = Sequence::new();
+
+        #[unroll]
+        for _ in 0..comptime!(p.seq_q) {
+            sequence.push(TA::init_state(config.tile_config()));
+        }
+
+        sequence
     }

     fn write<W: WriteEventListener, G: GlobalAttentionConfig>(
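To make the control flow of the rewritten execute concrete: the max/sum scratch is allocated once before the seq_kv loop, each tile's softmax updates the running state and returns a rescale factor, and the accumulator is corrected by that factor before the new contribution lands. A scalar, single-row sketch in plain Rust (illustrative names; the real code operates on RowWise values per unit inside CubeCL):

/// Per-row running state: running max `m` and running sum `l`.
struct RunningState { m: f32, l: f32 }

/// One online-softmax step over a row of scores, mirroring the shape of
/// `TA::softmax`: update (m, l) in place and return the accumulator scale.
fn softmax_step(scores: &mut [f32], state: &mut RunningState) -> f32 {
    let tile_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let new_max = state.m.max(tile_max);
    let scale = (state.m - new_max).exp(); // rescales everything accumulated so far
    let mut tile_sum = 0.0;
    for s in scores.iter_mut() {
        *s = (*s - new_max).exp(); // scores become unnormalized probabilities
        tile_sum += *s;
    }
    state.m = new_max;
    state.l = state.l * scale + tile_sum;
    scale
}

fn main() {
    let mut state = RunningState { m: f32::NEG_INFINITY, l: 0.0 };
    let mut acc = 0.0f32; // stands in for one accumulator element
    for mut tile in [vec![0.5f32, 1.0], vec![2.0, 0.1]] {
        let scale = softmax_step(&mut tile, &mut state);
        // accumulate_value: rescale old contributions, then add the new ones
        acc = acc * scale + tile.iter().sum::<f32>(); // value vector of all ones
        println!("scale = {scale}, running sum l = {}", state.l);
    }
    println!("normalized output: {}", acc / state.l); // the final rescale step
}

With an all-ones value vector the normalized output comes out exactly 1.0, a handy sanity check that the running sum matches the accumulated probabilities.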

crates/cubecl-attention/src/components/stage/dummy/config.rs

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 use cubecl_matmul::components::{MatrixLayout, StageIdent, TilingScheme, stage::StageMemoryConfig};

 use crate::components::{
-    AttentionIdent, AttentionSetupError, AttentionTilingScheme, stage::StageAttentionConfig,
+    AttentionSetupError, AttentionTilingScheme, stage::StageAttentionConfig,
     tile::dummy::AttentionMatmulConfig,
 };

@@ -46,8 +46,8 @@ impl<FC: AttentionMatmulConfig> StageAttentionConfig for DummyStageConfig<FC> {
         self.reuse_key_value
     }

-    fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32 {
-        self.tile_config.num_rows_per_unit(ident)
+    fn num_rows_per_unit(&self) -> u32 {
+        self.tile_config.num_rows_per_unit()
     }
 }

crates/cubecl-attention/src/components/stage/dummy/setup.rs

Lines changed: 2 additions & 2 deletions
@@ -52,11 +52,11 @@ impl<
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,
     ) -> Result<Self::Config, AttentionSetupError> {
-        let tile_config = TA::setup::<AP, R>(client, problem, selection, line_sizes)?;
-
         let num_planes = selection.tiling_scheme.stage_size.seq_q
             * TA::computation_resources()?.num_planes(selection.plane_dim)?;

+        let tile_config = TA::setup::<AP, R>(client, problem, selection, line_sizes, num_planes)?;
+
         DummyStageConfig::new(
             tile_config,
             score_attention_stage_memory_config(selection),

crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs

Lines changed: 0 additions & 33 deletions
@@ -4,10 +4,7 @@ use std::marker::PhantomData;
 use cubecl::prelude::*;
 use cubecl_core as cubecl;

-use crate::components::AttentionIdent;
-use crate::components::attention_types::*;
 use crate::components::global::dummy::QueryReader;
-use crate::components::tile::RunningState;
 use crate::components::{AttentionPrecision, stage::StageAttentionConfig, tile::TileAttention};

 #[derive(CubeType)]
@@ -307,33 +304,3 @@ impl<
         self.sequence.index_mut(index)
     }
 }
-
-#[derive(CubeType)]
-pub struct StageState<AP: AttentionPrecision> {
-    sequence: Sequence<RunningState<SM<AP>>>,
-}
-
-#[cube]
-impl<AP: AttentionPrecision> StageState<AP> {
-    pub fn init<S: StageAttentionConfig>(#[comptime] config: S) -> StageState<AP> {
-        let p = config.tiling_scheme().partition_size;
-        let mut sequence = Sequence::new();
-
-        #[unroll]
-        for _ in 0..comptime!(p.seq_q) {
-            sequence.push(RunningState::<SM<AP>>::init(
-                config.num_rows_per_unit(AttentionIdent::Softmax),
-            ));
-        }
-
-        StageState::<AP> { sequence }
-    }
-
-    pub fn get_at(&self, #[comptime] q: u32) -> &RunningState<SM<AP>> {
-        self.sequence.index(q)
-    }
-
-    pub fn get_at_mut(&mut self, #[comptime] q: u32) -> &mut RunningState<SM<AP>> {
-        self.sequence.index_mut(q)
-    }
-}

crates/cubecl-attention/src/components/tile/base.rs

Lines changed: 17 additions & 8 deletions
@@ -10,7 +10,7 @@ use crate::components::{
     AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
     AttentionSetupError, AvailableLineSizes,
     attention_types::*,
-    tile::{RowWise, RunningState, dummy::AttentionMatmulConfig},
+    tile::{KeyValueTile, QueryTile, RowWise, RunningState, dummy::AttentionMatmulConfig},
 };
 use crate::components::{InvalidConfigError, tile::AccumulatorTile};
 use crate::components::{TileMask, tile::SoftmaxTile};
@@ -33,6 +33,7 @@ pub trait TileAttentionFamily: Send + Sync + 'static {
         problem: &AttentionProblem,
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,
+        num_planes: u32,
     ) -> Result<Self::Config, AttentionSetupError>;

     /// Filters out line sizes that are incompatible with this Attention family.
@@ -50,10 +51,10 @@ pub trait TileAttention<AP: AttentionPrecision>: 'static + Send + Sync {
     /// The configuration type associated with this Attention.
     type Config: AttentionMatmulConfig;

-    type QueryTile: CubeType;
-    type KeyValueTile: CubeType;
+    type QueryTile: QueryTile<QT<AP>>;
+    type KeyValueTile: KeyValueTile<KVT<AP>>;
     type SoftmaxTile: SoftmaxTile<AP>;
-    type AccumulatorTile: AccumulatorTile<ACC<AP>>;
+    type AccumulatorTile: AccumulatorTile<AP>;

     fn rescale(
         acc: &mut Self::AccumulatorTile,
@@ -77,13 +78,15 @@ pub trait TileAttention<AP: AttentionPrecision>: 'static + Send + Sync {

     fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxTile;

-    fn fill_key<E: Numeric>(
+    fn init_state(#[comptime] config: Self::Config) -> RunningState<SM<AP>>;
+
+    fn fill_key<E: Float>(
         tile: &StridedTile<E>,
         rhs: &mut Self::KeyValueTile,
         #[comptime] config: Self::Config,
     );

-    fn fill_value<E: Numeric>(
+    fn fill_value<E: Float>(
         tile: &StridedTile<E>,
         rhs: &mut Self::KeyValueTile,
         #[comptime] config: Self::Config,
@@ -102,14 +105,20 @@ pub trait TileAttention<AP: AttentionPrecision>: 'static + Send + Sync {
         softmax: &mut Self::SoftmaxTile,
         mask: TileMask,
         state: &mut RunningState<SM<AP>>,
+        max_placeholder: &mut RowWise<SM<AP>>,
+        sum_placeholder: &mut RowWise<SM<AP>>,
         #[comptime] dk: u32,
-    ) -> RowWise<ACC<AP>>;
+        #[comptime] config: Self::Config,
+    ) -> RowWise<SM<AP>>;

     fn accumulate_value(
         softmax: &Self::SoftmaxTile,
         key_value: &Self::KeyValueTile,
         accumulator: &mut Self::AccumulatorTile,
-        scale: &RowWise<ACC<AP>>,
+        scale: &RowWise<SM<AP>>,
         #[comptime] config: Self::Config,
     );
+
+    fn init_max_placeholder(#[comptime] num_rows: u32) -> RowWise<SM<AP>>;
+    fn init_sum_placeholder(#[comptime] num_rows: u32) -> RowWise<SM<AP>>;
 }
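The init_max_placeholder/init_sum_placeholder pair and the extra &mut RowWise<SM<AP>> arguments suggest the row-wise reduction scratch is now owned by the caller and reused across every softmax call rather than re-created per tile. A rough plain-Rust stand-in for that shape (hypothetical; the real RowWise is a CubeCL type whose API is not shown in this diff):

/// Illustrative stand-in for `RowWise<E>`: one value per row a unit owns.
struct RowWise { vals: Vec<f32> }

impl RowWise {
    /// Analogue of `init_max_placeholder`: seeded for a max reduction.
    fn min_filled(num_rows: usize) -> Self {
        Self { vals: vec![f32::NEG_INFINITY; num_rows] }
    }
    /// Analogue of `init_sum_placeholder`: seeded for a sum reduction.
    fn zero_filled(num_rows: usize) -> Self {
        Self { vals: vec![0.0; num_rows] }
    }
    /// Row-wise max over a row-major tile, written into the reused scratch.
    fn row_max_of(&mut self, tile: &[f32], num_cols: usize) {
        for (r, v) in self.vals.iter_mut().enumerate() {
            *v = tile[r * num_cols..(r + 1) * num_cols]
                .iter()
                .cloned()
                .fold(f32::NEG_INFINITY, f32::max);
        }
    }
}

fn main() {
    let mut max_scratch = RowWise::min_filled(2);
    let _sum_scratch = RowWise::zero_filled(2); // reused the same way for row sums
    max_scratch.row_max_of(&[1.0, 3.0, 2.0, 0.5, 4.0, 0.0], 3);
    println!("{:?}", max_scratch.vals); // [3.0, 4.0]
}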

0 commit comments
