@@ -948,7 +948,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
948
948
)
949
949
950
950
951
- def main (args : argparse .Namespace ):
951
+ def main (args : argparse .Namespace ) -> dict [str , Any ]:
952
+ return asyncio .run (main_async (args ))
953
+
954
+ async def main_async (args : argparse .Namespace ) -> dict [str , Any ]:
952
955
print (args )
953
956
random .seed (args .seed )
954
957
np .random .seed (args .seed )
@@ -1025,8 +1028,7 @@ def main(args: argparse.Namespace):
1025
1028
gc .collect ()
1026
1029
gc .freeze ()
1027
1030
1028
- benchmark_result = asyncio .run (
1029
- benchmark (
1031
+ benchmark_result = await benchmark (
1030
1032
endpoint_type = args .endpoint_type ,
1031
1033
api_url = api_url ,
1032
1034
base_url = base_url ,
@@ -1052,62 +1054,62 @@ def main(args: argparse.Namespace):
1052
1054
ramp_up_start_rps = args .ramp_up_start_rps ,
1053
1055
ramp_up_end_rps = args .ramp_up_end_rps ,
1054
1056
ready_check_timeout_sec = args .ready_check_timeout_sec ,
1055
- ))
1057
+ )
1056
1058
1057
1059
# Save config and results to json
1058
- if args .save_result or args .append_result :
1059
- result_json : dict [str , Any ] = {}
1060
-
1061
- # Setup
1062
- current_dt = datetime .now ().strftime ("%Y%m%d-%H%M%S" )
1063
- result_json ["date" ] = current_dt
1064
- result_json ["endpoint_type" ] = args .endpoint_type
1065
- result_json ["label" ] = label
1066
- result_json ["model_id" ] = model_id
1067
- result_json ["tokenizer_id" ] = tokenizer_id
1068
- result_json ["num_prompts" ] = args .num_prompts
1069
-
1070
- # Metadata
1071
- if args .metadata :
1072
- for item in args .metadata :
1073
- if "=" in item :
1074
- kvstring = item .split ("=" )
1075
- result_json [kvstring [0 ].strip ()] = kvstring [1 ].strip ()
1076
- else :
1077
- raise ValueError (
1078
- "Invalid metadata format. Please use KEY=VALUE format."
1079
- )
1080
-
1081
- # Traffic
1082
- result_json ["request_rate" ] = (args .request_rate if args .request_rate
1083
- < float ("inf" ) else "inf" )
1084
- result_json ["burstiness" ] = args .burstiness
1085
- result_json ["max_concurrency" ] = args .max_concurrency
1060
+ result_json : dict [str , Any ] = {}
1061
+
1062
+ # Setup
1063
+ current_dt = datetime .now ().strftime ("%Y%m%d-%H%M%S" )
1064
+ result_json ["date" ] = current_dt
1065
+ result_json ["endpoint_type" ] = args .endpoint_type
1066
+ result_json ["label" ] = label
1067
+ result_json ["model_id" ] = model_id
1068
+ result_json ["tokenizer_id" ] = tokenizer_id
1069
+ result_json ["num_prompts" ] = args .num_prompts
1070
+
1071
+ # Metadata
1072
+ if args .metadata :
1073
+ for item in args .metadata :
1074
+ if "=" in item :
1075
+ kvstring = item .split ("=" )
1076
+ result_json [kvstring [0 ].strip ()] = kvstring [1 ].strip ()
1077
+ else :
1078
+ raise ValueError (
1079
+ "Invalid metadata format. Please use KEY=VALUE format."
1080
+ )
1086
1081
1087
- if args .ramp_up_strategy is not None :
1088
- result_json ["ramp_up_strategy" ] = args .ramp_up_strategy
1089
- result_json ["ramp_up_start_rps" ] = args .ramp_up_start_rps
1090
- result_json ["ramp_up_end_rps" ] = args .ramp_up_end_rps
1091
-
1092
- # Merge with benchmark result
1093
- result_json = {** result_json , ** benchmark_result }
1094
-
1095
- if not args .save_detailed :
1096
- # Remove fields with too many data points
1097
- for field in [
1098
- "input_lens" ,
1099
- "output_lens" ,
1100
- "ttfts" ,
1101
- "itls" ,
1102
- "generated_texts" ,
1103
- "errors" ,
1104
- ]:
1105
- if field in result_json :
1106
- del result_json [field ]
1107
- if field in benchmark_result :
1108
- del benchmark_result [field ]
1082
+ # Traffic
1083
+ result_json ["request_rate" ] = (args .request_rate if args .request_rate
1084
+ < float ("inf" ) else "inf" )
1085
+ result_json ["burstiness" ] = args .burstiness
1086
+ result_json ["max_concurrency" ] = args .max_concurrency
1087
+
1088
+ if args .ramp_up_strategy is not None :
1089
+ result_json ["ramp_up_strategy" ] = args .ramp_up_strategy
1090
+ result_json ["ramp_up_start_rps" ] = args .ramp_up_start_rps
1091
+ result_json ["ramp_up_end_rps" ] = args .ramp_up_end_rps
1092
+
1093
+ # Merge with benchmark result
1094
+ result_json = {** result_json , ** benchmark_result }
1095
+
1096
+ if not args .save_detailed :
1097
+ # Remove fields with too many data points
1098
+ for field in [
1099
+ "input_lens" ,
1100
+ "output_lens" ,
1101
+ "ttfts" ,
1102
+ "itls" ,
1103
+ "generated_texts" ,
1104
+ "errors" ,
1105
+ ]:
1106
+ if field in result_json :
1107
+ del result_json [field ]
1108
+ if field in benchmark_result :
1109
+ del benchmark_result [field ]
1109
1110
1110
1111
# Save to file
1112
+ if args .save_result or args .append_result :
1111
1113
base_model_id = model_id .split ("/" )[- 1 ]
1112
1114
max_concurrency_str = (f"-concurrency{ args .max_concurrency } "
1113
1115
if args .max_concurrency is not None else "" )
@@ -1129,3 +1131,5 @@ def main(args: argparse.Namespace):
1129
1131
outfile .write ("\n " )
1130
1132
json .dump (result_json , outfile )
1131
1133
save_to_pytorch_benchmark_format (args , result_json , file_name )
1134
+
1135
+ return result_json
0 commit comments