@@ -1,8 +1,11 @@
+import json
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Literal, Optional, Self, Union
+from typing import Any, Literal, Optional, TypeVar, Union
 
+import yaml
 from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
+from loguru import logger
 from transformers.tokenization_utils_base import ( # type: ignore[import]
     PreTrainedTokenizerBase,
 )
@@ -14,20 +17,28 @@
 
 __ALL__ = ["Scenario", "GenerativeTextScenario"]
 
+T = TypeVar("T", bound="Scenario")
+
 
 class Scenario(StandardBaseModel):
     target: str
 
-    def _update(self, **fields: Any) -> Self:
-        for k, v in fields.items():
-            if not hasattr(self, k):
-                raise ValueError(f"Invalid field {k}")
-            setattr(self, k, v)
-
-        return self
+    @classmethod
+    def from_file(
+        cls: type[T], filename: Union[str, Path], overrides: Optional[dict] = None
+    ) -> T:
+        try:
+            with open(filename) as f:
+                if str(filename).endswith(".yaml") or str(filename).endswith(".yml"):
+                    data = yaml.safe_load(f)
+                else: # Assume everything else is json
+                    data = json.load(f)
+        except (json.JSONDecodeError, yaml.YAMLError) as e:
+            logger.error("Failed to parse scenario")
+            raise e
 
-    def update(self, **fields: Any) -> Self:
-        return self._update(**{k: v for k, v in fields.items() if v is not None})
+        data.update(overrides or {})
+        return cls.model_validate(data)
 
 
 class GenerativeTextScenario(Scenario):
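
For context, a minimal usage sketch of the new `Scenario.from_file` classmethod. This is not part of the diff: the import path, file name, file contents, and override value are all assumed for illustration; only the `target` field and the `filename`/`overrides` parameters come from the code above.

```python
from pathlib import Path

# Hypothetical import path; adjust to wherever Scenario lives in this package.
from guidellm.benchmark.scenario import Scenario

# Write a tiny YAML scenario file (contents are illustrative only).
path = Path("scenario.yaml")
path.write_text("target: http://localhost:8000\n")

# Values in `overrides` replace the parsed fields before model validation runs.
scenario = Scenario.from_file(path, overrides={"target": "http://localhost:9000"})
print(scenario.target)  # -> http://localhost:9000
```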