@@ -143,6 +143,8 @@ async def cancel_invocation(self) -> None:
143143
144144 await ctx.cancel_invocation(await f.invocation_id())
145145 """
146+ inv = await self .invocation_id ()
147+ await self .context .cancel_invocation (inv )
146148
147149class ServerSendHandle (SendHandle ):
148150 """This class implements the send API"""
@@ -223,45 +225,47 @@ class SyncPoint:
223225 This class implements a synchronization point.
224226 """
225227
226- def __init__ (self ):
227- self ._cond = asyncio .Condition ()
228+ def __init__ (self ) -> None :
229+ self .cond : asyncio .Event | None = None
228230
229- async def wait (self ):
231+ def awaiter (self ):
230232 """Wait for the sync point."""
231- async with self ._cond :
232- await self ._cond .wait ()
233+ if self .cond is None :
234+ self .cond = asyncio .Event ()
235+ return self .cond .wait ()
233236
234237 async def arrive (self ):
235- """Arrive at the sync point."""
236- async with self ._cond :
237- self ._cond . notify_all ()
238+ """arrive at the sync point."""
239+ if self .cond is not None :
240+ self .cond . set ()
238241
239242class Tasks :
240243 """
241244 This class implements a list of tasks.
242245 """
243246
244247 def __init__ (self ) -> None :
245- self .tasks : List [asyncio .Future ] = []
248+ self .tasks : set [asyncio .Future ] = set ()
246249
247250 def add (self , task : asyncio .Future ):
248251 """Add a task to the list."""
249- self .tasks .append (task )
252+ self .tasks .add (task )
250253
251254 def safe_remove (_ ):
252255 """Remove the task from the list."""
253256 try :
254257 self .tasks .remove (task )
255- except ValueError :
258+ except KeyError :
256259 pass
257260
258261 task .add_done_callback (safe_remove )
259262
260263 def cancel (self ):
261264 """Cancel all tasks in the list."""
262- for task in self .tasks :
263- task .cancel ()
265+ to_cancel = list (self .tasks )
264266 self .tasks .clear ()
267+ for task in to_cancel :
268+ task .cancel ()
265269
266270# pylint: disable=R0902
267271class ServerInvocationContext (ObjectContext ):
@@ -358,13 +362,13 @@ def on_attempt_finished(self):
358362 async def receive_and_notify_input (self ):
359363 """Receive input from the state machine."""
360364 chunk = await self .receive ()
361- if chunk .get ('type' ) == 'http.request' :
365+ if chunk .get ('type' ) == 'http.disconnect' :
366+ raise DisconnectedException ()
367+ if chunk .get ('body' , None ) is not None :
362368 assert isinstance (chunk ['body' ], bytes )
363369 self .vm .notify_input (chunk ['body' ])
364370 if not chunk .get ('more_body' , False ):
365371 self .vm .notify_input_closed ()
366- if chunk .get ('type' ) == 'http.disconnect' :
367- raise DisconnectedException ()
368372
369373 async def take_and_send_output (self ):
370374 """Take output from state machine and send it"""
@@ -398,9 +402,6 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398402 return
399403 if isinstance (do_progress_response , DoProgressCancelSignalReceived ):
400404 raise TerminalError ("cancelled" , 409 )
401- if isinstance (do_progress_response , DoProgressReadFromInput ):
402- await self .receive_and_notify_input ()
403- continue
404405 if isinstance (do_progress_response , DoProgressExecuteRun ):
405406 fn = self .run_coros_to_execute [do_progress_response .handle ]
406407 del self .run_coros_to_execute [do_progress_response .handle ]
@@ -414,17 +415,16 @@ async def wrapper(f):
414415 task = asyncio .create_task (wrapper (fn ))
415416 self .tasks .add (task )
416417 continue
417- if isinstance (do_progress_response , DoWaitPendingRun ):
418- sync_task = asyncio .create_task (self .sync_point .wait ())
419- read_task = asyncio .create_task (self .receive_and_notify_input ())
418+ if isinstance (do_progress_response , (DoWaitPendingRun , DoProgressReadFromInput )):
419+ sync_task = asyncio .create_task (self .sync_point .awaiter ())
420420 self .tasks .add (sync_task )
421+
422+ read_task = asyncio .create_task (self .receive_and_notify_input ())
421423 self .tasks .add (read_task )
424+
422425 done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
423426 if read_task in done :
424- _ = read_task .result () # rethrow any exception
425- if sync_task in done :
426- continue
427-
427+ _ = read_task .result () # propagate exception
428428
429429 def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
430430 """Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +520,8 @@ async def create_run_coroutine(self,
520520 except TerminalError as t :
521521 failure = Failure (code = t .status_code , message = t .message )
522522 self .vm .propose_run_completion_failure (handle , failure )
523+ except asyncio .CancelledError as e :
524+ raise e from None
523525 # pylint: disable=W0718
524526 except Exception as e :
525527 if max_attempts is None and max_retry_duration is None :
0 commit comments