@@ -221,6 +221,34 @@ def generate_output_token_counts(mean, std, num, input_token_count):
     return output
 
 
+def generate_output_token_counts_from_existing(
+    distribution: List[int], num: int, input_token_count: int
+):
+    assert len(distribution) > 0, "Can't have a distribution with 0 tokens"
+    output = []
+    # Sample without replacement so that we don't have as much variance
+    for _ in range(num // len(distribution)):
+        random.shuffle(distribution)
+        output.extend(distribution)
+    random.shuffle(distribution)
+    output.extend(distribution[: num % len(distribution)])
+    assert len(output) == num
+
+    for i in range(len(output)):
+        output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count)
+    return output
+
+
+def read_distribution_from_file(fpath: str):
+    # Assumes the file is a JSON-formatted string representing a list of token counts
+    try:
+        with open(fpath, "r") as fin:
+            return json.load(fin)
+    except FileNotFoundError:
+        print("File not found. Exiting.")
+        raise
+
+
 def run_benchmark(
     model: str,
     framework: InferenceFramework,
@@ -231,17 +259,23 @@ def run_benchmark(
     concurrency: int,
     verbose: bool,
     local_port: int,
+    response_token_count_distribution: Optional[List] = None,
 ):
     prompt = generate_prompt(config.input_token_count, hf_model)
 
     prompt_num_tokens = config.input_token_count
 
-    output_token_counts = generate_output_token_counts(
-        config.output_token_count_mean,
-        config.output_token_count_std,
-        num_trials,
-        config.input_token_count,
-    )
+    if response_token_count_distribution is not None:
+        output_token_counts = generate_output_token_counts_from_existing(
+            response_token_count_distribution, num_trials, config.input_token_count
+        )
+    else:
+        output_token_counts = generate_output_token_counts(
+            config.output_token_count_mean,
+            config.output_token_count_std,
+            num_trials,
+            config.input_token_count,
+        )
 
     start = time.time()
     results = send_requests(
@@ -352,10 +386,18 @@ def run_benchmarks(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     """Run benchmarks."""
     all_statistics = []
     config = BenchmarkConfig(input_token_count, output_token_count_mean)
+
+    response_token_count_distribution = None
+    if response_token_count_distribution_file is not None:
+        response_token_count_distribution = read_distribution_from_file(
+            response_token_count_distribution_file
+        )
+
     try:
         if verbose:
             print(f"Running benchmark for config {config}")
@@ -375,6 +417,7 @@ def run_benchmarks(
             concurrency,
             verbose,
             local_port,
+            response_token_count_distribution,
         )
         all_statistics.append(statistics)
     except Exception:
@@ -404,6 +447,7 @@ def run_benchmarks_concurrency_range(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     if output_file is not None:
         # Create empty file
@@ -422,6 +466,7 @@ def run_benchmarks_concurrency_range(
             verbose,
             hf_model,
             local_port,
+            response_token_count_distribution_file,
         )
 
 
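For context, the new response_token_count_distribution_file option expects a plain JSON file containing a list of observed response token counts, and generate_output_token_counts_from_existing resamples that list (shuffling and drawing without replacement) to produce one output length per trial, clamped to the context window. Below is a minimal, self-contained sketch of that round trip; it is illustrative only: the MAX_CONTEXT_WINDOW value and the file name are assumptions, and the two helpers are restated from the patch above rather than imported.

import json
import random
from typing import List

MAX_CONTEXT_WINDOW = 4096  # assumed value; the real constant is defined elsewhere in the script


def read_distribution_from_file(fpath: str) -> List[int]:
    # Mirrors the helper added above: the file holds a JSON-formatted list.
    with open(fpath, "r") as fin:
        return json.load(fin)


def generate_output_token_counts_from_existing(
    distribution: List[int], num: int, input_token_count: int
) -> List[int]:
    # Mirrors the helper added above: shuffle and take without replacement,
    # then clamp each count to what fits in the context window.
    output: List[int] = []
    for _ in range(num // len(distribution)):
        random.shuffle(distribution)
        output.extend(distribution)
    random.shuffle(distribution)
    output.extend(distribution[: num % len(distribution)])
    return [min(x, MAX_CONTEXT_WINDOW - input_token_count) for x in output]


# The distribution file is just a JSON array of observed response token counts.
with open("response_token_counts.json", "w") as fout:
    json.dump([12, 45, 97, 200, 388], fout)

counts = generate_output_token_counts_from_existing(
    read_distribution_from_file("response_token_counts.json"),
    num=8,
    input_token_count=1000,
)
print(counts)  # eight lengths resampled (nearly) evenly from the observed distribution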