Skip to content

Commit ed30fa7

Browse files
janselpytorchmergebot
authored andcommitted
[inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (pytorch#139523)
Pull Request resolved: pytorch#139523 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#139364, pytorch#139365, pytorch#139370, pytorch#139452
1 parent b6fb135 commit ed30fa7

File tree

19 files changed

+77
-85
lines changed

19 files changed

+77
-85
lines changed

torch/_inductor/codegen/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class TensorArg:
170170
name: str
171171
buffer: str
172172
dtype: torch.dtype
173-
offset: sympy.Expr = sympy.Integer(0) # c++ only
173+
offset: sympy.Expr = sympy.S.Zero # c++ only
174174
alias_of: Optional[str] = None # halide only
175175

176176

torch/_inductor/codegen/cpp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def stride_at(index: sympy.Expr, var: sympy.Symbol):
246246
# see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu
247247
# which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation.
248248
# in this case, there is no dependencies between index and var.
249-
return sympy.Integer(0)
249+
return sympy.S.Zero
250250
replacement = {var: var + 1}
251251
new_index = sympy_subs(index, replacement) # type: ignore[arg-type]
252252
return sympy.simplify(new_index - index)
@@ -4711,8 +4711,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
47114711
class LoopLevel:
47124712
var: Optional[sympy.Expr] = None
47134713
size: Optional[sympy.Expr] = None
4714-
offset: sympy.Expr = sympy.Integer(0)
4715-
steps: sympy.Expr = sympy.Integer(1)
4714+
offset: sympy.Expr = sympy.S.Zero
4715+
steps: sympy.Expr = sympy.S.One
47164716
parallel: int = 0
47174717
simd_omp: bool = False
47184718
simd_vec: bool = False

torch/_inductor/codegen/cpp_template_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def store_pointwise_nodes(
216216
for i, sz in enumerate(var_sizes[0])
217217
}
218218
if not offsets:
219-
offsets = [sympy.Integer(0)] * len(var_sizes[0])
219+
offsets = [sympy.S.Zero] * len(var_sizes[0])
220220
if not reindexers:
221221
reindexers = [None] * len(nodes)
222222
assert len(offsets) == len(var_sizes[0])

torch/_inductor/codegen/halide.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def visit_floor_div(base, divisor):
788788
if not nodes:
789789
nodes.append(tree.lookup(1, tree.numel))
790790
handled_count = 0
791-
divisor = sympy.Integer(1)
791+
divisor = sympy.S.One
792792
added_sym_size = []
793793
# decide on a minimal set of symbols and put them in self.halide_vars
794794
while handled_count < len(nodes) and not eq(tree.numel, divisor):
@@ -846,7 +846,7 @@ def visit_floor_div(base, divisor):
846846
idx += 1
847847
divisor *= size
848848
length = 1
849-
expr = sympy.Integer(0)
849+
expr = sympy.S.Zero
850850
while not eq(node.length, length):
851851
sym, size = added_sym_size[idx]
852852
idx += 1
@@ -855,8 +855,8 @@ def visit_floor_div(base, divisor):
855855
self.index_replacements[node.symbol()] = expr
856856
except IndexError:
857857
assert had_fallback
858-
full_index = sympy.Integer(0)
859-
stride = sympy.Integer(1)
858+
full_index = sympy.S.Zero
859+
stride = sympy.S.One
860860
for sym, size in added_sym_size:
861861
full_index += stride * sym
862862
stride *= size
@@ -937,8 +937,8 @@ def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool):
937937
), sym
938938

