 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List
 
@@ -115,13 +117,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
 
+        # Start the response thread. It allows vLLM to keep making progress
+        # while response sender(s) send responses to the server frontend.
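+        # Each queue item is a (response_state, response, flags) tuple; a None
+        # item signals response_loop to exit.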
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()
 
     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
@@ -290,6 +298,27 @@ def get_sampling_params_dict(self, params_json):
 
         return params_dict
 
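+    # Runs on a dedicated thread, draining self._response_queue, so sending
+    # responses to the Triton frontend never blocks the engine's event loop.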
+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown, a None item will be added to the queue.
+            if item is None:
+                break
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
+            try:
+                response_sender.send(response, response_flag)
+                # Stop checking for cancellation once the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
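+                # For a request whose final response was handed off to this
+                # thread, the ongoing-request decrement happens here.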
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
 
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
@@ -330,7 +359,13 @@ async def generate(self, request):
         Forwards a single request to the LLM engine and returns responses.
         """
         response_sender = request.get_response_sender()
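+        # Shared with the response thread, which reports cancellation status
+        # back through this dict as it sends responses.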
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
         self.ongoing_request_count += 1
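+        # Cleared when the final (or cancellation) response is handed to the
+        # response thread, which then owns the decrement (see response_loop).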
+        decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
             prompt = pb_utils.get_input_tensor_by_name(
@@ -385,13 +420,31 @@ async def generate(self, request):
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
 
-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
-                if response_sender.is_cancelled():
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
+                is_cancelled = response_state["is_cancelled"]
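+                # With streaming, cancellation status is refreshed by the
+                # response thread after each send; otherwise poll the sender
+                # directly since nothing has been sent yet.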
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+                    if stream:
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
                     break
                 if stream:
                     prev_outputs_lengths = None
@@ -400,15 +453,13 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
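+                    # Hand each streamed response to the response thread
+                    # instead of sending it from the engine's event loop.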
+                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    flags = 0
                     if output.finished:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths),
-                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                        )
-                    else:
-                        response_sender.send(
-                            self.create_stream_response(output, prev_outputs_lengths)
-                        )
+                        response_state["last_response_generated"] = True
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_state, response, flags))
                 prev_outputs = output
 
             last_output = output
@@ -420,7 +471,7 @@ async def generate(self, request):
                 )
 
         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -433,7 +484,8 @@ async def generate(self, request):
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1
 
     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -500,6 +552,20 @@ def finalize(self):
500552 """
501553 self .logger .log_info ("[vllm] Issuing finalize to vllm backend" )
502554 self ._shutdown_event .set ()
503- if self ._loop_thread is not None :
504- self ._loop_thread .join ()
505- self ._loop_thread = None
555+
556+ # Shutdown the event thread.
557+ if self ._event_thread is not None :
558+ self ._event_thread .join ()
559+ self ._event_thread = None
560+
561+ # Shutdown the response thread.
562+ self ._response_queue .put (None )
563+ if self ._response_thread is not None :
564+ self ._response_thread .join ()
565+ self ._response_thread = None
566+
567+ # When using parallel tensors, the stub process may not shutdown due to
568+ # unreleased references, so manually run the garbage collector once.
569+ self .logger .log_info ("[vllm] Running Garbage Collector on finalize..." )
570+ gc .collect ()
571+ self .logger .log_info ("[vllm] Garbage Collector on finalize... done" )