from typing import List, Optional, Tuple

import torch
- from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                           PreTrainedTokenizerBase)
from tqdm import tqdm

- from vllm import LLM, SamplingParams
- from vllm.transformers_utils.tokenizer import get_tokenizer
-

def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
+     fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
+     if fixed_output_len is not None:
+         if fixed_output_len < 4:
+             raise ValueError("output_len too small")
+
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
@@ -35,6 +38,8 @@ def sample_requests(
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
+         if fixed_output_len is not None:
+             output_len = fixed_output_len
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
@@ -66,6 +71,7 @@ def run_vllm(
    trust_remote_code: bool,
    dtype: str,
) -> float:
+     from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
@@ -160,14 +166,37 @@ def run_hf(
    return end - start


+ def run_mii(
+     requests: List[Tuple[str, int, int]],
+     model: str,
+     tensor_parallel_size: int,
+     output_len: int,
+ ) -> float:
+     from mii import pipeline
+     llm = pipeline(model, tensor_parallel=tensor_parallel_size)
+     prompts = [prompt for prompt, _, _ in requests]
+
+     start = time.perf_counter()
+     llm(prompts, max_new_tokens=output_len)
+     end = time.perf_counter()
+     return end - start
+
+
def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
-     tokenizer = get_tokenizer(args.tokenizer,
-                               trust_remote_code=args.trust_remote_code)
-     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.tokenizer, trust_remote_code=args.trust_remote_code)
+     if args.dataset is None:
+         # Synthesize a prompt with the given input length.
+         prompt = "hi" * (args.input_len - 1)
+         requests = [(prompt, args.input_len, args.output_len)
+                     for _ in range(args.num_prompts)]
+     else:
+         requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
+                                    args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
@@ -179,6 +208,9 @@ def main(args: argparse.Namespace):
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
+     elif args.backend == "mii":
+         elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
+                                args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
@@ -191,12 +223,21 @@ def main(args: argparse.Namespace):
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
-                         choices=["vllm", "hf"],
+                         choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
-                         required=True,
+                         default=None,
                        help="Path to the dataset.")
+     parser.add_argument("--input-len",
+                         type=int,
+                         default=None,
+                         help="Input prompt length for each request")
+     parser.add_argument("--output-len",
+                         type=int,
+                         default=None,
+                         help="Output length for each request. Overrides the "
+                         "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
@@ -231,6 +272,13 @@ def main(args: argparse.Namespace):
                        'for FP32 and FP16 models, and BF16 precision '
                        'for BF16 models.')
    args = parser.parse_args()
+     if args.tokenizer is None:
+         args.tokenizer = args.model
+     if args.dataset is None:
+         assert args.input_len is not None
+         assert args.output_len is not None
+     else:
+         assert args.input_len is None

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
@@ -240,7 +288,18 @@ def main(args: argparse.Namespace):
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
-     if args.tokenizer is None:
-         args.tokenizer = args.model
-
+     elif args.backend == "mii":
+         if args.dtype != "auto":
+             raise ValueError("dtype must be auto for MII backend.")
+         if args.n != 1:
+             raise ValueError("n must be 1 for MII backend.")
+         if args.use_beam_search:
+             raise ValueError("Beam search is not supported for MII backend.")
+         if args.quantization is not None:
+             raise ValueError("Quantization is only for vLLM backend.")
+         if args.hf_max_batch_size is not None:
+             raise ValueError("HF max batch size is only for HF backend.")
+         if args.tokenizer != args.model:
+             raise ValueError("Tokenizer must be the same as the model for MII "
+                              "backend.")
    main(args)
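Usage sketch (not part of the diff): assuming this patch applies to the throughput benchmark script (the path benchmarks/benchmark_throughput.py is an assumption, and the model and lengths below are arbitrary examples), the new MII backend with synthetic prompts could be exercised with a command like:

    python benchmarks/benchmark_throughput.py --backend mii --model facebook/opt-125m \
        --input-len 256 --output-len 256 --num-prompts 100

Per the added validation, the MII backend requires the tokenizer to match the model (the default, since --tokenizer falls back to --model), and --input-len/--output-len must be given when --dataset is omitted.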