Skip to content

Commit 9dad5cc

Browse files
authored
[Kernel] Turn off CUTLASS scaled_mm for Ada Lovelace (#6384)
1 parent 6ef3bf9 commit 9dad5cc

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
1+
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1
22
model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
33
tasks:
44
- name: "gsm8k"
55
metrics:
66
- name: "exact_match,strict-match"
7-
value: 0.752
7+
value: 0.755
88
- name: "exact_match,flexible-extract"
9-
value: 0.752
10-
limit: 250
9+
value: 0.755
10+
limit: 1000
1111
num_fewshot: 5

.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ tasks:
44
- name: "gsm8k"
55
metrics:
66
- name: "exact_match,strict-match"
7-
value: 0.756
7+
value: 0.753
88
- name: "exact_match,flexible-extract"
9-
value: 0.752
10-
limit: 250
9+
value: 0.753
10+
limit: 1000
1111
num_fewshot: 5

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
3838
if (cuda_device_capability >= 90) {
3939
return CUDA_VERSION >= 12000;
4040
} else if (cuda_device_capability >= 89) {
41-
return CUDA_VERSION >= 12040;
41+
// CUTLASS kernels have not been tuned for Ada Lovelace systems
42+
// and are slower than torch.mm. Return false unconditionally in this case.
43+
return false;
44+
45+
// Once the CUTLASS kernels have been optimized for Lovelace systems,
46+
// use the following check:
47+
// return CUDA_VERSION >= 12040;
4248
}
4349
#endif
4450

@@ -98,4 +104,4 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
98104
TORCH_CHECK(version_num >= 75);
99105
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
100106
}
101-
}
107+
}

0 commit comments

Comments
 (0)