Commit 3904c27

Flash Attention: use loader from matmul + fix sync bug (#1067)
1 parent b24b963 commit 3904c27

15 files changed: +161 −153 lines

crates/cubecl-attention/src/components/global/simple/attention.rs

Lines changed: 14 additions & 14 deletions
@@ -1,17 +1,19 @@
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 use cubecl_matmul::components::global::PartitionedStage;
+use cubecl_matmul::components::global::read::FullStageGlobalReader;
 use cubecl_matmul::components::stage::StridedStageMemory;
 use cubecl_std::tensor::r#virtual::VirtualTensor;
 use cubecl_std::{CubeOption, CubeOptionExpand};
 use std::marker::PhantomData;

 use crate::components::attention_types::*;
+use crate::components::global::AttentionGlobalLayout;
 use crate::components::global::simple::QueryReader;
 use crate::components::global::simple::{AttentionWriter, AttentionWriterExpand, MaskReader};
-use crate::components::global::{AttentionGlobalLayout, simple::DummyKeyValueReader};
 use crate::components::stage::{
-    AttentionPartitioner, AttentionTilingLayout, StageAttention, StageAttentionConfig as _,
+    AttentionLoadingStrategy, AttentionPartitioner, AttentionTilingLayout, StageAttention,
+    StageAttentionConfig as _,
 };
 use crate::components::{
     AttentionPrecision,
@@ -33,8 +35,8 @@ impl<
     AP: AttentionPrecision,
 > GlobalAttention<AP> for SimpleGlobalAttention<AP, SA>
 {
-    type KeyReader = DummyKeyValueReader<KG<AP>, KS<AP>>;
-    type ValueReader = DummyKeyValueReader<VG<AP>, VS<AP>>;
+    type KeyReader = FullStageGlobalReader<KG<AP>, KS<AP>, AttentionLoadingStrategy>;
+    type ValueReader = FullStageGlobalReader<VG<AP>, VS<AP>, AttentionLoadingStrategy>;
     type MaskReader = MaskReader<AP>;

     type Writer = <SA::Partitioner as AttentionPartitioner>::Writer<OS<AP>, OG<AP>>;
@@ -51,10 +53,6 @@ impl<
         seq_kv: u32,
         #[comptime] config: Self::Config,
     ) {
-        // Init staging shared memories
-        let mut key_stage = key_reader.init_stage();
-        let mut value_stage = value_reader.init_stage();
-
         // Load queries which stay alive in registers for all the kernel
         let mut query_registers = SA::init_query(config.stage_config);
         SA::read_query(&query_reader, &mut query_registers, config.stage_config);
@@ -73,19 +71,21 @@ impl<
         let num_stage_iterations =
             seq_kv.div_ceil(config.stage_config.elements_in_partition_seq_kv());

+        let mut barrier = ();
+
         // Global loop over seq_kv
         for _ in 0..num_stage_iterations {
             // Put key and value into stage
-            key_reader.read_global(&mut key_stage);
-            value_reader.read_global(&mut value_stage);
+            key_reader.load_stage(&mut barrier, config.key_reader_config);
+            value_reader.load_stage(&mut barrier, config.value_reader_config);

             sync_cube();

             // Core of flash attention
             SA::execute(
                 &query_registers,
-                &key_stage,
-                &value_stage,
+                &key_reader.stage(),
+                &value_reader.stage(),
                 &mut key_value_registers,
                 &mask_reader,
                 &mut mask_registers,
@@ -135,7 +135,7 @@ impl<
         let step = config.stage_config.elements_in_partition_seq_kv().runtime();
         let layout =
             AttentionGlobalLayout::new(&key, batch_index, config.key_reader_config.gmem_config);
-        DummyKeyValueReader::new(key.view(layout), step, config.key_reader_config)
+        FullStageGlobalReader::new(key.view(layout), step, config.key_reader_config)
     }

     fn init_value_reader(
@@ -146,7 +146,7 @@ impl<
         let step = config.stage_config.elements_in_partition_seq_kv().runtime();
         let layout =
             AttentionGlobalLayout::new(&value, batch_index, config.value_reader_config.gmem_config);
-        DummyKeyValueReader::new(value.view(layout), step, config.value_reader_config)
+        FullStageGlobalReader::new(value.view(layout), step, config.value_reader_config)
     }

     fn init_mask_reader(
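
Net effect of this file's changes: the key/value readers are now the generic FullStageGlobalReader from cubecl-matmul, which owns its stage memory, so the kernel no longer creates stages by hand; it calls load_stage and borrows the stage back with stage(). The unit barrier (let mut barrier = ();) satisfies the loader's barrier argument, which a synchronous loading strategy never uses. Below is a minimal host-side sketch of that call shape, with hypothetical Reader and Stage types standing in for the real CubeCL device types:

// Hypothetical stand-ins for FullStageGlobalReader and its stage memory.
// This only mirrors the call shape; the real types are CubeCL device types.
struct Stage {
    data: Vec<f32>,
}

struct Reader {
    stage: Stage,
}

impl Reader {
    // Stands in for the global -> shared-memory copy the real loader performs.
    fn load_stage(&mut self, _barrier: &mut ()) {
        self.stage.data.fill(1.0);
    }

    // Borrows the reader-owned stage, replacing the hand-managed key/value stages.
    fn stage(&self) -> &Stage {
        &self.stage
    }
}

fn main() {
    let mut key_reader = Reader {
        stage: Stage { data: vec![0.0; 16] },
    };
    // Unit barrier: a synchronous loading strategy needs no real barrier.
    let mut barrier = ();
    for _ in 0..4 {
        key_reader.load_stage(&mut barrier);
        // On device, sync_cube() runs here before compute reads the stage.
        assert!(key_reader.stage().data.iter().all(|&x| x == 1.0));
    }
}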

crates/cubecl-attention/src/components/global/simple/reader/key_value.rs

Lines changed: 0 additions & 98 deletions
This file was deleted.
Lines changed: 0 additions & 2 deletions
@@ -1,7 +1,5 @@
-mod key_value;
 mod mask;
 mod query;

-pub use key_value::*;
 pub use mask::*;
 pub use query::*;

crates/cubecl-attention/src/components/global/simple/setup.rs

Lines changed: 23 additions & 10 deletions
@@ -55,42 +55,55 @@ impl<
         let specialization_tensor_config = SpecializationTensorConfig::MainFlowOnly;
         let plane_role_config = PlaneRoleConfig::new_unspecialized(stage_config.num_planes());

+        let seq_q_check_bounds = !problem
+            .seq_q
+            .is_multiple_of(stage_config.elements_in_stage_seq_q() as usize);
+        let seq_kv_check_bounds = !problem
+            .seq_kv
+            .is_multiple_of(stage_config.elements_in_partition_seq_kv() as usize);
+        let head_dim_check_bounds = !problem
+            .head_dim
+            .is_multiple_of(stage_config.elements_in_partition_head_dim() as usize);
+        let val_dim_check_bounds = !problem
+            .val_dim
+            .is_multiple_of(stage_config.elements_in_partition_val_dim() as usize);
+
         let query_gmem_config = GlobalMemoryConfig {
             line_size: line_sizes.query as u32,
-            check_row_bounds: false,
-            check_col_bounds: false,
+            check_row_bounds: seq_q_check_bounds,
+            check_col_bounds: head_dim_check_bounds,
             matrix_layout: MatrixLayout::RowMajor,
             view_direction: ViewDirection::None,
         };

         let mask_gmem_config = GlobalMemoryConfig {
             line_size: line_sizes.mask as u32,
-            check_row_bounds: false,
-            check_col_bounds: false,
+            check_row_bounds: seq_q_check_bounds,
+            check_col_bounds: seq_kv_check_bounds,
             matrix_layout: MatrixLayout::RowMajor,
             view_direction: ViewDirection::Col,
         };

         let key_gmem_config = GlobalMemoryConfig {
             line_size: line_sizes.key as u32,
-            check_row_bounds: false,
-            check_col_bounds: false,
+            check_row_bounds: seq_kv_check_bounds,
+            check_col_bounds: head_dim_check_bounds,
             matrix_layout: MatrixLayout::RowMajor,
             view_direction: ViewDirection::Row,
         };

         let value_gmem_config = GlobalMemoryConfig {
             line_size: line_sizes.value as u32,
-            check_row_bounds: false,
-            check_col_bounds: false,
+            check_row_bounds: seq_kv_check_bounds,
+            check_col_bounds: val_dim_check_bounds,
             matrix_layout: MatrixLayout::RowMajor,
             view_direction: ViewDirection::Row,
         };

         let out_gmem_config = GlobalMemoryConfig {
             line_size: line_sizes.out as u32,
-            check_row_bounds: false,
-            check_col_bounds: false,
+            check_row_bounds: seq_q_check_bounds,
+            check_col_bounds: val_dim_check_bounds,
             matrix_layout: MatrixLayout::RowMajor,
             view_direction: ViewDirection::None,
         };
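
All of the new flags reduce to one divisibility test: a dimension needs bounds checking only when the problem size is not an exact multiple of its stage or partition footprint, that is, when the last iteration covers a partial tile. A standalone sketch of that test with made-up sizes (usize::is_multiple_of, as used above, is stable since Rust 1.87):

// True when the stage does not tile the problem dimension exactly,
// so the trailing partial tile must be masked with bounds checks.
fn needs_bounds_check(problem_dim: usize, stage_elements: usize) -> bool {
    !problem_dim.is_multiple_of(stage_elements)
}

fn main() {
    // seq_kv = 100 over 32-element partitions leaves a 4-element tail: check.
    assert!(needs_bounds_check(100, 32));
    // head_dim = 64 tiles exactly into 32-element partitions: no check.
    assert!(!needs_bounds_check(64, 32));
}

Skipping the check when the division is exact presumably keeps per-element bounds branching out of the fully tiled fast path.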

crates/cubecl-attention/src/components/stage/base.rs

Lines changed: 8 additions & 1 deletion
@@ -1,7 +1,7 @@
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 use cubecl_matmul::components::{
-    global::{WriteEventListener, WriteTiling},
+    global::{WriteEventListener, WriteTiling, read::sync_full_cyclic::SyncFullCyclicLoading},
     stage::{ContiguousTilingLayout, RowMajorTilingOrder, StageFamily, StageMemoryConfig},
 };
 use std::{fmt::Debug, hash::Hash};
@@ -21,6 +21,7 @@ use cubecl_std::CubeOption;
 use cubecl_std::tensor::layout::Coords2d;

 pub type AttentionTilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
+pub type AttentionLoadingStrategy = SyncFullCyclicLoading<RowMajorTilingOrder>;

 /// A family of [TileAttention] implementations that operate with any [precision](AttentionPrecision).
 pub trait StageAttentionFamily: Send + Sync + 'static {
@@ -132,6 +133,7 @@ pub trait StageAttentionConfig:

     fn elements_in_partition_seq_q(&self) -> u32;
     fn elements_in_partition_seq_kv(&self) -> u32;
+    fn elements_in_partition_head_dim(&self) -> u32;
     fn elements_in_partition_val_dim(&self) -> u32;

     fn elements_in_stage_seq_q(&self) -> u32;
@@ -217,6 +219,11 @@ impl<TC: TileAttentionConfig> StageAttentionConfig for PartitionAttentionConfig<
         self.shared().partition_size.seq_kv * self.shared().tile_config.attention_tile_size().seq_kv
     }

+    fn elements_in_partition_head_dim(&self) -> u32 {
+        self.shared().partition_size.head_dim
+            * self.shared().tile_config.attention_tile_size().head_dim
+    }
+
     fn elements_in_partition_val_dim(&self) -> u32 {
         self.shared().partition_size.val_dim
             * self.shared().tile_config.attention_tile_size().val_dim
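
The new elements_in_partition_head_dim accessor follows the same formula as its seq_kv and val_dim siblings: elements per partition = tiles per partition times elements per tile. A tiny sketch with illustrative numbers (the structs here are simplified stand-ins for the real config types):

// Simplified stand-ins for the partition-size and tile-size config types.
struct PartitionSize {
    head_dim: u32, // tiles per partition along head_dim
}

struct AttentionTileSize {
    head_dim: u32, // elements per tile along head_dim
}

// Mirrors elements_in_partition_head_dim: tiles * elements-per-tile.
fn elements_in_partition_head_dim(p: &PartitionSize, t: &AttentionTileSize) -> u32 {
    p.head_dim * t.head_dim
}

fn main() {
    let p = PartitionSize { head_dim: 2 };
    let t = AttentionTileSize { head_dim: 16 };
    // 2 tiles of 16 elements -> a 32-element partition along head_dim.
    assert_eq!(elements_in_partition_head_dim(&p, &t), 32);
}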

crates/cubecl-attention/src/components/stage/partition_attention.rs

Lines changed: 2 additions & 2 deletions
@@ -117,7 +117,7 @@ impl<
         // Get the only key-value tile and fill it with hd,kv-th key data
         let key_tile = key_value_partition.get_key_mut();
         let key_data = SK::tile(key_stage, (kv, hd).runtime());
-        TA::fill_key_transposed(&key_data, key_tile.key_mut(), config.tile_config());
+        TA::load_key_transposed(&key_data, key_tile.key_mut(), config.tile_config());

         // Perform score matmul on query and key, and accumulate in softmax tile
         TA::score_matmul(
@@ -164,7 +164,7 @@ impl<
         // Get the only key-value tile and fill it with hd,kv-th key data
         let value_data = SV::tile(value_stage, (kv, vd).runtime());
         let value_tile = key_value_partition.get_value_mut();
-        TA::fill_value(&value_data, value_tile.value_mut(), config.tile_config());
+        TA::load_value(&value_data, value_tile.value_mut(), config.tile_config());

         // Get the q,vd-th accumulator and scale it with previously obtained scale
         let accumulator = accumulator_partition.get_at_mut(q, vd, config);

crates/cubecl-attention/src/components/stage/tile_ops/mask.rs

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ impl<AP: AttentionPrecision, TA: TileAttention<AP>> MaterializedTileMask<AP, TA>
     }

     pub fn update_tile(&mut self, tile: StridedTile<MSK<AP>>) {
-        TA::fill_mask(&tile, &mut self.fragment, self.config);
+        TA::load_mask(&tile, &mut self.fragment, self.config);
     }
 }

crates/cubecl-attention/src/components/stage/tile_ops/query.rs

Lines changed: 1 addition & 1 deletion
@@ -22,6 +22,6 @@ impl<AP: AttentionPrecision, TA: TileAttention<AP>> QueryTile<AP, TA> {

     /// Loads the query data into the fragment
     pub fn update(&mut self, tile: &StridedTile<QG<AP>>) {
-        TA::fill_query(tile, &mut self.fragment)
+        TA::load_query(tile, &mut self.fragment)
     }
 }

crates/cubecl-attention/src/components/tile/accelerated/attention.rs

Lines changed: 5 additions & 5 deletions
@@ -125,13 +125,13 @@ impl<AP: AttentionPrecision> TileAttention<AP> for BlackboxAcceleratedTileAttent
         HybridFragment::new(size, config)
     }

-    fn fill_query<E: Numeric>(tile: &StridedTile<E>, fragment: &mut Self::Query) {
+    fn load_query<E: Numeric>(tile: &StridedTile<E>, fragment: &mut Self::Query) {
         let (slice, stride) = tile.as_unlined();

         cmma::load(fragment, &slice, stride);
     }

-    fn fill_key_transposed<E: Float>(
+    fn load_key_transposed<E: Float>(
         tile: &StridedTile<E>,
         rhs: &mut Self::KeyValue,
         #[comptime] _config: Self::Config,
@@ -140,7 +140,7 @@ impl<AP: AttentionPrecision> TileAttention<AP> for BlackboxAcceleratedTileAttent
         cmma::load(rhs, &slice, stride);
     }

-    fn fill_value<E: Float>(
+    fn load_value<E: Float>(
         tile: &StridedTile<E>,
         rhs: &mut Self::KeyValue,
         #[comptime] _config: Self::Config,
@@ -149,12 +149,12 @@ impl<AP: AttentionPrecision> TileAttention<AP> for BlackboxAcceleratedTileAttent
         cmma::load(rhs, &slice, stride);
     }

-    fn fill_mask<E: Numeric>(
+    fn load_mask<E: Numeric>(
         tile: &StridedTile<E>,
         mask: &mut Self::Mask,
         #[comptime] _config: Self::Config,
     ) {
-        mask.fill_from_strided_tile(tile)
+        mask.load_from_strided_tile(tile)
     }

     fn write_results<E: Float>(
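
The hunks in this file, like those in partition_attention.rs, mask.rs, and query.rs above, are a mechanical fill_* to load_* rename that aligns the attention trait vocabulary with the matmul loader API. A trimmed-down, hypothetical mirror of the renamed surface (generics, tiles, and config parameters omitted):

#![allow(dead_code)]

// Hypothetical, simplified mirror of the renamed TileAttention loading methods.
trait TileLoading {
    type Query;
    type KeyValue;
    type Mask;

    fn load_query(fragment: &mut Self::Query); // was fill_query
    fn load_key_transposed(rhs: &mut Self::KeyValue); // was fill_key_transposed
    fn load_value(rhs: &mut Self::KeyValue); // was fill_value
    fn load_mask(mask: &mut Self::Mask); // was fill_mask
}

fn main() {}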
