@@ -7,7 +7,7 @@ use crate::{
         AttentionProblem, AttentionSelection, AttentionTilingScheme, batch::HypercubeSelection,
     },
     kernels::Algorithm,
-    tests::attention_test_launcher::test_attention_algorithm,
+    tests::{attention_test_launcher::test_attention_algorithm, test_utils::TestPrecision},
 };
 
 #[derive(Default)]
@@ -38,7 +38,7 @@ pub mod tiling_scheme_ops {
     }
 }
 
-pub fn attention_test_launch<A: Algorithm, R: Runtime>(
+pub fn attention_test_launch<A: Algorithm, P: TestPrecision, R: Runtime>(
     client: ComputeClient<R>,
     tiling_scheme: AttentionTilingScheme,
     problem: AttentionProblem,
@@ -52,16 +52,15 @@ pub fn attention_test_launch<A: Algorithm, R: Runtime>(
         two_rows_in_array_tile: test_options.two_rows_in_array_tile,
     };
 
-    test_attention_algorithm::<A, (f32, f32), R>(client, problem, selection);
-    // test_attention_algorithm::<A, (half::f16, half::f16), R>(client, problem, selection);
+    test_attention_algorithm::<A, P, R>(client, problem, selection);
 }
 
 #[macro_export]
 macro_rules! testgen_attention {
     () => {
         use super::*;
 
-        #[cfg(feature = "attention_tests")]
+        #[cfg(feature = "attention_tests_unit")]
         mod attention_unit {
             type Algorithm = cubecl_attention::kernels::unit::UnitAlgorithm;
             const TILE_SIZE: cubecl_attention::components::AttentionTileSize =
@@ -73,10 +72,10 @@ macro_rules! testgen_attention {
             };
             const STAGE_Q_BASE: u32 = 32;
 
-            $crate::testgen_attention_suite!();
+            $crate::testgen_attention_precision!();
         }
 
-        #[cfg(feature = "attention_tests")]
+        #[cfg(feature = "attention_tests_blackbox_accelerated")]
         mod attention_blackbox_accelerated {
             type Algorithm =
                 cubecl_attention::kernels::blackbox_accelerated::BlackboxAcceleratedAlgorithm;
@@ -98,7 +97,28 @@ macro_rules! testgen_attention {
             };
             const STAGE_Q_BASE: u32 = 1;
 
-            $crate::testgen_attention_suite!();
+            $crate::testgen_attention_precision!();
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! testgen_attention_precision {
+    () => {
+        use super::*;
+
+        #[cfg(feature = "attention_tests_f16")]
+        mod f16_ty {
+            use super::*;
+
+            $crate::testgen_attention_suite!((half::f16, half::f16));
+        }
+
+        #[cfg(feature = "attention_tests_f32")]
+        mod f32_ty {
+            use super::*;
+
+            $crate::testgen_attention_suite!((f32, f32));
         }
     };
 }
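
Note: a minimal, hypothetical sketch of how a backend test crate might consume the macro after this change. The invocation path cubecl_attention::testgen_attention!() and the surrounding module are assumptions for illustration, not part of this commit; which generated modules compile depends on the attention_tests_unit / attention_tests_blackbox_accelerated and attention_tests_f16 / attention_tests_f32 features enabled for the test build.

// Hypothetical consumer module (illustrative only); assumes the macro is
// exported from the cubecl_attention crate root and the relevant test
// features are enabled.
mod attention {
    // Expands attention_unit / attention_blackbox_accelerated; each now calls
    // testgen_attention_precision!, which gates one testgen_attention_suite!
    // instantiation per element-type pair behind a precision feature flag.
    cubecl_attention::testgen_attention!();
}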