     CreateBatchCompletionsRequestContent,
     TokenOutput,
 )
+from tqdm import tqdm
 
 CONFIG_FILE = os.getenv("CONFIG_FILE")
 AWS_REGION = os.getenv("AWS_REGION", "us-west-2")
@@ -123,11 +124,16 @@ async def batch_inference():
 
     results_generators = await generate_with_vllm(request, content, model, job_index)
 
+    bar = tqdm(total=len(content.prompts), desc="Processed prompts")
+
     outputs = []
     for generator in results_generators:
         last_output_text = ""
         tokens = []
         async for request_output in generator:
+            if request_output.finished:
+                bar.update(1)
+
             token_text = request_output.outputs[-1].text[len(last_output_text) :]
             log_probs = (
                 request_output.outputs[0].logprobs[-1] if content.return_token_log_probs else None
@@ -155,6 +161,8 @@ async def batch_inference():
 
         outputs.append(output.dict())
 
+    bar.close()
+
     if request.data_parallelism == 1:
         with smart_open.open(request.output_data_path, "w") as f:
             f.write(json.dumps(outputs))
@@ -178,6 +186,7 @@ async def generate_with_vllm(request, content, model, job_index):
         quantization=request.model_config.quantize,
         tensor_parallel_size=request.model_config.num_shards,
         seed=request.model_config.seed or 0,
+        disable_log_requests=True,
     )
 
     llm = AsyncLLMEngine.from_engine_args(engine_args)
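For context, the added lines follow a simple pattern: each vLLM results generator streams partial outputs, and the progress bar advances only when a request reports `finished`. Below is a minimal, self-contained sketch of that pattern; `FakeOutput`, `fake_generator`, and `NUM_PROMPTS` are hypothetical stand-ins for the engine objects, not code from this repository.

import asyncio
from dataclasses import dataclass

from tqdm import tqdm

NUM_PROMPTS = 3


@dataclass
class FakeOutput:
    finished: bool  # mirrors the `finished` flag on vLLM's RequestOutput


async def fake_generator():
    # Emit two intermediate outputs, then a final one marked finished.
    for _ in range(2):
        yield FakeOutput(finished=False)
    yield FakeOutput(finished=True)


async def main():
    bar = tqdm(total=NUM_PROMPTS, desc="Processed prompts")
    for generator in [fake_generator() for _ in range(NUM_PROMPTS)]:
        async for request_output in generator:
            if request_output.finished:
                bar.update(1)  # advance once per fully generated prompt
    bar.close()


asyncio.run(main())

Counting only `finished` outputs keeps the bar's total equal to the number of prompts rather than the number of streamed tokens.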