@@ -73,7 +73,7 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """
 
-    prompt: Union[str, Any]
+    prompt: Union[str, list[str]]
     prompt_len: int
     expected_output_len: int
     multi_modal_data: Optional[
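
Note on the widened type: with prompt: Union[str, list[str]], a single SampleRequest can now carry either one prompt or a whole batch of prompts. A minimal sketch of the two shapes, assuming the dataclass fields outside this diff (e.g. multi_modal_data, request_id) have defaults:

    # Hypothetical construction; defaults for fields not shown in the diff are assumed.
    single = SampleRequest(prompt="What is vLLM?", prompt_len=5, expected_output_len=16)
    batched = SampleRequest(prompt=["a", "b", "c"], prompt_len=3, expected_output_len=0)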
@@ -409,6 +409,7 @@ def sample(
         range_ratio: float = DEFAULT_RANGE_RATIO,
         input_len: int = DEFAULT_INPUT_LEN,
         output_len: int = DEFAULT_OUTPUT_LEN,
+        batchsize: int = 1,
         **kwargs,
     ) -> list[SampleRequest]:
 
@@ -439,6 +440,21 @@ def sample(
                     request_id=request_id_prefix + str(i),
                 )
             )
+        # only used for embeddings benchmark.
+        if batchsize > 1:
+            batch_requests = []
+            # Create batched requests
+            for i in range(0, num_requests, batchsize):
+                batch = requests[i:i + batchsize]
+                batch_requests.append(
+                    SampleRequest(
+                        prompt=[req.prompt for req in batch],
+                        prompt_len=sum(req.prompt_len for req in batch),
+                        expected_output_len=0,
+                        request_id=request_id_prefix + str(i // batchsize),
+                    )
+                )
+            requests = batch_requests
         return requests
 
     def get_prefix(
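
The batching pass above folds consecutive requests into chunks of batchsize: prompts are collected into a list, prompt_len becomes the sum over the chunk, and expected_output_len is pinned to 0, which fits the embeddings use case since embedding endpoints return vectors rather than generated tokens. When num_requests is not a multiple of batchsize, the final chunk simply carries the remainder. A standalone sketch of the same grouping, using a stand-in dataclass rather than the real SampleRequest:

    from dataclasses import dataclass

    @dataclass
    class Req:
        # Stand-in for SampleRequest in this sketch.
        prompt: str
        prompt_len: int

    requests = [Req(prompt=f"p{i}", prompt_len=4) for i in range(10)]
    batchsize = 4
    batches = [requests[i:i + batchsize] for i in range(0, len(requests), batchsize)]
    assert [len(b) for b in batches] == [4, 4, 2]       # trailing chunk holds the remainder
    assert sum(r.prompt_len for r in batches[0]) == 16  # batch prompt_len is the member sum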
@@ -475,8 +491,8 @@ def get_sampling_params(
         input_high = math.ceil(real_input_len * (1 + range_ratio))
         output_low = math.floor(output_len * (1 - range_ratio))
         output_high = math.ceil(output_len * (1 + range_ratio))
-        # Ensure the lower bound for output length is at least 1 to
-        # prevent sampling 0 tokens.
+        # Ensure the lower bound for output length is at least 1 to
+        # prevent sampling 0 tokens.
         output_low = max(output_low, 1)
 
         if input_low > input_high:
@@ -506,7 +522,6 @@ def get_sampling_params(
                                              size=num_requests)
         return input_lens, output_lens, offsets
 
-
     def generate_token_sequence(
         self,
         *,
@@ -1105,6 +1120,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
              "context length sampled from [input_len * (1 - range_ratio), "
              "input_len * (1 + range_ratio)]."),
     )
+    random_group.add_argument(
+        "--random-batch-size",
+        type=int,
+        default=1,
+        help=("Batch size for random sampling. "
+              "Only used for embeddings benchmark."),
+    )
 
     # random multimodal dataset options
     random_mm_group = parser.add_argument_group(
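
With the flag registered in the random dataset group, a benchmark run could opt into pre-batched prompts. A hypothetical invocation (the exact entry point depends on which benchmark script mounts this parser):

    python benchmark_serving.py --dataset-name random \
        --num-prompts 64 --random-input-len 128 --random-batch-size 8

Under these settings the 64 sampled prompts would be folded into 8 SampleRequests of 8 prompts each.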
@@ -1196,8 +1218,6 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]:
         ),
     )
 
-
-
     hf_group = parser.add_argument_group("hf dataset options")
     hf_group.add_argument("--hf-subset",
                           type=str,
@@ -1348,29 +1368,32 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     else:
         # For datasets that follow a similar structure, use a mapping.
         dataset_mapping = {
-            "sharegpt":
-            lambda: ShareGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).sample(
-                tokenizer=tokenizer,
-                num_requests=args.num_prompts,
-                output_len=args.sharegpt_output_len,
-                request_id_prefix=args.request_id_prefix,
-            ),
-            "burstgpt":
-            lambda: BurstGPTDataset(random_seed=args.seed,
-                                    dataset_path=args.dataset_path).
-            sample(tokenizer=tokenizer, num_requests=args.num_prompts,
-                   request_id_prefix=args.request_id_prefix,),
-            "random":
-            lambda: RandomDataset(random_seed=args.seed,
-                                  dataset_path=args.dataset_path).sample(
+            "sharegpt": lambda: ShareGPTDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
+                tokenizer=tokenizer,
+                num_requests=args.num_prompts,
+                output_len=args.sharegpt_output_len,
+                request_id_prefix=args.request_id_prefix,
+            ),
+            "burstgpt": lambda: BurstGPTDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
+                tokenizer=tokenizer,
+                num_requests=args.num_prompts,
+                request_id_prefix=args.request_id_prefix,
+            ),
+            "random": lambda: RandomDataset(
+                random_seed=args.seed, dataset_path=args.dataset_path
+            ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
                 prefix_len=args.random_prefix_len,
                 input_len=args.random_input_len,
                 output_len=args.random_output_len,
                 range_ratio=args.random_range_ratio,
                 request_id_prefix=args.request_id_prefix,
+                batchsize=args.random_batch_size,
             ),
             "random-mm":
             lambda: RandomMultiModalDataset(