|
25 | 25 |
|
26 | 26 |
|
27 | 27 | class SyntheticDatasetConfig(BaseModel):
|
| 28 | + prefix_tokens: int = Field( |
| 29 | + description="The number of shared prefix tokens to prepend to each prompt.", |
| 30 | + ge=0, |
| 31 | + default=0, |
| 32 | + ) |
28 | 33 | prompt_tokens: int = Field(
|
29 | 34 | description="The average number of text tokens generated for prompts.",
|
30 | 35 | gt=0,
|
@@ -164,15 +169,19 @@ def __iter__(
|
164 | 169 | # ensure diff distribution from output tokens
|
165 | 170 | rand = random.Random(self.random_seed + 2) # noqa: S311
|
166 | 171 |
|
| 172 | + prefix_index = rand.randint(0, len(self.text_creator.words)) |
| 173 | + prefix_tokens = self.config.prefix_tokens |
| 174 | + prefix = self._create_prompt(prefix_tokens, prefix_index) |
| 175 | + |
167 | 176 | for _, prompt_tokens, output_tokens in zip(
|
168 | 177 | range(self.config.samples),
|
169 | 178 | prompt_tokens_sampler,
|
170 | 179 | output_tokens_sampler,
|
171 | 180 | ):
|
172 | 181 | start_index = rand.randint(0, len(self.text_creator.words))
|
173 | 182 | yield {
|
174 |
| - "prompt": self._create_prompt(prompt_tokens, start_index), |
175 |
| - "prompt_tokens_count": prompt_tokens, |
| 183 | + "prompt": prefix + self._create_prompt(prompt_tokens, start_index), |
| 184 | + "prompt_tokens_count": prefix_tokens + prompt_tokens, |
176 | 185 | "output_tokens_count": output_tokens,
|
177 | 186 | }
|
178 | 187 |
|
|
0 commit comments