@@ -38,13 +38,21 @@
 
 class LLMClient:
     def __init__(self, flags: argparse.Namespace):
-        self._client = grpcclient.InferenceServerClient(
-            url=flags.url, verbose=flags.verbose
-        )
         self._flags = flags
-        self._loop = asyncio.get_event_loop()
         self._results_dict = {}
 
+    def get_triton_client(self):
+        try:
+            triton_client = grpcclient.InferenceServerClient(
+                url=self._flags.url,
+                verbose=self._flags.verbose,
+            )
+        except Exception as e:
+            print("channel creation failed: " + str(e))
+            sys.exit()
+
+        return triton_client
+
     async def async_request_iterator(
         self, prompts, sampling_parameters, exclude_input_in_output
     ):
@@ -65,8 +73,9 @@ async def async_request_iterator(
 
     async def stream_infer(self, prompts, sampling_parameters, exclude_input_in_output):
         try:
+            triton_client = self.get_triton_client()
             # Start streaming
-            response_iterator = self._client.stream_infer(
+            response_iterator = triton_client.stream_infer(
                 inputs_iterator=self.async_request_iterator(
                     prompts, sampling_parameters, exclude_input_in_output
                 ),
@@ -138,7 +147,7 @@ async def run(self):
             print("FAIL: vLLM example")
 
     def run_async(self):
-        self._loop.run_until_complete(self.run())
+        asyncio.run(self.run())
 
     def create_request(
         self,
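
Taken together, the diff moves client construction out of __init__ into a per-call get_triton_client() helper and replaces the stored event loop with asyncio.run(). Below is a minimal, self-contained sketch of the resulting pattern, not part of the commit: the tritonclient.grpc.aio import is an assumption based on the stream_infer(inputs_iterator=...) call above, and the run() body is a hypothetical placeholder for the real streaming logic.

# Sketch only; LazyClientSketch and the run() body are illustrative,
# and tritonclient.grpc.aio is an assumed import, not confirmed by the diff.
import argparse
import asyncio
import sys

import tritonclient.grpc.aio as grpcclient


class LazyClientSketch:
    def __init__(self, flags: argparse.Namespace):
        self._flags = flags

    def get_triton_client(self):
        # Build the client lazily at call time instead of once in __init__,
        # so connection failures surface at the point of use.
        try:
            return grpcclient.InferenceServerClient(
                url=self._flags.url, verbose=self._flags.verbose
            )
        except Exception as e:
            print("channel creation failed: " + str(e))
            sys.exit(1)

    async def run(self):
        triton_client = self.get_triton_client()
        # Placeholder for the real streaming-inference logic.
        await triton_client.close()

    def run_async(self):
        # asyncio.run() creates and tears down its own event loop, replacing
        # the removed get_event_loop()/run_until_complete() pair.
        asyncio.run(self.run())


if __name__ == "__main__":
    flags = argparse.Namespace(url="localhost:8001", verbose=False)
    LazyClientSketch(flags).run_async()

Creating the client per call trades a small amount of connection overhead for a simpler lifecycle: no connection state lives on the instance, which is also what lets run_async() drop the stored event loop.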