Commit 60d1c71

Revert "[inductor] Cooperative reductions (pytorch#137756)"
This reverts commit fed37db. Reverted pytorch#137756 on behalf of https://github.com/jeanschmidt because ROCm tests are timing out :( ([comment](pytorch#137756 (comment)))
1 parent 2487a83 commit 60d1c71
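For readers landing here without context: the reverted PR gated cooperative reductions behind an Inductor config flag, which the diffs below remove. A minimal sketch of how the feature was exercised before this revert; after this commit the flag no longer exists:

```python
# Sketch only: the config knob the reverted PR added. After this revert,
# `triton.cooperative_reductions` is gone from the Inductor config.
import torch
import torch._inductor.config as inductor_config

# inductor_config.triton.cooperative_reductions = True  # pre-revert opt-in

@torch.compile
def total(x):
    return torch.sum(x)

# total(torch.randn(256, 256, device="cuda"))  # requires a GPU build
```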

File tree

16 files changed: +76 -620 lines changed


test/inductor/test_cooperative_reductions.py

Lines changed: 0 additions & 116 deletions
This file was deleted.

test/inductor/test_perf.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -501,9 +501,6 @@ def f(x, scale, amax_keep_dim):
         expected_numel = (
             1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
         )
-        if config.triton.cooperative_reductions:
-            expected_numel = 134225922
-
         self.assertExpectedInline(count_numel(f, *inp, True), str(expected_numel))
         self.assertExpectedInline(count_numel(f, *inp, False), str(expected_numel))
 
```
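The deleted branch pinned `expected_numel` to a hard-coded value when cooperative reductions were enabled; with the feature gone, only the formula remains. A worked evaluation of that formula, assuming `hidden_size = 4096` purely for illustration (the real value is defined earlier in the test):

```python
# Hypothetical evaluation of the expected_numel formula above; hidden_size is
# an assumption here, not the value the test actually uses.
hidden_size = 4096
expected_numel = 1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
print(expected_numel)  # 67133442 with this assumed hidden_size
```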

test/inductor/test_torchinductor.py

Lines changed: 20 additions & 24 deletions
```diff
@@ -11747,34 +11747,30 @@ def fn(a: torch.Tensor) -> torch.Tensor:
             return torch.sum(a)
 
         kernels = self.get_kernels(fn, [torch.randn([256, 256], device=GPU_TYPE)])
-        expected_divisible = {
-            # kernel0 reduces from 256 to (xnumel=8, rnumel=8192), which means it reduces 256 by 256 into an array of
-            # size 8 by accumulating 8192 elements at once note that rnumel is equal to 512 * 16, so rnumel which is
-            # at slot 3 should be in the divisible by 16 descriptor
-            0: (0, 1, 3),
-            # kernel1 reduces from 8 elements to a single scalar.
-            # Since multi-kernel generate 2 variants for each kernel. The second
-            # persistent-reduction has index 2.
-            1: (0, 1),
-        }
         if config.triton.multi_kernel:
-            self.assertEqual(len(kernels), 4)
-            expected_divisible[2] = expected_divisible.pop(1)
-        elif config.triton.cooperative_reductions:
-            self.assertEqual(len(kernels), 1)
-            expected_divisible = {
-                # one kernel, with extra workspace/semaphore args
-                0: (0, 1, 2, 3, 5),
-            }
+            self.assertTrue(
+                len(kernels) == 4,
+                "SUM should result in four kernels when multi-kernel is enabled",
+            )
         else:
-            self.assertEqual(len(kernels), 2)
+            self.assertTrue(len(kernels) == 2, "SUM should result in two kernels")
 
-        for kernel_id, expected in expected_divisible.items():
-            divisible_by_16 = (
-                kernels[kernel_id].triton_meta["configs"][0].divisible_by_16
-            )
-            self.assertEqual(divisible_by_16, expected)
+        # kernel0 reduces from 256 to (xnumel=8, rnumel=8192), which means it reduces 256 by 256 into an array of
+        # size 8 by accumulating 8192 elements at once note that rnumel is equal to 512 * 16, so rnumel which is
+        # at slot 3 should be in the divisible by 16 descriptor
+        arguments_that_are_divisible_by_16_in_kernel0 = (
+            kernels[0].triton_meta["configs"][0].divisible_by_16
+        )
+        self.assertEqual(arguments_that_are_divisible_by_16_in_kernel0, (0, 1, 3))
 
+        # kernel1 reduces from 8 elements to a single scalar.
+        # Since multi-kernel generate 2 variants for each kernel. The second
+        # persistent-reduction has index 2.
+        kernel1_index = 2 if config.triton.multi_kernel else 1
+        arguments_that_are_divisible_by_16_in_kernel1 = (
+            kernels[kernel1_index].triton_meta["configs"][0].divisible_by_16
+        )
+        self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
         torch._dynamo.reset()
 
     @config.patch(assume_aligned_inputs=False)
```
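The metadata the rewritten test inspects can be sketched as a standalone helper. `get_kernels` is a helper on the test class, so the kernel objects here are assumed inputs rather than something this snippet produces itself:

```python
# Illustrative helper mirroring what the test checks: Inductor attaches
# `triton_meta` to each generated Triton kernel, and configs[0].divisible_by_16
# lists the argument slots the compiler may assume are divisible by 16,
# which enables aligned/vectorized memory access.
def divisible_by_16_slots(kernel):
    return kernel.triton_meta["configs"][0].divisible_by_16

# e.g. divisible_by_16_slots(kernels[0]) == (0, 1, 3): both pointer arguments
# plus the rnumel at slot 3 (8192 == 512 * 16) are divisible by 16.
```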

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -315,11 +315,6 @@ def test_reduction(
         full = torch.randn(full_size).to(device)
         view = torch.as_strided(full, view_size, full.stride())
 
-        if num_triton_kernels == 2 and config.triton.cooperative_reductions:
-            # fewer kernels with cooperative reductions
-            num_triton_kernels = 1
-            num_block_pointers -= 2
-
         # Expect at least 1 block pointer for the input.
         # Add 2 more if we generate 2 kernels.
         result, (code,) = self.run_and_compare(
```
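As a side note, the surviving context lines build the test input with `torch.as_strided`, which creates an aliasing view rather than a copy. A minimal standalone illustration with made-up sizes:

```python
# Minimal sketch of the view construction used by the test: as_strided
# reinterprets the existing storage with a new size while reusing the
# original strides, so no data is copied.
import torch

full = torch.randn(8, 8)
view = torch.as_strided(full, (4, 4), full.stride())
assert view.data_ptr() == full.data_ptr()  # same storage, zero-copy alias
```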

torch/_inductor/codegen/common.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1457,7 +1457,7 @@ def semaphores(self, min_size: sympy.Expr):
         arg = WorkspaceArg(
             count=min_size,
             zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH,
-            dtype=torch.uint32,
+            dtype=torch.int32,
             inner_name="sem_ptr",
             outer_name=f"semaphores_{current_device.type}_{current_device.index}",
             device=current_device,
```
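The dtype swap is the substantive change in this file: the semaphore workspace moves from `torch.uint32` back to `torch.int32`. A hedged illustration; the rationale (narrower operator coverage for the newer unsigned dtypes) is an assumption, not stated in the commit:

```python
# Sketch: allocating a zeroed semaphore-style workspace. torch.int32 is the
# long-supported choice; torch.uint32 is a newer dtype with more limited
# operator coverage, presumably why the revert lands on int32 (assumption).
import torch

semaphores = torch.zeros(64, dtype=torch.int32)  # zeroed once per graph
```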

torch/_inductor/codegen/halide.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1660,7 +1660,7 @@ class HalideScheduling(SIMDScheduling):
     int32_type = "hl.Int(32)"
     # TODO(jansel): Halide doesn't actually support 64 bit indexing...
     int64_type = "hl.Int(64)"
-    kernel_type = HalideKernel  # type: ignore[arg-type,assignment]
+    kernel_type = HalideKernel  # type: ignore[arg-type]
 
     @classmethod
     def get_backend_features(cls, device: torch.device):
```
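This one-line change narrows a mypy suppression. Error-code-scoped ignores silence only the codes listed in brackets, so dropping `assignment` re-enables that check on the line. A hypothetical illustration unrelated to the file above:

```python
# Hypothetical example of a scoped mypy ignore: only the listed error code
# ("assignment") is suppressed; any other error on this line still surfaces.
value: int = "not an int"  # type: ignore[assignment]
```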

torch/_inductor/codegen/simd.py

Lines changed: 2 additions & 17 deletions
```diff
@@ -330,7 +330,6 @@ def __init__(
         pid_cache=None,
         reduction_hint=ReductionHint.DEFAULT,
         override_persistent_reduction=None,
-        override_cooperative_reduction=None,
     ) -> None:
         if pid_cache is None:
             pid_cache = {}
@@ -349,11 +348,6 @@ def __init__(
         self.index_dtype: str = index_dtype
         self.last_usage: OrderedSet[str] = OrderedSet()
         self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
-        self.cooperative_reduction: bool = (
-            override_cooperative_reduction
-            if override_cooperative_reduction is not None
-            else self.should_use_cooperative_reduction()
-        )
         self.persistent_reduction: bool = (
             override_persistent_reduction
             if override_persistent_reduction is not None
@@ -427,9 +421,6 @@ def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
         finally:
             self.inside_reduction = prior
 
-    def should_use_cooperative_reduction(self) -> bool:
-        return False  # defined in subclass
-
     def should_use_persistent_reduction(self) -> bool:
         return False  # defined in subclass
 
@@ -515,7 +506,7 @@ def set_last_usage(self, nodes):
         )
 
     def disable_reduction(self):
-        should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction
+        should_flush = self.range_trees[-1].is_loop
 
         @contextlib.contextmanager
         def ctx():
@@ -1334,7 +1325,6 @@ def get_kernel_args(self, node_schedule, numel, reduction_numel):
     def codegen_node_schedule(
         self, node_schedule, buf_accesses, numel, reduction_numel
     ):
-        from torch._inductor.codegen.triton import TritonKernel
         from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel
 
         tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
@@ -1344,8 +1334,7 @@ def codegen_node_schedule(
             index_dtype,
         ) = self.get_kernel_args(node_schedule, numel, reduction_numel)
 
-        is_scan = schedule_contains_op(node_schedule, "scan")
-        is_split_scan = is_scan and any(
+        is_split_scan = any(
             isinstance(node, BaseSchedulerNode) and node.is_split_scan()
             for node in node_schedule
         )
@@ -1360,10 +1349,6 @@
             index_dtype=index_dtype,
         )
 
-        if is_scan and kernel_type == TritonKernel:
-            # TODO(jansel): scan does not yet work with cooperative reductions
-            kernel_kwargs["override_cooperative_reduction"] = False
-
         # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
         # so taking the hit of non-coalesced loads is okay
         if has_sort := schedule_contains_op(node_schedule, "sort"):
```
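The deleted constructor logic follows a pattern that survives for persistent reductions: an explicit constructor override wins, otherwise a subclass heuristic decides. A minimal standalone sketch of that pattern, grounded in the lines removed above:

```python
# Standalone sketch of the override-or-heuristic pattern used by the kernel
# class: an explicit constructor argument takes precedence; otherwise the
# subclass-overridable heuristic picks the default.
from typing import Optional

class KernelBase:
    def __init__(self, override_persistent_reduction: Optional[bool] = None) -> None:
        self.persistent_reduction: bool = (
            override_persistent_reduction
            if override_persistent_reduction is not None
            else self.should_use_persistent_reduction()
        )

    def should_use_persistent_reduction(self) -> bool:
        return False  # defined in subclass
```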
