@@ -170,41 +170,45 @@ def __iter__(
170
170
rand = random .Random (self .random_seed + 2 ) # noqa: S311
171
171
172
172
prefix_index = rand .randint (0 , len (self .text_creator .words ))
173
- prefix_tokens = self .config .prefix_tokens
174
- prefix = self ._create_prompt (prefix_tokens , prefix_index )
173
+ prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
175
174
176
175
for _ , prompt_tokens , output_tokens in zip (
177
176
range (self .config .samples ),
178
177
prompt_tokens_sampler ,
179
178
output_tokens_sampler ,
180
179
):
181
180
start_index = rand .randint (0 , len (self .text_creator .words ))
181
+ prompt_text = self .processor .decode (
182
+ prefix_tokens + self ._create_prompt (prompt_tokens , start_index ),
183
+ skip_special_tokens = True ,
184
+ )
182
185
yield {
183
- "prompt" : prefix + self . _create_prompt ( prompt_tokens , start_index ) ,
184
- "prompt_tokens_count" : prefix_tokens + prompt_tokens ,
186
+ "prompt" : prompt_text ,
187
+ "prompt_tokens_count" : self . config . prefix_tokens + prompt_tokens ,
185
188
"output_tokens_count" : output_tokens ,
186
189
}
187
190
188
def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]:
    """Build a list of token IDs of the requested length from the text source.

    A binary search is run over the length argument handed to
    ``self.text_creator.create_text`` (presumably a word count -- confirm
    against the text creator). Each candidate text is encoded with
    ``self.processor.encode`` and its token count is compared against
    ``prompt_tokens``.

    Args:
        prompt_tokens: Desired number of tokens. Values ``<= 0`` yield an
            empty list.
        start_index: Start position passed to ``create_text``; it is also the
            lower bound of the search window.

    Returns:
        The token IDs produced by ``self.processor.encode``. The list has
        exactly ``prompt_tokens`` entries when the search hits the target or
        overshoots it (the overshoot is truncated). It can be shorter only
        when even the widest candidate (``4 * prompt_tokens`` units of text)
        encodes to fewer tokens than requested.
    """
    if prompt_tokens <= 0:
        return []

    # Search window: between 0 and 4 * prompt_tokens units of text.
    left = start_index
    right = start_index + 4 * prompt_tokens

    while left < right:
        mid = (left + right) // 2
        test_prompt = self.text_creator.create_text(start_index, mid - start_index)
        test_tokens = self.processor.encode(test_prompt)

        if len(test_tokens) == prompt_tokens:
            return test_tokens
        if len(test_tokens) < prompt_tokens:
            left = mid + 1
        else:
            right = mid

    # No candidate matched exactly. Token counts do not have to grow one at a
    # time as text is added, so the search can step over the target. Since the
    # result is a token list, slice it so the caller never receives more
    # tokens than the count it reports.
    final_text = self.text_creator.create_text(start_index, left - start_index)
    final_tokens = self.processor.encode(final_text)
    return final_tokens[:prompt_tokens]
208
212
209
213
210
214
class SyntheticDatasetCreator (DatasetCreator ):
0 commit comments