Skip to content

Commit 730cc24

Browse files
authored
Adding default values to op serializers and deserializers (#3280)
- This will allow us to have more efficient passing of values over the wire for common or missing values. - This will also allow tokens for focused calibrations (see PR #3269) to not add a token arg for every gate. - This will also add defaults for deserialization only. Note: a second PR will later be needed to add the defaults for serializers. This is done in two steps for backwards compatibility. We need the ability to correct for defaults in deserialization on all servers before we can add support for defaults in serialization.
1 parent fa433c3 commit 730cc24

File tree

5 files changed

+110
-19
lines changed

5 files changed

+110
-19
lines changed

cirq/google/common_serializers.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,12 @@ def _convert_physical_z(op: ops.Operation, proto: v2.program_pb2.Operation):
175175
op_deserializer.DeserializingArg(
176176
serialized_name='axis_half_turns',
177177
constructor_arg_name='phase_exponent',
178+
default=0.0,
178179
),
179180
op_deserializer.DeserializingArg(
180181
serialized_name='half_turns',
181182
constructor_arg_name='exponent',
183+
default=1.0,
182184
),
183185
],
184186
),
@@ -189,6 +191,7 @@ def _convert_physical_z(op: ops.Operation, proto: v2.program_pb2.Operation):
189191
op_deserializer.DeserializingArg(
190192
serialized_name='half_turns',
191193
constructor_arg_name='exponent',
194+
default=1.0,
192195
),
193196
],
194197
op_wrapper=lambda op, proto: _convert_physical_z(op, proto)),
@@ -199,14 +202,17 @@ def _convert_physical_z(op: ops.Operation, proto: v2.program_pb2.Operation):
199202
op_deserializer.DeserializingArg(
200203
serialized_name='x_exponent',
201204
constructor_arg_name='x_exponent',
205+
default=0.0,
202206
),
203207
op_deserializer.DeserializingArg(
204208
serialized_name='z_exponent',
205209
constructor_arg_name='z_exponent',
210+
default=0.0,
206211
),
207212
op_deserializer.DeserializingArg(
208213
serialized_name='axis_phase_exponent',
209214
constructor_arg_name='axis_phase_exponent',
215+
default=0.0,
210216
),
211217
],
212218
),
@@ -322,7 +328,8 @@ def _convert_physical_z(op: ops.Operation, proto: v2.program_pb2.Operation):
322328
args=[
323329
op_deserializer.DeserializingArg(
324330
serialized_name='axis_half_turns',
325-
constructor_arg_name='phase_exponent'),
331+
constructor_arg_name='phase_exponent',
332+
),
326333
op_deserializer.DeserializingArg(serialized_name='axis_half_turns',
327334
constructor_arg_name='exponent',
328335
value_func=lambda _: 1),
@@ -376,8 +383,11 @@ def _convert_physical_z(op: ops.Operation, proto: v2.program_pb2.Operation):
376383
serialized_gate_id='cz',
377384
gate_constructor=ops.CZPowGate,
378385
args=[
379-
op_deserializer.DeserializingArg(serialized_name='half_turns',
380-
constructor_arg_name='exponent')
386+
op_deserializer.DeserializingArg(
387+
serialized_name='half_turns',
388+
constructor_arg_name='exponent',
389+
default=1.0,
390+
)
381391
])
382392

383393
#
@@ -539,10 +549,16 @@ def _can_serialize_limited_iswap(exponent: float):
539549
serialized_gate_id='fsim',
540550
gate_constructor=ops.FSimGate,
541551
args=[
542-
op_deserializer.DeserializingArg(serialized_name='theta',
543-
constructor_arg_name='theta'),
544-
op_deserializer.DeserializingArg(serialized_name='phi',
545-
constructor_arg_name='phi'),
552+
op_deserializer.DeserializingArg(
553+
serialized_name='theta',
554+
constructor_arg_name='theta',
555+
default=0.0,
556+
),
557+
op_deserializer.DeserializingArg(
558+
serialized_name='phi',
559+
constructor_arg_name='phi',
560+
default=0.0,
561+
),
546562
])
547563

548564

