|
23 | 23 | from typing import Any, ClassVar, Literal, TypeVar, cast |
24 | 24 |
|
25 | 25 | import yaml |
26 | | -from pydantic import ConfigDict, Field, computed_field, model_serializer |
| 26 | +from pydantic import ( |
| 27 | + ConfigDict, |
| 28 | + Field, |
| 29 | + ValidationError, |
| 30 | + ValidatorFunctionWrapHandler, |
| 31 | + computed_field, |
| 32 | + field_validator, |
| 33 | + model_serializer, |
| 34 | +) |
27 | 35 | from torch.utils.data import Sampler |
28 | 36 | from transformers import PreTrainedTokenizerBase |
29 | 37 |
|
@@ -1142,7 +1150,8 @@ def update_estimate( |
1142 | 1150 | ) |
1143 | 1151 | request_duration = ( |
1144 | 1152 | (request_end_time - request_start_time) |
1145 | | - if request_end_time and request_start_time else None |
| 1153 | + if request_end_time and request_start_time |
| 1154 | + else None |
1146 | 1155 | ) |
1147 | 1156 |
|
1148 | 1157 | # Always track concurrency |
@@ -1818,8 +1827,6 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: |
1818 | 1827 | else: |
1819 | 1828 | return factory({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above |
1820 | 1829 |
|
1821 | | - |
1822 | | - |
1823 | 1830 | model_config = ConfigDict( |
1824 | 1831 | extra="ignore", |
1825 | 1832 | use_enum_values=True, |
@@ -1931,6 +1938,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: |
1931 | 1938 | default=None, description="Maximum global error rate (0-1) before stopping" |
1932 | 1939 | ) |
1933 | 1940 |
|
| 1941 | + @field_validator("data", mode="wrap") |
| 1942 | + @classmethod |
| 1943 | + def single_to_list( |
| 1944 | + cls, value: Any, handler: ValidatorFunctionWrapHandler |
| 1945 | + ) -> list[Any]: |
| 1946 | + """ |
| 1947 | + Ensures 'data' field is always a list. |
| 1948 | +
|
| 1949 | + :param value: Input value for the 'data' field |
| 1950 | + :return: List of data sources |
| 1951 | + """ |
| 1952 | + try: |
| 1953 | + return handler(value) |
| 1954 | + except ValidationError as err: |
| 1955 | + # If validation fails, try wrapping the value in a list |
| 1956 | + if err.errors()[0]["type"] == "list_type": |
| 1957 | + return handler([value]) |
| 1958 | + else: |
| 1959 | + raise |
| 1960 | + |
1934 | 1961 | @model_serializer |
1935 | 1962 | def serialize_model(self): |
1936 | 1963 | """ |
|
0 commit comments