Skip to content

Commit 2487a83

Browse files
Revert "Add sym_log2 (pytorch#137980)"
This reverts commit 5d450d7. Reverted pytorch#137980 on behalf of https://github.com/jeanschmidt due to lint broke from this onwards on main ([comment](pytorch#137980 (comment)))
1 parent 8274dad commit 2487a83

File tree

7 files changed

+45
-93
lines changed

7 files changed

+45
-93
lines changed

test/test_dynamic_shapes.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -499,24 +499,14 @@ def test_sym_int(self):
499499
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
500500
)
501501

502-
def test_sym_log2(self):
503-
shape_env = ShapeEnv()
504-
a0 = create_symint(shape_env, 4)
505-
r = torch._sym_log2(a0)
506-
self.assertEqual(r, 2.0)
507-
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
508-
self.assertExpectedInline(
509-
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)"""
510-
)
511-
512502
def test_sym_sqrt(self):
513503
shape_env = ShapeEnv()
514504
a0 = create_symint(shape_env, 4)
515505
r = torch._sym_sqrt(a0)
516506
self.assertEqual(r, 2)
517507
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
518508
self.assertExpectedInline(
519-
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)"""
509+
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
520510
)
521511

522512
def test_sym_floor(self):
@@ -550,8 +540,7 @@ def test_sym_trunc(self):
550540
self.assertEqual(r, 2)
551541
self.assertIsInstance(r, torch.SymInt, msg=type(r))
552542
self.assertExpectedInline(
553-
str(shape_env.guards[1][0]),
554-
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""",
543+
str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
555544
)
556545

557546
def test_sym_ceil(self):

torch/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,6 @@ def _get_sym_math_fn(name):
876876
def fn(a):
877877
if overrides.has_torch_function_unary(a):
878878
return overrides.handle_torch_function(fn, (a,), a)
879-
if isinstance(a, SymInt):
880-
a = torch.sym_float(a)
881879
if hasattr(a, f"__sym_{name}__"):
882880
return getattr(a, f"__sym_{name}__")()
883881
return getattr(math, name)(a)
@@ -897,7 +895,6 @@ def fn(a):
897895
"asin",
898896
"acos",
899897
"atan",
900-
"log2",
901898
):
902899
__sym_name = f"_sym_{__name}"
903900
__fn = _get_sym_math_fn(__name)

torch/fx/experimental/sym_node.py

Lines changed: 43 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
# mypy: allow-untyped-defs
2-
3-
from __future__ import annotations
4-
5-
62
"""
73
This file does three things:
84
- Contains the definition of SymNode
@@ -149,12 +145,12 @@ def compute_hint():
149145
)
150146
self.fx_node = tx_validation_en and fx_node
151147

152-
def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
148+
def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
153149
return SymNode(
154150
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
155151
)
156152

