11import json
22import random
33from collections .abc import Iterable , Iterator
4+ from itertools import cycle
45from pathlib import Path
56from typing import Any , Literal , Optional , Union
67
@@ -168,6 +169,7 @@ def __iter__(
168169 )
169170 # ensure a different distribution from the output tokens
170171 rand = random .Random (self .random_seed + 2 ) # noqa: S311
172+ unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
171173
172174 prefix_index = rand .randint (0 , len (self .text_creator .words ))
173175 prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
@@ -179,7 +181,10 @@ def __iter__(
179181 ):
180182 start_index = rand .randint (0 , len (self .text_creator .words ))
181183 prompt_text = self .processor .decode (
182- prefix_tokens + self ._create_prompt (prompt_tokens , start_index ),
184+ prefix_tokens
185+ + self ._create_prompt (
186+ prompt_tokens , start_index , next (unique_prefix_iter )
187+ ),
183188 skip_special_tokens = True ,
184189 )
185190 yield {
@@ -188,17 +193,20 @@ def __iter__(
188193 "output_tokens_count" : output_tokens ,
189194 }
190195
191- def _create_prompt (self , prompt_tokens : int , start_index : int ) -> list [int ]:
196+ def _create_prompt (
197+ self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
198+ ) -> list [int ]:
192199 if prompt_tokens <= 0 :
193200 return []
194201
195202 left = start_index
196203 right = start_index + 4 * prompt_tokens
204+ start_tokens = [unique_prefix ] if unique_prefix else []
197205
198206 while left < right :
199207 mid = (left + right ) // 2
200208 test_prompt = self .text_creator .create_text (start_index , mid - start_index )
201- test_tokens = self .processor .encode (test_prompt )
209+ test_tokens = start_tokens + self .processor .encode (test_prompt )
202210
203211 if len (test_tokens ) == prompt_tokens :
204212 return test_tokens
@@ -208,7 +216,7 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
208216 right = mid
209217
210218 final_text = self .text_creator .create_text (start_index , left - start_index )
211- return self .processor .encode (final_text )
219+ return start_tokens + self .processor .encode (final_text )
212220
213221
214222class SyntheticDatasetCreator (DatasetCreator ):
0 commit comments