Skip to content

Commit 4d4297e

Browse files
authored
[Bench] Split serve.py:main into sync/async versions (#22405)
Signed-off-by: Linkun <[email protected]>
1 parent 2a4c825 commit 4d4297e

File tree

1 file changed

+58
-54
lines changed

1 file changed

+58
-54
lines changed

vllm/benchmarks/serve.py

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
948948
)
949949

950950

951-
def main(args: argparse.Namespace):
951+
def main(args: argparse.Namespace) -> dict[str, Any]:
952+
return asyncio.run(main_async(args))
953+
954+
async def main_async(args: argparse.Namespace) -> dict[str, Any]:
952955
print(args)
953956
random.seed(args.seed)
954957
np.random.seed(args.seed)
@@ -1025,8 +1028,7 @@ def main(args: argparse.Namespace):
10251028
gc.collect()
10261029
gc.freeze()
10271030

1028-
benchmark_result = asyncio.run(
1029-
benchmark(
1031+
benchmark_result = await benchmark(
10301032
endpoint_type=args.endpoint_type,
10311033
api_url=api_url,
10321034
base_url=base_url,
@@ -1052,62 +1054,62 @@ def main(args: argparse.Namespace):
10521054
ramp_up_start_rps=args.ramp_up_start_rps,
10531055
ramp_up_end_rps=args.ramp_up_end_rps,
10541056
ready_check_timeout_sec=args.ready_check_timeout_sec,
1055-
))
1057+
)
10561058

10571059
# Save config and results to json
1058-
if args.save_result or args.append_result:
1059-
result_json: dict[str, Any] = {}
1060-
1061-
# Setup
1062-
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
1063-
result_json["date"] = current_dt
1064-
result_json["endpoint_type"] = args.endpoint_type
1065-
result_json["label"] = label
1066-
result_json["model_id"] = model_id
1067-
result_json["tokenizer_id"] = tokenizer_id
1068-
result_json["num_prompts"] = args.num_prompts
1069-
1070-
# Metadata
1071-
if args.metadata:
1072-
for item in args.metadata:
1073-
if "=" in item:
1074-
kvstring = item.split("=")
1075-
result_json[kvstring[0].strip()] = kvstring[1].strip()
1076-
else:
1077-
raise ValueError(
1078-
"Invalid metadata format. Please use KEY=VALUE format."
1079-
)
1080-
1081-
# Traffic
1082-
result_json["request_rate"] = (args.request_rate if args.request_rate
1083-
< float("inf") else "inf")
1084-
result_json["burstiness"] = args.burstiness
1085-
result_json["max_concurrency"] = args.max_concurrency
1060+
result_json: dict[str, Any] = {}
1061+
1062+
# Setup
1063+
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
1064+
result_json["date"] = current_dt
1065+
result_json["endpoint_type"] = args.endpoint_type
1066+
result_json["label"] = label
1067+
result_json["model_id"] = model_id
1068+
result_json["tokenizer_id"] = tokenizer_id
1069+
result_json["num_prompts"] = args.num_prompts
1070+
1071+
# Metadata
1072+
if args.metadata:
1073+
for item in args.metadata:
1074+
if "=" in item:
1075+
kvstring = item.split("=")
1076+
result_json[kvstring[0].strip()] = kvstring[1].strip()
1077+
else:
1078+
raise ValueError(
1079+
"Invalid metadata format. Please use KEY=VALUE format."
1080+
)
10861081

1087-
if args.ramp_up_strategy is not None:
1088-
result_json["ramp_up_strategy"] = args.ramp_up_strategy
1089-
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
1090-
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
1091-
1092-
# Merge with benchmark result
1093-
result_json = {**result_json, **benchmark_result}
1094-
1095-
if not args.save_detailed:
1096-
# Remove fields with too many data points
1097-
for field in [
1098-
"input_lens",
1099-
"output_lens",
1100-
"ttfts",
1101-
"itls",
1102-
"generated_texts",
1103-
"errors",
1104-
]:
1105-
if field in result_json:
1106-
del result_json[field]
1107-
if field in benchmark_result:
1108-
del benchmark_result[field]
1082+
# Traffic
1083+
result_json["request_rate"] = (args.request_rate if args.request_rate
1084+
< float("inf") else "inf")
1085+
result_json["burstiness"] = args.burstiness
1086+
result_json["max_concurrency"] = args.max_concurrency
1087+
1088+
if args.ramp_up_strategy is not None:
1089+
result_json["ramp_up_strategy"] = args.ramp_up_strategy
1090+
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
1091+
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
1092+
1093+
# Merge with benchmark result
1094+
result_json = {**result_json, **benchmark_result}
1095+
1096+
if not args.save_detailed:
1097+
# Remove fields with too many data points
1098+
for field in [
1099+
"input_lens",
1100+
"output_lens",
1101+
"ttfts",
1102+
"itls",
1103+
"generated_texts",
1104+
"errors",
1105+
]:
1106+
if field in result_json:
1107+
del result_json[field]
1108+
if field in benchmark_result:
1109+
del benchmark_result[field]
11091110

11101111
# Save to file
1112+
if args.save_result or args.append_result:
11111113
base_model_id = model_id.split("/")[-1]
11121114
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
11131115
if args.max_concurrency is not None else "")
@@ -1129,3 +1131,5 @@ def main(args: argparse.Namespace):
11291131
outfile.write("\n")
11301132
json.dump(result_json, outfile)
11311133
save_to_pytorch_benchmark_format(args, result_json, file_name)
1134+
1135+
return result_json

0 commit comments

Comments
 (0)