Skip to content

Commit bbf2cee

Browse files
authored
Add support for extensions that are only able to wrap some values (#274)
1 parent 8dd8d58 commit bbf2cee

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

cirq/extension/extensions.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __init__(
3232
self,
3333
desired_to_actual_to_wrapper: Optional[Dict[
3434
Type[T_DESIRED],
35-
Dict[Type[T_ACTUAL], Callable[[T_ACTUAL], T_DESIRED]]]]=None
35+
Dict[Type[T_ACTUAL],
36+
Callable[[T_ACTUAL],
37+
Optional[T_DESIRED]]]]]=None
3638
) -> None:
3739
"""Specifies extensions.
3840
@@ -48,12 +50,12 @@ def __init__(
4850
{}
4951
if desired_to_actual_to_wrapper is None
5052
else desired_to_actual_to_wrapper
51-
) # type: Dict[Type[Any], Dict[Any, Callable[[Any], Any]]]
53+
) # type: Dict[Type[Any], Dict[Any, Callable[[Any], Optional[Any]]]]
5254

5355
def add_cast(self,
5456
desired_type: Type[T_DESIRED],
5557
actual_type: Type[T_ACTUAL],
56-
conversion: Callable[[T_ACTUAL], T_DESIRED],
58+
conversion: Callable[[T_ACTUAL], Optional[T_DESIRED]],
5759
also_add_inherited_conversions: bool = True,
5860
overwrite_existing: bool = False) -> None:
5961
"""Adds a way to turn one type of thing into another.
@@ -139,7 +141,9 @@ def try_cast(self,
139141
for actual_type in inspect.getmro(type(actual_value)):
140142
wrapper = actual_to_wrapper.get(actual_type)
141143
if wrapper:
142-
return wrapper(actual_value)
144+
wrapped = wrapper(actual_value)
145+
if wrapped is not None:
146+
return wrapped
143147

144148
if isinstance(actual_value, desired_type):
145149
return actual_value

cirq/extension/extensions_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,17 @@ def test_add_cast_redundant_including_subtypes():
263263
overwrite_existing=True)
264264
assert e.try_cast(Child(), Aunt) is o3
265265
assert e.try_cast(Child(), Cousin) is o3
266+
267+
268+
def test_add_potential_cast():
269+
a = Aunt()
270+
c1 = Child()
271+
c2 = Child()
272+
273+
e = extension.Extensions()
274+
e.add_cast(desired_type=Aunt,
275+
actual_type=Child,
276+
conversion=lambda e: a if e is c1 else None)
277+
278+
assert e.try_cast(c1, Aunt) is a
279+
assert e.try_cast(c2, Aunt) is None

0 commit comments

Comments
 (0)