Skip to content

Commit fda82b5

Browse files
authored
Engine fixes and tests
1 parent e540d6c commit fda82b5

File tree

3 files changed

+261
-14
lines changed

3 files changed

+261
-14
lines changed

cirq/google/engine/engine.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import numpy as np
2727
import oauth2client
28-
from apiclient.discovery import build
28+
from apiclient import discovery
2929
from google.protobuf.json_format import MessageToDict
3030

3131
from cirq.api.google.v1 import program_pb2
@@ -131,11 +131,11 @@ def __init__(self, api_key: str, api: str = 'quantum',
131131
'?version={apiVersion}&key=%s')
132132

133133
def run(self,
134+
options: EngineOptions,
134135
circuit: Circuit,
135136
device: Device,
136137
param_resolver: ParamResolver = ParamResolver({}),
137138
repetitions: int = 1,
138-
options: EngineOptions = None,
139139
priority: int = 50,
140140
target_route: str = '/xmonsim',
141141
) -> EngineTrialResult:
@@ -153,16 +153,16 @@ def run(self,
153153
Returns:
154154
Results for this run.
155155
"""
156-
return self.run_sweep(circuit, device, [param_resolver], repetitions,
157-
options, priority, target_route)[0]
156+
return self.run_sweep(options, circuit, device, [param_resolver],
157+
repetitions, priority, target_route)[0]
158158

159159
def run_sweep(self,
160+
options: EngineOptions,
160161
program: Union[Circuit, Schedule],
161162
device: Device = None,
162163
params: Sweepable = None,
163164
repetitions: int = 1,
164-
options: EngineOptions = None,
165-
priority: int = 50,
165+
priority: int = 500,
166166
target_route: str = '/xmonsim',
167167
) -> List[EngineTrialResult]:
168168
"""Runs the entire supplied Circuit or Schedule via Google Quantum
@@ -181,8 +181,8 @@ def run_sweep(self,
181181
Returns:
182182
Results for this run.
183183
"""
184-
if not 0 <= priority < 100:
185-
raise TypeError('priority must be between 0 and 100')
184+
if not 0 <= priority < 1000:
185+
raise TypeError('priority must be between 0 and 1000')
186186

187187
if isinstance(program, Circuit):
188188
if not device:
@@ -203,10 +203,10 @@ def run_sweep(self,
203203

204204
sweeps = _sweepable_to_sweeps(params or ParamResolver({}))
205205

206-
service = build(self.api, self.version,
207-
discoveryServiceUrl=self.discovery_url % (
208-
self.api_key,),
209-
credentials=options.credentials)
206+
service = discovery.build(self.api, self.version,
207+
discoveryServiceUrl=self.discovery_url % (
208+
self.api_key,),
209+
credentials=options.credentials)
210210

211211
proto_program = program_pb2.Program()
212212
for sweep in sweeps:
@@ -262,15 +262,16 @@ def run_sweep(self,
262262
trial_results = []
263263
for sweep_result in response['result']['sweepResults']:
264264
sweep_repetitions = sweep_result['repetitions']
265-
key_sizes = [(m['key'], m['size'])
265+
key_sizes = [(m['key'], len(m['qubits']))
266266
for m in sweep_result['measurementKeys']]
267267
for result in sweep_result['parameterizedResults']:
268268
data = base64.standard_b64decode(result['measurementResults'])
269269
measurements = unpack_results(data, sweep_repetitions,
270270
key_sizes)
271271

272272
trial_results.append(EngineTrialResult(
273-
params=ParamResolver(result.get('params', {})),
273+
params=ParamResolver(
274+
result.get('params', {}).get('assignments', {})),
274275
repetitions=sweep_repetitions,
275276
measurements=measurements))
276277
return trial_results

cirq/google/engine/engine_test.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright 2018 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for engine."""
16+
17+
import numpy as np
18+
import pytest
19+
20+
from apiclient import discovery
21+
from google.protobuf.json_format import MessageToDict
22+
23+
from cirq.api.google.v1 import operations_pb2, params_pb2, program_pb2
24+
from cirq.circuits import Circuit
25+
from cirq.devices import UnconstrainedDevice
26+
from cirq.google.engine.engine import Engine, EngineOptions
27+
from cirq.schedules.schedulers import moment_by_moment_schedule
28+
from cirq.study import ParamResolver, Points
29+
from cirq.testing.python3_mock import python3_mock_test, mock
30+
31+
_A_RESULT = program_pb2.Result(
32+
sweep_results=[program_pb2.SweepResult(repetitions=1, measurement_keys=[
33+
program_pb2.MeasurementKey(
34+
key='q',
35+
qubits=[operations_pb2.Qubit(row=1, col=1)])],
36+
parameterized_results=[
37+
program_pb2.ParameterizedResult(
38+
params=params_pb2.ParameterDict(assignments={'a': 1}),
39+
measurement_results=b'01')])])
40+
41+
_RESULTS = program_pb2.Result(
42+
sweep_results=[program_pb2.SweepResult(repetitions=1, measurement_keys=[
43+
program_pb2.MeasurementKey(
44+
key='q',
45+
qubits=[operations_pb2.Qubit(row=1, col=1)])],
46+
parameterized_results=[
47+
program_pb2.ParameterizedResult(
48+
params=params_pb2.ParameterDict(assignments={'a': 1}),
49+
measurement_results=b'01'),
50+
program_pb2.ParameterizedResult(
51+
params=params_pb2.ParameterDict(assignments={'a': 2}),
52+
measurement_results=b'01')])])
53+
54+
55+
@python3_mock_test(discovery, 'build')
56+
def test_run_circuit(build):
57+
service = mock.Mock()
58+
build.return_value = service
59+
programs = service.projects().programs()
60+
jobs = programs.jobs()
61+
programs.create().execute.return_value = {
62+
'name': 'projects/project-id/programs/test'}
63+
jobs.create().execute.return_value = {
64+
'name': 'projects/project-id/programs/test/jobs/test',
65+
'executionStatus': {'state': 'READY'}}
66+
jobs.get().execute.return_value = {
67+
'name': 'projects/project-id/programs/test/jobs/test',
68+
'executionStatus': {'state': 'SUCCESS'}}
69+
jobs.getResult().execute.return_value = {
70+
'result': MessageToDict(_A_RESULT)}
71+
72+
result = Engine(api_key="key").run(
73+
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'), Circuit(),
74+
UnconstrainedDevice)
75+
assert result.repetitions == 1
76+
assert result.params.param_dict == {'a': 1}
77+
assert result.measurements == {'q': np.array([[0]], dtype='uint8')}
78+
build.assert_called_with('quantum', 'v1alpha1', credentials=None,
79+
discoveryServiceUrl=('https://{api}.googleapis.com'
80+
'/$discovery/rest?version='
81+
'{apiVersion}&key=key'))
82+
assert programs.create.call_args[1]['parent'] == 'projects/project-id'
83+
assert jobs.create.call_args[1][
84+
'parent'] == 'projects/project-id/programs/test'
85+
assert jobs.get().execute.call_count == 1
86+
assert jobs.getResult().execute.call_count == 1
87+
88+
89+
@python3_mock_test(discovery, 'build')
90+
def test_run_circuit_failed(build):
91+
service = mock.Mock()
92+
build.return_value = service
93+
programs = service.projects().programs()
94+
jobs = programs.jobs()
95+
programs.create().execute.return_value = {
96+
'name': 'projects/project-id/programs/test'}
97+
jobs.create().execute.return_value = {
98+
'name': 'projects/project-id/programs/test/jobs/test',
99+
'executionStatus': {'state': 'READY'}}
100+
jobs.get().execute.return_value = {
101+
'name': 'projects/project-id/programs/test/jobs/test',
102+
'executionStatus': {'state': 'FAILURE'}}
103+
104+
with pytest.raises(RuntimeError, match='It is in state FAILURE'):
105+
Engine(api_key="key").run(
106+
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
107+
Circuit(),
108+
UnconstrainedDevice)
109+
110+
111+
@python3_mock_test(discovery, 'build')
112+
def test_run_sweep_params(build):
113+
service = mock.Mock()
114+
build.return_value = service
115+
programs = service.projects().programs()
116+
jobs = programs.jobs()
117+
programs.create().execute.return_value = {
118+
'name': 'projects/project-id/programs/test'}
119+
jobs.create().execute.return_value = {
120+
'name': 'projects/project-id/programs/test/jobs/test',
121+
'executionStatus': {'state': 'READY'}}
122+
jobs.get().execute.return_value = {
123+
'name': 'projects/project-id/programs/test/jobs/test',
124+
'executionStatus': {'state': 'SUCCESS'}}
125+
jobs.getResult().execute.return_value = {
126+
'result': MessageToDict(_RESULTS)}
127+
128+
results = Engine(api_key="key").run_sweep(
129+
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
130+
moment_by_moment_schedule(UnconstrainedDevice, Circuit()),
131+
params=[ParamResolver({'a': 1}), ParamResolver({'a': 2})])
132+
assert len(results) == 2
133+
for i, v in enumerate([1, 2]):
134+
assert results[i].repetitions == 1
135+
assert results[i].params.param_dict == {'a': v}
136+
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
137+
build.assert_called_with('quantum', 'v1alpha1', credentials=None,
138+
discoveryServiceUrl=('https://{api}.googleapis.com'
139+
'/$discovery/rest?version='
140+
'{apiVersion}&key=key'))
141+
assert programs.create.call_args[1]['parent'] == 'projects/project-id'
142+
sweeps = programs.create.call_args[1]['body']['code']['parameterSweeps']
143+
assert len(sweeps) == 2
144+
for i, v in enumerate([1, 2]):
145+
assert sweeps[i]['repetitions'] == 1
146+
assert sweeps[i]['sweep']['factors'][0]['sweeps'][0]['sweepPoints'][
147+
'points'] == [v]
148+
assert jobs.create.call_args[1][
149+
'parent'] == 'projects/project-id/programs/test'
150+
assert jobs.get().execute.call_count == 1
151+
assert jobs.getResult().execute.call_count == 1
152+
153+
154+
@python3_mock_test(discovery, 'build')
155+
def test_run_sweep_sweeps(build):
156+
service = mock.Mock()
157+
build.return_value = service
158+
programs = service.projects().programs()
159+
jobs = programs.jobs()
160+
programs.create().execute.return_value = {
161+
'name': 'projects/project-id/programs/test'}
162+
jobs.create().execute.return_value = {
163+
'name': 'projects/project-id/programs/test/jobs/test',
164+
'executionStatus': {'state': 'READY'}}
165+
jobs.get().execute.return_value = {
166+
'name': 'projects/project-id/programs/test/jobs/test',
167+
'executionStatus': {'state': 'SUCCESS'}}
168+
jobs.getResult().execute.return_value = {
169+
'result': MessageToDict(_RESULTS)}
170+
171+
results = Engine(api_key="key").run_sweep(
172+
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
173+
moment_by_moment_schedule(UnconstrainedDevice, Circuit()),
174+
params=Points('a', [1, 2]))
175+
assert len(results) == 2
176+
for i, v in enumerate([1, 2]):
177+
assert results[i].repetitions == 1
178+
assert results[i].params.param_dict == {'a': v}
179+
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
180+
build.assert_called_with('quantum', 'v1alpha1', credentials=None,
181+
discoveryServiceUrl=('https://{api}.googleapis.com'
182+
'/$discovery/rest?version='
183+
'{apiVersion}&key=key'))
184+
assert programs.create.call_args[1]['parent'] == 'projects/project-id'
185+
sweeps = programs.create.call_args[1]['body']['code']['parameterSweeps']
186+
assert len(sweeps) == 1
187+
assert sweeps[0]['repetitions'] == 1
188+
assert sweeps[0]['sweep']['factors'][0]['sweeps'][0]['sweepPoints'][
189+
'points'] == [1, 2]
190+
assert jobs.create.call_args[1][
191+
'parent'] == 'projects/project-id/programs/test'
192+
assert jobs.get().execute.call_count == 1
193+
assert jobs.getResult().execute.call_count == 1
194+
195+
196+
def test_bad_priority():
197+
with pytest.raises(TypeError, match='priority must be between 0 and 1000'):
198+
Engine(api_key="key").run(
199+
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
200+
Circuit(),
201+
UnconstrainedDevice,
202+
priority=1001)
203+

cirq/testing/python3_mock.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2018 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
17+
18+
class FakeMock:
19+
20+
class Mock:
21+
pass
22+
23+
24+
if sys.version_info < (3,):
25+
mock = FakeMock
26+
else:
27+
from unittest import mock
28+
29+
30+
def python3_mock_test(target, method):
31+
"""A decorator for tests that need to mock.patch.object() which is not
32+
supported in Python 2.7. The test only executes if running Python 3.
33+
34+
Args:
35+
target: Target to patch.
36+
method: The name of the method to mock.
37+
"""
38+
if sys.version_info >= (3,):
39+
return mock.patch.object(target, method)
40+
else:
41+
def nothing(f):
42+
pass
43+
return nothing

0 commit comments

Comments
 (0)