Commit ce4a097

Revert "Added swizzle searching, disabled fp16 accum, and enabled ping-pong for cutlass (pytorch#144829)"
This reverts commit 5508444. Reverted pytorch#144829 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#144829 (comment)))
1 parent 527101f commit ce4a097

3 files changed: +8 -12 lines


torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ def get_accumulator_dtype(
             return torch_dtype
         else:
             return torch.float
-    if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
+    if torch_dtype in (torch.bfloat16, torch.float):
         return torch.float
     if torch_dtype == torch.int8:
         return torch.int32
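
For context, a minimal Python sketch of the dtype-selection order this hunk restores. Only the lines visible in the hunk come from the source; the shape of the fp16 branch above it and the trailing error are assumptions based on the visible context lines.

import torch

def get_accumulator_dtype_sketch(torch_dtype, allow_fp16_accum=False):
    # Sketch only: with the revert, float16 no longer falls through to the
    # float32-accumulation check below; the half-precision branch above the
    # hunk (assumed here) decides between fp16 and fp32 accumulation.
    if torch_dtype == torch.float16:
        return torch_dtype if allow_fp16_accum else torch.float
    if torch_dtype in (torch.bfloat16, torch.float):  # line 248 after the revert
        return torch.float
    if torch_dtype == torch.int8:
        return torch.int32
    raise NotImplementedError(f"Unsupported dtype: {torch_dtype}")

assert get_accumulator_dtype_sketch(torch.bfloat16) is torch.float
assert get_accumulator_dtype_sketch(torch.float16, allow_fp16_accum=True) is torch.float16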

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 6 additions & 10 deletions
@@ -46,7 +46,7 @@
     CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
   }
   {{instance_type}}::Arguments arguments;
-  {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, swizzle,
+  {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw,
       X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}}
   {{instance_type}} gemm_op;
   if (workspace_size) {
@@ -118,7 +118,6 @@
     {{epilogue_arguments}},
     hw_info
   };
-  arguments.scheduler.max_swizzle_size = {{swizzle}};
 """

 # Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied,
@@ -502,11 +501,11 @@ def _add_cutlass_gemm_choices(

         ops = self.gen_ops()
         for name, op in ops:
-            for swizzle in (1, 2, 4, 8):
-                description = f"{name} swizzle={swizzle}"
-                self.maybe_append_choice(
-                    choices, description=description, op=op, swizzle=swizzle
-                )
+            self.maybe_append_choice(
+                choices,
+                description=name,
+                op=op,
+            )
         if len(ops) == 0:
             input_layouts = [node.get_layout() for node in input_nodes]
             input_strides = [node.get_stride() for node in input_nodes]
@@ -953,7 +952,6 @@ def render( # type: ignore[override]
             Bias=Bias,
             epilogue_template=epilogue_template,
             argument_template=argument_template,
-            swizzle=kwargs["swizzle"],
             should_swap_xw=should_swap_xw,
             template=self,
             kernel=kernel,
@@ -1218,7 +1216,6 @@ def render_gemm_arguments(
         argument_template: str,
         epilogue_template: str,
         should_swap_xw: bool,
-        swizzle: int,
         X: IRNode,
         W: IRNode,
         Bias: IRNode,
@@ -1264,7 +1261,6 @@ def render_gemm_arguments(
                 M="M",
                 N="N",
                 epilogue_args=epilogue_args,
-                swizzle=swizzle,
             )
         assert epilogue_template is not None

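
To make the effect of the third hunk concrete, here is an illustrative sketch (not the Inductor API itself) of how the autotuning candidate count changes: the deleted loop emitted four choices per generated op, one per swizzle size in (1, 2, 4, 8), while the reverted code appends a single choice per op. The op name below is hypothetical.

# Illustrative only; the op name is made up and maybe_append_choice is not called here.
ops = [("cutlass3x_sm90_gemm_f16", object())]  # hypothetical (name, op) pairs from gen_ops()

choices_before = [
    (f"{name} swizzle={swizzle}", op, swizzle)
    for name, op in ops
    for swizzle in (1, 2, 4, 8)  # the deleted sweep: 4 autotune candidates per op
]
choices_after = [(name, op) for name, op in ops]  # after the revert: 1 candidate per op

assert len(choices_before) == 4 * len(choices_after)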
torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -1253,7 +1253,7 @@ class cuda:
     # Set this to "pingpong" to avoid numerical issues
     # caused by the op ordering of the "pingpong" memory access
     # pattern used by some Cutlass Kernels.
-    cutlass_op_denylist_regex: Optional[str] = None
+    cutlass_op_denylist_regex: Optional[str] = "pingpong"


 class rocm:
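
A minimal usage sketch of the restored default, assuming the torch._inductor.config layout shown in the diff; the override at the end is only an example of opting back in.

import torch._inductor.config as inductor_config

# After this commit, CUTLASS ops whose names match "pingpong" are denied by default.
print(inductor_config.cuda.cutlass_op_denylist_regex)  # expected: "pingpong"

# Opting back in to ping-pong kernels (accepting the numerical-ordering caveat
# from the comment above) means clearing the denylist before compiling:
inductor_config.cuda.cutlass_op_denylist_regex = None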
