Skip to content

Commit 91ded05

Browse files
ezyangpytorchmergebot
authored andcommitted
Add sym_log2 (pytorch#137980)
Internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1515595595745313/ Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#137980 Approved by: https://github.com/bobrenjc93
1 parent 006130d commit 91ded05

File tree

7 files changed

+94
-46
lines changed

7 files changed

+94
-46
lines changed

test/test_dynamic_shapes.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,14 +499,24 @@ def test_sym_int(self):
499499
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
500500
)
501501

502+
def test_sym_log2(self):
503+
shape_env = ShapeEnv()
504+
a0 = create_symint(shape_env, 4)
505+
r = torch._sym_log2(a0)
506+
self.assertEqual(r, 2.0)
507+
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
508+
self.assertExpectedInline(
509+
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)"""
510+
)
511+
502512
def test_sym_sqrt(self):
503513
shape_env = ShapeEnv()
504514
a0 = create_symint(shape_env, 4)
505515
r = torch._sym_sqrt(a0)
506516
self.assertEqual(r, 2)
507517
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
508518
self.assertExpectedInline(
509-
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
519+
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)"""
510520
)
511521

512522
def test_sym_floor(self):
@@ -540,7 +550,8 @@ def test_sym_trunc(self):
540550
self.assertEqual(r, 2)
541551
self.assertIsInstance(r, torch.SymInt, msg=type(r))
542552
self.assertExpectedInline(
543-
str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
553+
str(shape_env.guards[1][0]),
554+
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""",
544555
)
545556

546557
def test_sym_ceil(self):

torch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,8 @@ def _get_sym_math_fn(name):
876876
def fn(a):
877877
if overrides.has_torch_function_unary(a):
878878
return overrides.handle_torch_function(fn, (a,), a)
879+
if isinstance(a, SymInt):
880+
a = torch.sym_float(a)
879881
if hasattr(a, f"__sym_{name}__"):
880882
return getattr(a, f"__sym_{name}__")()
881883
return getattr(math, name)(a)
@@ -895,6 +897,7 @@ def fn(a):
895897
"asin",
896898
"acos",
897899
"atan",
900+
"log2",
898901
):
899902
__sym_name = f"_sym_{__name}"
900903
__fn = _get_sym_math_fn(__name)

torch/fx/experimental/sym_node.py

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# mypy: allow-untyped-defs
2+
3+
from __future__ import annotations
4+
5+
26
"""
37
This file does three things:
48
- Contains the definition of SymNode
@@ -145,12 +149,12 @@ def compute_hint():
145149
)
146150
self.fx_node = tx_validation_en and fx_node
147151

148-
def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
152+
def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
149153
return SymNode(
150154
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
151155
)
152156

