Skip to content

Commit 5c98986

Browse files
authored
EngineJob with methods to cancel and get state and results.
1 parent 817d6c9 commit 5c98986

File tree

2 files changed

+165
-57
lines changed

2 files changed

+165
-57
lines changed

cirq/google/engine/engine.py

Lines changed: 134 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import re
2121
import string
2222
import time
23+
import urllib.parse
2324
from collections import Iterable
2425
from typing import Dict, List, Optional, Union, cast
2526

@@ -40,6 +41,7 @@
4041
from cirq.study.sweeps import Points, Unit, Zip
4142

4243
gcs_prefix_pattern = re.compile('gs://[a-z0-9._/-]+')
44+
TERMINAL_STATES = ['SUCCESS', 'FAILURE', 'CANCELLED']
4345

4446

4547
class EngineTrialResult(TrialResult):
@@ -74,7 +76,6 @@ class EngineOptions:
7476
"""
7577

7678
def __init__(self, project_id: str,
77-
credentials: Optional[oauth2client.client.Credentials] = None,
7879
program_id: Optional[str] = None,
7980
job_id: Optional[str] = None,
8081
gcs_prefix: Optional[str] = None,
@@ -84,7 +85,6 @@ def __init__(self, project_id: str,
8485
project_id and either gcs_prefix or gcs_program and gcs_results.
8586
8687
Args:
87-
credentials: Credentials to use.
8888
project_id: The project id string of the Google Cloud Project to
8989
use.
9090
program_id: Id of the program to create, defaults to a random
@@ -97,38 +97,46 @@ def __init__(self, project_id: str,
9797
gcs_results: Explicit override for the results storage location.
9898
"""
9999
self.project_id = project_id
100-
self.credentials = credentials
101-
self.program_id = program_id or 'prog-%s' % ''.join(
102-
random.choice(string.ascii_uppercase + string.digits) for _ in
103-
range(6))
104-
self.job_id = job_id or 'job-0'
105-
if not gcs_prefix and (not gcs_program or not gcs_results):
106-
raise TypeError('Either gcs_prefix must be provided or both'
107-
' gcs_program and gcs_results are required.')
108-
if gcs_prefix and not gcs_prefix_pattern.match(gcs_prefix):
109-
raise TypeError('gcs_prefix must be of the form "gs://'
110-
'<bucket name and optional object prefix>/"')
111-
self.gcs_prefix = gcs_prefix if not gcs_prefix or gcs_prefix.endswith(
112-
'/') else gcs_prefix + '/'
113-
self.gcs_program = gcs_program or '%sprograms/%s/%s' % (
114-
self.gcs_prefix, self.program_id, self.program_id)
115-
self.gcs_results = gcs_results or '%sprograms/%s/jobs/%s' % (
116-
self.gcs_prefix, self.program_id, self.job_id)
100+
self.program_id = program_id
101+
self.job_id = job_id
102+
self.gcs_prefix = gcs_prefix
103+
self.gcs_program = gcs_program
104+
self.gcs_results = gcs_results
117105

118106

119107
class Engine:
120-
"""Executor for Google Quantum Engine
108+
"""Executor for Google Quantum Engine.
121109
"""
122110

123111
def __init__(self, api_key: str, api: str = 'quantum',
124112
version: str = 'v1alpha1',
125-
discovery_url: Optional[str] = None) -> None:
113+
discovery_url: Optional[str] = None,
114+
credentials: Optional[oauth2client.client.Credentials] = None,
115+
gcs_prefix: Optional[str] = None
116+
) -> None:
117+
"""Engine service client.
118+
119+
Args:
120+
api_key: API key to use to retrieve discovery doc.
121+
api: API name.
122+
version: API version.
123+
discovery_url: Discovery url to use.
124+
credentials: Credentials to use.
125+
"""
126126
self.api_key = api_key
127127
self.api = api
128128
self.version = version
129129
self.discovery_url = discovery_url or ('https://{api}.googleapis.com/'
130130
'$discovery/rest'
131131
'?version={apiVersion}&key=%s')
132+
self.credentials = credentials
133+
self.gcs_prefix = gcs_prefix
134+
self.service = discovery.build(
135+
self.api,
136+
self.version,
137+
discoveryServiceUrl=self.discovery_url % urllib.parse.quote_plus(
138+
self.api_key),
139+
credentials=credentials)
132140

133141
def run(self,
134142
options: EngineOptions,
@@ -153,8 +161,8 @@ def run(self,
153161
Returns:
154162
Results for this run.
155163
"""
156-
return self.run_sweep(options, circuit, device, [param_resolver],
157-
repetitions, priority, target_route)[0]
164+
return list(self.run_sweep(options, circuit, device, [param_resolver],
165+
repetitions, priority, target_route))[0]
158166

