2323from typing import Any , ClassVar , Literal , TypeVar , cast
2424
2525import yaml
26- from pydantic import ConfigDict , Field , computed_field , model_serializer
26+ from pydantic import (
27+ AliasChoices ,
28+ AliasGenerator ,
29+ ConfigDict ,
30+ Field ,
31+ ValidationError ,
32+ ValidatorFunctionWrapHandler ,
33+ computed_field ,
34+ field_validator ,
35+ model_serializer ,
36+ )
2737from torch .utils .data import Sampler
2838from transformers import PreTrainedTokenizerBase
2939
@@ -1142,7 +1152,8 @@ def update_estimate(
11421152 )
11431153 request_duration = (
11441154 (request_end_time - request_start_time )
1145- if request_end_time and request_start_time else None
1155+ if request_end_time and request_start_time
1156+ else None
11461157 )
11471158
11481159 # Always track concurrency
@@ -1669,7 +1680,7 @@ def compile(
16691680 estimated_state : EstimatedBenchmarkState ,
16701681 scheduler_state : SchedulerState ,
16711682 profile : Profile ,
1672- requests : Iterable ,
1683+ requests : Iterable , # noqa: ARG003
16731684 backend : BackendInterface ,
16741685 environment : Environment ,
16751686 strategy : SchedulingStrategy ,
@@ -1787,9 +1798,8 @@ def create(
17871798 scenario_data = scenario_data ["args" ]
17881799 constructor_kwargs .update (scenario_data )
17891800
1790- for key , value in kwargs .items ():
1791- if value != cls .get_default (key ):
1792- constructor_kwargs [key ] = value
1801+ # Apply overrides from kwargs
1802+ constructor_kwargs .update (kwargs )
17931803
17941804 return cls .model_validate (constructor_kwargs )
17951805
@@ -1818,13 +1828,19 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18181828 else :
18191829 return factory ({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above
18201830
1821-
1822-
18231831 model_config = ConfigDict (
18241832 extra = "ignore" ,
18251833 use_enum_values = True ,
18261834 from_attributes = True ,
18271835 arbitrary_types_allowed = True ,
1836+ validate_by_alias = True ,
1837+ validate_by_name = True ,
1838+ alias_generator = AliasGenerator (
1839+ # Support field names with hyphens
1840+ validation_alias = lambda field_name : AliasChoices (
1841+ field_name , field_name .replace ("_" , "-" )
1842+ ),
1843+ ),
18281844 )
18291845
18301846 # Required
@@ -1838,7 +1854,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18381854 profile : StrategyType | ProfileType | Profile = Field (
18391855 default = "sweep" , description = "Benchmark profile or scheduling strategy type"
18401856 )
1841- rate : float | list [float ] | None = Field (
1857+ rate : list [float ] | None = Field (
18421858 default = None , description = "Request rate(s) for rate-based scheduling"
18431859 )
18441860 # Backend configuration
@@ -1871,6 +1887,12 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18711887 data_request_formatter : DatasetPreprocessor | dict [str , str ] | str = Field (
18721888 default = "chat_completions" ,
18731889 description = "Request formatting preprocessor or template name" ,
1890+ validation_alias = AliasChoices (
1891+ "data_request_formatter" ,
1892+ "data-request-formatter" ,
1893+ "request_type" ,
1894+ "request-type" ,
1895+ ),
18741896 )
18751897 data_collator : Callable | Literal ["generative" ] | None = Field (
18761898 default = "generative" , description = "Data collator for batch processing"
@@ -1931,6 +1953,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
19311953 default = None , description = "Maximum global error rate (0-1) before stopping"
19321954 )
19331955
1956+ @field_validator ("data" , "data_args" , "rate" , mode = "wrap" )
1957+ @classmethod
1958+ def single_to_list (
1959+ cls , value : Any , handler : ValidatorFunctionWrapHandler
1960+ ) -> list [Any ]:
1961+ """
1962+ Ensures field is always a list.
1963+
1964+ :param value: Input value for the 'data' field
1965+ :return: List of data sources
1966+ """
1967+ try :
1968+ return handler (value )
1969+ except ValidationError as err :
1970+ # If validation fails, try wrapping the value in a list
1971+ if err .errors ()[0 ]["type" ] == "list_type" :
1972+ return handler ([value ])
1973+ else :
1974+ raise
1975+
19341976 @model_serializer
19351977 def serialize_model (self ):
19361978 """
0 commit comments