Skip to content

Commit 2b8f4f7

Browse files
committed
Cancel event processing task
1 parent 51a3601 commit 2b8f4f7

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

llmstack/server/consumers.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,19 @@ async def connect(self):
128128
self.scope.get("user", None),
129129
self._preview,
130130
)
131+
self._event_response_task = None
132+
self._connected = True
131133
await self.accept()
132134

133135
async def disconnect(self, close_code):
136+
self._connected = False
134137
if self._app_runner:
135138
await self._app_runner.stop()
136139

137140
async def stop(self):
141+
self._connected = False
142+
if self._event_response_task:
143+
self._event_response_task.cancel()
138144
await self.close()
139145

140146
async def _respond_to_event(self, text_data):
@@ -151,6 +157,10 @@ async def _respond_to_event(self, text_data):
151157
try:
152158
response_iterator = self._app_runner.run(app_runner_request)
153159
async for response in response_iterator:
160+
# Check both cancellation and connection state
161+
if asyncio.current_task().cancelled() or not self._connected:
162+
break
163+
154164
if response.type == AppRunnerStreamingResponseType.OUTPUT_STREAM_CHUNK:
155165
await self.send(text_data=response.model_dump_json())
156166
elif response.type == AppRunnerStreamingResponseType.ERRORS:
@@ -356,7 +366,7 @@ async def _respond_to_event_old(self, text_data):
356366
self._coordinator_ref.stop()
357367

358368
async def receive(self, text_data):
359-
run_coro_in_new_loop(self._respond_to_event(text_data))
369+
self._event_response_task = run_coro_in_new_loop(self._respond_to_event(text_data), name="respond_to_event")
360370

361371

362372
class AssetStreamConsumer(AsyncWebsocketConsumer):
@@ -514,10 +524,17 @@ async def connect(self):
514524
processor_slug="",
515525
provider_slug="",
516526
)
527+
self._event_response_task = None
528+
self._app_runner = None
529+
self._connected = True
517530
await self.accept()
518531

519532
async def disconnect(self, close_code):
520-
pass
533+
self._connected = False
534+
if self._event_response_task:
535+
self._event_response_task.cancel()
536+
if self._app_runner:
537+
await self._app_runner.stop()
521538

522539
async def _respond_to_event(self, text_data):
523540
from llmstack.apps.apis import PlaygroundViewSet
@@ -544,12 +561,15 @@ async def _respond_to_event(self, text_data):
544561
client_request_id=client_request_id, session_id=session_id, input=input_data
545562
)
546563

547-
app_runner = await PlaygroundViewSet().get_app_runner_async(
564+
self._app_runner = await PlaygroundViewSet().get_app_runner_async(
548565
session_id, source, self.scope.get("user", None), input_data, config_data
549566
)
550567
try:
551-
response_iterator = app_runner.run(app_runner_request)
568+
response_iterator = self._app_runner.run(app_runner_request)
552569
async for response in response_iterator:
570+
if not self._connected:
571+
break
572+
553573
if response.type == AppRunnerStreamingResponseType.OUTPUT_STREAM_CHUNK:
554574
await self.send(text_data=response.model_dump_json())
555575
elif response.type == AppRunnerStreamingResponseType.OUTPUT:
@@ -560,10 +580,10 @@ async def _respond_to_event(self, text_data):
560580
)
561581
except Exception as e:
562582
logger.exception(f"Failed to run app: {e}")
563-
await app_runner.stop()
583+
await self._app_runner.stop()
564584

565585
async def receive(self, text_data):
566-
run_coro_in_new_loop(self._respond_to_event(text_data))
586+
self._event_response_task = run_coro_in_new_loop(self._respond_to_event(text_data))
567587

568588

569589
class StoreAppConsumer(AppConsumer):
@@ -592,6 +612,8 @@ async def connect(self):
592612
self._app_runner = await AppStoreAppViewSet().get_app_runner_async(
593613
self._session_id, self._app_slug, self._source, self.scope.get("user", None)
594614
)
615+
self._connected = True
616+
self._event_response_task = None
595617
await self.accept()
596618

597619

@@ -625,15 +647,18 @@ async def connect(self):
625647
},
626648
)
627649

650+
self._connected = True
651+
self._event_response_task = None
628652
await self.accept()
629653

630654
async def disconnect(self, close_code):
655+
self._connected = False
631656
if self._app_runner:
632657
await self._app_runner.stop()
633658
await self.close(code=close_code)
634659

635660
async def receive(self, text_data):
636-
run_coro_in_new_loop(self._respond_to_event(text_data))
661+
self._event_response_task = run_coro_in_new_loop(self._respond_to_event(text_data))
637662

638663
async def _respond_to_event(self, text_data):
639664
from llmstack.assets.stream import AssetStream
@@ -654,6 +679,9 @@ async def _respond_to_event(self, text_data):
654679

655680
# Iterate till we get the objrefs for input and output audio
656681
async for response in response_iterator:
682+
if not self._connected:
683+
break
684+
657685
if response.type == AppRunnerStreamingResponseType.OUTPUT_STREAM_CHUNK:
658686
deltas = response.data.deltas
659687
if "agent_input_audio_stream" in deltas:

0 commit comments

Comments
 (0)