@@ -96,6 +96,8 @@ def __init__(
         self,
         dataset_path: Optional[str] = None,
         random_seed: int = DEFAULT_SEED,
+        disable_shuffle: bool = False,
+        **kwargs,
     ) -> None:
         """
         Initialize the BenchmarkDataset with an optional dataset path and random
@@ -111,6 +113,7 @@ def __init__(
         # Set the random seed, ensuring that a None value is replaced with the
         # default seed.
         self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
+        self.disable_shuffle = disable_shuffle
         self.data = None

     def apply_multimodal_chat_transformation(
@@ -1044,7 +1047,8 @@ def load_data(self) -> None:
             if "conversations" in entry and len(entry["conversations"]) >= 2
         ]
         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(
         self,
@@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         action="store_true",
         help="Skip applying chat template to prompt for datasets that support it.",
     )
+    parser.add_argument(
+        "--disable-shuffle",
+        action="store_true",
+        help="Disable shuffling of dataset samples for deterministic ordering.",
+    )

     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
@@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         args.request_id_prefix = ""

     if args.dataset_name == "custom":
-        dataset = CustomDataset(dataset_path=args.dataset_path)
+        dataset = CustomDataset(
+            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
+        )
         input_requests = dataset.sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         )

     elif args.dataset_name == "sonnet":
-        dataset = SonnetDataset(dataset_path=args.dataset_path)
+        dataset = SonnetDataset(
+            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
+        )
         # For the "sonnet" dataset, formatting depends on the backend.
         if args.backend == "openai-chat":
             input_requests = dataset.sample(
@@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             random_seed=args.seed,
             no_stream=args.no_stream,
             hf_name=args.hf_name,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         # For datasets that follow a similar structure, use a mapping.
         dataset_mapping = {
             "spec_bench": lambda: SpecBench(
-                dataset_path=args.dataset_path, category=args.spec_bench_category
+                dataset_path=args.dataset_path,
+                category=args.spec_bench_category,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 num_requests=args.num_prompts,
                 tokenizer=tokenizer,
@@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "sharegpt": lambda: ShareGPTDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1618,15 +1636,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "burstgpt": lambda: BurstGPTDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
                 request_id_prefix=args.request_id_prefix,
                 no_oversample=args.no_oversample,
             ),
             "random": lambda: RandomDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "random-mm": lambda: RandomMultiModalDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 no_oversample=args.no_oversample,
             ),
             "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
-                random_seed=args.seed, dataset_path=args.dataset_path
+                random_seed=args.seed,
+                dataset_path=args.dataset_path,
+                disable_shuffle=args.disable_shuffle,
             ).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
@@ -1733,7 +1759,8 @@ def load_data(self) -> None:
             )

         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(
         self,
@@ -1825,7 +1852,8 @@ def load_data(self) -> None:
                 self.data.append({"prompt": prompt})

         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
@@ -2033,7 +2061,8 @@ def load_data(self) -> None:
             split=self.dataset_split,
             streaming=self.load_stream,
         )
-        self.data = self.data.shuffle(seed=self.random_seed)
+        if not getattr(self, "disable_shuffle", False):
+            self.data = self.data.shuffle(seed=self.random_seed)


# -----------------------------------------------------------------------------
@@ -2849,7 +2878,8 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
                 abs(token_mismatch_total),
                 sign,
             )
-        random.shuffle(requests)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(requests)
         return requests


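For reference, a minimal, self-contained sketch of the pattern this diff introduces: the constructor accepts a disable_shuffle flag, load_data seeds the RNG but shuffles only when shuffling is not disabled, and the flag is wired through an argparse option. ToyDataset, main, and the toy prompts are hypothetical stand-ins; only disable_shuffle, random_seed, DEFAULT_SEED, and the --disable-shuffle argument come from the diff itself.

import argparse
import random

DEFAULT_SEED = 0


class ToyDataset:
    """Hypothetical stand-in for a BenchmarkDataset subclass."""

    def __init__(
        self,
        random_seed: int = DEFAULT_SEED,
        disable_shuffle: bool = False,
        **kwargs,
    ) -> None:
        self.random_seed = random_seed if random_seed is not None else DEFAULT_SEED
        self.disable_shuffle = disable_shuffle
        self.data = None

    def load_data(self) -> None:
        # Toy entries standing in for records parsed from a dataset file.
        self.data = [{"prompt": f"prompt-{i}"} for i in range(5)]
        random.seed(self.random_seed)
        # Same gating as the diff: shuffle only when not disabled.
        if not getattr(self, "disable_shuffle", False):
            random.shuffle(self.data)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--disable-shuffle",
        action="store_true",
        help="Disable shuffling of dataset samples for deterministic ordering.",
    )
    args = parser.parse_args()

    dataset = ToyDataset(random_seed=42, disable_shuffle=args.disable_shuffle)
    dataset.load_data()
    print([entry["prompt"] for entry in dataset.data])


if __name__ == "__main__":
    main()

Running this with --disable-shuffle prints the prompts in insertion order; without the flag, the order is the seeded random permutation, which mirrors the deterministic-ordering behavior the new option is meant to provide.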