159167
def run_sweep(self,
160168
options: EngineOptions,
@@ -164,7 +172,7 @@ def run_sweep(self,
164172
repetitions: int = 1,
165173
priority: int = 500,
166174
target_route: str = '/xmonsim',
167-
) -> List[EngineTrialResult]:
175+
) -> 'EngineJob':
168176
"""Runs the entire supplied Circuit or Schedule via Google Quantum
169177
Engine.
170178
@@ -181,6 +189,29 @@ def run_sweep(self,
181189
Returns:
182190
Results for this run.
183191
"""
192+
# Check and compute engine options.
193+
gcs_prefix = options.gcs_prefix or self.gcs_prefix
194+
gcs_prefix = gcs_prefix if not gcs_prefix or gcs_prefix.endswith(
195+
'/') else gcs_prefix + '/'
196+
if gcs_prefix and not gcs_prefix_pattern.match(gcs_prefix):
197+
raise TypeError('gcs_prefix must be of the form "gs://'
198+
'<bucket name and optional object prefix>/"')
199+
if not gcs_prefix and (not options.gcs_program or
200+
not options.gcs_results):
201+
raise TypeError('Either gcs_prefix must be provided or both'
202+
' gcs_program and gcs_results are required.')
203+
204+
project_id = options.project_id
205+
program_id = options.program_id or 'prog-%s' % ''.join(
206+
random.choice(string.ascii_uppercase + string.digits) for _ in
207+
range(6))
208+
job_id = options.job_id or 'job-0'
209+
gcs_program = options.gcs_program or '%sprograms/%s/%s' % (
210+
gcs_prefix, program_id, program_id)
211+
gcs_results = options.gcs_results or '%sprograms/%s/jobs/%s' % (
212+
gcs_prefix, program_id, job_id)
213+
214+
# Check program to run and program parameters.
184215
if not 0 <= priority < 1000:
185216
raise TypeError('priority must be between 0 and 1000')
186217

