1
1
import json
2
2
import random
3
3
from collections .abc import Iterable , Iterator
4
+ from itertools import cycle
4
5
from pathlib import Path
5
6
from typing import Any , Literal , Optional , Union
6
7
@@ -168,6 +169,7 @@ def __iter__(
168
169
)
169
170
# ensure a different distribution from the one used for output tokens
170
171
rand = random .Random (self .random_seed + 2 ) # noqa: S311
172
+ unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
171
173
172
174
prefix_index = rand .randint (0 , len (self .text_creator .words ))
173
175
prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
@@ -179,7 +181,10 @@ def __iter__(
179
181
):
180
182
start_index = rand .randint (0 , len (self .text_creator .words ))
181
183
prompt_text = self .processor .decode (
182
- prefix_tokens + self ._create_prompt (prompt_tokens , start_index ),
184
+ prefix_tokens
185
+ + self ._create_prompt (
186
+ prompt_tokens , start_index , next (unique_prefix_iter )
187
+ ),
183
188
skip_special_tokens = True ,
184
189
)
185
190
yield {
@@ -188,17 +193,20 @@ def __iter__(
188
193
"output_tokens_count" : output_tokens ,
189
194
}
190
195
191
- def _create_prompt (self , prompt_tokens : int , start_index : int ) -> list [int ]:
196
+ def _create_prompt (
197
+ self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
198
+ ) -> list [int ]:
192
199
if prompt_tokens <= 0 :
193
200
return []
194
201
195
202
left = start_index
196
203
right = start_index + 4 * prompt_tokens
204
+ start_tokens = [unique_prefix ] if unique_prefix else []
197
205
198
206
while left < right :
199
207
mid = (left + right ) // 2
200
208
test_prompt = self .text_creator .create_text (start_index , mid - start_index )
201
- test_tokens = self .processor .encode (test_prompt )
209
+ test_tokens = start_tokens + self .processor .encode (test_prompt )
202
210
203
211
if len (test_tokens ) == prompt_tokens :
204
212
return test_tokens
@@ -208,7 +216,7 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
208
216
right = mid
209
217
210
218
final_text = self .text_creator .create_text (start_index , left - start_index )
211
- return self .processor .encode (final_text )
219
+ return start_tokens + self .processor .encode (final_text )
212
220
213
221
214
222
class SyntheticDatasetCreator (DatasetCreator ):
0 commit comments