@@ -25,9 +25,9 @@ pub struct MatmulProblemDefinition {
2525 pub k : usize ,
2626 pub lhs_pow2_factor : u8 ,
2727 pub rhs_pow2_factor : u8 ,
28- pub elem_lhs : ElemType ,
29- pub elem_rhs : ElemType ,
30- pub elem_out : ElemType ,
28+ pub elem_lhs : MatmulElemType ,
29+ pub elem_rhs : MatmulElemType ,
30+ pub elem_out : MatmulElemType ,
3131 pub matrix_layout_lhs : MatrixBatchLayout ,
3232 pub matrix_layout_rhs : MatrixBatchLayout ,
3333}
@@ -67,6 +67,12 @@ pub fn should_tune_double_buffering(fused: bool, key: &MatmulAutotuneKey) -> boo
6767 }
6868}
6969
/// Element-type descriptor for a single matmul operand: an `ElemType` paired with a
/// `quantized` flag.
///
/// In this change it replaces the bare `ElemType` as the type of the `elem_lhs`,
/// `elem_rhs` and `elem_out` fields of `MatmulProblemDefinition`, and of the matching
/// parameters of the key constructor in `impl MatmulAutotuneKey`.
///
/// Derives `Hash`, `Eq` and `PartialEq`, so two descriptors with the same `elem` but a
/// different `quantized` value compare (and hash) as distinct.
70+ #[ derive( Hash , Eq , PartialEq , Debug , Clone , Serialize , Deserialize , AutotuneKey ) ]
71+ pub struct MatmulElemType {
/// The element type of the operand.
72+ pub elem : ElemType ,
/// Quantization flag for the operand.
/// NOTE(review): presumably `true` when the tensor data is stored in quantized form;
/// the exact meaning is not visible here — confirm against the call sites that build this value.
73+ pub quantized : bool ,
74+ }
75+
7076impl MatmulAutotuneKey {
7177 /// Create the autotune key based on the shape of both lhs and rhs as well as the element type
7278 /// used for the calculation.
@@ -77,9 +83,9 @@ impl MatmulAutotuneKey {
7783 rhs_shape : & [ usize ] ,
7884 lhs_strides : & [ usize ] ,
7985 rhs_strides : & [ usize ] ,
80- elem_lhs : ElemType ,
81- elem_rhs : ElemType ,
82- elem_out : ElemType ,
86+ elem_lhs : MatmulElemType ,
87+ elem_rhs : MatmulElemType ,
88+ elem_out : MatmulElemType ,
8389 ) -> MatmulAutotuneKey {
8490 let ndims = lhs_shape. len ( ) ;
8591 let m = lhs_shape[ ndims - 2 ] ;
0 commit comments