@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 from collections.abc import Iterator
+from math import gcd
 from pathlib import Path
-from typing import Any, Callable
+from random import Random
+from typing import Any, Callable, ClassVar
 
 import yaml
 from datasets import Features, IterableDataset, Value
@@ -21,15 +23,33 @@
     "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
     "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
 ]
 
 
-class SyntheticTextDatasetConfig(StandardBaseModel):
+class SyntheticTextPrefixBucketConfig(StandardBaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixes to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
     prefix_tokens: int = Field(
-        description="The number of shared prefix tokens to prepend to each prompt.",
+        description="The number of prefix tokens per prompt for this bucket.",
         ge=0,
         default=0,
     )
+
+
+class SyntheticTextDatasetConfig(StandardBaseModel):
+    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
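As a quick orientation before the remaining hunks, here is a minimal sketch (not part of the diff) of how the two classes added above might be combined. The field names come from the diff; the enclosing module path and the rest of the dataset config fields are omitted, and the concrete numbers are made up for illustration.

# Illustrative configuration only; values are hypothetical.
# Roughly 75% of prompts share one long 2048-token prefix, while the remaining
# ~25% draw from ten distinct 128-token prefixes.
config = SyntheticTextDatasetConfig(
    prefix_buckets=[
        SyntheticTextPrefixBucketConfig(
            bucket_weight=75, prefix_count=1, prefix_tokens=2048
        ),
        SyntheticTextPrefixBucketConfig(
            bucket_weight=25, prefix_count=10, prefix_tokens=128
        ),
    ],
    prompt_tokens=256,
)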
@@ -75,6 +95,8 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
 
 
 class SyntheticTextGenerator:
+    PREFIX_DISTRIBUTION_PRECISION: ClassVar[int] = 1000
+
     def __init__(
         self,
         config: SyntheticTextDatasetConfig,
@@ -110,17 +132,15 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
         )
 
         # Create a shared prefix if specified
-        if self.config.prefix_tokens > 0:
-            prefix = self._create_prompt(self.config.prefix_tokens, faker)
-        else:
-            prefix = ""  # Always have a prefix key for consistency
+        rand = Random(self.random_seed + 3)
+        prefix_iter = self._create_prefix_iter(faker, rand)
 
         while True:
             prompt_tokens_count = next(prompt_tokens_sampler)
             output_tokens_count = next(output_tokens_sampler)
 
             yield {
-                "prefix": prefix,
+                "prefix": next(prefix_iter),
                 "prompt": self._create_prompt(
                     prompt_tokens_count, faker, f"{samples_generated} "
                 ),
@@ -149,6 +169,43 @@ def _create_prompt(
             prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
         )
 
+    def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
+        if not self.config.prefix_buckets:
+            while True:
+                yield ""
+
+        total_weight = sum(
+            bucket.bucket_weight for bucket in self.config.prefix_buckets
+        )
+        if total_weight <= 0:
+            raise ValueError("Total weight of prefix buckets must be greater than 0.")
+
+        # Calculate the divisor needed to achieve the minimum
+        # number of prompts given the weight ratios
+        percents = [
+            int(
+                self.PREFIX_DISTRIBUTION_PRECISION
+                * bucket.bucket_weight
+                / bucket.prefix_count
+                / total_weight
+            )
+            for bucket in self.config.prefix_buckets
+        ]
+        common_divisor = gcd(*percents)
+
+        # Create prefix list maintaining the correct distribution
+        prefixes = []
+        for bucket, percent in zip(self.config.prefix_buckets, percents):
+            bucket_prefixes = [
+                self._create_prompt(bucket.prefix_tokens, faker)
+                for _ in range(bucket.prefix_count)
+            ]
+            sample_count = percent // common_divisor
+            prefixes.extend(bucket_prefixes * sample_count)
+
+        while True:
+            yield rand.choice(prefixes)
+
 
 @DatasetDeserializerFactory.register("synthetic_text")
 class SyntheticTextDatasetDeserializer(DatasetDeserializer):
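To make the weight arithmetic in `_create_prefix_iter` concrete, here is a small worked sketch (illustrative only, with made-up bucket values) of how bucket weights become a flat, uniformly sampled prefix list:

from math import gcd

# Hypothetical buckets as (bucket_weight, prefix_count) pairs, with the class's
# PREFIX_DISTRIBUTION_PRECISION of 1000.
precision = 1000
buckets = [(75, 1), (25, 10)]
total_weight = sum(weight for weight, _ in buckets)  # 100

# Integer per-prefix share of the distribution:
#   bucket 1: 1000 * 75 / 1 / 100  -> 750
#   bucket 2: 1000 * 25 / 10 / 100 -> 25
percents = [
    int(precision * weight / count / total_weight) for weight, count in buckets
]
common_divisor = gcd(*percents)  # gcd(750, 25) == 25

# Each prefix is repeated percent // common_divisor times in the flat list:
# 30 copies of the single bucket-1 prefix plus 1 copy of each of the ten
# bucket-2 prefixes (40 entries total), so a uniform rand.choice() returns a
# bucket-1 prefix ~75% of the time and a bucket-2 prefix ~25% of the time.
sample_counts = [percent // common_divisor for percent in percents]  # [30, 1]

Repeating whole prefixes in a flat list keeps selection at O(1) per sample while preserving the configured weights, at the cost of an approximation whose granularity is set by PREFIX_DISTRIBUTION_PRECISION.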