|
1 | 1 | import json |
2 | 2 | from collections.abc import Iterable |
3 | 3 | from pathlib import Path |
4 | | -from typing import Any, Literal, Optional, TypeVar, Union |
| 4 | +from typing import Annotated, Any, Literal, Optional, TypeVar, Union |
5 | 5 |
|
6 | 6 | import yaml |
7 | 7 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict |
8 | 8 | from loguru import logger |
| 9 | +from pydantic import BeforeValidator |
9 | 10 | from transformers.tokenization_utils_base import ( # type: ignore[import] |
10 | 11 | PreTrainedTokenizerBase, |
11 | 12 | ) |
|
17 | 18 |
|
18 | 19 | __ALL__ = ["Scenario", "GenerativeTextScenario"] |
19 | 20 |
|
| 21 | + |
| 22 | +def parse_float_list(value: Union[str, float, list[float]]) -> list[float]: |
| 23 | + if isinstance(value, (int, float)): |
| 24 | + return [value] |
| 25 | + elif isinstance(value, list): |
| 26 | + return value |
| 27 | + |
| 28 | + values = value.split(",") if "," in value else [value] |
| 29 | + |
| 30 | + try: |
| 31 | + return [float(val) for val in values] |
| 32 | + except ValueError as err: |
| 33 | + raise ValueError( |
| 34 | + "must be a number or comma-separated list of numbers." |
| 35 | + ) from err |
| 36 | + |
| 37 | + |
20 | 38 | T = TypeVar("T", bound="Scenario") |
21 | 39 |
|
22 | 40 |
|
@@ -63,7 +81,7 @@ class Config: |
63 | 81 | data_args: Optional[dict[str, Any]] = None |
64 | 82 | data_sampler: Optional[Literal["random"]] = None |
65 | 83 | rate_type: Union[StrategyType, ProfileType] |
66 | | - rate: Optional[Union[float, list[float]]] = None |
| 84 | + rate: Annotated[Optional[list[float]], BeforeValidator(parse_float_list)] = None |
67 | 85 | max_seconds: Optional[float] = None |
68 | 86 | max_requests: Optional[int] = None |
69 | 87 | warmup_percent: Optional[float] = None |
|
0 commit comments