153-
def _value_eq(self, other: "SymNode") -> bool:
157+
def _value_eq(self, other: SymNode) -> bool:
154158
# Purposely don't include the shape_env in the eq.
155159
return (
156160
self._expr == other._expr
@@ -281,121 +285,121 @@ def _graph_repr(self) -> builtins.str:
281285

282286
# These methods call the metaprogrammed methods, they're hand written
283287
# here so we get good stack traces
284-
def abs(self) -> "SymNode":
288+
def abs(self) -> SymNode:
285289
return self._abs() # type: ignore[attr-defined]
286290

287-
def pos(self) -> "SymNode":
291+
def pos(self) -> SymNode:
288292
return self._pos() # type: ignore[attr-defined]
289293

290-
def round(self, ndigits=None) -> "SymNode":
294+
def round(self, ndigits=None) -> SymNode:
291295
return self._round(ndigits) # type: ignore[attr-defined]
292296

293-
def trunc(self) -> "SymNode":
297+
def trunc(self) -> SymNode:
294298
return self._trunc() # type: ignore[attr-defined]
295299

296-
def add(self, other) -> "SymNode":
300+
def add(self, other) -> SymNode:
297301
return self._add(other) # type: ignore[attr-defined]
298302

299-
def sub(self, other) -> "SymNode":
303+
def sub(self, other) -> SymNode:
300304
return self._sub(other) # type: ignore[attr-defined]
301305

302-
def mul(self, other) -> "SymNode":
306+
def mul(self, other) -> SymNode:
303307
return self._mul(other) # type: ignore[attr-defined]
304308

305-
def mod(self, other) -> "SymNode":
309+
def mod(self, other) -> SymNode:
306310
return self._mod(other) # type: ignore[attr-defined]
307311

308-
def float_pow(self, other) -> "SymNode":
312+
def float_pow(self, other) -> SymNode:
309313
return self._float_pow(other) # type: ignore[attr-defined]
310314

311-
def pow_by_natural(self, other) -> "SymNode":
315+
def pow_by_natural(self, other) -> SymNode:
312316
return self._pow_by_natural(other) # type: ignore[attr-defined]
313317

314-
def and_(self, other) -> "SymNode":
318+
def and_(self, other) -> SymNode:
315319
return self._and_(other) # type: ignore[attr-defined]
316320

317-
def or_(self, other) -> "SymNode":
321+
def or_(self, other) -> SymNode:
318322
return self._or_(other) # type: ignore[attr-defined]
319323

320-
def float_truediv(self, other) -> "SymNode":
324+
def float_truediv(self, other) -> SymNode:
321325
return self._float_truediv(other) # type: ignore[attr-defined]
322326

323-
def int_truediv(self, other) -> "SymNode":
327+
def int_truediv(self, other) -> SymNode:
324328
return self._int_truediv(other) # type: ignore[attr-defined]
325329

326-
def int_floordiv(self, other) -> "SymNode":
330+
def int_floordiv(self, other) -> SymNode:
327331
return self._int_floordiv(other) # type: ignore[attr-defined]
328332

329-
def lshift(self, other) -> "SymNode":
333+
def lshift(self, other) -> SymNode:
330334
return self._lshift(other) # type: ignore[attr-defined]
331335

332-
def rshift(self, other) -> "SymNode":
336+
def rshift(self, other) -> SymNode:
333337
return self._rshift(other) # type: ignore[attr-defined]
334338

335-
def sym_not(self) -> "SymNode": # noqa: F811
339+
def sym_not(self) -> SymNode: # noqa: F811
336340
return self._sym_not() # type: ignore[attr-defined]
337341

338-
def eq(self, other) -> "SymNode":
342+
def eq(self, other) -> SymNode:
339343
return self._eq(other) # type: ignore[attr-defined]
340344

341-
def ne(self, other) -> "SymNode":
345+
def ne(self, other) -> SymNode:
342346
return self._ne(other) # type: ignore[attr-defined]
343347

344-
def gt(self, other) -> "SymNode":
348+
def gt(self, other) -> SymNode:
345349
return self._gt(other) # type: ignore[attr-defined]
346350

347-
def lt(self, other) -> "SymNode":
351+
def lt(self, other) -> SymNode:
348352
return self._lt(other) # type: ignore[attr-defined]
349353

350-
def le(self, other) -> "SymNode":
354+
def le(self, other) -> SymNode:
351355
return self._le(other) # type: ignore[attr-defined]
352356

353-
def ge(self, other) -> "SymNode":
357+
def ge(self, other) -> SymNode:
354358
return self._ge(other) # type: ignore[attr-defined]
355359

356-
def floor(self) -> "SymNode":
360+
def floor(self) -> SymNode:
357361
return self._floor() # type: ignore[attr-defined]
358362

359-
def is_integer(self) -> "SymNode":
363+
def is_integer(self) -> SymNode:
360364
return self._is_integer() # type: ignore[attr-defined]
361365

362-
def sym_float(self) -> "SymNode": # noqa: F811
366+
def sym_float(self) -> SymNode: # noqa: F811
363367
return self._sym_float() # type: ignore[attr-defined]
364368

365-
def sym_int(self) -> "SymNode":
369+
def sym_int(self) -> SymNode:
366370
return self._sym_int() # type: ignore[attr-defined]
367371

368-
def ceil(self) -> "SymNode":
372+
def ceil(self) -> SymNode:
369373
return self._ceil() # type: ignore[attr-defined]
370374

371-
def neg(self) -> "SymNode":
375+
def neg(self) -> SymNode:
372376
return self._neg() # type: ignore[attr-defined]
373377

374-
def sym_min(self, other) -> "SymNode": # noqa: F811
378+
def sym_min(self, other) -> SymNode: # noqa: F811
375379
return self._sym_min(other) # type: ignore[attr-defined]
376380

377-
def sym_max(self, other) -> "SymNode": # noqa: F811
381+
def sym_max(self, other) -> SymNode: # noqa: F811
378382
return self._sym_max(other) # type: ignore[attr-defined]
379383

380-
def sym_ite(self, then_val, else_val) -> "SymNode":
384+
def sym_ite(self, then_val, else_val) -> SymNode:
381385
return self._sym_ite(then_val, else_val) # type: ignore[attr-defined]
382386

383-
def is_contiguous(self, sizes, strides) -> "SymNode":
387+
def is_contiguous(self, sizes, strides) -> SymNode:
384388
return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
385389

386-
def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
390+
def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
387391
return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
388392

389-
def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
393+
def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
390394
return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
391395

392-
def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
396+
def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
393397
return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
394398

395-
def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
399+
def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
396400
return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
397401

398-
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
402+
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
399403
return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
400404

401405
# Make C++ happy
@@ -409,7 +413,7 @@ def sym_and(self, other):
409413
def truediv(self, other):
410414
return self.float_truediv(other)
411415

412-
def floordiv(self, other) -> "SymNode":
416+
def floordiv(self, other) -> SymNode:
413417
return self.int_floordiv(other)
414418

415419
# We didn't bind integer pow in C++
@@ -426,7 +430,7 @@ def int_(self):
426430
# functions consider factoring it out to be metaprogrammed too. Note that
427431
# some load bearing logic is directly in torch.sym_sum
428432

429-
def sym_sum(self, args) -> "SymNode":
433+
def sym_sum(self, args) -> SymNode:
430434
import sympy
431435

432436
# Inner impl
@@ -629,6 +633,7 @@ def fn(self):
629633
"asin",
630634
"acos",
631635
"atan",
636+
"log2",
632637
)
633638
for name in math_op_names:
634639
sym_name = f"sym_{name}"
@@ -656,7 +661,7 @@ def fn(self):
656661
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
657662