157-
def _value_eq(self, other: SymNode) -> bool:
153+
def _value_eq(self, other: "SymNode") -> bool:
158154
# Purposely don't include the shape_env in the eq.
159155
return (
160156
self._expr == other._expr
@@ -285,121 +281,121 @@ def _graph_repr(self) -> builtins.str:
285281

286282
# These methods call the metaprogrammed methods, they're hand written
287283
# here so we get good stack traces
288-
def abs(self) -> SymNode:
284+
def abs(self) -> "SymNode":
289285
return self._abs() # type: ignore[attr-defined]
290286

291-
def pos(self) -> SymNode:
287+
def pos(self) -> "SymNode":
292288
return self._pos() # type: ignore[attr-defined]
293289

294-
def round(self, ndigits=None) -> SymNode:
290+
def round(self, ndigits=None) -> "SymNode":
295291
return self._round(ndigits) # type: ignore[attr-defined]
296292

297-
def trunc(self) -> SymNode:
293+
def trunc(self) -> "SymNode":
298294
return self._trunc() # type: ignore[attr-defined]
299295

300-
def add(self, other) -> SymNode:
296+
def add(self, other) -> "SymNode":
301297
return self._add(other) # type: ignore[attr-defined]
302298

303-
def sub(self, other) -> SymNode:
299+
def sub(self, other) -> "SymNode":
304300
return self._sub(other) # type: ignore[attr-defined]
305301

306-
def mul(self, other) -> SymNode:
302+
def mul(self, other) -> "SymNode":
307303
return self._mul(other) # type: ignore[attr-defined]
308304

309-
def mod(self, other) -> SymNode:
305+
def mod(self, other) -> "SymNode":
310306
return self._mod(other) # type: ignore[attr-defined]
311307

312-
def float_pow(self, other) -> SymNode:
308+
def float_pow(self, other) -> "SymNode":
313309
return self._float_pow(other) # type: ignore[attr-defined]
314310

315-
def pow_by_natural(self, other) -> SymNode:
311+
def pow_by_natural(self, other) -> "SymNode":
316312
return self._pow_by_natural(other) # type: ignore[attr-defined]
317313

318-
def and_(self, other) -> SymNode:
314+
def and_(self, other) -> "SymNode":
319315
return self._and_(other) # type: ignore[attr-defined]
320316

321-
def or_(self, other) -> SymNode:
317+
def or_(self, other) -> "SymNode":
322318
return self._or_(other) # type: ignore[attr-defined]
323319

324-
def float_truediv(self, other) -> SymNode:
320+
def float_truediv(self, other) -> "SymNode":
325321
return self._float_truediv(other) # type: ignore[attr-defined]
326322

327-
def int_truediv(self, other) -> SymNode:
323+
def int_truediv(self, other) -> "SymNode":
328324
return self._int_truediv(other) # type: ignore[attr-defined]
329325

330-
def int_floordiv(self, other) -> SymNode:
326+
def int_floordiv(self, other) -> "SymNode":
331327
return self._int_floordiv(other) # type: ignore[attr-defined]
332328

333-
def lshift(self, other) -> SymNode:
329+
def lshift(self, other) -> "SymNode":
334330
return self._lshift(other) # type: ignore[attr-defined]
335331

336-
def rshift(self, other) -> SymNode:
332+
def rshift(self, other) -> "SymNode":
337333
return self._rshift(other) # type: ignore[attr-defined]
338334

339-
def sym_not(self) -> SymNode: # noqa: F811
335+
def sym_not(self) -> "SymNode": # noqa: F811
340336
return self._sym_not() # type: ignore[attr-defined]
341337

342-
def eq(self, other) -> SymNode:
338+
def eq(self, other) -> "SymNode":
343339
return self._eq(other) # type: ignore[attr-defined]
344340

345-
def ne(self, other) -> SymNode:
341+
def ne(self, other) -> "SymNode":
346342
return self._ne(other) # type: ignore[attr-defined]
347343

348-
def gt(self, other) -> SymNode:
344+
def gt(self, other) -> "SymNode":
349345
return self._gt(other) # type: ignore[attr-defined]
350346

351-
def lt(self, other) -> SymNode:
347+
def lt(self, other) -> "SymNode":
352348
return self._lt(other) # type: ignore[attr-defined]
353349

354-
def le(self, other) -> SymNode:
350+
def le(self, other) -> "SymNode":
355351
return self._le(other) # type: ignore[attr-defined]
356352

357-
def ge(self, other) -> SymNode:
353+
def ge(self, other) -> "SymNode":
358354
return self._ge(other) # type: ignore[attr-defined]
359355

360-
def floor(self) -> SymNode:
356+
def floor(self) -> "SymNode":
361357
return self._floor() # type: ignore[attr-defined]
362358

363-
def is_integer(self) -> SymNode:
359+
def is_integer(self) -> "SymNode":
364360
return self._is_integer() # type: ignore[attr-defined]
365361

366-
def sym_float(self) -> SymNode: # noqa: F811
362+
def sym_float(self) -> "SymNode": # noqa: F811
367363
return self._sym_float() # type: ignore[attr-defined]
368364

369-
def sym_int(self) -> SymNode:
365+
def sym_int(self) -> "SymNode":
370366
return self._sym_int() # type: ignore[attr-defined]
371367

372-
def ceil(self) -> SymNode:
368+
def ceil(self) -> "SymNode":
373369
return self._ceil() # type: ignore[attr-defined]
374370

375-
def neg(self) -> SymNode:
371+
def neg(self) -> "SymNode":
376372
return self._neg() # type: ignore[attr-defined]
377373

378-
def sym_min(self, other) -> SymNode: # noqa: F811
374+
def sym_min(self, other) -> "SymNode": # noqa: F811
379375
return self._sym_min(other) # type: ignore[attr-defined]
380376

381-
def sym_max(self, other) -> SymNode: # noqa: F811
377+
def sym_max(self, other) -> "SymNode": # noqa: F811
382378
return self._sym_max(other) # type: ignore[attr-defined]
383379

384-
def sym_ite(self, then_val, else_val) -> SymNode:
380+
def sym_ite(self, then_val, else_val) -> "SymNode":
385381
return self._sym_ite(then_val, else_val) # type: ignore[attr-defined]
386382

387-
def is_contiguous(self, sizes, strides) -> SymNode:
383+
def is_contiguous(self, sizes, strides) -> "SymNode":
388384
return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
389385

390-
def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
386+
def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
391387
return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
392388

393-
def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
389+
def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
394390
return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
395391

396-
def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
392+
def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
397393
return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
398394

399-
def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
395+
def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
400396
return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
401397

402-
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
398+
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
403399
return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
404400

405401
# Make C++ happy
@@ -413,7 +409,7 @@ def sym_and(self, other):
413409
def truediv(self, other):
414410
return self.float_truediv(other)
415411

416-
def floordiv(self, other) -> SymNode:
412+
def floordiv(self, other) -> "SymNode":
417413
return self.int_floordiv(other)
418414

419415
# We didn't bind integer pow in C++
@@ -633,7 +629,6 @@ def fn(self):
633629
"asin",
634630
"acos",
635631
"atan",
636-
"log2",
637632
)
638633
for name in math_op_names:
639634
sym_name = f"sym_{name}"
@@ -661,7 +656,7 @@ def fn(self):
661656
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
662657

