Skip to content
3 changes: 2 additions & 1 deletion src/guidellm/benchmark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from collections.abc import Iterable
from pathlib import Path
from typing import Any, TypeAliasType
from typing import Any

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from transformers import ( # type: ignore[import]
PreTrainedTokenizerBase,
)
from typing_extensions import TypeAliasType

from guidellm.benchmark.aggregator import (
Aggregator,
Expand Down
2 changes: 2 additions & 0 deletions src/guidellm/scheduler/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Protocol,
TypeVar,
Union,
runtime_checkable,
)

from pydantic import Field, computed_field
Expand Down Expand Up @@ -232,6 +233,7 @@ def model_copy(self, **kwargs) -> ScheduledRequestInfo: # type: ignore[override
)


@runtime_checkable
class BackendInterface(Protocol, Generic[RequestT, ResponseT]):
"""
Abstract interface for request processing backends.
Expand Down
11 changes: 11 additions & 0 deletions src/guidellm/utils/pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,17 @@ class DatabaseConfig(BaseConfig):

schema_discriminator: ClassVar[str] = "model_type"

def __new__(cls, *args, **kwargs): # noqa: ARG004
"""
Prevent direct instantiation of base classes that use this mixin.

Only allows instantiation of concrete subclasses, not the base class.
"""
base_type = cls.__pydantic_schema_base_type__()
if cls is base_type:
raise TypeError(f"only children of '{cls.__name__}' may be instantiated")
return super().__new__(cls)

@classmethod
def register_decorator(
cls, clazz: RegisterClassT, name: str | list[str] | None = None
Expand Down
25 changes: 13 additions & 12 deletions src/guidellm/utils/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,9 @@ def from_request_times(
"""
if distribution_type == "concurrency":
# convert to delta changes based on when requests were running
time_deltas: dict[float, int] = defaultdict(int)
for start, end in requests:
time_deltas[start] += 1
time_deltas[end] -= 1

# convert to the events over time measuring concurrency changes
events = []
active = 0

for time, delta in sorted(time_deltas.items()):
active += delta
events.append((time, active))
events = [(start, 1) for start, _ in requests] + [
(end, -1) for _, end in requests
]
elif distribution_type == "rate":
# convert to events for when requests finished
global_start = min(start for start, _ in requests) if requests else 0
Expand All @@ -313,6 +304,16 @@ def from_request_times(
else:
flattened_events.append((time, val))

if distribution_type == "concurrency":
# convert to the events over time measuring concurrency changes
events_over_time: list[tuple[float, float]] = []
active = 0
for time, delta in flattened_events:
active += delta # type: ignore [assignment]
events_over_time.append((time, active))

flattened_events = events_over_time

# convert to value distribution function
distribution: dict[float, float] = defaultdict(float)

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def default_model(self) -> str | None:
def test_class_signatures(self):
"""Test Backend inheritance and type relationships."""
assert issubclass(Backend, RegistryMixin)
assert issubclass(Backend, BackendInterface)
assert isinstance(Backend, BackendInterface)
assert hasattr(Backend, "create")
assert hasattr(Backend, "register")
assert hasattr(Backend, "get_registered_object")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ def valid_instances(self, request):
def test_class_signatures(self):
"""Test GenerationRequestTimings inheritance and type relationships."""
assert issubclass(GenerationRequestTimings, MeasuredRequestTimings)
assert issubclass(GenerationRequestTimings, StandardBaseModel)
assert hasattr(GenerationRequestTimings, "model_dump")
assert hasattr(GenerationRequestTimings, "model_validate")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def test_info(self):
target="http://test", model="test-model", timeout=30.0
)

info = backend.info()
info = backend.info

assert info["target"] == "http://test"
assert info["model"] == "test-model"
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def test_get_chat_message_media_item_jpeg_file(self):
mock_image = Mock(spec=Image.Image)
mock_image.tobytes.return_value = b"fake_jpeg_data"

with patch("guidellm.backend.openai.Image.open", return_value=mock_image):
with patch("guidellm.backends.openai.Image.open", return_value=mock_image):
result = backend._get_chat_message_media_item(mock_jpeg_path)

