Commit eb655a2

alexsamardzic authored and pytorchmergebot committed
Fix CUTLASS 2.x kernels for auto-tuning (pytorch#146755)
Pull Request resolved: pytorch#146755
Approved by: https://github.com/henrylhtsang
1 parent 683bb12 commit eb655a2

File tree

3 files changed: +10 −34 lines changed


test/inductor/test_cutlass_backend.py

Lines changed: 6 additions & 21 deletions

@@ -45,7 +45,6 @@
 HAS_CUDA = HAS_CUDA and not torch.version.hip
 SM80OrLater = SM80OrLater and not torch.version.hip
 SM90OrLater = SM90OrLater and not torch.version.hip
-SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)


 def _get_path_without_sccache() -> str:
@@ -737,22 +736,14 @@ def forward(self, x, w):
         torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.01)

     # TODO: Enable dynamic test cases when dynamic support is added.
-    @unittest.skipIf(True, "disabled due to broken on A100")
-    # error: TypeError: can't multiply sequence by non-int of type 'str'
-    @unittest.skipIf(not SM80, "need sm_80 exactly")
+    @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
     @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_max_autotune_cutlass_backend_mixed_mm(
-        self, dynamic: bool, max_autotune_gemm_backends: str
-    ):
+    def test_max_autotune_cutlass_backend_mixed_mm(self, dynamic: bool):
         """
         Make sure autotuning mm in sub processes work without crashes.
         """

-        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
-            return
-
         def mm(a, b):
             return torch.mm(a, b.to(torch.half))

@@ -768,7 +759,7 @@ def mm(a, b):
             {
                 "max_autotune": True,
                 "autotune_in_subproc": True,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "max_autotune_gemm_backends": "CUTLASS",
                 "cuda.cutlass_dir": _CUTLASS_DIR,
                 "cuda.cutlass_max_profiling_configs": 2,
                 "use_mixed_mm": True,
@@ -792,22 +783,16 @@ def mm(a, b):
         assert cutlass_kernels_count > 0

     # TODO: Enable dynamic test cases when dynamic support is added.
-    @unittest.skipIf(True, "disabled due to broken on A100")
-    # error: TypeError: can't multiply sequence by non-int of type 'str'
-    @unittest.skipIf(not SM80, "need sm_80 exactly")
+    @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
     @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
-        self, dynamic: bool, max_autotune_gemm_backends: str
+        self, dynamic: bool
     ):
         """
         Make sure autotuning mm in sub processes work without crashes.
         """

-        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
-            return
-
         SparseSemiStructuredTensor._FORCE_CUTLASS = True

         def mm(a, b):
@@ -823,7 +808,7 @@ def mm(a, b):
             {
                 "max_autotune": True,
                 "autotune_in_subproc": True,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "max_autotune_gemm_backends": "CUTLASS",
                 "cuda.cutlass_dir": _CUTLASS_DIR,
                 "cuda.cutlass_max_profiling_configs": 2,
                 "autotune_local_cache": True,

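The new skipIf guard reads a little oddly at first. Below is a rough sketch of what it expresses; the flag definitions are illustrative reconstructions of what the test imports (assumed to be capability-tuple comparisons, as in torch.testing._internal.common_cuda), not the test's own code:

import torch

# Illustrative re-derivation of the flags (assumed semantics):
major, minor = torch.cuda.get_device_capability()
SM80OrLater = (major, minor) >= (8, 0)
SM90OrLater = (major, minor) >= (9, 0)

# skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly") skips unless
# the device is Ampere-class (sm_80 through sm_89, e.g. the A100's sm_80):
runs_here = SM80OrLater and not SM90OrLater  # De Morgan of the skip condition

This replaces the old exact `== (8, 0)` comparison, widening the test to all sm_8x parts while still excluding Hopper.
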
torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 2 additions & 1 deletion

@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union

-from sympy import Expr
+from sympy import Expr, symbols

 from torch import dtype as torch_dtype
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
@@ -404,6 +404,7 @@ def size(
         if len(sizes) == 0:
             return str(default_value)

+        sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
         val = sympy_product(sizes)
         return val

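This one-liner is what the deleted test comment (TypeError: can't multiply sequence by non-int of type 'str') traces back to: with dynamic shapes, some entries of sizes arrive as plain strings, and multiplying strings together fails. A minimal sketch of the failure mode, not the inductor implementation, assuming sympy_product reduces the list with * in the usual way:

import functools
import operator

from sympy import symbols

sizes = ["s0", "s1"]  # dynamic dims can arrive as plain strings
try:
    functools.reduce(operator.mul, sizes)
except TypeError as e:
    print(e)  # can't multiply sequence by non-int of type 'str'

# The fix lifts strings to sympy symbols first, so the product is symbolic:
sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
print(functools.reduce(operator.mul, sizes))  # s0*s1
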
torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 2 additions & 12 deletions

@@ -155,9 +155,6 @@
 PT_EXPORT {{kernel_call_signature}} {
   try {
   int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
-  int64_t M = {{kernel.size(X, -2)}};
-  int64_t K = {{kernel.size(W, -2)}};
-  int64_t N = {{kernel.size(W, -1)}};
   using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
@@ -176,13 +173,6 @@

   // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers
 #ifndef CUTLASS_BACKEND_DISABLE_CHECKS
-  {{kernel.check_not_null(X)}}
-  {{kernel.check_not_null(W)}}
-  {{kernel.check_not_null(Bias)}}
-  {{kernel.check_not_null(Meta)}}
-  {{kernel.check_not_null(Y)}}
-
-
   {
     auto status = gemm_op.can_implement(arguments);
     CUTLASS_CHECK(status);
@@ -278,7 +268,7 @@
     {
       static_cast<coord_t>({{M}}),
       static_cast<coord_t>({{N}}),
-      static_cast<coord_t>(K),
+      static_cast<coord_t>(2 * K),
     },  // GemmCoord problem_size
     X_ref,  // TensorRef<ElementA const, LayoutA> ref_A
     W_ref,  // TensorRef<ElementB const, LayoutB> ref_B
@@ -1382,7 +1372,7 @@ def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool:
         A_size = [int(i) for i in A_layout.size]
         B_size = [int(i) for i in B_layout.size]
         K = max(A_size[1], B_size[0])
-        return (K == A_size[1] or K == 2 * A_size[0]) and K == B_size[0]
+        return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0]

     def _shape_match(
         self,

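The _are_inputs_layout_compatible fix and the 2 * K problem-size change appear to follow the same 2:4 semi-structured convention: the sparse operand A is stored compressed along K, so its second dimension holds only K/2 elements, while B keeps all K rows, and the logical K must be reconstructed by doubling. A worked example with hypothetical shapes (not taken from the test suite):

# Dense mm:       A is (M, K),      B is (K, N)  ->  K == A_size[1]
# 2:4 sparse mm:  A is (M, K // 2), B is (K, N)  ->  K == 2 * A_size[1]
M, K, N = 48, 64, 16
A_size = [M, K // 2]  # compressed sparse operand: (48, 32)
B_size = [K, N]       # dense operand: (64, 16)

k = max(A_size[1], B_size[0])  # 64, recovered from B's leading dim
assert k == 2 * A_size[1] and k == B_size[0]  # fixed check accepts this layout
assert k != 2 * A_size[0]  # the old check compared K against 2*M by mistake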