Skip to content

Commit 3148467

Browse files
committed
Convert single data to list
Signed-off-by: Samuel Monson <[email protected]>
1 parent 638c277 commit 3148467

File tree

1 file changed

+31
-4
lines changed

1 file changed

+31
-4
lines changed

src/guidellm/benchmark/schemas.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@
2323
from typing import Any, ClassVar, Literal, TypeVar, cast
2424

2525
import yaml
26-
from pydantic import ConfigDict, Field, computed_field, model_serializer
26+
from pydantic import (
27+
ConfigDict,
28+
Field,
29+
ValidationError,
30+
ValidatorFunctionWrapHandler,
31+
computed_field,
32+
field_validator,
33+
model_serializer,
34+
)
2735
from torch.utils.data import Sampler
2836
from transformers import PreTrainedTokenizerBase
2937

@@ -1142,7 +1150,8 @@ def update_estimate(
11421150
)
11431151
request_duration = (
11441152
(request_end_time - request_start_time)
1145-
if request_end_time and request_start_time else None
1153+
if request_end_time and request_start_time
1154+
else None
11461155
)
11471156

11481157
# Always track concurrency
@@ -1818,8 +1827,6 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18181827
else:
18191828
return factory({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above
18201829

1821-
1822-
18231830
model_config = ConfigDict(
18241831
extra="ignore",
18251832
use_enum_values=True,
@@ -1931,6 +1938,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
19311938
default=None, description="Maximum global error rate (0-1) before stopping"
19321939
)
19331940

1941+
@field_validator("data", mode="wrap")
1942+
@classmethod
1943+
def single_to_list(
1944+
cls, value: Any, handler: ValidatorFunctionWrapHandler
1945+
) -> list[Any]:
1946+
"""
1947+
Ensures 'data' field is always a list.
1948+
1949+
:param value: Input value for the 'data' field
1950+
:return: List of data sources
1951+
"""
1952+
try:
1953+
return handler(value)
1954+
except ValidationError as err:
1955+
# If validation fails, try wrapping the value in a list
1956+
if err.errors()[0]["type"] == "list_type":
1957+
return handler([value])
1958+
else:
1959+
raise
1960+
19341961
@model_serializer
19351962
def serialize_model(self):
19361963
"""

0 commit comments

Comments
 (0)