Skip to content

Commit 8274dad

Browse files
ezyangpytorchmergebot
authored andcommitted
Make OpaqueUnaryFn pickleable (pytorch#138395)
Fixes pytorch#138070 Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#138395 Approved by: https://github.com/XuehaiPan, https://github.com/bobrenjc93
1 parent 4d9b5a8 commit 8274dad

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

test/test_sympy_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import math
66
import sys
77
from typing import Callable, List, Tuple, Type
8+
import pickle
89

910
import sympy
1011

@@ -30,6 +31,7 @@
3031
from torch.utils._sympy.singleton_int import SingletonInt
3132
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
3233
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
34+
from torch.utils._sympy.functions import OpaqueUnaryFn_cos
3335

3436

3537
UNARY_OPS = [
@@ -811,6 +813,13 @@ def test_simple_floordiv_gcd(self):
811813
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
812814

813815

816+
class TestSympyFunctions(TestCase):
817+
def test_pickle(self):
818+
x = OpaqueUnaryFn_cos(sympy.Symbol('a'))
819+
r = pickle.loads(pickle.dumps(x))
820+
self.assertEqual(x, r)
821+
822+
814823
class TestSingletonInt(TestCase):
815824
def test_basic(self):
816825
j1 = SingletonInt(1, coeff=1)

torch/utils/_sympy/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,9 @@ def eval(cls, a):
12021202
return getattr(sympy, name)(a)
12031203
return None
12041204

1205-
OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name
1205+
nm = "OpaqueUnaryFn_" + name
1206+
OpaqueUnaryFn.__name__ = nm
1207+
OpaqueUnaryFn.__qualname__ = nm
12061208

12071209
return OpaqueUnaryFn
12081210

0 commit comments

Comments
 (0)