expected_data = base64.b64encode(b"fake_jpeg_data").decode("utf-8")
Expand Down
Empty file removed tests/unit/objects/__init__.py
Empty file.
43 changes: 0 additions & 43 deletions tests/unit/objects/test_pydantic.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/unit/scheduler/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,11 @@ def test_create_constraint_raises(self, valid_instances):
def test_call_raises(self, valid_instances):
"""Test that calling constraint raises RuntimeError."""
instance, _ = valid_instances
state = SchedulerState(node_id="test_node", num_processes=1, start_time=0.0)
state = SchedulerState(node_id=0, num_processes=1, start_time=0.0)
request = ScheduledRequestInfo(
request_id="test_request",
status="pending",
scheduler_node_id="test_node",
scheduler_node_id=0,
scheduler_process_id=1,
scheduler_start_time=0.0,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/scheduler/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def test_sync_run_start(self, valid_instances, mock_time, delay, expected)

with (
patch("time.time", return_value=mock_time),
patch("guidellm.scheduler.environment.settings") as mock_settings,
patch("guidellm.scheduler.environments.settings") as mock_settings,
):
mock_settings.scheduler_start_delay_non_distributed = delay
start_time = await instance.sync_run_start()
Expand Down
62 changes: 35 additions & 27 deletions tests/unit/scheduler/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import typing
from collections.abc import AsyncIterator
from typing import Any, Optional, TypeVar, Union
from typing import Any, Literal, Optional, TypeVar, Union

import pytest
from pydantic import ValidationError
Expand All @@ -25,6 +25,13 @@
from guidellm.utils import StandardBaseModel


@MeasuredRequestTimings.register("test_request_timings")
class ConcreteMeasuredRequestTimings(MeasuredRequestTimings):
"""Concrete test implementation of MeasuredRequestTimings for testing."""

timings_type: Literal["test_request_timings"] = "test_request_timings"


def test_request_t():
"""Validate that RequestT is a TypeVar usable for generics and isn't bound."""
assert isinstance(RequestT, TypeVar)
Expand Down Expand Up @@ -110,7 +117,7 @@ def test_generic_type_parameters(self):

if hasattr(generic_base, "__args__"):
type_params = generic_base.__args__
assert len(type_params) == 3, "Should have 3 type parameters"
assert len(type_params) == 2, "Should have 2 type parameters"
param_names = [param.__name__ for param in type_params]
expected_names = ["RequestT", "ResponseT"]
assert param_names == expected_names
Expand All @@ -119,7 +126,7 @@ def test_generic_type_parameters(self):
def test_implementation_construction(self):
"""Test that a complete concrete implementation can be instantiated."""

class ConcreteBackend(BackendInterface[str, MeasuredRequestTimings, str]):
class ConcreteBackend(BackendInterface[str, str]):
@property
def processes_limit(self) -> int | None:
return 4
Expand Down Expand Up @@ -162,7 +169,7 @@ async def resolve(
async def test_implementation_async_methods(self): # noqa: C901
"""Test that async methods work correctly in concrete implementation."""

class AsyncBackend(BackendInterface[dict, MeasuredRequestTimings, dict]):
class AsyncBackend(BackendInterface[dict, dict]):
def __init__(self):
self.startup_called = False
self.validate_called = False
Expand Down Expand Up @@ -400,19 +407,23 @@ class TestRequestTimings:

@pytest.fixture(
params=[
{},
{"timings_type": "test_request_timings"},
{
"timings_type": "test_request_timings",
"request_start": None,
"request_end": None,
},
{
"timings_type": "test_request_timings",
"request_start": 1000.0,
"request_end": 1100.0,
},
{
"timings_type": "test_request_timings",
"request_start": 1000.0,
},
{
"timings_type": "test_request_timings",
"request_start": 0.0,
"request_end": 0.0,
},
Expand All @@ -428,13 +439,12 @@ class TestRequestTimings:
def valid_instances(self, request):
"""Creates various valid configurations of MeasuredRequestTimings."""
constructor_args = request.param
instance = MeasuredRequestTimings(**constructor_args)
instance = MeasuredRequestTimings.model_validate(constructor_args)
return instance, constructor_args

@pytest.mark.smoke
def test_class_signatures(self):
"""Test MeasuredRequestTimings inheritance and type relationships."""
assert issubclass(MeasuredRequestTimings, StandardBaseModel)
assert hasattr(MeasuredRequestTimings, "model_dump")
assert hasattr(MeasuredRequestTimings, "model_validate")

Expand All @@ -447,7 +457,13 @@ def test_class_signatures(self):
assert field_info.default is None

@pytest.mark.smoke
def test_initialization(self, valid_instances):
def test_initialization(self):
"""Base class initialization should fail."""
with pytest.raises(TypeError):
MeasuredRequestTimings()

@pytest.mark.smoke
def test_validation(self, valid_instances):
"""Test initialization with valid configurations."""
instance, constructor_args = valid_instances
assert isinstance(instance, MeasuredRequestTimings)
Expand All @@ -468,9 +484,9 @@ def test_initialization(self, valid_instances):
)
def test_invalid_initialization(self, field, value):
"""Test invalid initialization scenarios."""
kwargs = {field: value}
kwargs = {"timings_type": "test_request_timings", field: value}
with pytest.raises(ValidationError):
MeasuredRequestTimings(**kwargs)
MeasuredRequestTimings.model_validate(kwargs)

@pytest.mark.smoke
def test_marshalling(self, valid_instances):
Expand Down Expand Up @@ -534,6 +550,7 @@ class TestScheduledRequestInfo:
"finalized": 2150.0,
},
"request_timings": {
"timings_type": "test_request_timings",
"request_start": 2060.0,
"request_end": 2110.0,
},
Expand Down Expand Up @@ -586,8 +603,8 @@ def valid_instances(self, request):
**constructor_args["scheduler_timings"]
)
if "request_timings" in constructor_args:
constructor_args["request_timings"] = MeasuredRequestTimings(
**constructor_args["request_timings"]
constructor_args["request_timings"] = MeasuredRequestTimings.model_validate(
constructor_args["request_timings"]
)

instance = ScheduledRequestInfo(**constructor_args)
Expand All @@ -597,7 +614,6 @@ def valid_instances(self, request):
def test_class_signatures(self):
"""Test ScheduledRequestInfo inheritance and type relationships."""
assert issubclass(ScheduledRequestInfo, StandardBaseModel)
assert issubclass(ScheduledRequestInfo, typing.Generic)
assert hasattr(ScheduledRequestInfo, "model_dump")
assert hasattr(ScheduledRequestInfo, "model_validate")

Expand All @@ -607,18 +623,6 @@ def test_class_signatures(self):
assert isinstance(ScheduledRequestInfo.started_at, property)
assert isinstance(ScheduledRequestInfo.completed_at, property)

# Check that it's properly generic
orig_bases = getattr(ScheduledRequestInfo, "__orig_bases__", ())
generic_base = next(
(
base
for base in orig_bases
if hasattr(base, "__origin__") and base.__origin__ is typing.Generic
),
None,
)
assert generic_base is not None

# Check required fields
fields = ScheduledRequestInfo.model_fields
for key in self.CHECK_KEYS:
Expand Down Expand Up @@ -720,7 +724,9 @@ def test_started_at_property(self):
scheduler_process_id=0,
scheduler_start_time=1000.0,
scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0),
request_timings=MeasuredRequestTimings(request_start=2100.0),
request_timings=MeasuredRequestTimings.model_validate(
{"timings_type": "test_request_timings", "request_start": 2100.0}
),
)
assert instance.started_at == 2100.0

Expand Down Expand Up @@ -756,7 +762,9 @@ def test_completed_at_property(self):
scheduler_process_id=0,
scheduler_start_time=1000.0,
scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0),
request_timings=MeasuredRequestTimings(request_end=2100.0),
request_timings=MeasuredRequestTimings.model_validate(
{"timings_type": "test_request_timings", "request_end": 2100.0}
),
)
assert instance.completed_at == 2100.0

Expand Down
Loading
Loading