Commit 7c2ec0f

[Benchmarking] Add disable_shuffle option for dataset loading (#26258)
Signed-off-by: Yasmin Moslem <[email protected]>
1 parent 039b6ba · commit 7c2ec0f

1 file changed (+43 -13 lines changed)

vllm/benchmarks/datasets.py

Lines changed: 43 additions & 13 deletions
@@ -96,6 +96,8 @@ def __init__(
         self,
         dataset_path: Optional[str] = None,
         random_seed: int = DEFAULT_SEED,
+        disable_shuffle: bool = False,
+        **kwargs,
     ) -> None:
         """
         Initialize the BenchmarkDataset with an optional dataset path and random
@@ -111,6 +113,7 @@ def __init__(
         # Set the random seed, ensuring that a None value is replaced with the
         # default seed.
         self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
+        self.disable_shuffle = disable_shuffle
         self.data = None

     def apply_multimodal_chat_transformation(
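
With the constructor change above, every BenchmarkDataset subclass accepts disable_shuffle as a keyword argument. A minimal sketch of direct use, assuming a local ShareGPT-format file (the path below is a placeholder, not part of this commit):

from vllm.benchmarks.datasets import ShareGPTDataset

# Keep the on-disk order of the file instead of applying the seeded shuffle.
dataset = ShareGPTDataset(
    dataset_path="sharegpt_conversations.json",  # placeholder path
    random_seed=0,
    disable_shuffle=True,
)
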
@@ -1044,7 +1047,8 @@ def load_data(self) -> None:
             if "conversations" in entry and len(entry["conversations"]) >= 2
         ]
         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(
         self,
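
The same guard is repeated in the other load_data() implementations further down. A self-contained sketch of its effect on ordering (toy data, not the benchmark's real entries):

import random

entries = list(range(6))  # stand-in for parsed dataset entries
disable_shuffle = True

random.seed(42)
if not disable_shuffle:
    random.shuffle(entries)

# With disable_shuffle=True the original order [0, 1, 2, 3, 4, 5] is kept;
# otherwise the result is a seeded, reproducible permutation.
print(entries)
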
@@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         action="store_true",
         help="Skip applying chat template to prompt for datasets that support it.",
     )
+    parser.add_argument(
+        "--disable-shuffle",
+        action="store_true",
+        help="Disable shuffling of dataset samples for deterministic ordering.",
+    )

     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
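
Because the flag uses argparse's store_true action, it defaults to False, so shuffling stays enabled unless --disable-shuffle is passed explicitly. A standalone sketch of that behavior with plain argparse (outside vLLM):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-shuffle",
    action="store_true",
    help="Disable shuffling of dataset samples for deterministic ordering.",
)

print(parser.parse_args([]).disable_shuffle)                     # False
print(parser.parse_args(["--disable-shuffle"]).disable_shuffle)  # True
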
@@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         args.request_id_prefix = ""

     if args.dataset_name == "custom":
-        dataset = CustomDataset(dataset_path=args.dataset_path)
+        dataset = CustomDataset(
+            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
+        )
         input_requests = dataset.sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         )

     elif args.dataset_name == "sonnet":
-        dataset = SonnetDataset(dataset_path=args.dataset_path)
+        dataset = SonnetDataset(
+            dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle
+        )
         # For the "sonnet" dataset, formatting depends on the backend.
         if args.backend == "openai-chat":
             input_requests = dataset.sample(
@@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             random_seed=args.seed,
             no_stream=args.no_stream,
             hf_name=args.hf_name,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     # For datasets that follow a similar structure, use a mapping.
     dataset_mapping = {
         "spec_bench": lambda: SpecBench(
-            dataset_path=args.dataset_path, category=args.spec_bench_category
+            dataset_path=args.dataset_path,
+            category=args.spec_bench_category,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
@@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             no_oversample=args.no_oversample,
         ),
         "sharegpt": lambda: ShareGPTDataset(
-            random_seed=args.seed, dataset_path=args.dataset_path
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
@@ -1618,15 +1636,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             no_oversample=args.no_oversample,
         ),
         "burstgpt": lambda: BurstGPTDataset(
-            random_seed=args.seed, dataset_path=args.dataset_path
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         ),
         "random": lambda: RandomDataset(
-            random_seed=args.seed, dataset_path=args.dataset_path
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
@@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             no_oversample=args.no_oversample,
         ),
         "random-mm": lambda: RandomMultiModalDataset(
-            random_seed=args.seed, dataset_path=args.dataset_path
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
@@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             no_oversample=args.no_oversample,
         ),
         "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
-            random_seed=args.seed, dataset_path=args.dataset_path
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
         ).sample(
             tokenizer=tokenizer,
             num_requests=args.num_prompts,
@@ -1733,7 +1759,8 @@ def load_data(self) -> None:
         )

         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(
         self,
@@ -1825,7 +1852,8 @@ def load_data(self) -> None:
             self.data.append({"prompt": prompt})

         random.seed(self.random_seed)
-        random.shuffle(self.data)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(self.data)

     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
@@ -2033,7 +2061,8 @@ def load_data(self) -> None:
             split=self.dataset_split,
             streaming=self.load_stream,
         )
-        self.data = self.data.shuffle(seed=self.random_seed)
+        if not getattr(self, "disable_shuffle", False):
+            self.data = self.data.shuffle(seed=self.random_seed)


# -----------------------------------------------------------------------------
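
Unlike random.shuffle, Hugging Face's Dataset.shuffle(seed=...) is not in place; it returns a new dataset, which is why the guard above skips the whole reassignment. A hedged sketch (the dataset name is an arbitrary public one, chosen only for illustration):

from datasets import load_dataset

disable_shuffle = True
data = load_dataset("tatsu-lab/alpaca", split="train")
if not disable_shuffle:
    # Returns a shuffled copy; the original row order is left untouched.
    data = data.shuffle(seed=0)
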
@@ -2849,7 +2878,8 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
             abs(token_mismatch_total),
             sign,
         )
-        random.shuffle(requests)
+        if not getattr(self, "disable_shuffle", False):
+            random.shuffle(requests)
         return requests