939939
# group the expression by variables used
940-
offset = sympy.Integer(0)
941-
split_expr = {s: sympy.Integer(0) for s in symbols}
940+
offset = sympy.S.Zero
941+
split_expr = {s: sympy.S.Zero for s in symbols}
942942
split_failed: List[Tuple[List[sympy.Symbol], sympy.Expr]] = []
943943
index = sympy.expand(self.rename_indexing(index))
944944
for part in index.args if isinstance(index, sympy.Add) else [index]:
@@ -972,7 +972,7 @@ def expr_to_dimension(expr, syms):
972972
length = sympy.simplify(
973973
sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1
974974
)
975-
stride = sympy.Integer(1)
975+
stride = sympy.S.One
976976
if isinstance(expr, sympy.Mul):
977977
for term in expr.args:
978978
if isinstance(term, sympy.Integer):
@@ -994,11 +994,11 @@ def expr_to_dimension(expr, syms):
994994
if not dims: # scalar load/store
995995
if self.has_indirect_indexing:
996996
# workaround https://github.com/halide/Halide/issues/8338
997-
dims.append(DimensionInfo(sympy.Integer(0), 1, 1))
997+
dims.append(DimensionInfo(sympy.S.Zero, 1, 1))
998998
elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1):
999999
# Halide assumes dimension 0 is stride == 1, so add a dummy dimension
10001000
dims.insert(
1001-
0, DimensionInfo(sympy.Integer(0), 1 if is_store else dims[0].stride, 1)
1001+
0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1)
10021002
)
10031003

10041004
if dims and not is_store:

torch/_inductor/codegen/simd.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def __init__(
101101
prefix: str,
102102
*,
103103
kernel: SIMDKernel,
104-
divisor=sympy.Integer(1),
105-
length=sympy.Integer(1),
104+
divisor=sympy.S.One,
105+
length=sympy.S.One,
106106
root: IterationRangesRoot,
107107
) -> None:
108108
super().__init__()
@@ -205,7 +205,7 @@ def lookup(self, divisor, length):
205205
return self.nodes[expr]
206206

207207
def construct_entries(self, lengths: List[sympy.Expr]):
208-
divisor = sympy.Integer(1)
208+
divisor = sympy.S.One
209209
itervars = []
210210
for length in reversed(lengths):
211211
itervars.append(self.lookup(divisor, length))
@@ -224,7 +224,7 @@ def vars_and_sizes(self, index: sympy.Expr):
224224
x.divisor, fallback=config.unbacked_symint_fallback
225225
)
226226
)
227-
divisor = sympy.Integer(1)
227+
divisor = sympy.S.One
228228
index_vars = []
229229
sizes = []
230230

@@ -481,7 +481,7 @@ def combine_modular_indexing_pairs(self, index):
481481
new_index,
482482
{
483483
tree_node.root.index_sym(): tree_node.root.lookup(
484-
sympy.Integer(1), tree_node.root.numel
484+
sympy.S.One, tree_node.root.numel
485485
).symbol()
486486
},
487487
)
@@ -572,7 +572,7 @@ def getter(flat_vars):
572572
return_getters = []
573573
for size in length_group:
574574
if sv.statically_known_equals(size, 1): # type: ignore[arg-type]
575-
return_getters.append(lambda _: sympy.Integer(0))
575+
return_getters.append(lambda _: sympy.S.Zero)
576576
continue
577577

578578
while current_group < len(remaining) and sv.statically_known_equals(
@@ -635,7 +635,7 @@ def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
635635
"""
636636
groups = [rt.numel for rt in self.range_trees]
637637
if not self.inside_reduction:
638-
groups[-1] = sympy.Integer(1)
638+
groups[-1] = sympy.S.One
639639

640640
if len(lengths) == len(self.range_trees) and all(
641641
V.graph.sizevars.simplify(sympy_product(x) - g) == 0
@@ -1564,7 +1564,7 @@ def candidate_tilings(node):
15641564
return tilings
15651565

15661566
@classmethod
1567-
def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
1567+
def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
15681568
"""
15691569
Heuristics to decide how to tile kernels.
15701570
Currently, we tile based on stride-1 dimensions.

torch/_inductor/codegen/simd_kernel_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
self,
7171
node_schedule: List[NodeScheduleEntry],
7272
numel: sympy.Expr,
73-
reduction_numel: sympy.Expr = sympy.Integer(1),
73+
reduction_numel: sympy.Expr = sympy.S.One,
7474
):
7575
self.node_schedule = node_schedule
7676
self.numel = V.graph.sizevars.simplify(numel) # numel excludes reduction_numel

