|
10 | 10 | import yaml
|
11 | 11 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
12 | 12 | from loguru import logger
|
13 |
| -from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt |
| 13 | +from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt, SkipValidation |
14 | 14 | from transformers.tokenization_utils_base import ( # type: ignore[import]
|
15 | 15 | PreTrainedTokenizerBase,
|
16 | 16 | )
|
@@ -115,16 +115,18 @@ class Config:
|
115 | 115 | # types like PreTrainedTokenizerBase
|
116 | 116 | arbitrary_types_allowed = True
|
117 | 117 |
|
118 |
| - data: ( |
| 118 | + data: Annotated[ |
119 | 119 | Iterable[str]
|
120 | 120 | | Iterable[dict[str, Any]]
|
121 | 121 | | Dataset
|
122 | 122 | | DatasetDict
|
123 | 123 | | IterableDataset
|
124 | 124 | | IterableDatasetDict
|
125 | 125 | | str
|
126 |
| - | Path |
127 |
| - ) |
| 126 | + | Path, |
| 127 | + # BUG: See https://github.com/pydantic/pydantic/issues/9541 |
| 128 | + SkipValidation, |
| 129 | + ] |
128 | 130 | profile: StrategyType | ProfileType | Profile
|
129 | 131 | rate: Annotated[list[PositiveFloat] | None, BeforeValidator(parse_float_list)] = (
|
130 | 132 | None
|
@@ -159,7 +161,7 @@ def enable_scenarios(func: Callable) -> Any:
|
159 | 161 | @wraps(func)
|
160 | 162 | async def decorator(*args, scenario: Scenario | None = None, **kwargs) -> Any:
|
161 | 163 | if scenario is not None:
|
162 |
| - kwargs.update(**vars(scenario)) |
| 164 | + kwargs.update(**scenario.model_dump()) |
163 | 165 | return await func(*args, **kwargs)
|
164 | 166 |
|
165 | 167 | # Modify the signature of the decorator to include the `scenario` argument
|
|
0 commit comments