Commit a093f96
Flash Attention: test and fix f16 (#1074)
1 parent 29c0312

7 files changed: +110 −53 lines

crates/cubecl-attention/Cargo.toml

Lines changed: 10 additions & 1 deletion
@@ -15,7 +15,16 @@ default = ["std", "cubecl-runtime/default", "cubecl-core/default"]
 export_tests = ["pretty_assertions"]
 std = ["cubecl-runtime/std", "cubecl-core/std"]
 
-attention_tests = []
+attention_tests_f16 = []
+attention_tests_f32 = []
+attention_tests_unit = []
+attention_tests_blackbox_accelerated = []
+attention_tests_all = [
+    "attention_tests_f16",
+    "attention_tests_f32",
+    "attention_tests_unit",
+    "attention_tests_blackbox_accelerated",
+]
 
 [dependencies]
 bytemuck = { workspace = true }
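The single attention_tests feature is split by precision (f16/f32) and by algorithm (unit/blackbox accelerated), so a runtime can opt into exactly the suites it wants, while attention_tests_all restores the old run-everything behaviour. An f16-only unit run would then look something like `cargo test --features attention_tests_unit,attention_tests_f16` (illustrative invocation; the exact package and feature forwarding depend on the consuming runtime crate).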

crates/cubecl-attention/src/components/tile/base.rs

Lines changed: 4 additions & 2 deletions
@@ -15,11 +15,13 @@ use std::fmt::Debug;
 use std::hash::Hash;
 
 /// Logits below this are considered masked (effectively -inf)
-pub(crate) const LOGIT_MASKED: f32 = -1e5;
+/// Value chosen to fit within f16 range (~-65,504 max)
+pub(crate) const LOGIT_MASKED: f32 = -6e4;
 
 /// Any value smaller than this is considered numerically zero
 /// (used for fully-masked rows or tiny contributions)
-pub(crate) const FULLY_MASKED_ROW_THRESHOLD: f32 = 1e-7;
+/// Value chosen to be above f16 smallest normal (~6.1e-5)
+pub(crate) const FULLY_MASKED_ROW_THRESHOLD: f32 = 1e-4;
 
 #[cube]
 pub trait TileAttention<AP: AttentionPrecision>: Send + Sync + 'static {
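Both new constants follow from f16's numeric limits: the old mask -1e5 exceeds f16's largest finite magnitude (65504) and rounds to -inf, and the old 1e-7 threshold falls below f16's smallest positive normal (~6.1e-5). A quick standalone check of both bounds, a sketch assuming only the half crate (which the test code in this commit already references):

use half::f16;

fn main() {
    // The old sentinel -1e5 overflows f16 (finite range tops out at 65504)
    // and rounds to -inf; the new -6e4 remains a finite f16 value.
    assert!(f16::from_f32(-1e5).is_infinite());
    assert!(f16::from_f32(-6e4).is_finite());

    // The old threshold 1e-7 lies below f16's smallest positive normal
    // (f16::MIN_POSITIVE, ~6.1e-5), i.e. in subnormal territory; the new
    // 1e-4 is comfortably above it.
    let min_normal = f16::MIN_POSITIVE.to_f32();
    assert!(1e-7_f32 < min_normal);
    assert!(1e-4_f32 > min_normal);
}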

crates/cubecl-attention/src/tests/macros/mod.rs

Lines changed: 28 additions & 8 deletions
@@ -7,7 +7,7 @@ use crate::{
         AttentionProblem, AttentionSelection, AttentionTilingScheme, batch::HypercubeSelection,
     },
     kernels::Algorithm,
-    tests::attention_test_launcher::test_attention_algorithm,
+    tests::{attention_test_launcher::test_attention_algorithm, test_utils::TestPrecision},
 };
 
 #[derive(Default)]
@@ -38,7 +38,7 @@ pub mod tiling_scheme_ops {
     }
 }
 
-pub fn attention_test_launch<A: Algorithm, R: Runtime>(
+pub fn attention_test_launch<A: Algorithm, P: TestPrecision, R: Runtime>(
     client: ComputeClient<R>,
     tiling_scheme: AttentionTilingScheme,
     problem: AttentionProblem,
@@ -52,16 +52,15 @@ pub fn attention_test_launch<A: Algorithm, R: Runtime>(
         two_rows_in_array_tile: test_options.two_rows_in_array_tile,
     };
 
-    test_attention_algorithm::<A, (f32, f32), R>(client, problem, selection);
-    // test_attention_algorithm::<A, (half::f16, half::f16), R>(client, problem, selection);
+    test_attention_algorithm::<A, P, R>(client, problem, selection);
 }
 
 #[macro_export]
 macro_rules! testgen_attention {
     () => {
         use super::*;
 
-        #[cfg(feature = "attention_tests")]
+        #[cfg(feature = "attention_tests_unit")]
         mod attention_unit {
             type Algorithm = cubecl_attention::kernels::unit::UnitAlgorithm;
             const TILE_SIZE: cubecl_attention::components::AttentionTileSize =
@@ -73,10 +72,10 @@ macro_rules! testgen_attention {
             };
             const STAGE_Q_BASE: u32 = 32;
 
-            $crate::testgen_attention_suite!();
+            $crate::testgen_attention_precision!();
         }
 
-        #[cfg(feature = "attention_tests")]
+        #[cfg(feature = "attention_tests_blackbox_accelerated")]
        mod attention_blackbox_accelerated {
             type Algorithm =
                 cubecl_attention::kernels::blackbox_accelerated::BlackboxAcceleratedAlgorithm;
@@ -98,7 +97,28 @@ macro_rules! testgen_attention {
             };
             const STAGE_Q_BASE: u32 = 1;
 
-            $crate::testgen_attention_suite!();
+            $crate::testgen_attention_precision!();
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! testgen_attention_precision {
+    () => {
+        use super::*;
+
+        #[cfg(feature = "attention_tests_f16")]
+        mod f16_ty {
+            use super::*;
+
+            $crate::testgen_attention_suite!((half::f16, half::f16));
+        }
+
+        #[cfg(feature = "attention_tests_f32")]
+        mod f32_ty {
+            use super::*;
+
+            $crate::testgen_attention_suite!((f32, f32));
         }
     };
 }
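Taken together: the launcher no longer hard-codes (f32, f32) with the f16 call left commented out. Precision arrives as the P type parameter, and the new testgen_attention_precision! macro stamps out one feature-gated submodule per precision inside each algorithm module. Below is a minimal, self-contained sketch of that dispatch pattern with hypothetical names; the real TestPrecision trait lives in test_utils and its definition is not part of this diff.

use half::f16;

// Hypothetical stand-in for test_utils::TestPrecision, reduced to a label.
trait TestPrecision {
    const NAME: &'static str;
}

impl TestPrecision for (f32, f32) {
    const NAME: &'static str = "f32";
}

impl TestPrecision for (f16, f16) {
    const NAME: &'static str = "f16";
}

// Stand-in for attention_test_launch: generic over the precision pair
// instead of fixing (f32, f32) at the call site.
fn launch<P: TestPrecision>() {
    println!("running attention suite at {} precision", P::NAME);
}

fn main() {
    launch::<(f32, f32)>();
    launch::<(f16, f16)>();
}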
