Skip to content

Commit 46f8df7

Browse files
authored
Extending types for symbol resolution fast pass-through (#3366)
Extends ParamResolver's logic to circumvent sympy's slowness to members of numbers.Number. It also generalizes sympy constants instead of only handling pi and NegativeOne. Fixes #3359.
1 parent fa80d4c commit 46f8df7

File tree

2 files changed

+124
-25
lines changed

2 files changed

+124
-25
lines changed

cirq/study/resolver.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Resolves ParameterValues to assigned values."""
16-
16+
import numbers
1717
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast
1818
import numpy as np
1919
import sympy
@@ -89,18 +89,21 @@ def value_of(self,
8989
Returns:
9090
The value of the parameter as resolved by this resolver.
9191
"""
92-
# Input is a float, no resolution needed: return early
93-
if isinstance(value, float):
94-
return value
92+
93+
# Input is a pass through type, no resolution needed: return early
94+
v = _sympy_pass_through(value)
95+
if v is not None:
96+
return v
9597

9698
# Handles 2 cases:
9799
# Input is a string and maps to a number in the dictionary
98100
# Input is a symbol and maps to a number in the dictionary
99101
# In both cases, return it directly.
100102
if value in self.param_dict:
101103
param_value = self.param_dict[value]
102-
if isinstance(param_value, (float, int)):
103-
return param_value
104+
v = _sympy_pass_through(param_value)
105+
if v is not None:
106+
return v
104107

105108
# Input is a string and is not in the dictionary.
106109
# Treat it as a symbol instead.
@@ -111,10 +114,11 @@ def value_of(self,
111114

112115
# Input is a symbol (sympy.Symbol('a')) and its string maps to a number
113116
# in the dictionary ({'a': 1.0}). Return it.
114-
if (isinstance(value, sympy.Symbol) and value.name in self.param_dict):
117+
if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
115118
param_value = self.param_dict[value.name]
116-
if isinstance(param_value, (float, int)):
117-
return param_value
119+
v = _sympy_pass_through(param_value)
120+
if v is not None:
121+
return v
118122

119123
# The following resolves common sympy expressions
120124
# If sympy did its job and wasn't slower than molasses,
@@ -132,10 +136,6 @@ def value_of(self,
132136
if isinstance(value, sympy.Pow) and len(value.args) == 2:
133137
return np.power(self.value_of(value.args[0]),
134138
self.value_of(value.args[1]))
135-
if value == sympy.pi:
136-
return np.pi
137-
if value == sympy.S.NegativeOne:
138-
return -1
139139

140140
# Input is either a sympy formula or the dictionary maps to a
141141
# formula. Use sympy to resolve the value.
@@ -193,3 +193,15 @@ def _json_dict_(self) -> Dict[str, Any]:
193193
@classmethod
194194
def _from_json_dict_(cls, param_dict, **kwargs):
195195
return cls(dict(param_dict))
196+
197+
198+
def _sympy_pass_through(val: Any) -> Optional[Any]:
199+
if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic):
200+
return val
201+
if isinstance(val, sympy.core.numbers.IntegerConstant):
202+
return val.p
203+
if isinstance(val, sympy.core.numbers.RationalConstant):
204+
return val.p / val.q
205+
if val == sympy.pi:
206+
return np.pi
207+
return None

cirq/study/resolver_test.py

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,120 @@
1313
# limitations under the License.
1414

1515
"""Tests for parameter resolvers."""
16+
import fractions
1617

1718
import numpy as np
19+
import pytest
1820
import sympy
1921

2022
import cirq
2123

2224

23-
def test_value_of():
25+
@pytest.mark.parametrize('val', [
26+
3.2,
27+
np.float32(3.2),
28+
int(1),
29+
np.int(3),
30+
np.int32(45),
31+
np.float64(6.3),
32+
np.int32(2),
33+
np.complex64(1j),
34+
np.complex128(2j),
35+
np.complex(1j),
36+
fractions.Fraction(3, 2),
37+
])
38+
def test_value_of_pass_through_types(val):
39+
_assert_consistent_resolution(val, val)
40+
41+
42+
@pytest.mark.parametrize('val,resolved', [(sympy.pi, np.pi),
43+
(sympy.S.NegativeOne, -1),
44+
(sympy.S.Half, 0.5),
45+
(sympy.S.One, 1)])
46+
def test_value_of_transformed_types(val, resolved):
47+
_assert_consistent_resolution(val, resolved)
48+
49+
50+
@pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)])
51+
def test_value_of_substituted_types(val, resolved):
52+
_assert_consistent_resolution(val, resolved, True)
53+
54+
55+
def _assert_consistent_resolution(v, resolved, subs_called=False):
56+
"""Asserts that parameter resolution works consistently.
57+
58+
The ParamResolver.value_of method can resolve any Sympy expression -
59+
subclasses of sympy.Basic. In the generic case, it calls `sympy.Basic.subs`
60+
to substitute symbols with values specified in a dict, which is known to be
61+
very slow. Instead value_of defines a pass-through shortcut for known
62+
numeric types. For a given value `v` it is asserted that value_of resolves
63+
it to `resolved`, with the exact type of `resolved`.`subs_called` indicates
64+
whether it is expected to have `subs` called or not during the resolution.
65+
66+
Args:
67+
v: the value to resolve
68+
resolved: the expected resolution result
69+
subs_called: if True, it is expected that the slow subs method is called
70+
Raises:
71+
AssertionError in case resolution assertion fail.
72+
"""
73+
74+
class SubsAwareSymbol(sympy.Symbol):
75+
"""A Symbol that registers a call to its `subs` method."""
76+
77+
def __init__(self, sym: str):
78+
self.called = False
79+
self.symbol = sympy.Symbol(sym)
80+
81+
# note: super().subs() doesn't resolve based on the param_dict properly
82+
# for some reason, that's why a delegate (self.symbol) is used instead
83+
def subs(self, *args, **kwargs):
84+
self.called = True
85+
return self.symbol.subs(*args, **kwargs)
86+
87+
r = cirq.ParamResolver({'a': v})
88+
89+
# symbol based resolution
90+
s = SubsAwareSymbol('a')
91+
assert r.value_of(s) == resolved, (f"expected {resolved}, "
92+
f"got {r.value_of(s)}")
93+
assert subs_called == s.called, (
94+
f"For pass-through type "
95+
f"{type(v)} sympy.subs shouldn't have been called.")
96+
assert isinstance(r.value_of(s),
97+
type(resolved)), (f"expected {type(resolved)} "
98+
f"got {type(r.value_of(s))}")
99+
100+
# string based resolution (which in turn uses symbol based resolution)
101+
assert r.value_of('a') == resolved, (f"expected {resolved}, "
102+
f"got {r.value_of('a')}")
103+
assert isinstance(r.value_of('a'),
104+
type(resolved)), (f"expected {type(resolved)} "
105+
f"got {type(r.value_of('a'))}")
106+
107+
# value based resolution
108+
assert r.value_of(v) == resolved, (f"expected {resolved}, "
109+
f"got {r.value_of(v)}")
110+
assert isinstance(r.value_of(v),
111+
type(resolved)), (f"expected {type(resolved)} "
112+
f"got {type(r.value_of(v))}")
113+
114+
115+
def test_value_of_strings():
116+
assert cirq.ParamResolver().value_of('x') == sympy.Symbol('x')
117+
118+
119+
def test_value_of_calculations():
24120
assert not bool(cirq.ParamResolver())
25121

