
Commit 660a7fc

Add DeepSpeed MII backend to benchmark script (#1649)
1 parent 054072b commit 660a7fc

1 file changed (+71, -12)

benchmarks/benchmark_throughput.py

Lines changed: 71 additions & 12 deletions
@@ -6,18 +6,21 @@
 from typing import List, Optional, Tuple
 
 import torch
-from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          PreTrainedTokenizerBase)
 from tqdm import tqdm
 
-from vllm import LLM, SamplingParams
-from vllm.transformers_utils.tokenizer import get_tokenizer
-
 
 def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int],
 ) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None:
+        if fixed_output_len < 4:
+            raise ValueError("output_len too small")
+
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -35,6 +38,8 @@ def sample_requests(
     tokenized_dataset = []
     for i in range(len(dataset)):
         output_len = len(completion_token_ids[i])
+        if fixed_output_len is not None:
+            output_len = fixed_output_len
         tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
 
     # Filter out too long sequences.
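
The two hunks above thread a new fixed_output_len parameter through sample_requests: when it is set, it replaces the per-sample completion length found in the dataset. A minimal, self-contained sketch of that override semantics (the helper name and the toy tuples below are illustrative, not code from the commit):

from typing import List, Optional, Tuple


def apply_fixed_output_len(
    samples: List[Tuple[str, int, int]],
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    # Mirror the benchmark's behavior: keep (prompt, prompt_len, output_len)
    # tuples as-is unless a fixed output length was requested.
    if fixed_output_len is None:
        return samples
    if fixed_output_len < 4:
        raise ValueError("output_len too small")
    return [(prompt, prompt_len, fixed_output_len)
            for prompt, prompt_len, _ in samples]


# Example: every request now targets exactly 128 output tokens.
print(apply_fixed_output_len([("Hello", 1, 57), ("Hi there", 2, 301)], 128))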
@@ -66,6 +71,7 @@ def run_vllm(
     trust_remote_code: bool,
     dtype: str,
 ) -> float:
+    from vllm import LLM, SamplingParams
     llm = LLM(
         model=model,
         tokenizer=tokenizer,
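
The hunk above moves the vLLM import from module scope into run_vllm, presumably so the script does not need vLLM importable when benchmarking another backend. A rough sketch of that lazy-import pattern (the dispatcher below is illustrative, not code from the commit):

def run_backend(backend: str) -> None:
    # Each branch imports its own heavy dependency only when selected,
    # so installing a single backend is enough to benchmark it.
    if backend == "vllm":
        from vllm import LLM, SamplingParams  # noqa: F401
    elif backend == "mii":
        from mii import pipeline  # noqa: F401
    else:
        raise ValueError(f"Unknown backend: {backend}")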
@@ -160,14 +166,37 @@ def run_hf(
     return end - start
 
 
+def run_mii(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tensor_parallel_size: int,
+    output_len: int,
+) -> float:
+    from mii import pipeline
+    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
+    prompts = [prompt for prompt, _, _ in requests]
+
+    start = time.perf_counter()
+    llm(prompts, max_new_tokens=output_len)
+    end = time.perf_counter()
+    return end - start
+
+
 def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
 
     # Sample the requests.
-    tokenizer = get_tokenizer(args.tokenizer,
-                              trust_remote_code=args.trust_remote_code)
-    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer, trust_remote_code=args.trust_remote_code)
+    if args.dataset is None:
+        # Synthesize a prompt with the given input length.
+        prompt = "hi" * (args.input_len - 1)
+        requests = [(prompt, args.input_len, args.output_len)
+                    for _ in range(args.num_prompts)]
+    else:
+        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
+                                   args.output_len)
 
     if args.backend == "vllm":
         elapsed_time = run_vllm(requests, args.model, args.tokenizer,
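
For anyone who wants to try the new MII path outside the benchmark harness, here is a hedged stand-alone version of what run_mii does, using the same DeepSpeed-MII pipeline calls shown in the hunk above (the model name, prompts, and lengths are arbitrary examples, and deepspeed-mii must be installed):

import time

from mii import pipeline

llm = pipeline("facebook/opt-125m", tensor_parallel=1)
prompts = ["Hello, my name is", "The capital of France is"]

start = time.perf_counter()
llm(prompts, max_new_tokens=32)
end = time.perf_counter()
print(f"Generated {len(prompts)} completions in {end - start:.2f} s")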
@@ -179,6 +208,9 @@ def main(args: argparse.Namespace):
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.use_beam_search, args.hf_max_batch_size,
                               args.trust_remote_code)
+    elif args.backend == "mii":
+        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
+                               args.output_len)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
     total_num_tokens = sum(prompt_len + output_len
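
The dispatch above feeds the elapsed time from whichever backend ran into the same throughput calculation. A small sketch of that metric, reconstructed from the surrounding code (the request tuples, the elapsed time, and the exact report format are illustrative; only the token-count sum appears in this diff):

# (prompt, prompt_len, output_len) tuples, as produced by sample_requests
# or the synthetic-prompt path; values here are made up.
requests = [("prompt a", 16, 128), ("prompt b", 32, 128)]
elapsed_time = 1.7  # seconds, as returned by run_vllm / run_hf / run_mii

total_num_tokens = sum(prompt_len + output_len
                       for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} tokens/s")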
@@ -191,12 +223,21 @@
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
     parser.add_argument("--backend",
                         type=str,
-                        choices=["vllm", "hf"],
+                        choices=["vllm", "hf", "mii"],
                         default="vllm")
     parser.add_argument("--dataset",
                         type=str,
-                        required=True,
+                        default=None,
                         help="Path to the dataset.")
+    parser.add_argument("--input-len",
+                        type=int,
+                        default=None,
+                        help="Input prompt length for each request")
+    parser.add_argument("--output-len",
+                        type=int,
+                        default=None,
+                        help="Output length for each request. Overrides the "
+                        "output length from the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
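
The new --input-len and --output-len flags drive the synthetic-prompt path added to main() earlier in the diff: when no --dataset is given, the script repeats a fixed prompt instead of sampling from a dataset file. A hedged sketch of that path in isolation (the values stand in for the parsed arguments; the true token count of the "hi" prompt depends on the tokenizer, so input_len is approximate):

# Stand-in values for args.input_len, args.output_len, args.num_prompts.
input_len, output_len, num_prompts = 128, 128, 4

# Mirrors the diff: one synthetic prompt, roughly input_len tokens long,
# reused for every request.
prompt = "hi" * (input_len - 1)
requests = [(prompt, input_len, output_len) for _ in range(num_prompts)]

print(len(requests), "synthetic requests, prompt length", requests[0][1])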
@@ -231,6 +272,13 @@ def main(args: argparse.Namespace):
         'for FP32 and FP16 models, and BF16 precision '
         'for BF16 models.')
     args = parser.parse_args()
+    if args.tokenizer is None:
+        args.tokenizer = args.model
+    if args.dataset is None:
+        assert args.input_len is not None
+        assert args.output_len is not None
+    else:
+        assert args.input_len is None
 
     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
@@ -240,7 +288,18 @@ def main(args: argparse.Namespace):
             raise ValueError("HF max batch size is required for HF backend.")
         if args.quantization is not None:
             raise ValueError("Quantization is only for vLLM backend.")
-    if args.tokenizer is None:
-        args.tokenizer = args.model
-
+    elif args.backend == "mii":
+        if args.dtype != "auto":
+            raise ValueError("dtype must be auto for MII backend.")
+        if args.n != 1:
+            raise ValueError("n must be 1 for MII backend.")
+        if args.use_beam_search:
+            raise ValueError("Beam search is not supported for MII backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for vLLM backend.")
+        if args.hf_max_batch_size is not None:
+            raise ValueError("HF max batch size is only for HF backend.")
+        if args.tokenizer != args.model:
+            raise ValueError("Tokenizer must be the same as the model for MII "
+                             "backend.")
     main(args)
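
Putting the pieces together, one way the updated script might be launched with the new backend (a hedged sketch: the flag names come from the argparse changes in this diff, the model is an arbitrary example, --num-prompts is inferred from the existing args.num_prompts usage, and deepspeed-mii must be installed):

import subprocess

subprocess.run([
    "python", "benchmarks/benchmark_throughput.py",
    "--backend", "mii",            # new choice added by this commit
    "--model", "facebook/opt-125m",
    "--input-len", "128",          # required because no --dataset is given
    "--output-len", "128",
    "--num-prompts", "100",        # assumed flag, inferred from args.num_prompts
], check=True)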
