
Commit 3443131

Flash Attention: strengthen test suite (#1077)

1 parent bf6126c

13 files changed: +448 −181 lines

crates/cubecl-attention/src/base.rs

Lines changed: 6 additions & 2 deletions
@@ -127,15 +127,19 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
         val_dim: 8,
     };
 
+    assert!(problem.head_dim as u32 % tile_size.head_dim == 0);
+    let partition_head_dim = problem.head_dim as u32 / tile_size.head_dim;
+    let partition_val_dim = partition_head_dim;
+
     let selection = AttentionSelection {
         hypercube_selection: HypercubeSelection {},
         tiling_scheme: AttentionTilingScheme {
            tile_size,
            partition_size: AttentionPartitionSize {
                seq_q: 1,
-               head_dim: 1,
+               head_dim: partition_head_dim,
                seq_kv: 1,
-               val_dim: 1,
+               val_dim: partition_val_dim,
            },
            stage_size: AttentionStageSize { seq_q: 1 },
        },
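For context, the new selection logic simply tiles the head dimension: the problem's head dim must be an exact multiple of the tile's head dim, and the partition then covers the full head dim. A minimal standalone sketch of the arithmetic (the helper name and main are illustrative, not from the crate):

fn partition_head_dim(problem_head_dim: u32, tile_head_dim: u32) -> u32 {
    // The head dimension must tile evenly; otherwise the launch asserts.
    assert!(problem_head_dim % tile_head_dim == 0);
    problem_head_dim / tile_head_dim
}

fn main() {
    // e.g. a 64-wide head split into 8-wide tiles gives 8 tiles per partition.
    assert_eq!(partition_head_dim(64, 8), 8);
}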

crates/cubecl-attention/src/components/global/simple/reader/query.rs

Lines changed: 6 additions & 7 deletions
@@ -41,21 +41,20 @@ impl<AP: AttentionPrecision> QueryReader<AP> {
 
         let line_size = self.gmem_config.line_size;
 
+        let tile_head_dim = attention_tile_size.head_dim;
+
         let slice = self
             .query
             .slice(
-                (
-                    row * attention_tile_size.seq_q,
-                    col * attention_tile_size.head_dim,
-                ),
-                (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(),
+                (row * attention_tile_size.seq_q, col * tile_head_dim),
+                (attention_tile_size.seq_q, tile_head_dim).runtime(),
             )
             .to_linear_slice();
 
         let start = 0;
-        let length = attention_tile_size.seq_q * attention_tile_size.head_dim / line_size;
+        let length = attention_tile_size.seq_q * tile_head_dim / line_size;
         let end = start + length;
-        let stride = partition_head_dim * attention_tile_size.head_dim / line_size;
+        let stride = partition_head_dim * tile_head_dim / line_size;
 
         StridedTile::<QG<AP>>::new_strided(
             slice,
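The hoisted tile_head_dim makes the slice bookkeeping easier to read: a tile spans seq_q rows of tile_head_dim elements, addressed in line_size units, and the stride spans the full partitioned head dim. A standalone sketch of that arithmetic, assuming stride is the row pitch of the backing buffer (struct and function names are illustrative):

struct TileExtent {
    length: u32, // total lines occupied by one seq_q x tile_head_dim tile
    stride: u32, // lines between the starts of consecutive tile rows
}

fn tile_extent(
    seq_q: u32,
    tile_head_dim: u32,
    partition_head_dim: u32,
    line_size: u32,
) -> TileExtent {
    TileExtent {
        length: seq_q * tile_head_dim / line_size,
        stride: partition_head_dim * tile_head_dim / line_size,
    }
}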

crates/cubecl-attention/src/components/stage/base.rs

Lines changed: 19 additions & 4 deletions
@@ -179,14 +179,29 @@ impl<TC: TileAttentionConfig> PartitionAttentionConfig<TC> {
 
 pub fn validate<TC: TileAttentionConfig>(
     config: PartitionAttentionConfig<TC>,
+    problem: &AttentionProblem,
 ) -> Result<PartitionAttentionConfig<TC>, AttentionSetupError> {
     let tile_size = config.shared().tile_config.attention_tile_size();
     let partition_size = config.shared().partition_size;
 
-    if config.shared().reuse_key_value
-        && (tile_size.head_dim != tile_size.val_dim
-            || partition_size.head_dim != partition_size.val_dim)
-    {
+    if partition_size.head_dim * tile_size.head_dim != problem.head_dim as u32 {
+        return Err(AttentionSetupError::InvalidConfig(Box::new(
+            "Tiling scheme's total head dim must equal problem's head dim".to_string(),
+        )));
+    }
+
+    let head_val_different = tile_size.head_dim != tile_size.val_dim
+        || partition_size.head_dim != partition_size.val_dim;
+
+    if head_val_different {
+        return Err(AttentionSetupError::InvalidConfig(Box::new(
+            "Differing head dim and val dim is not yet supported".to_string(),
+        )));
+    }
+
+    // This check is stricter than the previous one, but the other may be removed
+    // eventually while this one will always remain true.
+    if config.shared().reuse_key_value && head_val_different {
         return Err(AttentionSetupError::InvalidConfig(Box::new(
             "When reusing key/value, head_dim must equal val_dim in both tile_size and partition_size."
                 .to_string(),
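The first new check ties the tiling scheme to the problem shape: the partition's tile count times the tile width must reproduce the problem's head dim exactly. A standalone illustration of how a mismatched scheme is rejected (helper name and main are illustrative, not the crate's API):

fn check_total_head_dim(
    partition_head_dim: u32,
    tile_head_dim: u32,
    problem_head_dim: u32,
) -> Result<(), String> {
    if partition_head_dim * tile_head_dim != problem_head_dim {
        return Err("Tiling scheme's total head dim must equal problem's head dim".to_string());
    }
    Ok(())
}

fn main() {
    // 4 tiles of width 16 cover a 64-wide head, but not a 128-wide one.
    assert!(check_total_head_dim(4, 16, 64).is_ok());
    assert!(check_total_head_dim(4, 16, 128).is_err());
}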

crates/cubecl-attention/src/components/stage/plane/setup.rs

Lines changed: 15 additions & 12 deletions
@@ -107,17 +107,20 @@ impl<
             num_stages: 1,
         };
 
-        validate(PartitionAttentionConfig::Plane(PlanePartitionStageConfig {
-            shared: SharedPartitionAttentionConfig {
-                tile_config,
-                partition_size: selection.tiling_scheme.partition_size,
-                stage_size: selection.tiling_scheme.stage_size,
-                reuse_key_value: selection.reuse_key_value,
-                num_planes,
-                key_smem_config,
-                value_smem_config,
-                out_smem_config,
-            },
-        }))
+        validate(
+            PartitionAttentionConfig::Plane(PlanePartitionStageConfig {
+                shared: SharedPartitionAttentionConfig {
+                    tile_config,
+                    partition_size: selection.tiling_scheme.partition_size,
+                    stage_size: selection.tiling_scheme.stage_size,
+                    reuse_key_value: selection.reuse_key_value,
+                    num_planes,
+                    key_smem_config,
+                    value_smem_config,
+                    out_smem_config,
+                },
+            }),
+            problem,
+        )
     }
 }

crates/cubecl-attention/src/components/stage/tile_ops/softmax/base.rs

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@ use crate::components::tile::TileAttention;
 #[cube]
 /// Applies softmax to a tile with masking and updates the running state.
 ///
-/// Scales by `1 / sqrt(dk)`, applies the mask, computes row-wise max and sum,
+/// Scales by `1 / sqrt(head_dim)`, applies the mask, computes row-wise max and sum,
 /// exponentiates, and updates the softmax state.
 ///
 /// Returns the exponential difference used for normalization.
@@ -26,12 +26,12 @@ pub fn tile_softmax<AP: AttentionPrecision, TA: TileAttention<AP>, R: Reducer>(
     state: &mut RunningState<SM<AP>>,
     max_placeholder: &mut RowWise<SM<AP>>,
     sum_placeholder: &mut RowWise<SM<AP>>,
-    #[comptime] dk: u32,
+    #[comptime] head_dim: u32,
     #[comptime] config: TA::Config,
 ) -> RowWise<SM<AP>> {
     TA::SoftmaxRow::scale_and_mask::<MaskTile<AP, TA>>(
         rowwise_softmax,
-        SM::<AP>::new(comptime!(1.0 / (dk as f32).sqrt())),
+        SM::<AP>::new(comptime!(1.0 / (head_dim as f32).sqrt())),
         mask,
     );
 
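The rename from dk to head_dim makes it clearer that this is the standard attention scale: scores are multiplied by 1 / sqrt(head_dim) before the row-wise softmax, as in softmax(Q · Kᵀ / sqrt(d_head)). A standalone illustration (function name and main are illustrative):

fn attention_scale(head_dim: u32) -> f32 {
    // Computed at comptime in the kernel; plain f32 math here.
    1.0 / (head_dim as f32).sqrt()
}

fn main() {
    // For head_dim = 64, each score is scaled by 1/8 = 0.125.
    assert_eq!(attention_scale(64), 0.125);
}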

crates/cubecl-attention/src/components/stage/unit/setup.rs

Lines changed: 15 additions & 12 deletions
@@ -115,17 +115,20 @@ impl<
             num_stages: 1,
         };
 
-        validate(PartitionAttentionConfig::Unit(UnitPartitionStageConfig {
-            shared: SharedPartitionAttentionConfig {
-                tile_config,
-                partition_size: selection.tiling_scheme.partition_size,
-                stage_size: selection.tiling_scheme.stage_size,
-                reuse_key_value: selection.reuse_key_value,
-                num_planes,
-                key_smem_config,
-                value_smem_config,
-                out_smem_config,
-            },
-        }))
+        validate(
+            PartitionAttentionConfig::Unit(UnitPartitionStageConfig {
+                shared: SharedPartitionAttentionConfig {
+                    tile_config,
+                    partition_size: selection.tiling_scheme.partition_size,
+                    stage_size: selection.tiling_scheme.stage_size,
+                    reuse_key_value: selection.reuse_key_value,
+                    num_planes,
+                    key_smem_config,
+                    value_smem_config,
+                    out_smem_config,
+                },
+            }),
+            problem,
+        )
     }
 }
