@@ -241,7 +241,7 @@ def codegen_broadcast_and_reshape(
241241
242242 # Reshape to add singletons.
243243 pre_broadcast_shape = [
244- sympy .Integer ( 1 ) if is_broadcasting else dim
244+ sympy .S . One if is_broadcasting else dim
245245 for dim , is_broadcasting in zip (
246246 self .broadcast_shape , self .broadcasting_dims
247247 )
@@ -342,7 +342,7 @@ def remove_dims(it):
342342 and V .kernel .numels [- 1 ] != 1
343343 ):
344344 # Need to expand rank by 1 to match rank when self.inside_reduction=True
345- final_shape .append (sympy .Integer ( 1 ) )
345+ final_shape .append (sympy .S . One )
346346
347347 return BlockPtrOptions (
348348 params = params ,
@@ -375,9 +375,7 @@ def format(self, name: str, roffset=True) -> str:
375375 f = V .kernel .index_to_str
376376 offsets = [* self .offsets ]
377377 if not roffset :
378- offsets = [
379- self .replace_roffset (offset , sympy .Integer (0 )) for offset in offsets
380- ]
378+ offsets = [self .replace_roffset (offset , sympy .S .Zero ) for offset in offsets ]
381379 args = [
382380 (
383381 f"{ name } + ({ f (self .constant_offset )} )"
@@ -408,9 +406,7 @@ def boundary_check(self) -> List[int]:
408406 idx
409407 for idx in range (len (self .shape ))
410408 if (
411- not sizevars .statically_known_equals (
412- self .strides [idx ], sympy .Integer (0 )
413- )
409+ not sizevars .statically_known_equals (self .strides [idx ], sympy .S .Zero )
414410 and not sizevars .statically_known_multiple_of (
415411 self .shape [idx ], self .block_shape [idx ]
416412 )
@@ -437,7 +433,7 @@ def advance_roffset(self):
437433 advance = [
438434 (
439435 self .replace_roffset (offset , rblock )
440- - self .replace_roffset (offset , sympy .Integer ( 0 ) )
436+ - self .replace_roffset (offset , sympy .S . Zero )
441437 )
442438 for offset in self .offsets
443439 ]
@@ -1655,7 +1651,7 @@ def get_slice_numels(dims: List[Any]) -> List[Any]:
16551651 Compute the cumulative size of each dimension's slice.
16561652 This proceeds from the last dim up to the second.
16571653 """
1658- numels = [sympy .Integer ( 1 ) ]
1654+ numels = [sympy .S . One ]
16591655 for dim in dims [:0 :- 1 ]:
16601656 numel = dim * numels [0 ]
16611657 numels .insert (0 , numel )
@@ -1680,10 +1676,10 @@ def get_slice_numels(dims: List[Any]) -> List[Any]:
16801676 # Provide default values for unmatched dims and strides.
16811677 for dim in dims [1 :]:
16821678 if dim not in match :
1683- match [dim ] = sympy .Integer ( 1 )
1679+ match [dim ] = sympy .S . One
16841680 for stride in strides [1 :]:
16851681 if stride not in match :
1686- match [stride ] = sympy .Integer ( 0 )
1682+ match [stride ] = sympy .S . Zero
16871683
16881684 sizevars = V .graph .sizevars
16891685
@@ -1786,7 +1782,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
17861782 # For example xindex * 5 + rindex * 3 is partitioned to
17871783 # (xindex * 5, rindex * 3).
17881784 symbol = tree .symbol ()
1789- subexpr = sympy .Integer ( 0 ) + sum (
1785+ subexpr = sympy .S . Zero + sum (
17901786 expr for expr in index_terms if symbol in expr .free_symbols
17911787 )
17921788
0 commit comments