Skip to content

Commit ae303d0

Browse files
Add quantized in matmul key (#982)
1 parent 6ce6ed7 commit ae303d0

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

crates/cubecl-matmul/src/tune_key.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ pub struct MatmulProblemDefinition {
2525
pub k: usize,
2626
pub lhs_pow2_factor: u8,
2727
pub rhs_pow2_factor: u8,
28-
pub elem_lhs: ElemType,
29-
pub elem_rhs: ElemType,
30-
pub elem_out: ElemType,
28+
pub elem_lhs: MatmulElemType,
29+
pub elem_rhs: MatmulElemType,
30+
pub elem_out: MatmulElemType,
3131
pub matrix_layout_lhs: MatrixBatchLayout,
3232
pub matrix_layout_rhs: MatrixBatchLayout,
3333
}
@@ -67,6 +67,12 @@ pub fn should_tune_double_buffering(fused: bool, key: &MatmulAutotuneKey) -> boo
6767
}
6868
}
6969

70+
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
71+
pub struct MatmulElemType {
72+
pub elem: ElemType,
73+
pub quantized: bool,
74+
}
75+
7076
impl MatmulAutotuneKey {
7177
/// Create the autotune key based on the shape of both lhs and rhs as well as the element type
7278
/// used for the calculation.
@@ -77,9 +83,9 @@ impl MatmulAutotuneKey {
7783
rhs_shape: &[usize],
7884
lhs_strides: &[usize],
7985
rhs_strides: &[usize],
80-
elem_lhs: ElemType,
81-
elem_rhs: ElemType,
82-
elem_out: ElemType,
86+
elem_lhs: MatmulElemType,
87+
elem_rhs: MatmulElemType,
88+
elem_out: MatmulElemType,
8389
) -> MatmulAutotuneKey {
8490
let ndims = lhs_shape.len();
8591
let m = lhs_shape[ndims - 2];

0 commit comments

Comments
 (0)