@@ -4,11 +4,8 @@ use cubecl_std::tensor::TensorHandle;
44
55use crate :: {
66 components:: {
7- AttentionElems , AttentionIdent , AttentionPartitionSize , AttentionProblem ,
8- AttentionSelection , AttentionSetupError , AttentionStageSize , AttentionTileSize ,
9- AttentionTilingScheme , AvailableLineSizes ,
7+ AttentionElems , AttentionIdent , AttentionProblem , AttentionSetupError , AvailableLineSizes ,
108 args:: { TensorArgs , TensorInputsLaunch } ,
11- batch:: HypercubeSelection ,
129 } ,
1310 kernels:: { Algorithm , blackbox_accelerated:: BlackboxAcceleratedAlgorithm , unit:: UnitAlgorithm } ,
1411} ;
@@ -25,15 +22,15 @@ pub enum Strategy {
2522#[ allow( clippy:: result_large_err, clippy:: too_many_arguments) ]
2623pub fn launch < R : Runtime > (
2724 strategy : & Strategy ,
28- client : & ComputeClient < R :: Server > ,
25+ client : & ComputeClient < R > ,
2926 query : TensorHandle < R > ,
3027 key : TensorHandle < R > ,
3128 value : TensorHandle < R > ,
3229 mask : Option < TensorHandle < R > > ,
3330 out : TensorHandle < R > ,
3431 attention_elems : AttentionElems ,
3532) -> Result < ( ) , AttentionSetupError > {
36- launch_ref :: < R > (
33+ launch_ref (
3734 strategy,
3835 client,
3936 & query. as_ref ( ) ,
@@ -48,7 +45,7 @@ pub fn launch<R: Runtime>(
4845#[ allow( clippy:: result_large_err, clippy:: too_many_arguments) ]
4946pub fn launch_ref < R : Runtime > (
5047 strategy : & Strategy ,
51- client : & ComputeClient < R :: Server > ,
48+ client : & ComputeClient < R > ,
5249 query : & TensorHandleRef < R > ,
5350 key : & TensorHandleRef < R > ,
5451 value : & TensorHandleRef < R > ,
@@ -79,26 +76,35 @@ pub fn launch_ref<R: Runtime>(
7976}
8077
8178pub fn launch_attention < R : Runtime , A : Algorithm > (
82- client : & ComputeClient < R :: Server > ,
79+ client : & ComputeClient < R > ,
8380 query : & TensorHandleRef < R > ,
8481 key : & TensorHandleRef < R > ,
8582 value : & TensorHandleRef < R > ,
8683 mask : & Option < TensorHandleRef < R > > ,
8784 out : & TensorHandleRef < R > ,
8885 attention_elems : & AttentionElems ,
8986) -> Result < ( ) , AttentionSetupError > {
90- let line_sizes = AvailableLineSizes :: from_elem_types :: < R > (
91- query. elem_size ,
92- attention_elems. mask . size ( ) ,
93- out. elem_size ,
94- ) ;
95- let line_sizes = A :: filter_line_sizes ( line_sizes)
96- . filter_with_tensor ( AttentionIdent :: Query , query. strides , query. shape )
97- . filter_with_tensor ( AttentionIdent :: Key , key. strides , key. shape )
98- . filter_with_tensor ( AttentionIdent :: Value , value. strides , value. shape )
99- . filter_with_tensor ( AttentionIdent :: Out , out. strides , out. shape )
100- . pick_max ( )
101- . unwrap ( ) ;
87+ let line_sizes = {
88+ let ls = AvailableLineSizes :: from_elem_types (
89+ client,
90+ query. elem_size ,
91+ attention_elems. mask . size ( ) ,
92+ out. elem_size ,
93+ ) ;
94+ let ls = A :: filter_line_sizes ( ls)
95+ . filter_with_tensor ( AttentionIdent :: Query , query. strides , query. shape )
96+ . filter_with_tensor ( AttentionIdent :: Key , key. strides , key. shape )
97+ . filter_with_tensor ( AttentionIdent :: Value , value. strides , value. shape )
98+ . filter_with_tensor ( AttentionIdent :: Out , out. strides , out. shape ) ;
99+
100+ if let Some ( mask) = mask. as_ref ( ) {
101+ ls. filter_with_tensor ( AttentionIdent :: Mask , mask. strides , mask. shape )
102+ } else {
103+ ls
104+ }
105+ }
106+ . pick_max ( )
107+ . unwrap ( ) ;
102108
103109 let problem = AttentionProblem {
104110 batch : query. shape [ 0 ] ,
@@ -111,47 +117,22 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
111117 causal : false ,
112118 } ;
113119
114- let tile_size = AttentionTileSize {
115- seq_q : 8 ,
116- head_dim : 8 ,
117- seq_kv : 8 ,
118- val_dim : 8 ,
119- } ;
120-
121- let selection = AttentionSelection {
122- hypercube_selection : HypercubeSelection { } ,
123- tiling_scheme : AttentionTilingScheme {
124- tile_size,
125- partition_size : AttentionPartitionSize {
126- seq_q : 1 ,
127- head_dim : 1 ,
128- seq_kv : 1 ,
129- val_dim : 1 ,
130- } ,
131- stage_size : AttentionStageSize { seq_q : 1 } ,
132- } ,
133- plane_dim : 32 ,
134- reuse_key_value : false ,
135- two_rows_in_array_tile : false ,
136- } ;
137-
138- let config = BlackboxAcceleratedAlgorithm :: setup :: < R > (
120+ let selection = A :: selection (
139121 client,
140122 & problem,
141- & selection ,
123+ client . properties ( ) . hardware . plane_size_max ,
142124 & line_sizes,
143125 attention_elems,
144126 ) ?;
145127
128+ let config = A :: setup ( client, & problem, & selection, & line_sizes, attention_elems) ?;
129+
146130 let cube_count_plan = config
147131 . hypercube_config ( )
148132 . cube_count_plan ( & problem, & selection) ;
149133
150- unsafe {
151- <BlackboxAcceleratedAlgorithm as Algorithm >:: BatchAttention :: launch_unchecked :: <
152- TensorArgs ,
153- R ,
154- > (
134+ let result = unsafe {
135+ <A as Algorithm >:: BatchAttention :: launch_unchecked :: < TensorArgs , R > (
155136 client,
156137 config. cube_dim ( ) ,
157138 cube_count_plan. resolve ( ) ,
@@ -167,8 +148,11 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
167148 cube_count_plan. as_args ( ) ,
168149 config,
169150 attention_elems,
170- ) ;
171- }
151+ )
152+ } ;
172153
173- Ok ( ( ) )
154+ match result {
155+ Ok ( _) => Ok ( ( ) ) ,
156+ Err ( err) => Err ( AttentionSetupError :: Execution ( err) ) ,
157+ }
174158}