@@ -170,41 +170,45 @@ def __iter__(
170170 rand = random .Random (self .random_seed + 2 ) # noqa: S311
171171
172172 prefix_index = rand .randint (0 , len (self .text_creator .words ))
173- prefix_tokens = self .config .prefix_tokens
174- prefix = self ._create_prompt (prefix_tokens , prefix_index )
173+ prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
175174
176175 for _ , prompt_tokens , output_tokens in zip (
177176 range (self .config .samples ),
178177 prompt_tokens_sampler ,
179178 output_tokens_sampler ,
180179 ):
181180 start_index = rand .randint (0 , len (self .text_creator .words ))
181+ prompt_text = self .processor .decode (
182+ prefix_tokens + self ._create_prompt (prompt_tokens , start_index ),
183+ skip_special_tokens = True ,
184+ )
182185 yield {
183- "prompt" : prefix + self . _create_prompt ( prompt_tokens , start_index ) ,
184- "prompt_tokens_count" : prefix_tokens + prompt_tokens ,
186+ "prompt" : prompt_text ,
187+ "prompt_tokens_count" : self . config . prefix_tokens + prompt_tokens ,
185188 "output_tokens_count" : output_tokens ,
186189 }
187190
188- def _create_prompt (self , prompt_tokens : int , start_index : int ) -> str :
191+ def _create_prompt (self , prompt_tokens : int , start_index : int ) -> list [ int ] :
189192 if prompt_tokens <= 0 :
190- return ""
193+ return []
191194
192195 left = start_index
193196 right = start_index + 4 * prompt_tokens
194197
195198 while left < right :
196199 mid = (left + right ) // 2
197200 test_prompt = self .text_creator .create_text (start_index , mid - start_index )
198- test_tokens = len ( self .processor .tokenize (test_prompt ) )
201+ test_tokens = self .processor .encode (test_prompt )
199202
200- if test_tokens == prompt_tokens :
201- return test_prompt
202- elif test_tokens < prompt_tokens :
203+ if len ( test_tokens ) == prompt_tokens :
204+ return test_tokens
205+ elif len ( test_tokens ) < prompt_tokens :
203206 left = mid + 1
204207 else :
205208 right = mid
206209
207- return self .text_creator .create_text (start_index , left - start_index )
210+ final_text = self .text_creator .create_text (start_index , left - start_index )
211+ return self .processor .encode (final_text )
208212
209213
210214class SyntheticDatasetCreator (DatasetCreator ):
0 commit comments