663658
# Methods that are only for float
664-
only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
659+
only_float_magic_methods = {"is_integer", "round", "sym_int"}
665660

666661

667662
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}

torch/utils/_sympy/functions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,6 @@ def eval(cls, a):
11971197
a = sympy.oo
11981198
if a is -int_oo:
11991199
a = -sympy.oo
1200-
if name == "log2":
1201-
return sympy.log(a, 2)
12021200
return getattr(sympy, name)(a)
12031201
return None
12041202

@@ -1223,4 +1221,3 @@ def eval(cls, a):
12231221
OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
12241222
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
12251223
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
1226-
OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")

torch/utils/_sympy/interp.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
Min,
3232
Mod,
3333
ModularIndexing,
34-
OpaqueUnaryFn_log2,
3534
PowByNatural,
3635
PythonMod,
3736
RoundDecimal,
@@ -102,11 +101,7 @@ def handlers():
102101
Identity: "identity",
103102
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
104103
RoundDecimal: "round_decimal",
105-
# TODO: do the rest of the opaque unary functions...
106-
OpaqueUnaryFn_log2: "log2",
107104
}
108-
# TODO: This is kind of pointless, we shouldn't be generating sympy.sin
109-
# for these functions, they should be Opaque instead
110105
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
111106
HANDLERS[getattr(sympy, name)] = name
112107

torch/utils/_sympy/reference.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Mod,
1818
OpaqueUnaryFn_exp,
1919
OpaqueUnaryFn_log,
20-
OpaqueUnaryFn_log2,
2120
OpaqueUnaryFn_sqrt,
2221
PowByNatural,
2322
RoundDecimal,
@@ -163,10 +162,6 @@ def exp(x):
163162
def log(x):
164163
return OpaqueUnaryFn_log(x)
165164

166-
@staticmethod
167-
def log2(x):
168-
return OpaqueUnaryFn_log2(x)
169-
170165
@staticmethod
171166
def sqrt(x):
172167
return OpaqueUnaryFn_sqrt(x)
@@ -252,10 +247,6 @@ def exp(x):
252247
def log(x):
253248
raise AssertionError("log is not valid shape sympy expr")
254249

255-
@staticmethod
256-
def log2(x):
257-
return torch._sym_log2(x) # type: ignore[attr-defined]
258-
259250
@staticmethod
260251
def sqrt(x):
261252
return torch._sym_sqrt(x) # type: ignore[attr-defined]
@@ -481,10 +472,6 @@ def exp(x):
481472
def log(x):
482473
return torch.ops.aten.log.default(x)
483474

484-
@staticmethod
485-
def log2(x):
486-
return torch.ops.aten.log2.default(x)
487-
488475
@staticmethod
489476
def sqrt(x):
490477
return torch.ops.aten.sqrt.default(x)

torch/utils/_sympy/value_ranges.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
IntTrueDiv,
3535
OpaqueUnaryFn_exp,
3636
OpaqueUnaryFn_log,
37-
OpaqueUnaryFn_log2,
3837
OpaqueUnaryFn_sqrt,
3938
PowByNatural,
4039
RoundDecimal,
@@ -761,13 +760,6 @@ def log(x):
761760
return ValueRanges.unknown()
762761
return ValueRanges.increasing_map(x, OpaqueUnaryFn_log)
763762

764-
@staticmethod
765-
def log2(x):
766-
x = ValueRanges.wrap(x)
767-
if x.lower <= 0:
768-
return ValueRanges.unknown()
769-
return ValueRanges.increasing_map(x, OpaqueUnaryFn_log2)
770-
771763
@classmethod
772764
def minimum(cls, a, b):
773765
return cls.min_or_max(a, b, sympy.Min)

0 commit comments

Comments
 (0)