11use core:: marker:: PhantomData ;
22use cubecl:: prelude:: * ;
3+ use cubecl_matmul:: AsyncReadingStrategy ;
34use cubecl_matmul:: components:: batch:: HypercubeSelection ;
45use cubecl_matmul:: components:: stage:: PartitionBuffering ;
56use cubecl_matmul:: components:: {
@@ -11,17 +12,10 @@ use cubecl_matmul::kernels::layered::double_unit::DoubleUnitSelectionArgs;
1112use cubecl_matmul:: kernels:: layered:: ordered_double_buffering:: OrderedSelectionArgs ;
1213use cubecl_matmul:: kernels:: layered:: simple:: SimpleArgs ;
1314use cubecl_matmul:: kernels:: layered:: simple_unit:: SimpleUnitSelectionArgs ;
14- use cubecl_matmul:: kernels:: layered:: {
15- MatmulSelection , MultiRowStrategy , Selection , TileSizeSelection , closest_factor_pair,
16- } ;
1715use cubecl_matmul:: kernels:: layered:: { Selection , TileSizeSelection } ;
18- use cubecl_matmul:: { self as matmul} ;
1916use cubecl_matmul:: {
2017 self as matmul, MatmulInputHandle , SyncPartialReadingStrategy , SyncReadingStrategy ,
2118} ;
22- use cubecl_matmul:: { self as matmul, SyncPartialReadingStrategy , SyncReadingStrategy } ;
23- use cubecl_matmul:: { AsyncReadingStrategy , components:: MatmulPrecision } ;
24- use cubecl_matmul:: { SyncPartialReadingStrategy , SyncReadingStrategy } ;
2519use std:: collections:: BTreeMap ;
2620use std:: time:: Duration ;
2721
@@ -98,8 +92,8 @@ impl<R: Runtime, MP: MatmulPrecision> Benchmark for MatmulBench<R, MP> {
9892 matmul_elems. rhs_global,
9993 matmul_elems. rhs_stage,
10094 matmul_elems. rhs_register,
101- matmul_elems. acc ,
102- matmul_elems. out ,
95+ matmul_elems. acc_register ,
96+ matmul_elems. acc_global ,
10397 self . strategy
10498 )
10599 . to_lowercase ( )
@@ -145,13 +139,13 @@ fn entry(m: usize, n: usize, k: usize) -> (usize, usize, usize, usize) {
145139#[ allow( dead_code) ]
146140fn run < R : Runtime , MP : MatmulPrecision > ( device : R :: Device , strategy : matmul:: Strategy ) {
147141 for tl in [ false ] {
148- for tr in [ true ] {
142+ for tr in [ false ] {
149143 for ( b, m, n, k) in [
150144 // entry(8192, 8192, 8192),
151- // entry(6144, 6144, 6144),
145+ entry ( 6144 , 6144 , 6144 ) ,
152146 // entry(4096, 4096, 4096),
153147 // entry(2048, 2048, 2048),
154- entry ( 1024 , 1024 , 1024 ) ,
148+ // entry(1024, 1024, 1024),
155149 // entry(512, 512, 512),
156150 // entry(64, 1024, 64),
157151 // entry(32, 1024, 32),
@@ -397,7 +391,7 @@ fn run_algos_wmma<R: Runtime, MP: MatmulPrecision>() {
397391#[ allow( unused) ]
398392fn run_benches < R : Runtime , MP : MatmulPrecision > ( ) {
399393 // run_grid_search::<R, MP>();
400- run_algos_unit :: < R , MP > ( ) ;
394+ // run_algos_unit::<R, MP>();
401395 run_algos_wmma :: < R , MP > ( ) ;
402396 // run_algos_vecmat::<R, MP>();
403397}
0 commit comments