@@ -2,8 +2,7 @@ use crate::components::tile::TileAttentionConfig;
22use crate :: components:: { AttentionTileSize , attention_types:: * } ;
33use cubecl_core as cubecl;
44use cubecl_core:: prelude:: * ;
5- use cubecl_matmul:: components:: MatrixLayout ;
6- use cubecl_matmul:: components:: global:: memory:: { GlobalIterator , ViewDirection } ;
5+ use cubecl_matmul:: components:: global:: memory:: { GlobalIterator , GlobalMemoryConfig } ;
76use cubecl_matmul:: components:: tile:: StridedTile ;
87use cubecl_std:: tensor:: { View , layout:: Coords2d } ;
98use cubecl_std:: { Swizzle , tensor:: layout:: Coordinates } ;
@@ -44,6 +43,8 @@ pub struct MaterializedMaskReader<M: Numeric> {
4443 logical_iter : LogicalIterator ,
4544 // TODO not sure if mandatory, but i need for the stride when reading in global memory
4645 seq_kv_shape : u32 ,
46+ #[ cube( comptime) ]
47+ gmem_config : GlobalMemoryConfig ,
4748}
4849
4950#[ derive( CubeType ) ]
@@ -64,15 +65,16 @@ impl<AP: AttentionPrecision> MaskReader<AP> {
6465 mask : View < Line < MSK < AP > > , Coords2d > ,
6566 step : u32 ,
6667 seq_kv_shape : u32 ,
67- #[ comptime] view_direction : ViewDirection ,
68+ #[ comptime] gmem_config : GlobalMemoryConfig ,
6869 ) -> Self {
6970 let mask = mask. slice ( ( stage_q_offset, 0 ) , mask. shape ( ) ) ;
70- let global_iter = GlobalIterator :: new ( mask, step, view_direction, false ) ;
71+ let global_iter = GlobalIterator :: new ( mask, step, gmem_config . view_direction , false ) ;
7172
7273 MaskReader :: < AP > :: new_Materialized ( MaterializedMaskReader :: new (
7374 global_iter,
7475 LogicalIterator :: init ( partition_q_offset, step) ,
7576 seq_kv_shape,
77+ gmem_config,
7678 ) )
7779 }
7880
@@ -117,11 +119,13 @@ impl<M: Numeric> MaterializedMaskReader<M> {
117119 global_iter : GlobalIterator < Line < M > > ,
118120 logical_iter : LogicalIterator ,
119121 seq_kv_shape : u32 ,
122+ #[ comptime] gmem_config : GlobalMemoryConfig ,
120123 ) -> Self {
121124 MaterializedMaskReader :: < M > {
122125 global_iter,
123126 logical_iter,
124127 seq_kv_shape,
128+ gmem_config,
125129 }
126130 }
127131
@@ -135,20 +139,29 @@ impl<M: Numeric> MaterializedMaskReader<M> {
135139
136140 let row = row_offset + P :: seq_q_index ( ) * elements_in_partition_seq_q;
137141
142+ let slice = self
143+ . global_iter
144+ . view ( )
145+ . slice (
146+ ( row, col. runtime ( ) ) ,
147+ ( attention_tile_size. seq_q , attention_tile_size. seq_kv ) . runtime ( ) ,
148+ )
149+ . to_linear_slice ( ) ;
150+
151+ let line_size = self . gmem_config . line_size ;
152+ let start = 0 ;
153+ let length = attention_tile_size. seq_q * attention_tile_size. seq_kv / line_size;
154+ let end = start + length;
155+ let stride = self . seq_kv_shape / line_size;
156+
138157 StridedTile :: < M > :: new_strided (
139- self . global_iter
140- . view ( )
141- . slice (
142- ( row, col. runtime ( ) ) ,
143- ( attention_tile_size. seq_q , attention_tile_size. seq_kv ) . runtime ( ) ,
144- )
145- . to_linear_slice ( ) ,
146- 0 ,
147- attention_tile_size. seq_q * attention_tile_size. seq_kv ,
148- self . seq_kv_shape ,
158+ slice,
159+ start,
160+ end,
161+ stride,
149162 Swizzle :: none ( ) ,
150- MatrixLayout :: RowMajor ,
151- 1u32 ,
163+ self . gmem_config . matrix_layout ,
164+ line_size ,
152165 )
153166 }
154167
0 commit comments