@@ -201,64 +232,53 @@ def run_sweep(self,
201232
else:
202233
raise TypeError('Unexpected execution type')
203234

235+
# Create program.
204236
sweeps = _sweepable_to_sweeps(params or ParamResolver({}))
205-
206-
service = discovery.build(self.api, self.version,
207-
discoveryServiceUrl=self.discovery_url % (
208-
self.api_key,),
209-
credentials=options.credentials)
210-
211237
proto_program = program_pb2.Program()
212238
for sweep in sweeps:
213239
sweep_proto = proto_program.parameter_sweeps.add()
214240
sweep_to_proto(sweep, sweep_proto)
215241
sweep_proto.repetitions = repetitions
216242
proto_program.operations.extend(list(schedule_to_proto(schedule)))
217-
218243
code = {
219244
'@type': 'type.googleapis.com/cirq.api.google.v1.Program'}
220245
code.update(MessageToDict(proto_program))
221-
222246
request = {
223-
'name': 'projects/%s/programs/%s' % (options.project_id,
224-
options.program_id,),
225-
'gcs_code_location': {'uri': options.gcs_program, },
247+
'name': 'projects/%s/programs/%s' % (project_id,
248+
program_id,),
249+
'gcs_code_location': {'uri': gcs_program, },
226250
'code': code,
227251
}
252+
response = self.service.projects().programs().create(
253+
parent='projects/%s' % project_id, body=request).execute()
228254

229-
response = service.projects().programs().create(
230-
parent='projects/%s' % options.project_id, body=request).execute()
231-
255+
# Create job.
232256
request = {
233-
'name': '%s/jobs/%s' % (response['name'], options.job_id),
257+
'name': '%s/jobs/%s' % (response['name'], job_id),
234258
'output_config': {
235259
'gcs_results_location': {
236-
'uri': options.gcs_results
260+
'uri': gcs_results
237261
}
238262
},
239263
'scheduling_config': {
240264
'priority': priority,
241265
'target_route': target_route
242266
},
243267
}
244-
response = service.projects().programs().jobs().create(
268+
response = self.service.projects().programs().jobs().create(
245269
parent=response['name'], body=request).execute()
246270

247-
for _ in range(1000):
248-
if response['executionStatus']['state'] in ['SUCCESS', 'FAILURE',
249-
'CANCELLED']:
250-
break
251-
time.sleep(0.5)
252-
response = service.projects().programs().jobs().get(
253-
name=response['name']).execute()
254-
255-
if response['executionStatus']['state'] != 'SUCCESS':
256-
raise RuntimeError('Job %s did not succeed. It is in state %s.' % (
257-
response['name'], response['executionStatus']['state']))
271+
return EngineJob(
272+
EngineOptions(project_id, program_id, job_id, gcs_prefix,
273+
gcs_program, gcs_results), response, self)
258274

259-
response = service.projects().programs().jobs().getResult(
260-
parent=response['name']).execute()
275+
def get_job(self, job_resource_name) -> Dict:
276+
return self.service.projects().programs().jobs().get(
277+
name=job_resource_name).execute()
261278

279+
def get_job_results(self, job_resource_name) -> List[EngineTrialResult]:
280+
response = self.service.projects().programs().jobs().getResult(
281+
parent=job_resource_name).execute()
262282
trial_results = []
263283
for sweep_result in response['result']['sweepResults']:
264284
sweep_repetitions = sweep_result['repetitions']
@@ -276,6 +296,66 @@ def run_sweep(self,
276296
measurements=measurements))
277297
return trial_results
278298

299+
def cancel_job(self, job_resource_name):
300+
self.service.projects().programs().jobs().cancel(
301+
name=job_resource_name, body={}).execute()
302+
303+
304+
class EngineJob:
305+
"""A job running on the engine.
306+
307+
Attributes:
308+
engine_options: The engine options used for the job.
309+
job_resource_name: The full resource name of the engine job.
310+
"""
311+
312+
def __init__(self,
313+
engine_options: EngineOptions,
314+
job: Dict,
315+
engine: Engine) -> None:
316+
"""A job submitted to the engine.
317+
318+
Args:
319+
engine_options: The EngineOptions used to create the job.
320+
job: A full Job Dict.
321+
engine: Engine connected to the job.
322+
"""
323+
self.engine_options = engine_options
324+
self._job = job
325+
self._engine = engine
326+
self.job_resource_name = job['name']
327+
self._results = None # type: List[EngineTrialResult]
328+
329+
def _update_job(self):
330+
if self._job['executionStatus']['state'] not in TERMINAL_STATES:
331+
self._job = self._engine.get_job(self.job_resource_name)
332+
return self._job
333+
334+
def state(self):
335+
return self._update_job()['executionStatus']['state']
336+
337+
def cancel(self):
338+
self._engine.cancel_job(self.job_resource_name)
339+
340+
def results(self) -> List[EngineTrialResult]:
341+
if not self._results:
342+
job = self._update_job()
343+
for _ in range(1000):
344+
if job['executionStatus']['state'] in TERMINAL_STATES:
345+
break
346+
time.sleep(0.5)
347+
job = self._update_job()
348+
if job['executionStatus']['state'] != 'SUCCESS':
349+
raise RuntimeError(
350+
'Job %s did not succeed. It is in state %s.' % (
351+
job['name'], job['executionStatus']['state']))
352+
self._results = self._engine.get_job_results(
353+
self.job_resource_name)
354+
return self._results
355+
356+
def __iter__(self):
357+
return self.results().__iter__()
358+
279359

280360
def _sweepable_to_sweeps(sweepable: Sweepable) -> List[Sweep]:
281361
if isinstance(sweepable, ParamResolver):

cirq/google/engine/engine_test.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,11 @@ def test_run_sweep_params(build):
125125
jobs.getResult().execute.return_value = {
126126
'result': MessageToDict(_RESULTS)}
127127

128-
results = Engine(api_key="key").run_sweep(
128+
job = Engine(api_key="key").run_sweep(
129129
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
130130
moment_by_moment_schedule(UnconstrainedDevice, Circuit()),
131131
params=[ParamResolver({'a': 1}), ParamResolver({'a': 2})])
132+
results = job.results()
132133
assert len(results) == 2
133134
for i, v in enumerate([1, 2]):
134135
assert results[i].repetitions == 1
@@ -168,10 +169,11 @@ def test_run_sweep_sweeps(build):
168169
jobs.getResult().execute.return_value = {
169170
'result': MessageToDict(_RESULTS)}
170171

171-
results = Engine(api_key="key").run_sweep(
172+
job = Engine(api_key="key").run_sweep(
172173
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
173174
moment_by_moment_schedule(UnconstrainedDevice, Circuit()),
174175
params=Points('a', [1, 2]))
176+
results = job.results()
175177
assert len(results) == 2
176178
for i, v in enumerate([1, 2]):
177179
assert results[i].repetitions == 1
@@ -193,11 +195,37 @@ def test_run_sweep_sweeps(build):
193195
assert jobs.getResult().execute.call_count == 1
194196

195197

196-
def test_bad_priority():
198+
@python3_mock_test(discovery, 'build')
199+
def test_bad_priority(build):
197200
with pytest.raises(TypeError, match='priority must be between 0 and 1000'):
198201
Engine(api_key="key").run(
199202
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
200203
Circuit(),
201204
UnconstrainedDevice,
202205
priority=1001)
203206

207+
208+
@python3_mock_test(discovery, 'build')
209+
def test_cancel(build):
210+
service = mock.Mock()
211+
build.return_value = service
212+
programs = service.projects().programs()
213+
jobs = programs.jobs()
214+
programs.create().execute.return_value = {
215+
'name': 'projects/project-id/programs/test'}
216+
jobs.create().execute.return_value = {
217+
'name': 'projects/project-id/programs/test/jobs/test',
218+
'executionStatus': {'state': 'READY'}}
219+
jobs.get().execute.return_value = {
220+
'name': 'projects/project-id/programs/test/jobs/test',
221+
'executionStatus': {'state': 'CANCELLED'}}
222+
223+
job = Engine(api_key="key").run_sweep(
224+
EngineOptions('project-id', gcs_prefix='gs://bucket/folder'),
225+
Circuit(), device=UnconstrainedDevice)
226+
job.cancel()
227+
assert job.job_resource_name == ('projects/project-id/programs/test/'
228+
'jobs/test')
229+
assert job.state() == 'CANCELLED'
230+
assert jobs.cancel.call_args[1][
231+
'name'] == 'projects/project-id/programs/test/jobs/test'

0 commit comments

Comments
 (0)