
Commit 3ecbb14

[Benchmarks] add benchmark for embedding models (#23000)
Signed-off-by: zjy0516 <[email protected]>
1 parent 7d67a9d commit 3ecbb14

3 files changed: 274 additions, 107 deletions


vllm/benchmarks/datasets.py

Lines changed: 45 additions & 22 deletions
@@ -73,7 +73,7 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """

-    prompt: Union[str, Any]
+    prompt: Union[str, list[str]]
     prompt_len: int
     expected_output_len: int
     multi_modal_data: Optional[
@@ -409,6 +409,7 @@ def sample(
         range_ratio: float = DEFAULT_RANGE_RATIO,
         input_len: int = DEFAULT_INPUT_LEN,
         output_len: int = DEFAULT_OUTPUT_LEN,
+        batchsize: int = 1,
         **kwargs,
     ) -> list[SampleRequest]:

@@ -439,6 +440,21 @@ def sample(
                     request_id=request_id_prefix + str(i),
                 )
             )
+        # only used for embeddings benchmark.
+        if batchsize > 1:
+            batch_requests = []
+            # Create batched requests
+            for i in range(0, num_requests, batchsize):
+                batch = requests[i : i + batchsize]
+                batch_requests.append(
+                    SampleRequest(
+                        prompt=[req.prompt for req in batch],
+                        prompt_len=sum(req.prompt_len for req in batch),
+                        expected_output_len=0,
+                        request_id=request_id_prefix + str(i // batchsize),
+                    )
+                )
+            requests = batch_requests
         return requests

     def get_prefix(
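
As a toy illustration of the batching branch above (not vLLM code): with 10 sampled prompts and batchsize=4, the slicing loop produces three batched requests holding 4, 4, and 2 prompts, with request ids 0, 1, 2 and expected_output_len fixed to 0.

prompts = [f"prompt-{i}" for i in range(10)]
batchsize = 4

# Same grouping arithmetic as the new branch in RandomDataset.sample
batches = [prompts[i : i + batchsize] for i in range(0, len(prompts), batchsize)]
for batch_id, batch in enumerate(batches):
    # batch_id corresponds to i // batchsize in the diff
    print(batch_id, len(batch), batch)
# -> 0 4 [...]  /  1 4 [...]  /  2 2 [...]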
@@ -475,8 +491,8 @@ def get_sampling_params(
         input_high = math.ceil(real_input_len * (1 + range_ratio))
         output_low = math.floor(output_len * (1 - range_ratio))
         output_high = math.ceil(output_len * (1 + range_ratio))
-        # Ensure the lower bound for output length is at least 1 to
-        # prevent sampling 0 tokens.
+        # Ensure the lower bound for output length is at least 1 to
+        # prevent sampling 0 tokens.
         output_low = max(output_low, 1)

         if input_low > input_high:
@@ -506,7 +522,6 @@ def get_sampling_params(
                                         size=num_requests)
         return input_lens, output_lens, offsets

-
     def generate_token_sequence(
         self,
         *,
@@ -1105,6 +1120,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
              "context length sampled from [input_len * (1 - range_ratio), "
              "input_len * (1 + range_ratio)]."),
    )
+    random_group.add_argument(
+        "--random-batch-size",
+        type=int,
+        default=1,
+        help=("Batch size for random sampling. "
+              "Only used for embeddings benchmark."),
+    )

     # random multimodal dataset options
     random_mm_group = parser.add_argument_group(
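
For context, a minimal, self-contained sketch of how the new flag behaves when parsed (the real definition lives in add_dataset_parser; the command-line value here is made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--random-batch-size",
    type=int,
    default=1,
    help=("Batch size for random sampling. "
          "Only used for embeddings benchmark."),
)

# argparse exposes the dashed flag as args.random_batch_size
args = parser.parse_args(["--random-batch-size", "8"])
print(args.random_batch_size)  # 8; any value > 1 enables the batching branch in sample()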
@@ -1196,8 +1218,6 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]:
         ),
     )

-
-
     hf_group = parser.add_argument_group("hf dataset options")
     hf_group.add_argument("--hf-subset",
                           type=str,
@@ -1348,29 +1368,32 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     else:
         # For datasets that follow a similar structure, use a mapping.
         dataset_mapping = {
-            "sharegpt":
-            lambda: ShareGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).sample(
-                                        tokenizer=tokenizer,
-                                        num_requests=args.num_prompts,
-                                        output_len=args.sharegpt_output_len,
-                                        request_id_prefix=args.request_id_prefix,
-                                    ),
-            "burstgpt":
-            lambda: BurstGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).
-            sample(tokenizer=tokenizer, num_requests=args.num_prompts,
-                   request_id_prefix=args.request_id_prefix,),
-            "random":
-            lambda: RandomDataset(random_seed=args.seed,
-                                  dataset_path=args.dataset_path).sample(
+            "sharegpt": lambda: ShareGPTDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
+                tokenizer=tokenizer,
+                num_requests=args.num_prompts,
+                output_len=args.sharegpt_output_len,
+                request_id_prefix=args.request_id_prefix,
+            ),
+            "burstgpt": lambda: BurstGPTDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
+                tokenizer=tokenizer,
+                num_requests=args.num_prompts,
+                request_id_prefix=args.request_id_prefix,
+            ),
+            "random": lambda: RandomDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
                 prefix_len=args.random_prefix_len,
                 input_len=args.random_input_len,
                 output_len=args.random_output_len,
                 range_ratio=args.random_range_ratio,
                 request_id_prefix=args.request_id_prefix,
+                batchsize=args.random_batch_size,
             ),
             "random-mm":
             lambda: RandomMultiModalDataset(
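
Putting the pieces together, a direct driver of the new parameter might look roughly like the sketch below; the tokenizer/model name is a placeholder and the sample() arguments omitted here are assumed to have defaults, so treat this as an illustration of the wiring rather than a verified snippet.

from transformers import AutoTokenizer

from vllm.benchmarks.datasets import RandomDataset

# Placeholder tokenizer; any HF tokenizer works for the sketch.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")

requests = RandomDataset(random_seed=0, dataset_path=None).sample(
    tokenizer=tokenizer,
    num_requests=16,
    input_len=128,
    range_ratio=0.0,
    batchsize=4,  # mirrors --random-batch-size
)

# 16 single-prompt requests are folded into 4 SampleRequests, each carrying
# a list of 4 prompts destined for a single /v1/embeddings call.
print(len(requests), requests[0].prompt_len)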

vllm/benchmarks/lib/endpoint_request_func.py

Lines changed: 53 additions & 4 deletions
@@ -69,8 +69,8 @@ async def async_request_openai_completions(
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

     payload = {
-        "model": request_func_input.model_name \
-        if request_func_input.model_name else request_func_input.model,
+        "model": request_func_input.model_name
+        if request_func_input.model_name else request_func_input.model,
         "prompt": request_func_input.prompt,
         "temperature": 0.0,
         "repetition_penalty": 1.0,
@@ -135,7 +135,7 @@ async def async_request_openai_completions(
                         # Decoding phase
                         else:
                             output.itl.append(timestamp -
-                                              most_recent_timestamp)
+                                               most_recent_timestamp)

                         most_recent_timestamp = timestamp
                         generated_text += text or ""
@@ -254,7 +254,7 @@ async def async_request_openai_chat_completions(
                         # Decoding phase
                         else:
                             output.itl.append(timestamp -
-                                              most_recent_timestamp)
+                                               most_recent_timestamp)

                         generated_text += content or ""
                 elif usage := data.get("usage"):
@@ -394,12 +394,61 @@ def to_bytes(y, sr):
     return output


+async def async_request_openai_embeddings(
+    request_func_input: RequestFuncInput,
+    session: aiohttp.ClientSession,
+    pbar: Optional[tqdm] = None,
+):
+    api_url = request_func_input.api_url
+    assert api_url.endswith(
+        "embeddings"
+    ), "OpenAI Embeddings API URL must end with 'embeddings'."
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+    }
+
+    payload = {
+        "model": request_func_input.model,
+        "input": request_func_input.prompt,
+    }
+
+    output = RequestFuncOutput()
+    st = time.perf_counter()
+    try:
+        async with session.post(
+            url=api_url,
+            headers=headers,
+            json=payload
+        ) as response:
+            if response.status == 200:
+                output.latency = time.perf_counter() - st
+                data = await response.json()
+                output.success = True
+                output.generated_text = ""
+                output.prompt_len = data.get(
+                    "usage", {}).get(
+                        "prompt_tokens", 0)
+            else:
+                output.success = False
+                output.error = response.reason or ""
+    except Exception as e:
+        output.success = False
+        output.error = str(e)
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 # TODO: Add more request functions for different API protocols.
 ASYNC_REQUEST_FUNCS = {
     "vllm": async_request_openai_completions,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,
     "openai-audio": async_request_openai_audio,
+    "openai-embeddings": async_request_openai_embeddings,
 }

 OPENAI_COMPATIBLE_BACKENDS = [
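
For reference, the request shape issued by async_request_openai_embeddings can be reproduced outside the benchmark with a few lines of aiohttp. This standalone sketch assumes an OpenAI-compatible server (e.g. a local vLLM instance) listening on localhost:8000 with an embedding model loaded; the URL, model name, and inputs are placeholders.

import asyncio
import time

import aiohttp


async def main():
    api_url = "http://localhost:8000/v1/embeddings"  # must end with "embeddings"
    payload = {
        "model": "BAAI/bge-base-en-v1.5",             # placeholder model name
        "input": ["first prompt", "second prompt"],   # a batched SampleRequest.prompt
    }
    async with aiohttp.ClientSession() as session:
        st = time.perf_counter()
        async with session.post(api_url, json=payload) as resp:
            data = await resp.json()
            latency = time.perf_counter() - st
            # The benchmark records latency and the server-reported prompt token count.
            print(resp.status, round(latency, 4),
                  data.get("usage", {}).get("prompt_tokens", 0))


if __name__ == "__main__":
    asyncio.run(main())

With the registry entry above, the serving benchmark reaches this code path when the openai-embeddings backend is selected.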