658663
# Methods that are only for float
659-
only_float_magic_methods = {"is_integer", "round", "sym_int"}
664+
only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
660665

661666

662667
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}

torch/utils/_sympy/functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,8 @@ def eval(cls, a):
11971197
a = sympy.oo
11981198
if a is -int_oo:
11991199
a = -sympy.oo
1200+
if name == "log2":
1201+
return sympy.log(a, 2)
12001202
return getattr(sympy, name)(a)
12011203
return None
12021204

@@ -1221,3 +1223,4 @@ def eval(cls, a):
12211223
OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
12221224
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
12231225
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
1226+
OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")

torch/utils/_sympy/interp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Min,
3232
Mod,
3333
ModularIndexing,
34+
OpaqueUnaryFn_log2,
3435
PowByNatural,
3536
PythonMod,
3637
RoundDecimal,
@@ -101,7 +102,11 @@ def handlers():
101102
Identity: "identity",
102103
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
103104
RoundDecimal: "round_decimal",
105+
# TODO: do the rest of the opaque unary functions...
106+
OpaqueUnaryFn_log2: "log2",
104107
}
108+
# TODO: This is kind of pointless, we shouldn't be generating sympy.sin
109+
# for these functions, they should be Opaque instead
105110
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
106111
HANDLERS[getattr(sympy, name)] = name
107112

torch/utils/_sympy/reference.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Mod,
1818
OpaqueUnaryFn_exp,
1919
OpaqueUnaryFn_log,
20+
OpaqueUnaryFn_log2,
2021
OpaqueUnaryFn_sqrt,
2122
PowByNatural,
2223
RoundDecimal,
@@ -162,6 +163,10 @@ def exp(x):
162163
def log(x):
163164
return OpaqueUnaryFn_log(x)
164165

166+
@staticmethod
167+
def log2(x):
168+
return OpaqueUnaryFn_log2(x)
169+
165170
@staticmethod
166171
def sqrt(x):
167172
return OpaqueUnaryFn_sqrt(x)
@@ -247,6 +252,10 @@ def exp(x):
247252
def log(x):
248253
raise AssertionError("log is not valid shape sympy expr")
249254

255+
@staticmethod
256+
def log2(x):
257+
return torch._sym_log2(x) # type: ignore[attr-defined]
258+
250259
@staticmethod
251260
def sqrt(x):
252261
return torch._sym_sqrt(x) # type: ignore[attr-defined]
@@ -472,6 +481,10 @@ def exp(x):
472481
def log(x):
473482
return torch.ops.aten.log.default(x)
474483

484+
@staticmethod
485+
def log2(x):
486+
return torch.ops.aten.log2.default(x)
487+
475488
@staticmethod
476489
def sqrt(x):
477490
return torch.ops.aten.sqrt.default(x)

torch/utils/_sympy/value_ranges.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
IntTrueDiv,
3535
OpaqueUnaryFn_exp,
3636
OpaqueUnaryFn_log,
37+
OpaqueUnaryFn_log2,
3738
OpaqueUnaryFn_sqrt,
3839
PowByNatural,
3940
RoundDecimal,
@@ -760,6 +761,13 @@ def log(x):
760761
return ValueRanges.unknown()
761762
return ValueRanges.increasing_map(x, OpaqueUnaryFn_log)
762763

764+
@staticmethod
765+
def log2(x):
766+
x = ValueRanges.wrap(x)
767+
if x.lower <= 0:
768+
return ValueRanges.unknown()
769+
return ValueRanges.increasing_map(x, OpaqueUnaryFn_log2)
770+
763771
@classmethod
764772
def minimum(cls, a, b):
765773
return cls.min_or_max(a, b, sympy.Min)

0 commit comments

Comments
 (0)