torch/_inductor/codegen/triton.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def codegen_broadcast_and_reshape(
241241

242242
# Reshape to add singletons.
243243
pre_broadcast_shape = [
244-
sympy.Integer(1) if is_broadcasting else dim
244+
sympy.S.One if is_broadcasting else dim
245245
for dim, is_broadcasting in zip(
246246
self.broadcast_shape, self.broadcasting_dims
247247
)
@@ -342,7 +342,7 @@ def remove_dims(it):
342342
and V.kernel.numels[-1] != 1
343343
):
344344
# Need to expand rank by 1 to match rank when self.inside_reduction=True
345-
final_shape.append(sympy.Integer(1))
345+
final_shape.append(sympy.S.One)
346346

347347
return BlockPtrOptions(
348348
params=params,
@@ -375,9 +375,7 @@ def format(self, name: str, roffset=True) -> str:
375375
f = V.kernel.index_to_str
376376
offsets = [*self.offsets]
377377
if not roffset:
378-
offsets = [
379-
self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets
380-
]
378+
offsets = [self.replace_roffset(offset, sympy.S.Zero) for offset in offsets]
381379
args = [
382380
(
383381
f"{name} + ({f(self.constant_offset)})"
@@ -408,9 +406,7 @@ def boundary_check(self) -> List[int]:
408406
idx
409407
for idx in range(len(self.shape))
410408
if (
411-
not sizevars.statically_known_equals(
412-
self.strides[idx], sympy.Integer(0)
413-
)
409+
not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero)
414410
and not sizevars.statically_known_multiple_of(
415411
self.shape[idx], self.block_shape[idx]
416412
)
@@ -437,7 +433,7 @@ def advance_roffset(self):
437433
advance = [
438434
(
439435
self.replace_roffset(offset, rblock)
440-
- self.replace_roffset(offset, sympy.Integer(0))
436+
- self.replace_roffset(offset, sympy.S.Zero)
441437
)
442438
for offset in self.offsets
443439
]
@@ -1655,7 +1651,7 @@ def get_slice_numels(dims: List[Any]) -> List[Any]:
16551651
Compute the cumulative size of each dimension's slice.
16561652
This proceeds from the last dim up to the second.
16571653
"""
1658-
numels = [sympy.Integer(1)]
1654+
numels = [sympy.S.One]
16591655
for dim in dims[:0:-1]:
16601656
numel = dim * numels[0]
16611657
numels.insert(0, numel)
@@ -1680,10 +1676,10 @@ def get_slice_numels(dims: List[Any]) -> List[Any]:
16801676
# Provide default values for unmatched dims and strides.
16811677
for dim in dims[1:]:
16821678
if dim not in match:
1683-
match[dim] = sympy.Integer(1)
1679+
match[dim] = sympy.S.One
16841680
for stride in strides[1:]:
16851681
if stride not in match:
1686-
match[stride] = sympy.Integer(0)
1682+
match[stride] = sympy.S.Zero
16871683

16881684
sizevars = V.graph.sizevars
16891685

@@ -1786,7 +1782,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
17861782
# For example xindex * 5 + rindex * 3 is partitioned to
17871783
# (xindex * 5, rindex * 3).
17881784
symbol = tree.symbol()
1789-
subexpr = sympy.Integer(0) + sum(
1785+
subexpr = sympy.S.Zero + sum(
17901786
expr for expr in index_terms if symbol in expr.free_symbols
17911787
)
17921788

torch/_inductor/dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def get_numel(self) -> sympy.Expr:
204204
numel = V.graph.get_numel(self.name)
205205
else:
206206
vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
207-
numel = sympy.Integer(1)
207+
numel = sympy.S.One
208208
for var, size in zip(self.var_names, self.size):
209209
if var in vars:
210210
numel = numel * size
@@ -328,7 +328,7 @@ def index(self):
328328
raise NotImplementedError("WeakDep does not have an index")
329329

330330
def get_numel(self) -> sympy.Expr:
331-
return sympy.Integer(1)
331+
return sympy.S.One
332332

333333
def rename(self, renames: Dict[str, str]) -> "WeakDep":
334334
if self.name in renames:

torch/_inductor/index_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
135135
if not is_integer_dtype(result_type):
136136
return NotImplemented
137137

138-
result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
138+
result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr)
139139
return TypedExpr(result_expr, result_type)
140140

141141
@staticmethod
@@ -152,7 +152,7 @@ def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
152152
x_expr.is_nonnegative is not None
153153
and x_expr.is_nonnegative == y_expr.is_positive
154154
):
155-
result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
155+
result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr)
156156
return TypedExpr(result_expr, result_type)
157157
return NotImplemented
158158

0 commit comments

Comments
 (0)