Skip to content

Commit a3eed17

Browse files
committed
Add prefix before decode
Signed-off-by: Samuel Monson <[email protected]>
1 parent a5d5772 commit a3eed17

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

src/guidellm/dataset/synthetic.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -170,41 +170,45 @@ def __iter__(
170170
rand = random.Random(self.random_seed + 2) # noqa: S311
171171

172172
prefix_index = rand.randint(0, len(self.text_creator.words))
173-
prefix_tokens = self.config.prefix_tokens
174-
prefix = self._create_prompt(prefix_tokens, prefix_index)
173+
prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
175174

176175
for _, prompt_tokens, output_tokens in zip(
177176
range(self.config.samples),
178177
prompt_tokens_sampler,
179178
output_tokens_sampler,
180179
):
181180
start_index = rand.randint(0, len(self.text_creator.words))
181+
prompt_text = self.processor.decode(
182+
prefix_tokens + self._create_prompt(prompt_tokens, start_index),
183+
skip_special_tokens=True,
184+
)
182185
yield {
183-
"prompt": prefix + self._create_prompt(prompt_tokens, start_index),
184-
"prompt_tokens_count": prefix_tokens + prompt_tokens,
186+
"prompt": prompt_text,
187+
"prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
185188
"output_tokens_count": output_tokens,
186189
}
187190

188-
def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
191+
def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
189192
if prompt_tokens <= 0:
190-
return ""
193+
return []
191194

192195
left = start_index
193196
right = start_index + 4 * prompt_tokens
194197

195198
while left < right:
196199
mid = (left + right) // 2
197200
test_prompt = self.text_creator.create_text(start_index, mid - start_index)
198-
test_tokens = len(self.processor.tokenize(test_prompt))
201+
test_tokens = self.processor.encode(test_prompt)
199202

200-
if test_tokens == prompt_tokens:
201-
return test_prompt
202-
elif test_tokens < prompt_tokens:
203+
if len(test_tokens) == prompt_tokens:
204+
return test_tokens
205+
elif len(test_tokens) < prompt_tokens:
203206
left = mid + 1
204207
else:
205208
right = mid
206209

207-
return self.text_creator.create_text(start_index, left - start_index)
210+
final_text = self.text_creator.create_text(start_index, left - start_index)
211+
return self.processor.encode(final_text)
208212

209213

210214
class SyntheticDatasetCreator(DatasetCreator):

0 commit comments

Comments
 (0)