
Commit 54707e8

Flash Attention: Masking (causal, out of bounds, attention mask) (#962)
1 parent: a2a55f8

File tree: 53 files changed, +2301 −1060 lines changed


crates/cubecl-attention/src/base.rs

Lines changed: 10 additions & 2 deletions
@@ -28,6 +28,7 @@ pub fn launch<R: Runtime, AP: AttentionPrecision>(
     query: TensorHandle<R, QG<AP>>,
     key: TensorHandle<R, KG<AP>>,
     value: TensorHandle<R, VG<AP>>,
+    mask: Option<TensorHandle<R, MSK<AP>>>,
     out: TensorHandle<R, OG<AP>>,
 ) -> Result<(), AttentionSetupError> {
     launch_ref::<R, AP>(
@@ -36,6 +37,7 @@ pub fn launch<R: Runtime, AP: AttentionPrecision>(
         &query.as_ref(),
         &key.as_ref(),
         &value.as_ref(),
+        &mask.as_ref().map(|m| m.as_ref()),
         &out.as_ref(),
     )
 }
@@ -47,10 +49,11 @@ pub fn launch_ref<R: Runtime, AP: AttentionPrecision>(
     query: &TensorHandleRef<R>,
     key: &TensorHandleRef<R>,
     value: &TensorHandleRef<R>,
+    mask: &Option<TensorHandleRef<R>>,
     out: &TensorHandleRef<R>,
 ) -> Result<(), AttentionSetupError> {
     match strategy {
-        Strategy::Tmp => launch_tmp::<R, AP>(client, query, key, value, out),
+        Strategy::Tmp => launch_tmp::<R, AP>(client, query, key, value, mask, out),
     }
 }
 
@@ -59,6 +62,7 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
     query: &TensorHandleRef<R>,
     key: &TensorHandleRef<R>,
     value: &TensorHandleRef<R>,
+    mask: &Option<TensorHandleRef<R>>,
     out: &TensorHandleRef<R>,
 ) -> Result<(), AttentionSetupError> {
     let line_sizes = AvailableLineSizes::from_elem_types::<R>(
@@ -81,7 +85,8 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
         num_heads: query.shape[2],
         head_dim: query.shape[3],
         val_dim: value.shape[3],
-        masked: false,
+        masked: mask.is_some(),
+        causal: false,
     };
 
     let tile_size = AttentionTileSize {
@@ -123,6 +128,9 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
             query.as_tensor_arg(line_sizes.query),
             key.as_tensor_arg(line_sizes.key),
             value.as_tensor_arg(line_sizes.value),
+            mask.as_ref()
+                .map(|it| it.as_tensor_arg(line_sizes.out))
+                .into(),
         ),
         out.as_tensor_arg(line_sizes.out),
         cube_count_plan.as_args(),
