@@ -287,9 +287,13 @@ def response_loop(self):
             # To signal shutdown a None item will be added to the queue.
             if item is None:
                 break
-            response_sender, response, response_flag = item
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
             try:
                 response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
             except Exception as e:
                 self.logger.log_error(
                     f"An error occurred while sending a response: {e}"
@@ -338,6 +342,11 @@ async def generate(self, request):
         Forwards single request to LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
         decrement_ongoing_request_count = True
         try:
@@ -399,10 +408,26 @@ async def generate(self, request):
             )
 
             async for output in response_iterator:
-                if response_sender.is_cancelled():
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -414,9 +439,10 @@ async def generate(self, request):
                     response = self.create_stream_response(output, prev_outputs_lengths)
                     flags = 0
                     if output.finished:
+                        response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
-                    self._response_queue.put_nowait((response_sender, response, flags))
+                    self._response_queue.put_nowait((response_state, response, flags))
                     prev_outputs = output
 
                 last_output = output
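
For a decoupled (streaming) request, the net effect of this diff is that the response-loop thread, rather than the generate coroutine, polls the sender for cancellation, and it publishes the result through the shared response_state dict; polling stops once the final response has been queued. Below is a minimal, self-contained sketch of that hand-off pattern. The StubSender class, FINAL constant, and produce function are illustrative stand-ins, not the Triton backend's API.

import queue
import threading
import time

FINAL = 1  # stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL

class StubSender:
    """Illustrative stand-in for Triton's response sender."""

    def __init__(self, cancel_after=3):
        self._sent = 0
        self._cancel_after = cancel_after

    def send(self, response, flags):
        self._sent += 1
        print(f"sent {response!r} flags={flags}")

    def is_cancelled(self):
        # Pretend the client cancels after a few responses.
        return self._sent >= self._cancel_after

def response_loop(q):
    """Consumer thread: sends responses and reports cancellation via the state dict."""
    while True:
        item = q.get()
        if item is None:  # shutdown signal, mirroring the backend's convention
            break
        state, response, flags = item
        state["response_sender"].send(response, flags)
        # Stop checking for cancellation once the last response is generated.
        if not state["last_response_generated"]:
            state["is_cancelled"] = state["response_sender"].is_cancelled()

def produce(q, state, n=10):
    """Producer: reads the flag set by the loop thread instead of polling the sender."""
    for i in range(n):
        if state["is_cancelled"]:
            state["last_response_generated"] = True
            q.put_nowait((state, "request was cancelled", FINAL))
            return
        flags = FINAL if i == n - 1 else 0
        if flags == FINAL:
            state["last_response_generated"] = True
        q.put_nowait((state, f"token-{i}", flags))
        time.sleep(0.01)  # let the consumer drain so cancellation is observed

q = queue.Queue()
state = {
    "response_sender": StubSender(),
    "is_cancelled": False,
    "last_response_generated": False,
}
worker = threading.Thread(target=response_loop, args=(q,))
worker.start()
produce(q, state)
q.put(None)
worker.join()

As in the patch, cancellation is observed with a short lag (bounded by the queue depth); that lag is the trade-off for doing the poll on the thread that already touches the sender on every response.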