Skip to content

Commit 6662be6

Browse files
MML-codersjmonson
andcommitted
Add unique single-token prefix to every request
Co-authored-by: Mehul <[email protected]> Co-authored-by: Samuel Monson <[email protected]> Signed-off-by: Samuel Monson <[email protected]>
1 parent 94a4508 commit 6662be6

File tree

3 files changed

+646
-4
lines changed

3 files changed

+646
-4
lines changed

src/guidellm/dataset/synthetic.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import random
33
from collections.abc import Iterable, Iterator
4+
from itertools import cycle
45
from pathlib import Path
56
from typing import Any, Literal, Optional, Union
67

@@ -168,6 +169,7 @@ def __iter__(
168169
)
169170
# ensure diff distribution from output tokens
170171
rand = random.Random(self.random_seed + 2) # noqa: S311
172+
unique_prefix_iter = cycle(self.processor.get_vocab().values())
171173

172174
prefix_index = rand.randint(0, len(self.text_creator.words))
173175
prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
@@ -179,7 +181,10 @@ def __iter__(
179181
):
180182
start_index = rand.randint(0, len(self.text_creator.words))
181183
prompt_text = self.processor.decode(
182-
prefix_tokens + self._create_prompt(prompt_tokens, start_index),
184+
prefix_tokens
185+
+ self._create_prompt(
186+
prompt_tokens, start_index, next(unique_prefix_iter)
187+
),
183188
skip_special_tokens=True,
184189
)
185190
yield {
@@ -188,17 +193,20 @@ def __iter__(
188193
"output_tokens_count": output_tokens,
189194
}
190195

191-
def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
196+
def _create_prompt(
197+
self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
198+
) -> list[int]:
192199
if prompt_tokens <= 0:
193200
return []
194201

195202
left = start_index
196203
right = start_index + 4 * prompt_tokens
204+
start_tokens = [unique_prefix] if unique_prefix else []
197205

198206
while left < right:
199207
mid = (left + right) // 2
200208
test_prompt = self.text_creator.create_text(start_index, mid - start_index)
201-
test_tokens = self.processor.encode(test_prompt)
209+
test_tokens = start_tokens + self.processor.encode(test_prompt)
202210

203211
if len(test_tokens) == prompt_tokens:
204212
return test_tokens
@@ -208,7 +216,7 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
208216
right = mid
209217

210218
final_text = self.text_creator.create_text(start_index, left - start_index)
211-
return self.processor.encode(final_text)
219+
return start_tokens + self.processor.encode(final_text)
212220

213221

214222
class SyntheticDatasetCreator(DatasetCreator):

tests/unit/dataset/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)