Skip to content

Commit 7f3960b

Browse files
committed
Add basic prefix support back to synthetic data
Signed-off-by: Samuel Monson <[email protected]>
1 parent 0443820 commit 7f3960b

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

src/guidellm/data/deserializers/synthetic.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525

2626

2727
class SyntheticTextDatasetConfig(StandardBaseModel):
28+
prefix_tokens: int = Field(
29+
description="The number of shared prefix tokens to prepend to each prompt.",
30+
ge=0,
31+
default=0,
32+
)
2833
prompt_tokens: int = Field(
2934
description="The average number of text tokens generated for prompts.",
3035
gt=0,
@@ -104,20 +109,29 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
104109
)
105110
)
106111

112+
# Create a shared prefix if specified
113+
if self.config.prefix_tokens > 0:
114+
prefix = self._create_prompt(self.config.prefix_tokens, faker)
115+
else:
116+
prefix = "" # Always have a prefix key for consistency
117+
107118
while True:
108119
prompt_tokens_count = next(prompt_tokens_sampler)
109120
output_tokens_count = next(output_tokens_sampler)
110121

111122
yield {
123+
"prefix": prefix,
112124
"prompt": self._create_prompt(
113-
prompt_tokens_count, samples_generated, faker
125+
prompt_tokens_count, faker, f"{samples_generated} "
114126
),
115127
"prompt_tokens_count": prompt_tokens_count,
116128
"output_tokens_count": output_tokens_count,
117129
}
118130
samples_generated += 1
119131

120-
def _create_prompt(self, prompt_tokens_count: int, index: int, faker: Faker) -> str:
132+
def _create_prompt(
133+
self, prompt_tokens_count: int, faker: Faker, unique: str = ""
134+
) -> str:
121135
prompt_token_ids = []
122136
avg_chars_per_token = 5
123137
margin_of_safety = 1.5
@@ -128,7 +142,7 @@ def _create_prompt(self, prompt_tokens_count: int, index: int, faker: Faker) ->
128142
num_chars = (
129143
prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
130144
)
131-
text = f"{index} " + faker.text(max_nb_chars=num_chars)
145+
text = unique + faker.text(max_nb_chars=num_chars)
132146
prompt_token_ids = self.processor.encode(text)
133147

134148
return self.processor.decode(
@@ -166,6 +180,7 @@ def __call__(
166180
),
167181
features=Features(
168182
{
183+
"prefix": Value("string"),
169184
"prompt": Value("string"),
170185
"prompt_tokens_count": Value("int32"),
171186
"output_tokens_count": Value("int32"),

0 commit comments

Comments
 (0)