26122
r = cirq.ParamResolver({'a': 0.5, 'b': 0.1, 'c': 1 + 1j})
27123
assert bool(r)
28124

29-
assert r.value_of('x') == sympy.Symbol('x')
30-
assert r.value_of('a') == 0.5
31-
assert r.value_of(sympy.Symbol('a')) == 0.5
32-
assert r.value_of(0.5) == 0.5
33-
assert r.value_of(sympy.Symbol('b')) == 0.1
34-
assert r.value_of(0.3) == 0.3
35-
assert r.value_of(sympy.Symbol('a') * 3) == 1.5
36-
assert r.value_of(sympy.Symbol('b') / 0.1 - sympy.Symbol('a')) == 0.5
37-
38-
assert r.value_of(sympy.pi) == np.pi
39125
assert r.value_of(2 * sympy.pi) == 2 * np.pi
40126
assert r.value_of(4**sympy.Symbol('a') + sympy.Symbol('b') * 10) == 3
41-
assert r.value_of('c') == 1 + 1j
42127
assert r.value_of(sympy.I * sympy.pi) == np.pi * 1j
128+
assert r.value_of(sympy.Symbol('a') * 3) == 1.5
129+
assert r.value_of(sympy.Symbol('b') / 0.1 - sympy.Symbol('a')) == 0.5
43130

44131

45132
def test_param_dict():

0 commit comments

Comments
 (0)