2020import re
2121import string
2222import time
23+ import urllib .parse
2324from collections import Iterable
2425from typing import Dict , List , Optional , Union , cast
2526
4041from cirq .study .sweeps import Points , Unit , Zip
4142
4243gcs_prefix_pattern = re .compile ('gs://[a-z0-9._/-]+' )
44+ TERMINAL_STATES = ['SUCCESS' , 'FAILURE' , 'CANCELLED' ]
4345
4446
4547class EngineTrialResult (TrialResult ):
@@ -74,7 +76,6 @@ class EngineOptions:
7476 """
7577
7678 def __init__ (self , project_id : str ,
77- credentials : Optional [oauth2client .client .Credentials ] = None ,
7879 program_id : Optional [str ] = None ,
7980 job_id : Optional [str ] = None ,
8081 gcs_prefix : Optional [str ] = None ,
@@ -84,7 +85,6 @@ def __init__(self, project_id: str,
8485 project_id and either gcs_prefix or gcs_program and gcs_results.
8586
8687 Args:
87- credentials: Credentials to use.
8888 project_id: The project id string of the Google Cloud Project to
8989 use.
9090 program_id: Id of the program to create, defaults to a random
@@ -97,38 +97,46 @@ def __init__(self, project_id: str,
9797 gcs_results: Explicit override for the results storage location.
9898 """
9999 self .project_id = project_id
100- self .credentials = credentials
101- self .program_id = program_id or 'prog-%s' % '' .join (
102- random .choice (string .ascii_uppercase + string .digits ) for _ in
103- range (6 ))
104- self .job_id = job_id or 'job-0'
105- if not gcs_prefix and (not gcs_program or not gcs_results ):
106- raise TypeError ('Either gcs_prefix must be provided or both'
107- ' gcs_program and gcs_results are required.' )
108- if gcs_prefix and not gcs_prefix_pattern .match (gcs_prefix ):
109- raise TypeError ('gcs_prefix must be of the form "gs://'
110- '<bucket name and optional object prefix>/"' )
111- self .gcs_prefix = gcs_prefix if not gcs_prefix or gcs_prefix .endswith (
112- '/' ) else gcs_prefix + '/'
113- self .gcs_program = gcs_program or '%sprograms/%s/%s' % (
114- self .gcs_prefix , self .program_id , self .program_id )
115- self .gcs_results = gcs_results or '%sprograms/%s/jobs/%s' % (
116- self .gcs_prefix , self .program_id , self .job_id )
100+ self .program_id = program_id
101+ self .job_id = job_id
102+ self .gcs_prefix = gcs_prefix
103+ self .gcs_program = gcs_program
104+ self .gcs_results = gcs_results
117105
118106
119107class Engine :
120- """Executor for Google Quantum Engine
108+ """Executor for Google Quantum Engine.
121109 """
122110
123111 def __init__ (self , api_key : str , api : str = 'quantum' ,
124112 version : str = 'v1alpha1' ,
125- discovery_url : Optional [str ] = None ) -> None :
113+ discovery_url : Optional [str ] = None ,
114+ credentials : Optional [oauth2client .client .Credentials ] = None ,
115+ gcs_prefix : Optional [str ] = None
116+ ) -> None :
117+ """Engine service client.
118+
119+ Args:
120+ api_key: API key to use to retrieve discovery doc.
121+ api: API name.
122+ version: API version.
123+ discovery_url: Discovery url to use.
124+ credentials: Credentials to use.
125+ """
126126 self .api_key = api_key
127127 self .api = api
128128 self .version = version
129129 self .discovery_url = discovery_url or ('https://{api}.googleapis.com/'
130130 '$discovery/rest'
131131 '?version={apiVersion}&key=%s' )
132+ self .credentials = credentials
133+ self .gcs_prefix = gcs_prefix
134+ self .service = discovery .build (
135+ self .api ,
136+ self .version ,
137+ discoveryServiceUrl = self .discovery_url % urllib .parse .quote_plus (
138+ self .api_key ),
139+ credentials = credentials )
132140
133141 def run (self ,
134142 options : EngineOptions ,
@@ -153,8 +161,8 @@ def run(self,
153161 Returns:
154162 Results for this run.
155163 """
156- return self .run_sweep (options , circuit , device , [param_resolver ],
157- repetitions , priority , target_route )[0 ]
164+ return list ( self .run_sweep (options , circuit , device , [param_resolver ],
165+ repetitions , priority , target_route ) )[0 ]
158166
159167 def run_sweep (self ,
160168 options : EngineOptions ,
@@ -164,7 +172,7 @@ def run_sweep(self,
164172 repetitions : int = 1 ,
165173 priority : int = 500 ,
166174 target_route : str = '/xmonsim' ,
167- ) -> List [ EngineTrialResult ] :
175+ ) -> 'EngineJob' :
168176 """Runs the entire supplied Circuit or Schedule via Google Quantum
169177 Engine.
170178
@@ -181,6 +189,29 @@ def run_sweep(self,
181189 Returns:
182190 Results for this run.
183191 """
192+ # Check and compute engine options.
193+ gcs_prefix = options .gcs_prefix or self .gcs_prefix
194+ gcs_prefix = gcs_prefix if not gcs_prefix or gcs_prefix .endswith (
195+ '/' ) else gcs_prefix + '/'
196+ if gcs_prefix and not gcs_prefix_pattern .match (gcs_prefix ):
197+ raise TypeError ('gcs_prefix must be of the form "gs://'
198+ '<bucket name and optional object prefix>/"' )
199+ if not gcs_prefix and (not options .gcs_program or
200+ not options .gcs_results ):
201+ raise TypeError ('Either gcs_prefix must be provided or both'
202+ ' gcs_program and gcs_results are required.' )
203+
204+ project_id = options .project_id
205+ program_id = options .program_id or 'prog-%s' % '' .join (
206+ random .choice (string .ascii_uppercase + string .digits ) for _ in
207+ range (6 ))
208+ job_id = options .job_id or 'job-0'
209+ gcs_program = options .gcs_program or '%sprograms/%s/%s' % (
210+ gcs_prefix , program_id , program_id )
211+ gcs_results = options .gcs_results or '%sprograms/%s/jobs/%s' % (
212+ gcs_prefix , program_id , job_id )
213+
214+ # Check program to run and program parameters.
184215 if not 0 <= priority < 1000 :
185216 raise TypeError ('priority must be between 0 and 1000' )
186217
@@ -201,64 +232,53 @@ def run_sweep(self,
201232 else :
202233 raise TypeError ('Unexpected execution type' )
203234
235+ # Create program.
204236 sweeps = _sweepable_to_sweeps (params or ParamResolver ({}))
205-
206- service = discovery .build (self .api , self .version ,
207- discoveryServiceUrl = self .discovery_url % (
208- self .api_key ,),
209- credentials = options .credentials )
210-
211237 proto_program = program_pb2 .Program ()
212238 for sweep in sweeps :
213239 sweep_proto = proto_program .parameter_sweeps .add ()
214240 sweep_to_proto (sweep , sweep_proto )
215241 sweep_proto .repetitions = repetitions
216242 proto_program .operations .extend (list (schedule_to_proto (schedule )))
217-
218243 code = {
219244 '@type' : 'type.googleapis.com/cirq.api.google.v1.Program' }
220245 code .update (MessageToDict (proto_program ))
221-
222246 request = {
223- 'name' : 'projects/%s/programs/%s' % (options . project_id ,
224- options . program_id ,),
225- 'gcs_code_location' : {'uri' : options . gcs_program , },
247+ 'name' : 'projects/%s/programs/%s' % (project_id ,
248+ program_id ,),
249+ 'gcs_code_location' : {'uri' : gcs_program , },
226250 'code' : code ,
227251 }
252+ response = self .service .projects ().programs ().create (
253+ parent = 'projects/%s' % project_id , body = request ).execute ()
228254
229- response = service .projects ().programs ().create (
230- parent = 'projects/%s' % options .project_id , body = request ).execute ()
231-
255+ # Create job.
232256 request = {
233- 'name' : '%s/jobs/%s' % (response ['name' ], options . job_id ),
257+ 'name' : '%s/jobs/%s' % (response ['name' ], job_id ),
234258 'output_config' : {
235259 'gcs_results_location' : {
236- 'uri' : options . gcs_results
260+ 'uri' : gcs_results
237261 }
238262 },
239263 'scheduling_config' : {
240264 'priority' : priority ,
241265 'target_route' : target_route
242266 },
243267 }
244- response = service .projects ().programs ().jobs ().create (
268+ response = self . service .projects ().programs ().jobs ().create (
245269 parent = response ['name' ], body = request ).execute ()
246270
247- for _ in range (1000 ):
248- if response ['executionStatus' ]['state' ] in ['SUCCESS' , 'FAILURE' ,
249- 'CANCELLED' ]:
250- break
251- time .sleep (0.5 )
252- response = service .projects ().programs ().jobs ().get (
253- name = response ['name' ]).execute ()
254-
255- if response ['executionStatus' ]['state' ] != 'SUCCESS' :
256- raise RuntimeError ('Job %s did not succeed. It is in state %s.' % (
257- response ['name' ], response ['executionStatus' ]['state' ]))
271+ return EngineJob (
272+ EngineOptions (project_id , program_id , job_id , gcs_prefix ,
273+ gcs_program , gcs_results ), response , self )
258274
259- response = service .projects ().programs ().jobs ().getResult (
260- parent = response ['name' ]).execute ()
275+ def get_job (self , job_resource_name ) -> Dict :
276+ return self .service .projects ().programs ().jobs ().get (
277+ name = job_resource_name ).execute ()
261278
279+ def get_job_results (self , job_resource_name ) -> List [EngineTrialResult ]:
280+ response = self .service .projects ().programs ().jobs ().getResult (
281+ parent = job_resource_name ).execute ()
262282 trial_results = []
263283 for sweep_result in response ['result' ]['sweepResults' ]:
264284 sweep_repetitions = sweep_result ['repetitions' ]
@@ -276,6 +296,66 @@ def run_sweep(self,
276296 measurements = measurements ))
277297 return trial_results
278298
299+ def cancel_job (self , job_resource_name ):
300+ self .service .projects ().programs ().jobs ().cancel (
301+ name = job_resource_name , body = {}).execute ()
302+
303+
304+ class EngineJob :
305+ """A job running on the engine.
306+
307+ Attributes:
308+ engine_options: The engine options used for the job.
309+ job_resource_name: The full resource name of the engine job.
310+ """
311+
312+ def __init__ (self ,
313+ engine_options : EngineOptions ,
314+ job : Dict ,
315+ engine : Engine ) -> None :
316+ """A job submitted to the engine.
317+
318+ Args:
319+ engine_options: The EngineOptions used to create the job.
320+ job: A full Job Dict.
321+ engine: Engine connected to the job.
322+ """
323+ self .engine_options = engine_options
324+ self ._job = job
325+ self ._engine = engine
326+ self .job_resource_name = job ['name' ]
327+ self ._results = None # type: List[EngineTrialResult]
328+
329+ def _update_job (self ):
330+ if self ._job ['executionStatus' ]['state' ] not in TERMINAL_STATES :
331+ self ._job = self ._engine .get_job (self .job_resource_name )
332+ return self ._job
333+
334+ def state (self ):
335+ return self ._update_job ()['executionStatus' ]['state' ]
336+
337+ def cancel (self ):
338+ self ._engine .cancel_job (self .job_resource_name )
339+
340+ def results (self ) -> List [EngineTrialResult ]:
341+ if not self ._results :
342+ job = self ._update_job ()
343+ for _ in range (1000 ):
344+ if job ['executionStatus' ]['state' ] in TERMINAL_STATES :
345+ break
346+ time .sleep (0.5 )
347+ job = self ._update_job ()
348+ if job ['executionStatus' ]['state' ] != 'SUCCESS' :
349+ raise RuntimeError (
350+ 'Job %s did not succeed. It is in state %s.' % (
351+ job ['name' ], job ['executionStatus' ]['state' ]))
352+ self ._results = self ._engine .get_job_results (
353+ self .job_resource_name )
354+ return self ._results
355+
356+ def __iter__ (self ):
357+ return self .results ().__iter__ ()
358+
279359
280360def _sweepable_to_sweeps (sweepable : Sweepable ) -> List [Sweep ]:
281361 if isinstance (sweepable , ParamResolver ):
0 commit comments