cirq/google/op_deserializer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,14 @@ class DeserializingArg:
4444
None.
4545
required: Whether a value must be specified when constructing the
4646
deserialized gate. Defaults to True.
47+
default: default value to set if the value is not present in the
48+
arg. If set, required is ignored.
4749
"""
4850
serialized_name: str
4951
constructor_arg_name: str
5052
value_func: Optional[Callable[[arg_func_langs.ARG_LIKE], Any]] = None
5153
required: bool = True
54+
default: Any = None
5255

5356

5457
class GateOpDeserializer:
@@ -107,10 +110,14 @@ def _args_from_proto(self, proto: v2.program_pb2.Operation, *,
107110
) -> Dict[str, arg_func_langs.ARG_LIKE]:
108111
return_args = {}
109112
for arg in self.args:
110-
if arg.serialized_name not in proto.args and arg.required:
111-
raise ValueError(
112-
'Argument {} not in deserializing args, but is required.'.
113-
format(arg.serialized_name))
113+
if arg.serialized_name not in proto.args:
114+
if arg.default:
115+
return_args[arg.constructor_arg_name] = arg.default
116+
continue
117+
elif arg.required:
118+
raise ValueError(
119+
f'Argument {arg.serialized_name} '
120+
'not in deserializing args, but is required.')
114121

115122
value = arg_func_langs._arg_from_proto(
116123
proto.args[arg.serialized_name],

cirq/google/op_deserializer_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,30 @@ def test_from_proto_required_arg_not_assigned():
362362
})
363363
with pytest.raises(ValueError):
364364
deserializer.from_proto(serialized)
365+
366+
367+
def test_defaults():
368+
deserializer = cg.GateOpDeserializer(
369+
serialized_gate_id='my_gate',
370+
gate_constructor=GateWithAttribute,
371+
args=[
372+
cg.DeserializingArg(serialized_name='my_val',
373+
constructor_arg_name='val',
374+
default=1.0),
375+
cg.DeserializingArg(serialized_name='not_req',
376+
constructor_arg_name='not_req',
377+
default='hello',
378+
required=False)
379+
])
380+
serialized = op_proto({
381+
'gate': {
382+
'id': 'my_gate'
383+
},
384+
'args': {},
385+
'qubits': [{
386+
'id': '1_2'
387+
}]
388+
})
389+
g = GateWithAttribute(1.0)
390+
g.not_req = 'hello'
391+
assert deserializer.from_proto(serialized) == g(cirq.GridQubit(1, 2))

cirq/google/op_serializer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from dataclasses import dataclass
16-
from typing import (Callable, List, Optional, Type, TypeVar, Union,
16+
from typing import (Any, Callable, List, Optional, Type, TypeVar, Union,
1717
TYPE_CHECKING)
1818

1919
import numpy as np
@@ -44,11 +44,14 @@ class SerializingArg:
4444
returns this value (i.e. `lambda x: default_value`)
4545
required: Whether this argument is a required argument for the
4646
serialized form.
47+
default: default value. avoid serializing if this is the value.
48+
Note that the DeserializingArg must also have this as default.
4749
"""
4850
serialized_name: str
4951
serialized_type: Type[arg_func_langs.ARG_LIKE]
5052
op_getter: Union[str, Callable[['cirq.Operation'], arg_func_langs.ARG_LIKE]]
5153
required: bool = True
54+
default: Any = None
5255

5356

5457
class GateOpSerializer:
@@ -121,7 +124,7 @@ def to_proto(
121124
msg.qubits.add().id = v2.qubit_to_proto_id(qubit)
122125
for arg in self.args:
123126
value = self._value_from_gate(op, arg)
124-
if value is not None:
127+
if value is not None and (not arg.default or value != arg.default):
125128
_arg_to_proto(value,
126129
out=msg.args[arg.serialized_name],
127130
arg_function_language=arg_function_language)
@@ -156,16 +159,16 @@ def _value_from_gate(self, op: 'cirq.Operation', arg: SerializingArg
156159

157160
def _check_type(self, value: arg_func_langs.ARG_LIKE,
158161
arg: SerializingArg) -> None:
159-
if arg.serialized_type == List[bool]:
160-
if (not isinstance(value, (list, tuple, np.ndarray)) or
161-
not all(isinstance(x, (bool, np.bool_)) for x in value)):
162-
raise ValueError('Expected type List[bool] but was {}'.format(
163-
type(value)))
164-
elif arg.serialized_type == float:
162+
if arg.serialized_type == float:
165163
if not isinstance(value, (float, int)):
166164
raise ValueError(
167165
'Expected type convertible to float but was {}'.format(
168166
type(value)))
167+
elif arg.serialized_type == List[bool]:
168+
if (not isinstance(value, (list, tuple, np.ndarray)) or
169+
not all(isinstance(x, (bool, np.bool_)) for x in value)):
170+
raise ValueError('Expected type List[bool] but was {}'.format(
171+
type(value)))
169172
elif value is not None and not isinstance(value, arg.serialized_type):
170173
raise ValueError(
171174
'Argument {} had type {} but gate returned type {}'.format(

cirq/google/op_serializer_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,41 @@ def test_can_serialize_operation_subclass():
416416
q = cirq.GridQubit(1, 1)
417417
assert serializer.can_serialize_operation(SubclassGate(1)(q))
418418
assert not serializer.can_serialize_operation(SubclassGate(0)(q))
419+
420+
421+
def test_defaults_not_serialized():
422+
serializer = cg.GateOpSerializer(gate_type=GateWithAttribute,
423+
serialized_gate_id='my_gate',
424+
args=[
425+
cg.SerializingArg(
426+
serialized_name='my_val',
427+
serialized_type=float,
428+
default=1.0,
429+
op_getter='val')
430+
])
431+
q = cirq.GridQubit(1, 2)
432+
no_default = op_proto({
433+
'gate': {
434+
'id': 'my_gate'
435+
},
436+
'args': {
437+
'my_val': {
438+
'arg_value': {
439+
'float_value': 0.125
440+
}
441+
}
442+
},
443+
'qubits': [{
444+
'id': '1_2'
445+
}]
446+
})
447+
assert no_default == serializer.to_proto(GateWithAttribute(0.125)(q))
448+
with_default = op_proto({
449+
'gate': {
450+
'id': 'my_gate'
451+
},
452+
'qubits': [{
453+
'id': '1_2'
454+
}]
455+
})
456+
assert with_default == serializer.to_proto(GateWithAttribute(1.0)(q))

0 commit comments

Comments
 (0)