|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | """Tests for parameter resolvers.""" |
| 16 | +import fractions |
16 | 17 |
|
17 | 18 | import numpy as np |
| 19 | +import pytest |
18 | 20 | import sympy |
19 | 21 |
|
20 | 22 | import cirq |
21 | 23 |
|
22 | 24 |
|
23 | | -def test_value_of(): |
| 25 | +@pytest.mark.parametrize('val', [ |
| 26 | + 3.2, |
| 27 | + np.float32(3.2), |
| 28 | + int(1), |
| 29 | + np.int(3), |
| 30 | + np.int32(45), |
| 31 | + np.float64(6.3), |
| 32 | + np.int32(2), |
| 33 | + np.complex64(1j), |
| 34 | + np.complex128(2j), |
| 35 | + np.complex(1j), |
| 36 | + fractions.Fraction(3, 2), |
| 37 | +]) |
| 38 | +def test_value_of_pass_through_types(val): |
| 39 | + _assert_consistent_resolution(val, val) |
| 40 | + |
| 41 | + |
| 42 | +@pytest.mark.parametrize('val,resolved', [(sympy.pi, np.pi), |
| 43 | + (sympy.S.NegativeOne, -1), |
| 44 | + (sympy.S.Half, 0.5), |
| 45 | + (sympy.S.One, 1)]) |
| 46 | +def test_value_of_transformed_types(val, resolved): |
| 47 | + _assert_consistent_resolution(val, resolved) |
| 48 | + |
| 49 | + |
| 50 | +@pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)]) |
| 51 | +def test_value_of_substituted_types(val, resolved): |
| 52 | + _assert_consistent_resolution(val, resolved, True) |
| 53 | + |
| 54 | + |
| 55 | +def _assert_consistent_resolution(v, resolved, subs_called=False): |
| 56 | + """Asserts that parameter resolution works consistently. |
| 57 | +
|
| 58 | + The ParamResolver.value_of method can resolve any Sympy expression - |
| 59 | + subclasses of sympy.Basic. In the generic case, it calls `sympy.Basic.subs` |
| 60 | + to substitute symbols with values specified in a dict, which is known to be |
| 61 | + very slow. Instead value_of defines a pass-through shortcut for known |
| 62 | + numeric types. For a given value `v` it is asserted that value_of resolves |
| 63 | + it to `resolved`, with the exact type of `resolved`.`subs_called` indicates |
| 64 | + whether it is expected to have `subs` called or not during the resolution. |
| 65 | +
|
| 66 | + Args: |
| 67 | + v: the value to resolve |
| 68 | + resolved: the expected resolution result |
| 69 | + subs_called: if True, it is expected that the slow subs method is called |
| 70 | + Raises: |
| 71 | + AssertionError in case resolution assertion fail. |
| 72 | + """ |
| 73 | + |
| 74 | + class SubsAwareSymbol(sympy.Symbol): |
| 75 | + """A Symbol that registers a call to its `subs` method.""" |
| 76 | + |
| 77 | + def __init__(self, sym: str): |
| 78 | + self.called = False |
| 79 | + self.symbol = sympy.Symbol(sym) |
| 80 | + |
| 81 | + # note: super().subs() doesn't resolve based on the param_dict properly |
| 82 | + # for some reason, that's why a delegate (self.symbol) is used instead |
| 83 | + def subs(self, *args, **kwargs): |
| 84 | + self.called = True |
| 85 | + return self.symbol.subs(*args, **kwargs) |
| 86 | + |
| 87 | + r = cirq.ParamResolver({'a': v}) |
| 88 | + |
| 89 | + # symbol based resolution |
| 90 | + s = SubsAwareSymbol('a') |
| 91 | + assert r.value_of(s) == resolved, (f"expected {resolved}, " |
| 92 | + f"got {r.value_of(s)}") |
| 93 | + assert subs_called == s.called, ( |
| 94 | + f"For pass-through type " |
| 95 | + f"{type(v)} sympy.subs shouldn't have been called.") |
| 96 | + assert isinstance(r.value_of(s), |
| 97 | + type(resolved)), (f"expected {type(resolved)} " |
| 98 | + f"got {type(r.value_of(s))}") |
| 99 | + |
| 100 | + # string based resolution (which in turn uses symbol based resolution) |
| 101 | + assert r.value_of('a') == resolved, (f"expected {resolved}, " |
| 102 | + f"got {r.value_of('a')}") |
| 103 | + assert isinstance(r.value_of('a'), |
| 104 | + type(resolved)), (f"expected {type(resolved)} " |
| 105 | + f"got {type(r.value_of('a'))}") |
| 106 | + |
| 107 | + # value based resolution |
| 108 | + assert r.value_of(v) == resolved, (f"expected {resolved}, " |
| 109 | + f"got {r.value_of(v)}") |
| 110 | + assert isinstance(r.value_of(v), |
| 111 | + type(resolved)), (f"expected {type(resolved)} " |
| 112 | + f"got {type(r.value_of(v))}") |
| 113 | + |
| 114 | + |
| 115 | +def test_value_of_strings(): |
| 116 | + assert cirq.ParamResolver().value_of('x') == sympy.Symbol('x') |
| 117 | + |
| 118 | + |
| 119 | +def test_value_of_calculations(): |
24 | 120 | assert not bool(cirq.ParamResolver()) |
25 | 121 |
|
26 | 122 | r = cirq.ParamResolver({'a': 0.5, 'b': 0.1, 'c': 1 + 1j}) |
27 | 123 | assert bool(r) |
28 | 124 |
|
29 | | - assert r.value_of('x') == sympy.Symbol('x') |
30 | | - assert r.value_of('a') == 0.5 |
31 | | - assert r.value_of(sympy.Symbol('a')) == 0.5 |
32 | | - assert r.value_of(0.5) == 0.5 |
33 | | - assert r.value_of(sympy.Symbol('b')) == 0.1 |
34 | | - assert r.value_of(0.3) == 0.3 |
35 | | - assert r.value_of(sympy.Symbol('a') * 3) == 1.5 |
36 | | - assert r.value_of(sympy.Symbol('b') / 0.1 - sympy.Symbol('a')) == 0.5 |
37 | | - |
38 | | - assert r.value_of(sympy.pi) == np.pi |
39 | 125 | assert r.value_of(2 * sympy.pi) == 2 * np.pi |
40 | 126 | assert r.value_of(4**sympy.Symbol('a') + sympy.Symbol('b') * 10) == 3 |
41 | | - assert r.value_of('c') == 1 + 1j |
42 | 127 | assert r.value_of(sympy.I * sympy.pi) == np.pi * 1j |
| 128 | + assert r.value_of(sympy.Symbol('a') * 3) == 1.5 |
| 129 | + assert r.value_of(sympy.Symbol('b') / 0.1 - sympy.Symbol('a')) == 0.5 |
43 | 130 |
|
44 | 131 |
|
45 | 132 | def test_param_dict(): |
|
0 commit comments