diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index aeae6cfd..9d8ce129 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -287,7 +287,7 @@ def unbind_connecting_task() -> None: else: logger.debug("unbind_connecting_task failed, id did not match") - if not self._connecting_task: + if not self._connecting_task or self._connecting_task.done(): self._connecting_task = asyncio.create_task( _do_ensure_connected( transport_options=self._transport_options, @@ -308,9 +308,16 @@ def unbind_connecting_task() -> None: ) ) - await self._connecting_task + try: + await self._connecting_task + except asyncio.CancelledError: + pass + if self._terminating_task: - await self._terminating_task + try: + await self._terminating_task + except asyncio.CancelledError: + pass def is_terminal(self) -> bool: """ @@ -403,7 +410,10 @@ async def close( "seconds to close, leaking", ) return - await self._close_internal(reason) + try: + await self._close_internal(reason) + except asyncio.CancelledError: + pass def _close_internal_nowait(self, reason: Exception | None = None) -> None: """ @@ -501,10 +511,9 @@ async def do_close() -> None: # This will get us GC'd, so this should be the last thing. self._close_session_callback(self) - if self._terminating_task: - return self._terminating_task + if not self._terminating_task: + self._terminating_task = asyncio.create_task(do_close()) - self._terminating_task = asyncio.create_task(do_close()) return self._terminating_task def _start_buffered_message_sender(