
Commit 27df0c5

Revert "[inductor] use int64 for large index (pytorch#154575)"
This reverts commit 2596e3d. Reverted pytorch#154575 on behalf of https://github.com/clee2000 because it broke inductor/test_op_dtype_prop.py::TestCaseCUDA::test_op_dtype_propagation_add_cuda_int32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/15510656657/job/43673763835) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/2596e3d0617852469241be8777cf46db5c83928c), note for self: bad TD ([comment](pytorch#154575 (comment)))
1 parent 49888e6 commit 27df0c5

File tree: 7 files changed (+6 −79 lines)


test/inductor/test_torchinductor.py

Lines changed: 0 additions & 29 deletions
```diff
@@ -13477,35 +13477,6 @@ def pad_same(x, k, s, d=(1, 1), value=0):
         ref = pad_same(x, (5, 5), (2, 2))
         self.assertEqual(res, ref, atol=0, rtol=0)
 
-    @skip_if_halide  # only 32-bit indexing
-    @largeTensorTest("16GB", inductor=True)
-    def test_split_reduction_with_int64_size(self):
-        if torch._inductor.config.cpu_backend == "triton":
-            raise unittest.SkipTest(
-                "Fail for triton cpu backend with error: https://gist.github.com/shunting314/a873fb32b6b7b5a437f44280ae86839f"
-            )
-
-        if self.device == "cpu":
-            raise unittest.SkipTest(
-                "The test sometimes fails on CI: "
-                "https://github.com/pytorch/pytorch/actions/runs/15333913377/job/43153170162. "
-                "Skip for now."
-            )
-
-        size = (30000, 100000)
-
-        # rand rather than randn, since the mean of the latter is close to 0,
-        # which happens to be close to the value produced by the bug.
-        t = torch.rand(size, dtype=torch.float, device=self.device)
-        op = torch.mean
-        expected = op(t)
-        actual = torch.compile(op)(t)
-        # self.common takes more GPU memory; do the check directly.
-        self.assertTrue(
-            torch.allclose(expected, actual, atol=1e-2, rtol=1e-2),
-            f"{expected=} {actual=}",
-        )
-
     def test_remove_noop_view_default(self):
         def f(x):
             batch_size = x.shape[0]
```
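For orientation, the deleted test exercised indexing beyond the signed 32-bit range: a (30000, 100000) float32 tensor holds 3×10⁹ elements. A minimal sketch of that arithmetic (plain Python written here for illustration, not part of the diff):

```python
# Sketch: why test_split_reduction_with_int64_size needs 64-bit indexing.
size = (30000, 100000)
numel = size[0] * size[1]              # 3_000_000_000 elements
int32_max = 2**31 - 1                  # 2_147_483_647
assert numel > int32_max               # int32 offsets cannot address this tensor

bytes_needed = numel * 4               # float32 is 4 bytes -> ~12 GB,
print(numel, int32_max, bytes_needed)  # hence the 16GB largeTensorTest guard
```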

test/inductor/test_torchinductor_dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -61,7 +61,7 @@
     "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu")
     ),
-    "test_randint_distribution_dynamic_shapes": TestFailure(("xpu",)),
+    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda", "xpu")),
 }
 if not torch._inductor.config.cpp_wrapper:
     test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure(
```
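This hunk re-marks the test as an expected failure on cuda in addition to xpu. A simplified sketch of how such a table gates devices (a stand-in for the suite's TestFailure bookkeeping, not the real harness):

```python
# Sketch: map test names to the device suffixes on which failure is expected.
test_failures = {
    "test_randint_distribution_dynamic_shapes": ("cuda", "xpu"),
}

def is_expected_failure(test_name: str, device: str) -> bool:
    # A test only counts as a regression on devices absent from its entry.
    return device in test_failures.get(test_name, ())

assert is_expected_failure("test_randint_distribution_dynamic_shapes", "cuda")
assert not is_expected_failure("test_randint_distribution_dynamic_shapes", "cpu")
```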

torch/_inductor/codegen/triton.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -227,12 +227,7 @@ def has_rmask(self) -> bool:
 
     @property
     def mask_str(self) -> str:
-        # The sorted call is added to make sure the order is still
-        # deterministic if self.mask_vars contains a mix of string
-        # and TritonCSEVariable entries
-        return (
-            " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None"
-        )
+        return " & ".join(map(str, self.mask_vars)) if self.mask_vars else "None"
 
 
 @dataclasses.dataclass
```
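For context on the deleted comment: if mask_vars is an unordered collection (e.g. a set), join order can vary from run to run, and sorting the stringified entries kept the emitted mask expression deterministic. A minimal standalone sketch, with MaskVar as an illustrative stand-in for TritonCSEVariable:

```python
# Sketch: deterministic mask strings from a collection mixing plain strings
# with objects whose str() is the variable name (stand-in for TritonCSEVariable).
class MaskVar:
    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name

mask_vars = {"xmask", MaskVar("rmask"), "tmp3_mask"}

# Set iteration order is not guaranteed, but sorting the stringified
# entries yields the same " & " expression on every run.
mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
print(mask_str)  # rmask & tmp3_mask & xmask
```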

torch/_inductor/codegen/triton_utils.py

Lines changed: 1 addition & 24 deletions
```diff
@@ -111,34 +111,11 @@ def signature_to_meta(
     size_dtype: Optional[str],
     argdefs: list[ArgName],
     indices: Optional[list[int]] = None,
-    is_template: bool = False,
 ) -> dict[str, str]:
     if indices is None:
         indices = list(range(len(signature)))
-
-    def _decide_tl_dtype(arg):
-        # Even if the ks0 symbol itself is within the tl.int32 range, it's
-        # risky to use the tl.int32 dtype, since we may have ks0 * ks1 later
-        # for kernels like torch.mean when dynamic shapes are enabled.
-        #
-        # Check config.triton.use_block_ptr, since Triton block pointers
-        # do not support 64-bit indexing:
-        # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6
-        #
-        # If the triton metadata is for a template, don't use a tl.int64 index.
-        # Templates like flex attention/decoding use block pointers, which
-        # do not support 64-bit indexing.
-        if (
-            not config.triton.use_block_ptr
-            and not is_template
-            and isinstance(arg, SizeArg)
-            and arg.name.startswith("ks")
-        ):
-            return "tl.int64"
-        return size_dtype
-
     return {
-        argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg))
+        argdefs[i].name: signature_of(arg, size_dtype=size_dtype)
         for i, arg in zip(indices, signature)
     }
```
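The deleted _decide_tl_dtype codified the concern that, even when each dynamic size symbol (ks0, ks1, ...) fits in tl.int32, a derived index bound like ks0 * ks1 may not. A minimal sketch of that overflow in plain Python with manual 32-bit wrapping (the numbers are illustrative, not taken from a real kernel):

```python
# Sketch: each dynamic size fits in int32, but their product (a linear
# index bound such as ks0 * ks1) does not, so indexing must use int64.
INT32_MAX = 2**31 - 1

ks0, ks1 = 30000, 100000          # both well within int32 range
product = ks0 * ks1               # 3_000_000_000 as an exact Python int
assert ks0 <= INT32_MAX and ks1 <= INT32_MAX
assert product > INT32_MAX        # the 32-bit product would wrap

# What a wrapping 32-bit multiply would produce:
low = product & 0xFFFFFFFF
wrapped = low - (1 << 32) if low & 0x80000000 else low
print(product, wrapped)           # 3000000000 -1294967296
```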

torch/_inductor/ir.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -90,7 +90,6 @@
     convert_shape_to_symint,
     developer_warning,
     do_bench_using_profiling,
-    dtype_from_size,
     get_dtype_size,
     get_kernel_metadata,
     GPU_ALIGN_BYTES,
@@ -1679,10 +1678,9 @@ def body() -> OpsValue:
             return loader(new_index, reindex([indices]))
 
         if need_mask:
-            index_dtype = dtype_from_size(reduction_numel)
             mask = ops.lt(
-                ops.index_expr(indices, index_dtype),
-                ops.index_expr(reduction_numel, index_dtype),
+                ops.index_expr(indices, torch.int32),
+                ops.index_expr(reduction_numel, torch.int32),
             )
             return ops.masked(mask, body, default)
         else:
```
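The revert returns to hard-coded torch.int32 for the mask comparison. A rough tensor-level sketch of what the need_mask branch computes (plain PyTorch rather than Inductor's ops IR; sizes are illustrative):

```python
import torch

# Sketch: the need_mask branch compares a per-element index against
# reduction_numel; padding positions (index >= numel) get masked out.
reduction_numel = 10
block = 16                                   # padded block size
indices = torch.arange(block, dtype=torch.int32)
mask = indices < torch.tensor(reduction_numel, dtype=torch.int32)
print(mask)  # True for the first 10 lanes, False for the 6 padding lanes

# With int32, an index bound beyond 2**31 - 1 is not representable,
# which is what the reverted change addressed by switching to int64.
assert torch.iinfo(torch.int32).max == 2**31 - 1
```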

torch/_inductor/select_algorithm.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -494,10 +494,7 @@ def jit_lines(self):
         argdefs, _, signature, _ = self.args.python_argdefs()
         triton_meta: dict[str, Any] = {
             "signature": signature_to_meta(
-                signature,
-                size_dtype=self.index_dtype,
-                argdefs=argdefs,
-                is_template=True,
+                signature, size_dtype=self.index_dtype, argdefs=argdefs
             ),
             "device": DeviceProperties.create(self.output_node.get_device()),
             "constants": {},
```

torch/_inductor/utils.py

Lines changed: 0 additions & 11 deletions
```diff
@@ -3129,14 +3129,3 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
         isinstance(wrapper, SubgraphPythonWrapperCodegen)
         and wrapper.partition_signatures is not None
     )
-
-
-def dtype_from_size(size: int) -> torch.dtype:
-    from .virtualized import V
-
-    if V.graph.sizevars.statically_known_lt(
-        size, 2**31
-    ) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
-        return torch.int32
-    else:
-        return torch.int64
```
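The deleted dtype_from_size chose the narrowest index dtype a size provably fits in, falling back to int64. A standalone sketch of the same decision for concrete sizes (plain ints in place of the symbolic V.graph.sizevars reasoning):

```python
import torch

# Sketch: choose the narrowest index dtype that can represent `size`,
# mirroring the deleted dtype_from_size but for concrete integers.
def dtype_from_concrete_size(size: int) -> torch.dtype:
    if -(2**31) <= size < 2**31:     # fits in a signed 32-bit integer
        return torch.int32
    return torch.int64

assert dtype_from_concrete_size(30000 * 100000) is torch.int64  # 3e9 elements
assert dtype_from_concrete_size(2**31 - 1) is torch.int32
```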
