diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index 685b44323343..5c197e0858c6 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -598,7 +598,7 @@ def matmul_ogs(x, w, bias, out_matmul_scale = out_matmul_scale.data.view(torch.uint8) if has_scratchpad and "mx_out_scale" in memory["scratchpad"]: out_matmul_scale = memory["scratchpad"]["mx_out_scale"] - out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1 + out_matmul_has_mx = out_matmul_scale is not None and bitwidth(out_dtype) == 8 # matrix multiplication flex = precision_config.flex_ctx bias_stride = None if bias is None else bias.stride(0)