|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import math |
3 | 4 | from collections.abc import Iterator |
4 | | -from math import gcd |
5 | 5 | from pathlib import Path |
6 | 6 | from random import Random |
7 | | -from typing import Any, Callable, ClassVar |
| 7 | +from typing import Any, Callable |
8 | 8 |
|
9 | 9 | import yaml |
10 | 10 | from datasets import Features, IterableDataset, Value |
@@ -95,8 +95,6 @@ class SyntheticTextDatasetConfig(StandardBaseModel): |
95 | 95 |
|
96 | 96 |
|
97 | 97 | class SyntheticTextGenerator: |
98 | | - PREFIX_DISTRIBUTION_PRECISION: ClassVar[int] = 1000 |
99 | | - |
100 | 98 | def __init__( |
101 | 99 | self, |
102 | 100 | config: SyntheticTextDatasetConfig, |
@@ -174,34 +172,26 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]: |
174 | 172 | while True: |
175 | 173 | yield "" |
176 | 174 |
|
177 | | - total_weight = sum( |
178 | | - bucket.bucket_weight for bucket in self.config.prefix_buckets |
| 175 | + # Increase weights to ensure an integer number of samples per prefix
| 176 | + least_common_prefix_count = math.lcm( |
| 177 | + *(bucket.prefix_count for bucket in self.config.prefix_buckets) |
179 | 178 | ) |
180 | | - if total_weight <= 0: |
181 | | - raise ValueError("Total weight of prefix buckets must be greater than 0.") |
182 | | - |
183 | | - # Calculate the divisor needed to achieve the minimum |
184 | | - # number of prompts given the weight ratios |
185 | | - percents = [ |
186 | | - int( |
187 | | - self.PREFIX_DISTRIBUTION_PRECISION |
188 | | - * bucket.bucket_weight |
189 | | - / bucket.prefix_count |
190 | | - / total_weight |
191 | | - ) |
| 179 | + unnorm_weights = [ |
| 180 | + least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count |
192 | 181 | for bucket in self.config.prefix_buckets |
193 | 182 | ] |
194 | | - common_divisor = gcd(*percents) |
| 183 | + # Use GCD to reduce the weights to smallest integer ratio |
| 184 | + common_divisor = math.gcd(*unnorm_weights) |
195 | 185 |
|
196 | 186 | # Create prefix list maintaining the correct distribution |
197 | 187 | prefixes = [] |
198 | | - for bucket, percent in zip(self.config.prefix_buckets, percents): |
| 188 | + for bucket, weight in zip(self.config.prefix_buckets, unnorm_weights): |
199 | 189 | bucket_prefixes = [ |
200 | 190 | self._create_prompt(bucket.prefix_tokens, faker) |
201 | 191 | for _ in range(bucket.prefix_count) |
202 | 192 | ] |
203 | | - sample_count = percent // common_divisor |
204 | | - prefixes.extend([bucket_prefixes] * sample_count) |
| 193 | + sample_count = weight // common_divisor |
| 194 | + prefixes.extend(bucket_prefixes * sample_count) |
205 | 195 |
|
206 | 196 | while True: |
207 | 197 | yield rand.choice(prefixes) |
|
0 commit comments