diff --git a/src/guidellm/__init__.py b/src/guidellm/__init__.py index 9333860e..f2206e94 100644 --- a/src/guidellm/__init__.py +++ b/src/guidellm/__init__.py @@ -20,7 +20,8 @@ hf_logging.set_verbosity_error() logging.getLogger("transformers").setLevel(logging.ERROR) -from .config import ( +from .logger import configure_logger, logger +from .settings import ( DatasetSettings, Environment, LoggingSettings, @@ -30,7 +31,6 @@ reload_settings, settings, ) -from .logger import configure_logger, logger __all__ = [ "DatasetSettings", diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 7cba6a7c..f82c19cf 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -13,9 +13,9 @@ ) from guidellm.benchmark.entrypoints import benchmark_with_scenario from guidellm.benchmark.scenario import GenerativeTextScenario, get_builtin_scenarios -from guidellm.config import print_config from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset from guidellm.scheduler import StrategyType +from guidellm.settings import print_config from guidellm.utils import DefaultGroupHandler from guidellm.utils import cli as cli_tools diff --git a/src/guidellm/backend/backend.py b/src/guidellm/backend/backend.py index bf2788a7..ceffdc77 100644 --- a/src/guidellm/backend/backend.py +++ b/src/guidellm/backend/backend.py @@ -7,7 +7,7 @@ from PIL import Image from guidellm.backend.response import ResponseSummary, StreamingTextResponse -from guidellm.config import settings +from guidellm.settings import settings __all__ = [ "Backend", diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index e62e9003..dff807af 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -16,7 +16,7 @@ ResponseSummary, StreamingTextResponse, ) -from guidellm.config import settings +from guidellm.settings import settings __all__ = [ "CHAT_COMPLETIONS", diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py index bfa738d8..a5d8fe45 100644 --- a/src/guidellm/backend/response.py +++ b/src/guidellm/backend/response.py @@ -2,7 +2,7 @@ from pydantic import computed_field -from guidellm.config import settings +from guidellm.settings import settings from guidellm.utils import StandardBaseModel __all__ = [ diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index 450b536a..d5bd237e 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -21,7 +21,6 @@ GenerativeTextErrorStats, GenerativeTextResponseStats, ) -from guidellm.config import settings from guidellm.request import ( GenerationRequest, GenerativeRequestLoaderDescription, @@ -34,6 +33,7 @@ SchedulerRequestResult, WorkerDescription, ) +from guidellm.settings import settings from guidellm.utils import ( RunningStats, StandardBaseModel, diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 225ed2b1..d3fff6c9 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -20,10 +20,10 @@ SweepProfile, ThroughputProfile, ) -from guidellm.config import settings from guidellm.presentation import UIDataBuilder from guidellm.presentation.injector import create_report from guidellm.scheduler import strategy_display_str +from guidellm.settings import settings from guidellm.utils import ( Colors, DistributionSummary, diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index d46f2b16..73c3df90 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -4,7 +4,6 @@ import numpy as np from pydantic import Field, computed_field -from guidellm.config import settings from guidellm.scheduler import ( AsyncConstantStrategy, AsyncPoissonStrategy, @@ -14,6 +13,7 @@ SynchronousStrategy, ThroughputStrategy, ) +from guidellm.settings import settings from guidellm.utils import StandardBaseModel __all__ = [ diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py index ac235c99..48b41a49 100644 --- a/src/guidellm/logger.py +++ b/src/guidellm/logger.py @@ -41,7 +41,7 @@ from loguru import logger -from guidellm.config import LoggingSettings, settings +from guidellm.settings import LoggingSettings, settings __all__ = ["configure_logger", "logger"] diff --git a/src/guidellm/presentation/injector.py b/src/guidellm/presentation/injector.py index 02d53b1d..bb1fd684 100644 --- a/src/guidellm/presentation/injector.py +++ b/src/guidellm/presentation/injector.py @@ -4,7 +4,7 @@ from loguru import logger -from guidellm.config import settings +from guidellm.settings import settings from guidellm.utils.text import load_text diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 2eff87d5..e207a2e1 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -11,9 +11,9 @@ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase # type: ignore[import] -from guidellm.config import settings from guidellm.dataset import ColumnInputTypes, load_dataset from guidellm.request.request import GenerationRequest +from guidellm.settings import settings from guidellm.utils import StandardBaseModel __all__ = [ diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index 37bf1fd5..1ca8fb69 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -1,21 +1,46 @@ -from .result import ( - SchedulerRequestInfo, - SchedulerRequestResult, - SchedulerResult, - SchedulerRunInfo, +from .constraints import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + MaxDurationConstraint, + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, + MaxNumberConstraint, + PydanticConstraintInitializer, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from .objects import ( + BackendInterface, + BackendT, + MeasuredRequestTimings, + MeasuredRequestTimingsT, + MultiTurnRequestT, + RequestSchedulerTimings, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, ) from .scheduler import Scheduler from .strategy import ( AsyncConstantStrategy, AsyncPoissonStrategy, ConcurrentStrategy, + ConstantRateRequestTimings, + LastCompletionRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, + ScheduledRequestTimings, SchedulingStrategy, + StrategyT, StrategyType, SynchronousStrategy, ThroughputStrategy, - strategy_display_str, ) -from .types import RequestT, ResponseT from .worker import ( GenerativeRequestsWorker, GenerativeRequestsWorkerDescription, @@ -29,24 +54,46 @@ __all__ = [ "AsyncConstantStrategy", "AsyncPoissonStrategy", + "BackendInterface", + "BackendT", "ConcurrentStrategy", + "ConstantRateRequestTimings", + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", "GenerativeRequestsWorker", "GenerativeRequestsWorkerDescription", + "LastCompletionRequestTimings", + "MaxDurationConstraint", + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", + "MaxNumberConstraint", + "MeasuredRequestTimings", + "MeasuredRequestTimingsT", + "MultiTurnRequestT", + "NoDelayRequestTimings", + "PoissonRateRequestTimings", + "PydanticConstraintInitializer", + "RequestSchedulerTimings", "RequestT", "RequestsWorker", "ResolveStatus", "ResponseT", + "ScheduledRequestInfo", + "ScheduledRequestTimings", "Scheduler", - "SchedulerRequestInfo", - "SchedulerRequestResult", - "SchedulerResult", - "SchedulerRunInfo", + "SchedulerState", + "SchedulerUpdateAction", + "SchedulerUpdateActionProgress", "SchedulingStrategy", + "SerializableConstraintInitializer", + "StrategyT", "StrategyType", "SynchronousStrategy", "ThroughputStrategy", + "UnserializableConstraintInitializer", "WorkerDescription", "WorkerProcessRequest", "WorkerProcessResult", - "strategy_display_str", ] diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py new file mode 100644 index 00000000..fd2f082a --- /dev/null +++ b/src/guidellm/scheduler/constraints.py @@ -0,0 +1,990 @@ +""" +Constraint system for scheduler behavior control and request processing limits. + +Provides flexible constraints for managing scheduler behavior with configurable +thresholds based on time, error rates, and request counts. Constraints evaluate +scheduler state and individual requests to determine whether processing should +continue or stop based on predefined limits. The constraint system enables +sophisticated benchmark stopping criteria through composable constraint types. +""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import Field, field_validator + +from guidellm.scheduler.objects import ( + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, +) +from guidellm.settings import settings +from guidellm.utils import InfoMixin, RegistryMixin, StandardBaseModel + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "MaxDurationConstraint", + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", + "MaxNumberConstraint", + "PydanticConstraintInitializer", + "SerializableConstraintInitializer", + "UnserializableConstraintInitializer", +] + + +@runtime_checkable +class Constraint(Protocol): + """Protocol for constraint evaluation functions that control scheduler behavior.""" + + def __call__( + self, state: SchedulerState, request: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against scheduler state and request information. + + :param state: Current scheduler state with metrics and timing information + :param request: Individual request information and metadata + :return: Action indicating whether to continue or stop scheduler operations + """ + + +@runtime_checkable +class ConstraintInitializer(Protocol): + """Protocol for constraint initializer factory functions that create constraints.""" + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance from configuration parameters. + + :param kwargs: Configuration parameters for constraint creation + :return: Configured constraint evaluation function + """ + + +@runtime_checkable +class SerializableConstraintInitializer(Protocol): + """Protocol for serializable constraint initializers supporting persistence.""" + + @classmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + :param args: Positional arguments for constraint configuration + :param kwargs: Keyword arguments for constraint configuration + :return: Validated parameter dictionary for constraint creation + """ + + @classmethod + def model_validate(cls, **kwargs) -> ConstraintInitializer: + """ + Create validated constraint initializer from configuration. + + :param kwargs: Configuration dictionary for initializer creation + :return: Validated constraint initializer instance + """ + + def model_dump(self) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :return: Dictionary representation of constraint initializer + """ + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create constraint instance from this initializer. + + :param kwargs: Additional configuration parameters + :return: Configured constraint evaluation function + """ + + +class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): + """ + Registry factory for creating and managing constraint initializers. + + Provides centralized access to registered constraint types with support for + creating constraints from configuration dictionaries, simple values, or + pre-configured instances. Handles constraint resolution and type validation + for the scheduler constraint system. + + Example: + :: + from guidellm.scheduler import ConstraintsInitializerFactory + + # Register new constraint type + @ConstraintsInitializerFactory.register("new_constraint") + class NewConstraint: + def create_constraint(self, **kwargs) -> Constraint: + return lambda state, request: SchedulerUpdateAction() + + # Create and use constraint + constraint = ConstraintsInitializerFactory.create_constraint("new_constraint") + """ + + @classmethod + def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: + """ + Create a constraint initializer for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for initializer creation + :param kwargs: Keyword arguments for initializer creation + :return: Configured constraint initializer instance + :raises ValueError: If the key is not registered in the factory + """ + if cls.registry is None or key not in cls.registry: + raise ValueError(f"Unknown constraint initializer key: {key}") + + initializer_class = cls.registry[key] + + return ( + initializer_class(*args, **kwargs) # type: ignore[operator] + if not isinstance(initializer_class, type) + or not issubclass(initializer_class, SerializableConstraintInitializer) + else initializer_class( + **initializer_class.validated_kwargs(*args, **kwargs) # type: ignore[misc] + ) + ) + + @classmethod + def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :param initializer: Constraint initializer to serialize + :return: Dictionary representation or unserializable placeholder + """ + if isinstance(initializer, SerializableConstraintInitializer): + return initializer.model_dump() + else: + unserializable = UnserializableConstraintInitializer( + orig_info=InfoMixin.extract_from_obj(initializer) + ) + return unserializable.model_dump() + + @classmethod + def deserialize( + cls, initializer_dict: dict[str, Any] + ) -> SerializableConstraintInitializer: + """ + Deserialize constraint initializer from dictionary format. + + :param initializer_dict: Dictionary representation of constraint initializer + :return: Reconstructed constraint initializer instance + :raises ValueError: If constraint type is unknown or cannot be deserialized + """ + if initializer_dict.get("type_") == "unserializable": + return UnserializableConstraintInitializer.model_validate(initializer_dict) + + if ( + cls.registry is not None + and initializer_dict.get("type_") + and initializer_dict["type_"] in cls.registry + ): + initializer_class = cls.registry[initializer_dict["type_"]] + if hasattr(initializer_class, "model_validate"): + return initializer_class.model_validate(initializer_dict) # type: ignore[return-value] + else: + return initializer_class(**initializer_dict) # type: ignore[return-value,operator] + + raise ValueError( + f"Cannot deserialize unknown constraint initializer: " + f"{initializer_dict.get('type_', 'unknown')}" + ) + + @classmethod + def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: + """ + Create a constraint instance for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for constraint creation + :param kwargs: Keyword arguments for constraint creation + :return: Configured constraint function ready for evaluation + :raises ValueError: If the key is not registered in the factory + """ + return cls.create(key, *args, **kwargs).create_constraint() + + @classmethod + def resolve( + cls, + initializers: dict[ + str, + Any | dict[str, Any] | Constraint | ConstraintInitializer, + ], + ) -> dict[str, Constraint]: + """ + Resolve mixed constraint specifications to callable constraints. + + :param initializers: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any key is not registered in the factory + """ + constraints = {} + + for key, val in initializers.items(): + if isinstance(val, Constraint): + constraints[key] = val + elif isinstance(val, ConstraintInitializer): + constraints[key] = val.create_constraint() + elif isinstance(val, dict): + constraints[key] = cls.create_constraint(key, **val) + else: + constraints[key] = cls.create_constraint(key, val) + + return constraints + + @classmethod + def resolve_constraints( + cls, + constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> dict[str, Constraint]: + """ + Resolve constraints from mixed constraint specifications. + + :param constraints: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any constraint key is not registered + """ + resolved_constraints = {} + + for key, val in constraints.items(): + if isinstance(val, Constraint): + resolved_constraints[key] = val + elif isinstance(val, dict): + resolved_constraints[key] = cls.create_constraint(key, **val) + else: + resolved_constraints[key] = cls.create_constraint(key, val) + + return resolved_constraints + + +class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin): + """ + Abstract base for Pydantic-based constraint initializers. + + Provides standardized serialization, validation, and metadata handling for + constraint initializers using Pydantic models. Subclasses implement specific + constraint creation logic while inheriting validation and persistence support. + """ + + type_: str = Field(description="Type identifier for the constraint initializer") + + @property + def info(self) -> dict[str, Any]: + """ + Extract serializable information from this constraint initializer. + + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + @classmethod + @abstractmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + Must be implemented by subclasses to handle their specific parameter patterns + and validation requirements. + + :param args: Positional arguments passed to the constraint + :param kwargs: Keyword arguments passed to the constraint + :return: Validated dictionary of parameters for constraint creation + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + @abstractmethod + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance. + + Must be implemented by subclasses to return their specific constraint type + with appropriate configuration and validation. + + :param kwargs: Additional keyword arguments (usually unused) + :return: Configured constraint instance + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + +class UnserializableConstraintInitializer(PydanticConstraintInitializer): + """ + Placeholder for constraints that cannot be serialized or executed. + + Represents constraint initializers that failed serialization or contain + non-serializable components. Cannot be executed and raises errors when + invoked to prevent runtime failures from invalid constraint state. + """ + + type_: Literal["unserializable"] = "unserializable" # type: ignore[assignment] + orig_info: dict[str, Any] = Field( + default_factory=dict, + description="Original constraint information before serialization failure", + ) + + @classmethod + def validated_kwargs( + cls, + orig_info: dict[str, Any] | None = None, + **kwargs, # noqa: ARG003 + ) -> dict[str, Any]: + """ + Validate arguments for unserializable constraint creation. + + :param orig_info: Original constraint information before serialization failure + :param kwargs: Additional arguments (ignored) + :return: Validated parameters for unserializable constraint creation + """ + return {"orig_info": orig_info or {}} + + def create_constraint( + self, + **kwargs, # noqa: ARG002 + ) -> Constraint: + """ + Raise error for unserializable constraint creation attempt. + + :param kwargs: Additional keyword arguments (unused) + :raises RuntimeError: Always raised since unserializable constraints + cannot be executed + """ + raise RuntimeError( + "Cannot create constraint from unserializable constraint instance. " + "This constraint cannot be serialized and therefore cannot be executed." + ) + + def __call__( + self, + state: SchedulerState, # noqa: ARG002 + request: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Raise error since unserializable constraints cannot be invoked. + + :param state: Current scheduler state (unused) + :param request: Individual request information (unused) + :raises RuntimeError: Always raised for unserializable constraints + """ + raise RuntimeError( + "Cannot invoke unserializable constraint instance. " + "This constraint was not properly serialized and cannot be executed." + ) + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_number", "max_num", "max_requests", "max_req"] +) +class MaxNumberConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on maximum request counts. + + Stops request queuing when created requests reach the limit and stops local + request processing when processed requests reach the limit. Provides progress + tracking based on remaining requests and completion fraction. + """ + + type_: Literal["max_number"] = "max_number" # type: ignore[assignment] + max_num: int | float | list[int | float] = Field( + description="Maximum number of requests allowed before triggering constraint", + ) + current_index: int = Field( + default=-1, description="Current index for list-based max_num values" + ) + + @classmethod + def validated_kwargs( + cls, max_num: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxNumberConstraint creation. + + :param max_num: Maximum number of requests to allow + :param kwargs: Supports max_num, max_number, max_requests, max_req, + and optional type_ + :return: Validated dictionary with max_num and type_ fields + """ + aliases = ["max_number", "max_num", "max_requests", "max_req"] + for alias in aliases: + if max_num is None: + max_num = kwargs.get(alias) + + return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)} + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state and request count. + + :param state: Current scheduler state with request counts + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_num = ( + self.max_num + if isinstance(self.max_num, (int, float)) + else self.max_num[min(current_index, len(self.max_num) - 1)] + ) + + create_exceeded = state.created_requests >= max_num + processed_exceeded = state.processed_requests >= max_num + remaining_fraction = min( + max(0.0, 1.0 - state.processed_requests / float(max_num)), 1.0 + ) + remaining_requests = max(0, max_num - state.processed_requests) + + return SchedulerUpdateAction( + request_queuing="stop" if create_exceeded else "continue", + request_processing="stop_local" if processed_exceeded else "continue", + metadata={ + "max_number": max_num, + "create_exceeded": create_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": state.processed_requests, + "remaining_fraction": remaining_fraction, + "remaining_requests": remaining_requests, + }, + progress=SchedulerUpdateActionProgress( + remaining_fraction=remaining_fraction, + remaining_requests=remaining_requests, + ), + ) + + @field_validator("max_num") + @classmethod + def _validate_max_num( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + f"max_num must be set and truthful, received {value} ({val} failed)" + ) + if not isinstance(val, (int, float)) or val <= 0: + raise ValueError( + f"max_num must be a positive num, received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"] +) +class MaxDurationConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on maximum time duration. + + Stops both request queuing and processing when the elapsed time since scheduler + start exceeds the maximum duration. Provides progress tracking based on + remaining time and completion fraction. + """ + + type_: Literal["max_duration"] = "max_duration" # type: ignore[assignment] + max_duration: int | float | list[int | float] = Field( + description="Maximum duration in seconds before triggering constraint" + ) + current_index: int = Field(default=-1, description="Current index in duration list") + + @classmethod + def validated_kwargs( + cls, max_duration: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxDurationConstraint creation. + + :param max_duration: Maximum duration in seconds + :param kwargs: Supports max_duration, max_dur, max_sec, max_seconds, + max_min, max_minutes, and optional type_ + :return: Validated dictionary with max_duration and type_ fields + """ + seconds_aliases = ["max_dur", "max_sec", "max_seconds"] + for alias in seconds_aliases: + if max_duration is None: + max_duration = kwargs.get(alias) + minutes_aliases = ["max_min", "max_minutes"] + for alias in minutes_aliases: + minutes = kwargs.get(alias) + if minutes is not None and max_duration is None: + max_duration = minutes * 60 + + return { + "max_duration": max_duration, + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state and elapsed time. + + :param state: Current scheduler state with start time + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_duration = ( + self.max_duration + if isinstance(self.max_duration, (int, float)) + else self.max_duration[min(current_index, len(self.max_duration) - 1)] + ) + + current_time = time.time() + elapsed = current_time - state.start_time + duration_exceeded = elapsed >= max_duration + + return SchedulerUpdateAction( + request_queuing="stop" if duration_exceeded else "continue", + request_processing="stop_local" if duration_exceeded else "continue", + metadata={ + "max_duration": max_duration, + "elapsed_time": elapsed, + "duration_exceeded": duration_exceeded, + "start_time": state.start_time, + "current_time": current_time, + }, + progress=SchedulerUpdateActionProgress( + remaining_fraction=max(0.0, 1.0 - elapsed / float(max_duration)), + remaining_duration=max(0.0, max_duration - elapsed), + ), + ) + + @field_validator("max_duration") + @classmethod + def _validate_max_duration( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_duration must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, (int, float)) or val <= 0: + raise ValueError( + "max_duration must be a positive num," + f"received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_errors", "max_err", "max_error", "max_errs"] +) +class MaxErrorsConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on absolute error count. + + Stops both request queuing and all request processing when the total number + of errored requests reaches the maximum threshold. Uses global error tracking + across all requests for immediate constraint evaluation. + """ + + type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] + max_errors: int | float | list[int | float] = Field( + description="Maximum number of errors allowed before triggering constraint", + ) + current_index: int = Field(default=-1, description="Current index in error list") + + @classmethod + def validated_kwargs( + cls, max_errors: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorsConstraint creation. + + :param max_errors: Maximum number of errors to allow + :param kwargs: Supports max_errors, max_err, max_error, max_errs, + and optional type_ + :return: Validated dictionary with max_errors and type_ fields + """ + aliases = ["max_errors", "max_err", "max_error", "max_errs"] + for alias in aliases: + if max_errors is None: + max_errors = kwargs.get(alias) + + return { + "max_errors": max_errors, + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current error count. + + :param state: Current scheduler state with error counts + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_errors = ( + self.max_errors + if isinstance(self.max_errors, (int, float)) + else self.max_errors[min(current_index, len(self.max_errors) - 1)] + ) + errors_exceeded = state.errored_requests >= max_errors + + return SchedulerUpdateAction( + request_queuing="stop" if errors_exceeded else "continue", + request_processing="stop_all" if errors_exceeded else "continue", + metadata={ + "max_errors": max_errors, + "errors_exceeded": errors_exceeded, + "current_errors": state.errored_requests, + }, + ) + + @field_validator("max_errors") + @classmethod + def _validate_max_errors( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_errors must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, (int, float)) or val <= 0: + raise ValueError( + f"max_errors must be a positive num,received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_error_rate", "max_err_rate", "max_errors_rate"] +) +class MaxErrorRateConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on sliding window error rate. + + Tracks error status of recent requests in a sliding window and stops all + processing when the error rate exceeds the threshold. Only applies the + constraint after processing enough requests to fill the minimum window size + for statistical significance. + """ + + type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] + max_error_rate: int | float | list[int | float] = Field( + description="Maximum error rate allowed (0.0, 1.0)" + ) + window_size: int | float = Field( + default=30, + gt=0, + description="Size of sliding window for calculating error rate", + ) + error_window: list[bool] = Field( + default_factory=list, + description="Sliding window tracking error status of recent requests", + ) + current_index: int = Field( + default=-1, description="Current index in the error window" + ) + + @classmethod + def validated_kwargs( + cls, max_error_rate: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorRateConstraint creation. + + :param max_error_rate: Maximum error rate to allow + :param kwargs: Supports max_error_rate, max_err_rate, max_errors_rate, + optional window_size, and optional type_ + :return: Validated dictionary with max_error_rate, window_size, + and type_ fields + """ + aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] + for alias in aliases: + if max_error_rate is None: + max_error_rate = kwargs.get(alias) + + return { + "max_error_rate": max_error_rate, + "window_size": kwargs.get( + "window_size", settings.constraint_error_window_size + ), + "error_window": kwargs.get("error_window", []), + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Create a new instance of MaxErrorRateConstraint (due to stateful window). + + :param kwargs: Additional keyword arguments (unused) + :return: New instance of the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, state: SchedulerState, request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against sliding window error rate. + + :param state: Current scheduler state with request counts + :param request_info: Individual request with completion status + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_error_rate = ( + self.max_error_rate + if isinstance(self.max_error_rate, (int, float)) + else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] + ) + + if request_info.status in ["completed", "errored", "cancelled"]: + self.error_window.append(request_info.status == "errored") + if len(self.error_window) > self.window_size: + self.error_window.pop(0) + + error_count = sum(self.error_window) + window_requests = len(self.error_window) + error_rate = ( + error_count / float(window_requests) if window_requests > 0 else 0.0 + ) + exceeded_min_processed = state.processed_requests >= self.window_size + exceeded_error_rate = error_rate >= max_error_rate + + return SchedulerUpdateAction( + request_queuing=( + "stop" if exceeded_min_processed and exceeded_error_rate else "continue" + ), + request_processing=( + "stop_all" + if exceeded_min_processed and exceeded_error_rate + else "continue" + ), + metadata={ + "max_error_rate": max_error_rate, + "window_size": self.window_size, + "error_count": error_count, + "processed_count": state.processed_requests, + "current_window_size": len(self.error_window), + "current_error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + }, + ) + + @field_validator("max_error_rate") + @classmethod + def _validate_max_error_rate( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_error_rate must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, (int, float)) or val <= 0 or val >= 1: + raise ValueError( + "max_error_rate must be a number between 0 and 1," + f"received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"] +) +class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on global error rate. + + Calculates error rate across all processed requests and stops all processing + when the rate exceeds the threshold. Only applies the constraint after + processing the minimum number of requests to ensure statistical significance + for global error rate calculations. + """ + + type_: Literal["max_global_error_rate"] = "max_global_error_rate" # type: ignore[assignment] + max_error_rate: int | float = Field( + description="Maximum error rate allowed (0.0 to 1.0)" + ) + min_processed: int | float | None = Field( + default=30, + gt=0, + description="Minimum requests processed before applying error rate constraint", + ) + current_index: int = Field( + default=-1, description="Current index for list-based max_error_rate values" + ) + + @classmethod + def validated_kwargs( + cls, max_error_rate: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxGlobalErrorRateConstraint creation. + + :param max_error_rate: Maximum error rate to allow + :param kwargs: Supports max_global_error_rate, max_global_err_rate, + max_global_errors_rate, optional min_processed, and optional type_ + :return: Validated dictionary with max_error_rate, min_processed, + and type_ fields + """ + for alias in [ + "max_global_error_rate", + "max_global_err_rate", + "max_global_errors_rate", + ]: + if max_error_rate is None: + max_error_rate = kwargs.get(alias) + + return { + "max_error_rate": max_error_rate, + "min_processed": kwargs.get( + "min_processed", settings.constraint_error_min_processed + ), + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against global error rate. + + :param state: Current scheduler state with global request and error counts + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_error_rate = ( + self.max_error_rate + if isinstance(self.max_error_rate, (int, float)) + else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] + ) + + exceeded_min_processed = ( + self.min_processed is None or state.processed_requests >= self.min_processed + ) + error_rate = ( + state.errored_requests / float(state.processed_requests) + if state.processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + + return SchedulerUpdateAction( + request_queuing="stop" if should_stop else "continue", + request_processing="stop_all" if should_stop else "continue", + metadata={ + "max_error_rate": max_error_rate, + "min_processed": self.min_processed, + "processed_requests": state.processed_requests, + "errored_requests": state.errored_requests, + "error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + }, + ) + + @field_validator("max_error_rate") + @classmethod + def _validate_max_error_rate( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_error_rate must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, (int, float)) or val <= 0 or val >= 1: + raise ValueError( + "max_error_rate must be a number between 0 and 1," + f"received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py new file mode 100644 index 00000000..2cae2abd --- /dev/null +++ b/src/guidellm/scheduler/objects.py @@ -0,0 +1,444 @@ +""" +Core data structures and interfaces for the GuideLLM scheduler system. + +Provides type-safe abstractions for distributed request processing, timing +measurements, and backend interfaces for benchmarking operations. Central to +the scheduler architecture, enabling request lifecycle tracking, backend +coordination, and state management across distributed worker processes. +""" + +from __future__ import annotations + +import time +import uuid +from collections.abc import AsyncIterator +from typing import ( + Any, + Generic, + Literal, + Protocol, + TypeVar, + Union, +) + +from pydantic import Field, computed_field +from typing_extensions import TypeAliasType, TypedDict + +from guidellm.utils import StandardBaseModel + +__all__ = [ + "BackendInterface", + "BackendT", + "MeasuredRequestTimings", + "MeasuredRequestTimingsT", + "MultiTurnRequestT", + "RequestSchedulerTimings", + "RequestT", + "ResponseT", + "ScheduledRequestInfo", + "SchedulerState", + "SchedulerUpdateAction", + "SchedulerUpdateActionProgress", +] + +RequestT = TypeVar("RequestT") +"""Generic request object type for scheduler processing.""" + +ResponseT = TypeVar("ResponseT") +"""Generic response object type returned by backend processing.""" + +MultiTurnRequestT = TypeAliasType( + "MultiTurnRequestT", + Union[ + list[Union[RequestT, tuple[RequestT, float]]], + tuple[Union[RequestT, tuple[RequestT, float]]], + ], + type_params=(RequestT,), +) +"""Multi-turn request structure supporting conversation history with optional delays.""" + + +class RequestSchedulerTimings(StandardBaseModel): + """ + Scheduler-level timing measurements for request lifecycle tracking. + All timestamps are expected to be in Unix time (seconds since epoch). + """ + + targeted_start: float | None = Field( + default=None, + description="When the request was initially targeted for execution", + ) + queued: float | None = Field( + default=None, + description="When the request was placed into the processing queue", + ) + dequeued: float | None = Field( + default=None, + description="When the request was removed from the queue for processing", + ) + scheduled_at: float | None = Field( + default=None, description="When the request was scheduled for processing" + ) + resolve_start: float | None = Field( + default=None, description="When backend resolution of the request began" + ) + resolve_end: float | None = Field( + default=None, description="When backend resolution of the request completed" + ) + finalized: float | None = Field( + default=None, + description="When the request was processed/acknowledged by the scheduler", + ) + + +class MeasuredRequestTimings(StandardBaseModel): + """ + Base timing measurements for backend request processing. + All timestamps are expected to be in Unix time (seconds since epoch). + """ + + request_start: float | None = Field( + default=None, description="When the backend began processing the request" + ) + request_end: float | None = Field( + default=None, description="When the backend completed processing the request" + ) + + +MeasuredRequestTimingsT = TypeVar( + "MeasuredRequestTimingsT", bound=MeasuredRequestTimings +) +"""Generic timing measurements type for backend-specific request processing.""" + + +class ScheduledRequestInfo(StandardBaseModel, Generic[MeasuredRequestTimingsT]): + """ + Complete request information including status, timings, and metadata. + + Central data structure for tracking request lifecycle from creation through + completion, containing scheduling metadata, timing measurements, and processing + status. Used by scheduler components to coordinate request processing across + distributed worker processes. + + Example: + :: + from guidellm.scheduler.objects import ScheduledRequestInfo + + # Create request info with automatic ID generation + request_info = ScheduledRequestInfo() + request_info.status = "in_progress" + request_info.scheduler_timings.queued = time.time() + + # Check processing completion + if request_info.completed_at: + duration = request_info.completed_at - request_info.started_at + """ + + request_id: str = Field( + description="Unique identifier for the request", + default_factory=lambda: str(uuid.uuid4()), + ) + status: Literal[ + "queued", "pending", "in_progress", "completed", "errored", "cancelled" + ] = Field(description="Current processing status of the request", default="queued") + scheduler_node_id: int = Field( + description="ID/rank of the scheduler node handling the request", + default=-1, + ) + scheduler_process_id: int = Field( + description="ID/rank of the node's scheduler process handling the request", + default=-1, + ) + scheduler_start_time: float = Field( + description="Unix timestamp for the local time when scheduler processing began", + default=-1, + ) + + error: str | None = Field( + default=None, description="Error message if the request.status is 'errored'" + ) + scheduler_timings: RequestSchedulerTimings = Field( + default_factory=RequestSchedulerTimings, + description="Scheduler-level timing measurements for request lifecycle", + ) + request_timings: MeasuredRequestTimingsT | None = Field( + default=None, + description="Backend-specific timing measurements for request processing", + ) + + @computed_field # type: ignore[misc] + @property + def started_at(self) -> float | None: + """ + Get the effective request processing start time. + + :return: Unix timestamp when processing began, or None if not started. + """ + request_start = ( + self.request_timings.request_start if self.request_timings else None + ) + + return request_start or self.scheduler_timings.resolve_start + + @computed_field # type: ignore[misc] + @property + def completed_at(self) -> float | None: + """ + Get the effective request processing completion time. + + :return: Unix timestamp when processing completed, or None if not completed. + """ + request_end = self.request_timings.request_end if self.request_timings else None + + return request_end or self.scheduler_timings.resolve_end + + def model_copy(self, **kwargs) -> ScheduledRequestInfo: # type: ignore[override] # noqa: ARG002 + """ + Create a deep copy of the request info with copied timing objects. + + :return: New ScheduledRequestInfo instance with independent timing objects + """ + return super().model_copy( + update={ + "scheduler_timings": self.scheduler_timings.model_copy(), + "request_timings": ( + self.request_timings.model_copy() if self.request_timings else None + ), + }, + deep=False, + ) + + +class BackendInterface(Protocol, Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): + """ + Abstract interface for request processing backends. + + Defines the contract for backend implementations that process requests within + the scheduler system. Backends handle initialization, validation, processing, + and shutdown lifecycle management. Must ensure all properties are pickleable + before process_startup is invoked for multi-process environments. + + Example: + :: + from guidellm.scheduler.objects import BackendInterface + + class CustomBackend(BackendInterface): + @property + def processes_limit(self) -> int: + return 4 + + async def resolve(self, request, request_info, history=None): + # Process request and yield responses + yield response, updated_request_info + """ + + @property + def processes_limit(self) -> int | None: + """ + :return: Maximum worker processes supported, or None if unlimited + """ + + @property + def requests_limit(self) -> int | None: + """ + :return: Maximum concurrent requests supported, or None if unlimited + """ + + @property + def info(self) -> dict[str, Any]: + """ + :return: Backend metadata including model initialization and configuration + """ + + async def process_startup(self) -> None: + """ + Perform backend initialization and startup procedures. + + :raises: Implementation-specific exceptions for startup failures. + """ + + async def validate(self) -> None: + """ + Validate backend configuration and operational status. + + :raises: Implementation-specific exceptions for validation failures. + """ + + async def process_shutdown(self) -> None: + """ + Perform backend cleanup and shutdown procedures. + + :raises: Implementation-specific exceptions for shutdown failures. + """ + + async def resolve( + self, + request: RequestT, + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + history: list[tuple[RequestT, ResponseT]] | None = None, + ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo[MeasuredRequestTimingsT]]]: + """ + Process a request and yield incremental response updates. + + :param request: The request object to process + :param request_info: Scheduling metadata and timing information + :param history: Optional conversation history for multi-turn requests + :yield: Tuples of (response, updated_request_info) for each response chunk + :raises: Implementation-specific exceptions for processing failures + """ + + +BackendT = TypeVar("BackendT", bound=BackendInterface) +"""Generic backend interface type for request processing.""" + + +class SchedulerUpdateActionProgress(TypedDict, total=False): + """ + Progress information for a scheduler update action. + + Optional progress tracking data that provides estimates for remaining work + in scheduler operations. Used by constraints and monitoring systems to + track execution progress and make termination decisions. + """ + + remaining_fraction: float | None + remaining_requests: float | None + remaining_duration: float | None + + +class SchedulerUpdateAction(StandardBaseModel): + """ + Scheduler behavior control directives and actions. + + Encapsulates control signals for scheduler operations including request + queuing and processing directives. Used by constraints to communicate + termination conditions and progress information to scheduler components. + + Example: + :: + from guidellm.scheduler.objects import SchedulerUpdateAction + + # Signal to stop queuing but continue processing + action = SchedulerUpdateAction( + request_queuing="stop", + request_processing="continue", + metadata={"reason": "max_requests_reached"} + ) + """ + + request_queuing: Literal["continue", "stop"] = Field( + default="continue", description="Action to take for request queuing operations" + ) + request_processing: Literal["continue", "stop_local", "stop_all"] = Field( + default="continue", + description="Action to take for request processing operations", + ) + metadata: dict[str, Any] = Field( + default_factory=dict, + description="Additional context and data for the scheduler action", + ) + progress: SchedulerUpdateActionProgress = Field( + default_factory=lambda: SchedulerUpdateActionProgress(), + description="Progress information for the scheduler action", + ) + + +class SchedulerState(StandardBaseModel): + """ + Scheduler operation state tracking and statistics. + + Comprehensive state container for tracking scheduler execution progress, + request counts, timing information, and constraint enforcement. Central + to scheduler coordination and provides real-time metrics for monitoring + and decision-making across distributed worker processes. + + Example: + :: + from guidellm.scheduler.objects import SchedulerState + + # Initialize scheduler state + state = SchedulerState(node_id=0, num_processes=4) + + # Track request processing + state.created_requests += 1 + state.queued_requests += 1 + + # Monitor completion progress + completion_rate = state.processed_requests / state.created_requests + """ + + node_id: int = Field( + description="Unique identifier for this scheduler node", default=-1 + ) + num_processes: int = Field( + description="Number of worker processes in this scheduler", default=-1 + ) + start_time: float = Field( + description="Unix timestamp when the scheduler started", + default_factory=time.time, + ) + end_time: float | None = Field( + default=None, description="Unix timestamp when the scheduler stopped" + ) + end_queuing_time: float | None = Field( + default=None, description="When request queuing stopped, if applicable" + ) + end_queuing_constraints: dict[str, SchedulerUpdateAction] = Field( + default_factory=dict, + description="Constraints that triggered queuing termination", + ) + end_processing_time: float | None = Field( + default=None, description="When request processing stopped, if applicable" + ) + end_processing_constraints: dict[str, SchedulerUpdateAction] = Field( + default_factory=dict, + description="Constraints that triggered processing termination", + ) + scheduler_constraints: dict[str, SchedulerUpdateAction] = Field( + default_factory=dict, + description=( + "The latest state from all constraints applied during the scheduler run" + ), + ) + + remaining_fraction: float | None = Field( + default=None, + description=( + "Estimated fraction for the remaining progress of the run, if known" + ), + ) + remaining_requests: int | None = Field( + default=None, + description="Estimated number of requests remaining to be processed, if known", + ) + remaining_duration: float | None = Field( + default=None, + description=( + "Estimated time remaining in seconds for the scheduler run, if known" + ), + ) + + created_requests: int = Field( + default=0, description="Total number of requests created" + ) + queued_requests: int = Field( + default=0, description="Total number of requests queued for processing" + ) + pending_requests: int = Field( + default=0, description="Number of requests currently pending processing" + ) + processing_requests: int = Field( + default=0, description="Number of requests currently being processed" + ) + processed_requests: int = Field( + default=0, description="Total number of requests that completed processing" + ) + successful_requests: int = Field( + default=0, description="Number of requests that completed successfully" + ) + errored_requests: int = Field( + default=0, description="Number of requests that failed with errors" + ) + cancelled_requests: int = Field( + default=0, description="Number of requests that were cancelled" + ) diff --git a/src/guidellm/scheduler/result.py b/src/guidellm/scheduler/result.py deleted file mode 100644 index 0cca530b..00000000 --- a/src/guidellm/scheduler/result.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import ( - Generic, - Literal, - Optional, -) - -from guidellm.scheduler.strategy import SchedulingStrategy -from guidellm.scheduler.types import RequestT, ResponseT -from guidellm.utils import StandardBaseModel - -__all__ = [ - "SchedulerRequestInfo", - "SchedulerRequestResult", - "SchedulerResult", - "SchedulerRunInfo", -] - - -class SchedulerRunInfo(StandardBaseModel): - """ - Information about the current run of the scheduler. - This class holds metadata about the scheduling run, - including the start and end times, the number of processes, - and the scheduling strategy used. - It also tracks the number of requests created, queued, pending, - and completed during the run. - - :param start_time: The start time of the scheduling run. - :param end_time: The end time of the scheduling run; - if None, then this will be math.inf. - :param end_number: The maximum number of requests to be processed; - if None, then this will be math.inf. - :param processes: The number of processes used in the scheduling run. - :param strategy: The scheduling strategy used in the run. - This should be an instance of SchedulingStrategy. - :param created_requests: The number of requests created during the run. - :param queued_requests: The number of requests queued during the run. - :param scheduled_requests: The number of requests scheduled during the run. - (requests pending being sent to the worker but recieved by a process) - :param processing_requests: The number of requests actively being run. - :param completed_requests: The number of requests completed during the run. - """ - - start_time: float - end_time: float - end_number: float - processes: int - strategy: SchedulingStrategy - - created_requests: int = 0 - queued_requests: int = 0 - scheduled_requests: int = 0 - processing_requests: int = 0 - completed_requests: int = 0 - - -class SchedulerRequestInfo(StandardBaseModel): - """ - Information about a specific request run through the scheduler. - This class holds metadata about the request, including - the targeted start time, queued time, start time, end time, - and the process ID that handled the request. - - :param targeted_start_time: The targeted start time for the request (time.time()). - :param queued_time: The time the request was queued (time.time()). - :param scheduled_time: The time the request was scheduled (time.time()) - (any sleep time before the request was sent to the worker). - :param worker_start: The time the worker started processing request (time.time()). - :param worker_end: The time the worker finished processing request. (time.time()). - :param process_id: The ID of the underlying process that handled the request. - """ - - requested: bool = False - completed: bool = False - errored: bool = False - canceled: bool = False - - targeted_start_time: float = -1 - queued_time: float = -1 - dequeued_time: float = -1 - scheduled_time: float = -1 - worker_start: float = -1 - request_start: float = -1 - request_end: float = -1 - worker_end: float = -1 - process_id: int = -1 - - -class SchedulerResult(StandardBaseModel): - """ - The yielded, iterative result for a scheduler run. - These are triggered on the start and end of the run, - as well as on the start and end of each request. - Depending on the type, it will hold the request and response - along with information and statistics about the request and general run. - - :param type_: The type of the result, which can be one of: - - "run_start": Indicates the start of the run. - - "run_complete": Indicates the completion of the run (teardown happens after). - - "request_start": Indicates the start of a request. - - "request_complete": Indicates the completion of a request. - :param request: The request that was processed. - :param response: The response from the worker for the request. - :param request_info: Information about the request, including - the targeted start time, queued time, start time, end time, - and the process ID that handled the request. - :param run_info: Information about the current run of the scheduler, - including the start and end times, the number of processes, - and the scheduling strategy used. - It also tracks the number of requests created, queued, pending, - and completed during the run. - """ - - pydantic_type: Literal["scheduler_result"] = "scheduler_result" - type_: Literal[ - "run_start", - "run_complete", - "request_scheduled", - "request_start", - "request_complete", - ] - run_info: SchedulerRunInfo - - -class SchedulerRequestResult( - SchedulerResult, - Generic[RequestT, ResponseT], -): - pydantic_type: Literal["scheduler_request_result"] = "scheduler_request_result" # type: ignore[assignment] - type_: Literal[ - "request_scheduled", - "request_start", - "request_complete", - ] - request: RequestT - request_info: SchedulerRequestInfo - response: Optional[ResponseT] = None diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index 06203827..f051a564 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -9,24 +9,19 @@ Any, Generic, Optional, - Union, ) from loguru import logger -from guidellm.config import settings -from guidellm.scheduler.result import ( - SchedulerRequestResult, - SchedulerResult, - SchedulerRunInfo, -) +from guidellm.scheduler.objects import RequestT, ResponseT from guidellm.scheduler.strategy import SchedulingStrategy -from guidellm.scheduler.types import RequestT, ResponseT from guidellm.scheduler.worker import ( RequestsWorker, WorkerProcessRequest, WorkerProcessResult, ) +from guidellm.settings import settings +from guidellm.utils import StandardBaseDict __all__ = ["Scheduler"] @@ -70,9 +65,7 @@ async def run( scheduling_strategy: SchedulingStrategy, max_number: Optional[int] = None, max_duration: Optional[float] = None, - ) -> AsyncGenerator[ - Union[SchedulerResult, SchedulerRequestResult[RequestT, ResponseT]], None - ]: + ) -> AsyncGenerator[Any, None]: """ The main method that runs the scheduler. This method is a generator that yields SchedulerResult objects @@ -126,7 +119,7 @@ async def run( run_info, requests_iter, times_iter = self._run_setup( futures, scheduling_strategy, max_number, max_duration ) - yield SchedulerResult( + yield StandardBaseDict( type_="run_start", run_info=run_info, ) @@ -166,7 +159,7 @@ async def run( except Exception as err: raise RuntimeError(f"Scheduler run failed: {err}") from err - yield SchedulerResult( + yield StandardBaseDict( type_="run_complete", run_info=run_info, ) @@ -249,7 +242,7 @@ def _run_setup( scheduling_strategy: SchedulingStrategy, max_number: Optional[int], max_duration: Optional[float], - ) -> tuple[SchedulerRunInfo, Iterator[Any], Iterator[float]]: + ) -> tuple[StandardBaseDict, Iterator[Any], Iterator[float]]: requests_iter = iter(self.request_loader) start_time = time.time() times_iter = iter(scheduling_strategy.request_times()) @@ -270,7 +263,7 @@ def _run_setup( "scheduler will run indefinitely until the request loader is exhausted." ) - info = SchedulerRunInfo( + info = StandardBaseDict( start_time=start_time, end_time=end_time, end_number=end_number, @@ -285,7 +278,7 @@ def _add_requests( requests_iter: Optional[Iterator[Any]], times_iter: Iterator[float], requests_queue: multiprocessing.Queue, - run_info: SchedulerRunInfo, + run_info: StandardBaseDict, ) -> Optional[Iterator[Any]]: if requests_iter is not None: try: @@ -325,8 +318,8 @@ def _add_requests( def _check_result_ready( self, responses_queue: multiprocessing.Queue, - run_info: SchedulerRunInfo, - ) -> Optional[SchedulerRequestResult[RequestT, ResponseT]]: + run_info: StandardBaseDict, + ) -> Optional[StandardBaseDict]: try: process_response: WorkerProcessResult[RequestT, ResponseT] = ( responses_queue.get_nowait() @@ -338,7 +331,7 @@ def _check_result_ready( run_info.queued_requests -= 1 run_info.scheduled_requests += 1 - return SchedulerRequestResult( + return StandardBaseDict( type_="request_scheduled", run_info=run_info, request=process_response.request, @@ -350,7 +343,7 @@ def _check_result_ready( run_info.scheduled_requests -= 1 run_info.processing_requests += 1 - return SchedulerRequestResult( + return StandardBaseDict( type_="request_start", run_info=run_info, request=process_response.request, @@ -362,7 +355,7 @@ def _check_result_ready( run_info.processing_requests -= 1 run_info.completed_requests += 1 - return SchedulerRequestResult( + return StandardBaseDict( type_="request_complete", run_info=run_info, request=process_response.request, diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py index d4c065da..8c791671 100644 --- a/src/guidellm/scheduler/strategy.py +++ b/src/guidellm/scheduler/strategy.py @@ -1,493 +1,700 @@ +""" +Request scheduling strategies for controlling how benchmark requests are processed. + +This module provides timing implementations and concrete strategies that control request +concurrency, timing patterns, and throughput characteristics to simulate real-world +usage scenarios. The scheduling system separates timing logic from strategy constraints, +enabling flexible combination of timing behaviors with process and concurrency limits. +""" + +from __future__ import annotations + import math -import os import random import time -from collections.abc import Generator -from typing import ( - Literal, - Optional, - Union, -) +from abc import ABC, abstractmethod +from typing import Annotated, ClassVar, Literal, TypeVar -from pydantic import Field +from pydantic import Field, PrivateAttr -from guidellm.config import settings -from guidellm.utils import StandardBaseModel +from guidellm.scheduler.objects import ScheduledRequestInfo +from guidellm.utils import InfoMixin, PydanticClassRegistryMixin, StandardBaseModel __all__ = [ "AsyncConstantStrategy", "AsyncPoissonStrategy", "ConcurrentStrategy", + "ConstantRateRequestTimings", + "LastCompletionRequestTimings", + "NoDelayRequestTimings", + "PoissonRateRequestTimings", + "ScheduledRequestTimings", "SchedulingStrategy", + "StrategyT", "StrategyType", "SynchronousStrategy", "ThroughputStrategy", - "strategy_display_str", ] -StrategyType = Literal["synchronous", "concurrent", "throughput", "constant", "poisson"] +StrategyType = Annotated[ + Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], + "Valid strategy type identifiers for scheduling request patterns", +] + +def _exponential_decay_tau(max_progress: float, convergence: float = 0.99) -> float: + """ + Calculate tau value for exponential decay to reach target progress level. -class SchedulingStrategy(StandardBaseModel): + :param max_progress: The max progress value to reach + :param convergence: The target convergence level for reaching max_progress + :return: The calculated tau value for the given max_progress and convergence """ - An abstract base class for scheduling strategies. - This class defines the interface for scheduling requests and provides - a common structure for all scheduling strategies. - Subclasses should implement the `request_times` method to provide - specific scheduling behavior. - - :param type_: The type of scheduling strategy to use. - This should be one of the predefined strategy types. + return max_progress / (-math.log(1 - convergence)) + + +def _exponential_decay_fraction(progress: float, tau: float = 1.0) -> float: """ + Calculate completion fraction based on exponential decay curve. - type_: Literal["strategy"] = Field( - description="The type of scheduling strategy schedule requests with.", + :param progress: The current progress value (>=0) + :param tau: The scale factor for the exponential decay + :return: The fraction of completion based on exponential decay (0 -> 1) + """ + return 1 - math.exp(-progress / tau) + + +class ScheduledRequestTimings(StandardBaseModel, ABC): + """ + Abstract base class for controlling when requests are scheduled. + + Defines the interface for timing implementations that determine request scheduling + behavior. Different implementations provide various patterns like synchronous, + constant-rate, or stochastic scheduling to simulate real-world scenarios. + """ + + @abstractmethod + def next_offset(self) -> float: + """ + Calculate the time offset for the next request to be scheduled. + + :return: The offset in seconds from scheduler start time for next request + """ + + @abstractmethod + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion and update internal timing state. + + :param request_info: Information about the completed request including + timing details and completion status + """ + + +class LastCompletionRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for synchronous and concurrent scheduling strategies. + + Schedules the next request immediately after the last request completes, enabling + sequential or limited concurrent processing with completion-based timing control. + """ + + offset: float = Field( + default=0.0, + description="Current time offset in seconds from scheduler start time", + ) + startup_requests: int = Field( + default=0, + description="Number of initial requests to schedule with equal spacing", + ge=0, + ) + startup_requests_delay: float = Field( + default=0.0, + description="Delay in seconds between startup requests", + ge=0, ) + _requests_count: int = PrivateAttr(0) - @property - def processing_mode(self) -> Literal["sync", "async"]: + def next_offset(self) -> float: """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. - This property should be implemented by subclasses to return - the appropriate processing mode. + Get the current offset value and apply startup delay if applicable. - :return: The processing mode for the scheduling strategy, - either 'sync' or 'async'. + :return: The current offset value in seconds from scheduler start time """ - return "async" + self._requests_count += 1 - @property - def processes_limit(self) -> int: + if self._requests_count <= self.startup_requests: + self.offset += self.startup_requests_delay + + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. + Update timing state based on the completed request. - :return: The number of processes for the scheduling strategy. + :param request_info: Information about the completed request """ - cpu_cores = os.cpu_count() or 1 + if ( + self._requests_count > self.startup_requests + and request_info.completed_at is not None + ): + # set the next sync offset to the time when the previous request completed + self.offset = request_info.completed_at - request_info.scheduler_start_time - return min(max(1, cpu_cores - 1), settings.max_worker_processes) - @property - def queued_requests_limit(self) -> Optional[int]: +class NoDelayRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for throughput-maximizing scheduling strategies. + + Schedules requests with minimal delay to achieve maximum throughput, with optional + startup ramping to gradually increase request processing during initialization. + """ + + offset: float = Field( + default=0.0, + description="Base time offset in seconds from scheduler start time", + ge=0, + ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for gradual startup ramp", + ge=0, + ) + startup_target_requests: int = Field( + default=1, + description="Target number of requests to converge to during startup", + gt=0, + ) + startup_convergence: float = Field( + default=0.99, + description="Target convergence rate during startup phase", + ) + _start_time: float | None = PrivateAttr(None) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Calculate offset with optional startup adjustment. - :return: The maximum number of queued requests for the scheduling strategy. + :return: Static offset plus any startup adjustment """ - return settings.max_concurrency + if self._start_time is None: + self._start_time = time.time() - @property - def processing_requests_limit(self) -> int: + self._requests_count += 1 + elapsed = time.time() - self._start_time + + if self.startup_duration > 0 and elapsed < self.startup_duration: + startup_percent = _exponential_decay_fraction( + self._requests_count, + _exponential_decay_tau( + self.startup_target_requests, self.startup_convergence + ), + ) + else: + startup_percent = 1.0 + + return self.offset + startup_percent * self.startup_duration + + def request_completed(self, request_info: ScheduledRequestInfo): """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Handle request completion (no action needed for throughput strategy). - :return: The maximum number of processing requests for the scheduling strategy. + :param request_info: Information about the completed request (unused) """ - return settings.max_concurrency - def request_times(self) -> Generator[float, None, None]: + +class ConstantRateRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for constant-rate scheduling strategies. + + Schedules requests at a fixed rate with evenly spaced intervals to provide + predictable timing behavior for steady-state load simulation. + """ + + rate: float = Field( + description="Target rate in requests per second", + gt=0, + ) + offset: float = Field( + default=0.0, + description="Base time offset in seconds from scheduler start time", + ge=0, + ) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: """ - A generator that yields timestamps for when requests should be sent. - This method should be implemented by subclasses to provide specific - scheduling behavior. + Calculate the offset for the next request at a constant rate. - :return: A generator that yields timestamps for request scheduling - or -1 for requests that should be sent immediately. + :return: The offset in seconds for the next request """ - raise NotImplementedError("Subclasses must implement request_times() method.") + num_requests = self._requests_count + self._requests_count += 1 + interval = 1.0 / self.rate + return self.offset + interval * num_requests -class SynchronousStrategy(SchedulingStrategy): + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion (no action needed for constant rate strategy). + + :param request_info: Information about the completed request (unused) + """ + + +class PoissonRateRequestTimings(ScheduledRequestTimings): """ - A class representing a synchronous scheduling strategy. - This strategy schedules requests synchronously, one at a time, - with the maximum rate possible. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for synchronous scheduling. - - :param type_: The synchronous StrategyType to schedule requests synchronously. + Timing implementation for Poisson-distributed scheduling strategies. + + Schedules requests following a Poisson process with exponentially distributed + inter-arrival times to simulate realistic traffic patterns with random variance. """ - type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] + rate: float = Field( + description="Target average rate in requests per second", + gt=0, + ) + random_seed: int = Field( + default=42, + description="Seed for random number generator for reproducible behavior", + ) + offset: float = Field( + default=0.0, + description="Base time offset in seconds from scheduler start time", + ) + _requests_count: int = PrivateAttr(0) + _random: random.Random | None = PrivateAttr(None) + + def next_offset(self) -> float: + """ + Calculate the offset for the next request using Poisson distribution. + + :return: The cumulative offset in seconds for the next request + """ + self._requests_count += 1 + + if self._random is None: + self._random = random.Random(self.random_seed) + else: + next_delay = self._random.expovariate(self.rate) + self.offset += next_delay + + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion (no action needed for Poisson rate strategy). + + :param request_info: Information about the completed request (unused) + """ + + +class SchedulingStrategy(PydanticClassRegistryMixin["SchedulingStrategy"], InfoMixin): + """ + Abstract base class for scheduling strategies controlling request processing. + + Defines the interface for strategies that combine timing implementations with + process and concurrency constraints to enable various benchmark scenarios. + """ + + schema_discriminator: ClassVar[str] = "type_" + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]: + if cls.__name__ == "SchedulingStrategy": + return cls + + return SchedulingStrategy + + type_: Literal["strategy"] = Field( + description="The type of scheduling strategy to schedule requests with", + ) @property - def processing_mode(self) -> Literal["sync"]: + def processes_limit(self) -> int | None: """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. + Get the maximum number of worker processes supported by this strategy. - :return: 'sync' for synchronous scheduling strategy - for the single worker process. + :return: Maximum number of worker processes, None if unlimited """ - return "sync" + return None @property - def processes_limit(self) -> int: + def requests_limit(self) -> int | None: """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of concurrent requests supported by this strategy. - :return: 1 for the synchronous scheduling strategy to limit - the worker processes to one. + :return: Maximum number of concurrent requests, None if unlimited """ - return 1 + return None + + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: + """ + Create a timing instance to define scheduling behavior for a worker process. + + :param local_rank: The rank of the worker process within local world size + :param local_world_size: Total number of worker processes in local world + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: A ScheduledRequestTimings instance for the worker process + :raises NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError( + "create_worker_timings method must be implemented by subclasses." + ) + + +StrategyT = TypeVar("StrategyT", bound=SchedulingStrategy) + + +@SchedulingStrategy.register("synchronous") +class SynchronousStrategy(SchedulingStrategy): + """ + Sequential request processing strategy with single-process constraint. + + Processes requests one at a time in strict sequential order, providing predictable + timing behavior ideal for measuring maximum sequential throughput and ensuring + request isolation. + """ + + type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] + + def __str__(self) -> str: + """ + Return string representation of the strategy. + + :return: String identifier for synchronous strategy + """ + return "synchronous" @property - def queued_requests_limit(self) -> int: + def processes_limit(self) -> int | None: """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Get maximum number of worker processes for synchronous scheduling. - :return: 1 for the synchronous scheduling strategy to limit - the queued requests to one that is ready to be processed. + :return: Always returns 1 to enforce single-process constraint """ return 1 @property - def processing_requests_limit(self) -> int: + def requests_limit(self) -> int | None: """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Get maximum number of concurrent requests for synchronous scheduling. - :return: 1 for the synchronous scheduling strategy to limit - the processing requests to one that is ready to be processed. + :return: Always returns 1 to enforce single-request constraint """ return 1 - def request_times(self) -> Generator[float, None, None]: + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 + ) -> ScheduledRequestTimings: """ - A generator that yields time.time() so requests are sent immediately, - while scheduling them synchronously. + Create timing implementation for synchronous request scheduling. - :return: A generator that yields time.time() for immediate request scheduling. + :param local_rank: The rank of the worker process (must be 0) + :param local_world_size: Total number of worker processes (must be 1) + :param local_max_concurrency: Maximum concurrent requests (unused) + :return: LastCompletionRequestTimings instance for sequential processing + :raises ValueError: If multiple workers or non-zero rank specified """ - while True: - yield time.time() + if local_world_size > 1 or local_rank != 0: + raise ValueError( + "SynchronousStrategy can only be used with a single worker process." + ) + return LastCompletionRequestTimings() + +@SchedulingStrategy.register("concurrent") class ConcurrentStrategy(SchedulingStrategy): """ - A class representing a concurrent scheduling strategy. - This strategy schedules requests concurrently with the specified - number of streams. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for concurrent scheduling. - - :param type_: The concurrent StrategyType to schedule requests concurrently. - :param streams: The number of concurrent streams to use for scheduling requests. - Each stream runs synchronously with the maximum rate possible. - This must be a positive integer. + Parallel request processing strategy with controlled concurrency limits. + + Enables concurrent request processing up to a specified number of streams, + providing balanced throughput while maintaining predictable resource usage + and completion-based timing coordination. """ type_: Literal["concurrent"] = "concurrent" # type: ignore[assignment] streams: int = Field( - description=( - "The number of concurrent streams to use for scheduling requests. " - "Each stream runs sychronously with the maximum rate possible. " - "This must be a positive integer." - ), + description="Number of concurrent streams for scheduling requests", gt=0, ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for distributing startup requests", + ge=0, + ) - @property - def processing_mode(self) -> Literal["sync"]: + def __str__(self) -> str: """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. + Return string representation of the strategy. - :return: 'sync' for synchronous scheduling strategy - for the multiple worker processes equal to streams. + :return: String identifier with stream count """ - return "sync" + return f"concurrent@{self.streams}" @property def processes_limit(self) -> int: """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. + Get maximum number of worker processes for concurrent scheduling. - :return: {self.streams} for the concurrent scheduling strategy to limit - the worker processes to the number of streams. + :return: Number of streams as maximum worker processes """ return self.streams @property - def queued_requests_limit(self) -> int: + def requests_limit(self) -> int: """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Get maximum number of concurrent requests for concurrent scheduling. - :return: {self.streams} for the concurrent scheduling strategy to limit - the queued requests to the number of streams that are ready to be processed. + :return: Number of streams as maximum concurrent requests """ return self.streams - @property - def processing_requests_limit(self) -> int: - """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: {self.streams} for the concurrent scheduling strategy to limit - the processing requests to the number of streams that ready to be processed. - """ - return self.streams - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields time.time() so requests are sent - immediately, while scheduling them concurrently with the specified - number of streams. + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 + ) -> LastCompletionRequestTimings: + """ + Create timing implementation for concurrent request scheduling. + + :param local_rank: The rank of the worker process (must be < streams) + :param local_world_size: Total worker processes (must not exceed streams) + :param local_max_concurrency: Maximum concurrent requests (unused) + :return: LastCompletionRequestTimings instance for stream-based processing + :raises ValueError: If worker configuration exceeds stream limits + """ + if local_world_size > self.streams: + raise ValueError( + "ConcurrentStrategy can only be used with up to " + f"{self.streams} worker processes." + ) + + if local_rank >= self.streams: + raise ValueError( + f"Local rank {local_rank} exceeds the number of streams {self.streams}." + ) + + if self.startup_duration > 0: + # Ensure equal global distribution of the start up for concurrent streams + # Ex: for 10 streams, 2 workers, and 8 seconds start up duration, + # the first worker should start at 0.0, 1.6, 3.2, 4.8, 6.4 + # and the second worker should start at 0.8, 2.4, 4.0, 5.6, 7.2 + delay_per_stream = self.startup_duration / self.streams + streams_per_worker = self.streams // local_world_size + + offset = local_rank * streams_per_worker * delay_per_stream + startup_requests = streams_per_worker + ( + 1 + if local_world_size > 1 and local_rank < self.streams % local_world_size + else 0 + ) + startup_requests_delay = delay_per_stream * local_world_size + else: + offset = 0.0 + startup_requests = 0 + startup_requests_delay = 0.0 - :return: A generator that yields time.time() for immediate request scheduling. - """ - while True: - yield time.time() + return LastCompletionRequestTimings( + offset=offset, + startup_requests=startup_requests, + startup_requests_delay=startup_requests_delay, + ) +@SchedulingStrategy.register("throughput") class ThroughputStrategy(SchedulingStrategy): """ - A class representing a throughput scheduling strategy. - This strategy schedules as many requests asynchronously as possible, - with the maximum rate possible. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for throughput scheduling. - - :param type_: The throughput StrategyType to schedule requests asynchronously. + Maximum throughput strategy with optional concurrency limits. + + Schedules requests to maximize system throughput by allowing unlimited concurrent + processing with optional constraints and startup ramping for controlled ramp-up. """ type_: Literal["throughput"] = "throughput" # type: ignore[assignment] - max_concurrency: Optional[int] = Field( + max_concurrency: int | None = Field( default=None, - description=( - "The maximum number of concurrent requests to schedule. " - "If set to None, the concurrency value from settings will be used. " - "This must be a positive integer greater than 0." - ), + description="Maximum number of concurrent requests to schedule", gt=0, ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for startup request distribution", + ge=0, + ) - @property - def processing_mode(self) -> Literal["async"]: + def __str__(self) -> str: """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. + Return string representation of the strategy. - :return: 'async' for asynchronous scheduling strategy - for the multiple worker processes handling requests. + :return: String identifier for throughput strategy """ - return "async" + return "throughput" @property - def queued_requests_limit(self) -> int: + def processes_limit(self) -> int | None: """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Get maximum number of worker processes for throughput scheduling. - :return: The processing requests limit to ensure that there are enough - requests even for the worst case scenario where the max concurrent - requests are pulled at once for processing. + :return: The max_concurrency value if set, otherwise None for unlimited """ - return self.processing_requests_limit + return self.max_concurrency @property - def processing_requests_limit(self) -> int: + def requests_limit(self) -> int | None: """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Get maximum number of concurrent requests for throughput scheduling. - :return: {self.max_concurrency} for the throughput scheduling strategy to limit - the processing requests to the maximum concurrency. - If max_concurrency is None, then the default processing requests limit - will be used. + :return: The max_concurrency value if set, otherwise None for unlimited """ - return self.max_concurrency or super().processing_requests_limit + return self.max_concurrency - def request_times(self) -> Generator[float, None, None]: + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: """ - A generator that yields the start time.time() so requests are sent - immediately, while scheduling as many asynchronously as possible. + Create timing implementation for throughput request scheduling. - :return: A generator that yields the start time.time() - for immediate request scheduling. + :param local_rank: The rank of the worker process + :param local_world_size: Total number of worker processes + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: NoDelayRequestTimings instance for immediate request scheduling """ - start_time = time.time() + if self.startup_duration > 0: + # Vary offset by up to 5% of the startup duration for a bit of variance + offset = 0.05 * self.startup_duration * (local_rank / local_world_size) + # Use local_max_concurrency as the target requests for startup convergence + startup_target_requests = local_max_concurrency + else: + offset = 0.0 + startup_target_requests = 1 - while True: - yield start_time + return NoDelayRequestTimings( + startup_duration=self.startup_duration, + startup_target_requests=startup_target_requests, + offset=offset, + ) +@SchedulingStrategy.register("constant") class AsyncConstantStrategy(ThroughputStrategy): """ - A class representing an asynchronous constant scheduling strategy. - This strategy schedules requests asynchronously at a constant request rate - in requests per second. - If initial_burst is set, it will send an initial burst of math.floor(rate) - requests to reach the target rate. - This is useful to ensure that the target rate is reached quickly - and then maintained. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for asynchronous constant scheduling. - - :param type_: The constant StrategyType to schedule requests asynchronously. - :param rate: The rate at which to schedule requests asynchronously in - requests per second. This must be a positive float. - :param initial_burst: True to send an initial burst of requests - (math.floor(self.rate)) to reach target rate. - False to not send an initial burst. + Asynchronous constant-rate scheduling strategy for predictable load patterns. + + Schedules requests at a fixed rate distributed evenly across worker processes, + providing predictable timing behavior for steady-state load simulation and + consistent system performance measurement. """ type_: Literal["constant"] = "constant" # type: ignore[assignment] rate: float = Field( - description=( - "The rate at which to schedule requests asynchronously in " - "requests per second. This must be a positive float." - ), + description="Rate for scheduling requests asynchronously in requests/second", gt=0, ) - initial_burst: bool = Field( - default=True, - description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." - ), + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for startup request distribution", + ge=0, ) - def request_times(self) -> Generator[float, None, None]: + def __str__(self) -> str: """ - A generator that yields timestamps for when requests should be sent. - This method schedules requests asynchronously at a constant rate - in requests per second. - If burst_time is set, it will send an initial burst of requests - to reach the target rate. - This is useful to ensure that the target rate is reached quickly - and then maintained. + Return string representation of the strategy. - :return: A generator that yields timestamps for request scheduling. + :return: String identifier with rate value """ - start_time = time.time() - constant_increment = 1.0 / self.rate + return f"constant@{self.rate:.2f}" - # handle bursts first to get to the desired rate - if self.initial_burst is not None: - # send an initial burst equal to the rate - # to reach the target rate - burst_count = math.floor(self.rate) - for _ in range(burst_count): - yield start_time - - start_time += constant_increment + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 + ) -> ScheduledRequestTimings: + """ + Create timing implementation for constant-rate request scheduling. - counter = 0 + :param local_rank: The rank of the worker process + :param local_world_size: Total number of worker processes for rate division + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: ConstantRateRequestTimings instance with per-worker rate + """ + # Divide the rate evenly across all worker processes + worker_rate = self.rate / local_world_size + # Start each worker with an offset to interleave rates + worker_offset = (1 / self.rate) * local_rank - # continue with constant rate after bursting - while True: - yield start_time + constant_increment * counter - counter += 1 + return ConstantRateRequestTimings( + rate=worker_rate, + offset=worker_offset, + ) +@SchedulingStrategy.register("poisson") class AsyncPoissonStrategy(ThroughputStrategy): """ - A class representing an asynchronous Poisson scheduling strategy. - This strategy schedules requests asynchronously at a Poisson request rate - in requests per second. - If initial_burst is set, it will send an initial burst of math.floor(rate) - requests to reach the target rate. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for asynchronous Poisson scheduling. - - :param type_: The Poisson StrategyType to schedule requests asynchronously. - :param rate: The rate at which to schedule requests asynchronously in - requests per second. This must be a positive float. - :param initial_burst: True to send an initial burst of requests - (math.floor(self.rate)) to reach target rate. - False to not send an initial burst. + Asynchronous Poisson-distributed scheduling strategy for realistic load simulation. + + Schedules requests following a Poisson process with exponentially distributed + inter-arrival times, providing realistic simulation of user behavior and network + traffic patterns with random variance around the target rate. """ type_: Literal["poisson"] = "poisson" # type: ignore[assignment] rate: float = Field( - description=( - "The rate at which to schedule requests asynchronously in " - "requests per second. This must be a positive float." - ), + description="Rate for scheduling requests asynchronously in requests/second", gt=0, ) - initial_burst: bool = Field( - default=True, - description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." - ), + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for startup request distribution", + ge=0, ) random_seed: int = Field( default=42, - description=("The random seed to use for the Poisson distribution. "), + description="Random seed to use for Poisson distribution", ) - def request_times(self) -> Generator[float, None, None]: + def __str__(self) -> str: """ - A generator that yields timestamps for when requests should be sent. - This method schedules requests asynchronously at a Poisson rate - in requests per second. - The inter arrival time between requests is exponentially distributed - based on the rate. + Return string representation of the strategy. - :return: A generator that yields timestamps for request scheduling. + :return: String identifier with rate value """ - start_time = time.time() - - if self.initial_burst is not None: - # send an initial burst equal to the rate - # to reach the target rate - burst_count = math.floor(self.rate) - for _ in range(burst_count): - yield start_time - else: - yield start_time - - # set the random seed for reproducibility - rand = random.Random(self.random_seed) # noqa: S311 + return f"poisson@{self.rate:.2f}" - while True: - inter_arrival_time = rand.expovariate(self.rate) - start_time += inter_arrival_time - yield start_time - - -def strategy_display_str(strategy: Union[StrategyType, SchedulingStrategy]) -> str: - strategy_type = strategy if isinstance(strategy, str) else strategy.type_ - strategy_instance = strategy if isinstance(strategy, SchedulingStrategy) else None + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 + ) -> ScheduledRequestTimings: + """ + Create timing implementation for Poisson-distributed request scheduling. - if strategy_type == "concurrent": - rate = f"@{strategy_instance.streams}" if strategy_instance else "@##" # type: ignore[attr-defined] - elif strategy_type in ("constant", "poisson"): - rate = f"@{strategy_instance.rate:.2f}" if strategy_instance else "@#.##" # type: ignore[attr-defined] - else: - rate = "" + :param local_rank: The rank of the worker process for seed generation + :param local_world_size: Total number of worker processes for rate division + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: PoissonRateRequestTimings instance with per-worker rate and unique seed + """ + # Divide the rate evenly across all worker processes + worker_rate = self.rate / local_world_size + # Use a different seed for each worker to ensure different sequences + worker_seed = self.random_seed + local_rank + # Start each worker with an offset to interleave rates + worker_offset = (1 / self.rate) * local_rank - return f"{strategy_type}{rate}" + return PoissonRateRequestTimings( + rate=worker_rate, + random_seed=worker_seed, + offset=worker_offset, + ) diff --git a/src/guidellm/scheduler/types.py b/src/guidellm/scheduler/types.py deleted file mode 100644 index 42535d71..00000000 --- a/src/guidellm/scheduler/types.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import TypeVar - -__all__ = ["RequestT", "ResponseT"] - - -RequestT = TypeVar("RequestT") -ResponseT = TypeVar("ResponseT") diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index ab16e4db..fafb6d69 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -25,9 +25,8 @@ StreamingTextResponse, ) from guidellm.request import GenerationRequest -from guidellm.scheduler.result import SchedulerRequestInfo -from guidellm.scheduler.types import RequestT, ResponseT -from guidellm.utils import StandardBaseModel +from guidellm.scheduler.objects import RequestT, ResponseT +from guidellm.utils import StandardBaseDict, StandardBaseModel __all__ = [ "GenerativeRequestsWorker", @@ -53,7 +52,7 @@ class WorkerProcessResult(Generic[RequestT, ResponseT]): type_: Literal["request_scheduled", "request_start", "request_complete"] request: RequestT response: Optional[ResponseT] - info: SchedulerRequestInfo + info: Any @dataclass @@ -142,7 +141,7 @@ async def resolve_scheduler_request( results_queue: multiprocessing.Queue, process_id: int, ): - info = SchedulerRequestInfo( + info = StandardBaseDict( targeted_start_time=start_time, queued_time=queued_time, dequeued_time=dequeued_time, diff --git a/src/guidellm/config.py b/src/guidellm/settings.py similarity index 98% rename from src/guidellm/config.py rename to src/guidellm/settings.py index beda55fc..efeefa71 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/settings.py @@ -133,6 +133,8 @@ class Settings(BaseSettings): max_concurrency: int = 512 max_worker_processes: int = 10 max_add_requests_per_loop: int = 20 + constraint_error_window_size: float = 30 + constraint_error_min_processed: float = 30 # Data settings dataset: DatasetSettings = DatasetSettings() diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 6c8561ac..47a517b2 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -29,6 +29,7 @@ InterProcessMessagingPipe, InterProcessMessagingQueue, ) +from .mixins import InfoMixin from .pydantic_utils import ( PydanticClassRegistryMixin, ReloadableBaseModel, @@ -65,6 +66,7 @@ "Encoder", "EncodingTypesAlias", "EndlessTextCreator", + "InfoMixin", "IntegerRangeSampler", "InterProcessMessaging", "InterProcessMessagingManagerQueue", diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index beebfe37..6c5adbe4 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -22,7 +22,7 @@ from loguru import logger from guidellm import data as package_data -from guidellm.config import settings +from guidellm.settings import settings from guidellm.utils.colors import Colors __all__ = [ diff --git a/tests/unit/backend/test_openai_backend.py b/tests/unit/backend/test_openai_backend.py index 0a4c2c38..7123c590 100644 --- a/tests/unit/backend/test_openai_backend.py +++ b/tests/unit/backend/test_openai_backend.py @@ -3,7 +3,7 @@ import pytest from guidellm.backend import OpenAIHTTPBackend, ResponseSummary, StreamingTextResponse -from guidellm.config import settings +from guidellm.settings import settings @pytest.mark.smoke diff --git a/tests/unit/backend/test_openai_backend_custom_configs.py b/tests/unit/backend/test_openai_backend_custom_configs.py index 7f6706ad..5855152d 100644 --- a/tests/unit/backend/test_openai_backend_custom_configs.py +++ b/tests/unit/backend/test_openai_backend_custom_configs.py @@ -1,7 +1,7 @@ import pytest from guidellm.backend import OpenAIHTTPBackend -from guidellm.config import settings +from guidellm.settings import settings @pytest.mark.smoke diff --git a/tests/unit/presentation/test_injector.py b/tests/unit/presentation/test_injector.py index cdaa7619..9d97d021 100644 --- a/tests/unit/presentation/test_injector.py +++ b/tests/unit/presentation/test_injector.py @@ -3,8 +3,8 @@ import pytest from pydantic import BaseModel -from guidellm.config import settings from guidellm.presentation.injector import create_report, inject_data +from guidellm.settings import settings class ExampleModel(BaseModel): diff --git a/tests/unit/scheduler/__init__.py b/tests/unit/scheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/scheduler/test_constraints.py b/tests/unit/scheduler/test_constraints.py new file mode 100644 index 00000000..00d279d4 --- /dev/null +++ b/tests/unit/scheduler/test_constraints.py @@ -0,0 +1,1410 @@ +import inspect +import random +import time +from abc import ABC +from typing import Protocol + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + MaxDurationConstraint, + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, + MaxNumberConstraint, + PydanticConstraintInitializer, + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from guidellm.utils import InfoMixin, StandardBaseModel + + +class TestConstraint: + """Test the Constraint protocol.""" + + @pytest.mark.smoke + def test_is_protocol(self): + """Test that Constraint is a protocol and runtime checkable.""" + assert issubclass(Constraint, Protocol) + assert hasattr(Constraint, "_is_protocol") + assert Constraint._is_protocol is True + assert hasattr(Constraint, "_is_runtime_protocol") + assert Constraint._is_runtime_protocol is True + + @pytest.mark.smoke + def test_protocol_method_signature(self): + """Test that the Constraint protocol has the correct method signature.""" + call_method = Constraint.__call__ + sig = inspect.signature(call_method) + + expected_params = ["self", "state", "request"] + assert list(sig.parameters.keys()) == expected_params + + params = sig.parameters + assert "state" in params + assert "request" in params + + @pytest.mark.smoke + def test_runtime_is_constraint(self): + """Test that Constraint can be checked at runtime using isinstance.""" + + class ValidConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + valid_instance = ValidConstraint() + assert isinstance(valid_instance, Constraint) + + class InvalidConstraint: + pass + + invalid_instance = InvalidConstraint() + assert not isinstance(invalid_instance, Constraint) + + @pytest.mark.smoke + def test_runtime_is_not_intializer(self): + """ + Test that a class not implementing the ConstraintInitializer + protocol is not recognized as such. + """ + + class ValidConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + not_initializer_instance = ValidConstraint() + assert not isinstance(not_initializer_instance, ConstraintInitializer) + + +class TestConstraintInitializer: + """Test the ConstraintInitializer protocol.""" + + @pytest.mark.smoke + def test_is_protocol(self): + """Test that ConstraintInitializer is a protocol and runtime checkable.""" + assert issubclass(ConstraintInitializer, Protocol) + assert hasattr(ConstraintInitializer, "_is_protocol") + assert ConstraintInitializer._is_protocol is True + assert hasattr(ConstraintInitializer, "_is_runtime_protocol") + assert ConstraintInitializer._is_runtime_protocol is True + + @pytest.mark.smoke + def test_protocol_method_signature(self): + """Test that ConstraintInitializer protocol has correct method signature.""" + create_constraint_method = ConstraintInitializer.create_constraint + sig = inspect.signature(create_constraint_method) + + expected_params = ["self", "kwargs"] + assert list(sig.parameters.keys()) == expected_params + kwargs_param = sig.parameters["kwargs"] + assert kwargs_param.kind == kwargs_param.VAR_KEYWORD + + @pytest.mark.smoke + def test_runtime_is_initializer(self): + """Test that ConstraintInitializer can be checked at runtime.""" + + class ValidInitializer: + def create_constraint(self, **kwargs) -> Constraint: + class SimpleConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + return SimpleConstraint() + + valid_instance = ValidInitializer() + assert isinstance(valid_instance, ConstraintInitializer) + + @pytest.mark.smoke + def test_runtime_is_not_constraint(self): + """ + Test that a class not implementing the Constraint protocol + is not recognized as such. + """ + + class ValidInitializer: + def create_constraint(self, **kwargs) -> Constraint: + class SimpleConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + return SimpleConstraint() + + not_constraint_instance = ValidInitializer() + assert not isinstance(not_constraint_instance, Constraint) + + +class TestSerializableConstraintInitializer: + """Test the SerializableConstraintInitializer protocol.""" + + @pytest.mark.smoke + def test_is_protocol(self): + """Test SerializableConstraintInitializer is a protocol and checkable.""" + assert issubclass(SerializableConstraintInitializer, Protocol) + assert hasattr(SerializableConstraintInitializer, "_is_protocol") + assert SerializableConstraintInitializer._is_protocol is True + assert hasattr(SerializableConstraintInitializer, "_is_runtime_protocol") + assert SerializableConstraintInitializer._is_runtime_protocol is True + + @pytest.mark.smoke + def test_protocol_method_signatures(self): + """Test SerializableConstraintInitializer protocol has correct signatures.""" + methods = [ + "validated_kwargs", + "model_validate", + "model_dump", + "create_constraint", + ] + + for method_name in methods: + assert hasattr(SerializableConstraintInitializer, method_name) + + @pytest.mark.smoke + def test_runtime_is_serializable_initializer(self): + """Test that SerializableConstraintInitializer can be checked at runtime.""" + + class ValidSerializableInitializer: + @classmethod + def validated_kwargs(cls, *args, **kwargs): + return kwargs + + @classmethod + def model_validate(cls, **kwargs): + return cls() + + def model_dump(self): + return {} + + def create_constraint(self, **kwargs): + class SimpleConstraint: + def __call__(self, state, request): + return SchedulerUpdateAction() + + return SimpleConstraint() + + valid_instance = ValidSerializableInitializer() + assert isinstance(valid_instance, SerializableConstraintInitializer) + + +class TestPydanticConstraintInitializer: + """Test the PydanticConstraintInitializer implementation.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticConstraintInitializer inheritance and abstract methods.""" + assert issubclass(PydanticConstraintInitializer, StandardBaseModel) + assert issubclass(PydanticConstraintInitializer, ABC) + assert issubclass(PydanticConstraintInitializer, InfoMixin) + + @pytest.mark.smoke + def test_abstract_methods(self): + """Test that PydanticConstraintInitializer has required abstract methods.""" + abstract_methods = PydanticConstraintInitializer.__abstractmethods__ + expected_methods = {"validated_kwargs", "create_constraint"} + assert abstract_methods == expected_methods + + @pytest.mark.sanity + def test_cannot_instantiate_directly(self): + """Test that PydanticConstraintInitializer cannot be instantiated directly.""" + with pytest.raises(TypeError): + PydanticConstraintInitializer(type_="test") + + +class TestUnserializableConstraintInitializer: + """Test the UnserializableConstraintInitializer implementation.""" + + @pytest.fixture( + params=[ + {"orig_info": {}}, + {"orig_info": {"class": "SomeClass", "module": "some.module"}}, + ] + ) + def valid_instances(self, request): + """Fixture providing test data for UnserializableConstraintInitializer.""" + constructor_args = request.param + instance = UnserializableConstraintInitializer(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test UnserializableConstraintInitializer inheritance.""" + assert issubclass( + UnserializableConstraintInitializer, PydanticConstraintInitializer + ) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test UnserializableConstraintInitializer initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, UnserializableConstraintInitializer) + assert instance.type_ == "unserializable" + assert instance.orig_info == constructor_args["orig_info"] + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test validated_kwargs class method.""" + result = UnserializableConstraintInitializer.validated_kwargs( + orig_info={"test": "data"} + ) + assert result == {"orig_info": {"test": "data"}} + + result = UnserializableConstraintInitializer.validated_kwargs() + assert result == {"orig_info": {}} + + @pytest.mark.sanity + def test_create_constraint_raises(self, valid_instances): + """Test that create_constraint raises RuntimeError.""" + instance, _ = valid_instances + with pytest.raises( + RuntimeError, match="Cannot create constraint from unserializable" + ): + instance.create_constraint() + + @pytest.mark.sanity + def test_call_raises(self, valid_instances): + """Test that calling constraint raises RuntimeError.""" + instance, _ = valid_instances + state = SchedulerState() + request = ScheduledRequestInfo() + + with pytest.raises( + RuntimeError, match="Cannot invoke unserializable constraint" + ): + instance(state, request) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test UnserializableConstraintInitializer serialization/deserialization.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert data["type_"] == "unserializable" + assert data["orig_info"] == constructor_args["orig_info"] + + reconstructed = UnserializableConstraintInitializer.model_validate(data) + assert reconstructed.type_ == instance.type_ + assert reconstructed.orig_info == instance.orig_info + + +class TestMaxNumberConstraint: + """Test the MaxNumberConstraint implementation.""" + + @pytest.fixture(params=[{"max_num": 100}, {"max_num": 50.5}, {"max_num": 1}]) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxNumberConstraint(**constructor_args) + + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxNumberConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """Test MaxNumberConstraint satisfies the ConstraintInitializer protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxNumberConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxNumberConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxNumberConstraint() + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num=-1) + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num=0) + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions and progress""" + instance, constructor_args = valid_instances + start_time = time.time() + + for num_requests in range(0, int(constructor_args["max_num"]) * 2 + 1, 1): + state = SchedulerState( + start_time=start_time, + created_requests=num_requests, + processed_requests=num_requests, + errored_requests=0, + ) + request_info = ScheduledRequestInfo(request_id="test", status="completed") + + action = instance(state, request_info) + assert isinstance(action, SchedulerUpdateAction) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxNumberConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxNumberConstraint.model_validate(data) + assert reconstructed.max_num == instance.max_num + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_create_constraint_functionality(self, valid_instances): + """Test the constraint initializer functionality.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == constructor_args["max_num"] + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxNumberConstraint.validated_kwargs class method.""" + result = MaxNumberConstraint.validated_kwargs(max_num=100) + assert result == {"max_num": 100, "current_index": -1} + + result = MaxNumberConstraint.validated_kwargs(50.5) + assert result == {"max_num": 50.5, "current_index": -1} + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxNumberConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxNumberConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_num == instance.max_num + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxNumberConstraint is properly registered with expected aliases.""" + expected_aliases = ["max_number", "max_num", "max_requests", "max_req"] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxNumberConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", ["max_number", "max_num", "max_requests", "max_req"] + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint(alias, max_num=100) + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == 100 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 50) + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == 50 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_number": {"max_num": 200}} + ) + assert isinstance(resolved["max_number"], MaxNumberConstraint) + assert resolved["max_number"].max_num == 200 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_num": 150}) + assert isinstance(resolved["max_num"], MaxNumberConstraint) + assert resolved["max_num"].max_num == 150 + + # Test with instance + instance = MaxNumberConstraint(max_num=75) + resolved = ConstraintsInitializerFactory.resolve({"max_requests": instance}) + assert resolved["max_requests"] is instance + + +class TestMaxDurationConstraint: + """Test the MaxDurationConstraint implementation.""" + + @pytest.fixture( + params=[{"max_duration": 2.0}, {"max_duration": 1}, {"max_duration": 0.5}] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxDurationConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxDurationConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxDurationConstraint also satisfies + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxDurationConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxDurationConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxDurationConstraint() + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration=-1) + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration=0) + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions and progress through a time loop""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_duration = constructor_args["max_duration"] + sleep_interval = max_duration * 0.05 + target_duration = max_duration * 1.5 + + elapsed = 0.0 + step = 0 + + while elapsed <= target_duration: + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=step + 1, + processed_requests=step, + ) + request = ScheduledRequestInfo( + request_id=f"test-{step}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + + duration_exceeded = elapsed >= max_duration + + if not duration_exceeded: + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + else: + assert action.request_queuing == "stop" + assert action.request_processing == "stop_local" + assert isinstance(action.metadata, dict) + assert action.metadata["max_duration"] == max_duration + assert action.metadata["elapsed_time"] == pytest.approx(elapsed, abs=0.01) + assert action.metadata["duration_exceeded"] == duration_exceeded + assert action.metadata["start_time"] == start_time + assert isinstance(action.progress, dict) + expected_remaining_fraction = max(0.0, 1.0 - elapsed / max_duration) + expected_remaining_duration = max(0.0, max_duration - elapsed) + assert action.progress["remaining_fraction"] == pytest.approx( + expected_remaining_fraction, abs=0.1 + ) + assert action.progress["remaining_duration"] == pytest.approx( + expected_remaining_duration, abs=0.1 + ) + time.sleep(sleep_interval) + elapsed = time.time() - start_time + step += 1 + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxDurationConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxDurationConstraint.model_validate(data) + assert reconstructed.max_duration == instance.max_duration + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_create_constraint_functionality(self, valid_instances): + """Test the constraint initializer functionality.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == constructor_args["max_duration"] + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxDurationConstraint.validated_kwargs class method.""" + result = MaxDurationConstraint.validated_kwargs(max_duration=60.0) + assert result == {"max_duration": 60.0, "current_index": -1} + + result = MaxDurationConstraint.validated_kwargs(30) + assert result == {"max_duration": 30, "current_index": -1} + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxDurationConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxDurationConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_duration == instance.max_duration + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxDurationConstraint is properly registered with expected aliases.""" + expected_aliases = [ + "max_duration", + "max_dur", + "max_sec", + "max_seconds", + "max_min", + "max_minutes", + ] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxDurationConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", + ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"], + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_duration=60.0 + ) + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == 60.0 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 30.0) + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == 30.0 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_duration": {"max_duration": 120.0}} + ) + assert isinstance(resolved["max_duration"], MaxDurationConstraint) + assert resolved["max_duration"].max_duration == 120.0 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_sec": 90.0}) + assert isinstance(resolved["max_sec"], MaxDurationConstraint) + assert resolved["max_sec"].max_duration == 90.0 + + # Test with instance + instance = MaxDurationConstraint(max_duration=45.0) + resolved = ConstraintsInitializerFactory.resolve({"max_minutes": instance}) + assert resolved["max_minutes"] is instance + + +class TestMaxErrorsConstraint: + """Test the MaxErrorsConstraint implementation.""" + + @pytest.fixture(params=[{"max_errors": 10}, {"max_errors": 5.5}, {"max_errors": 1}]) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxErrorsConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxErrorsConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxErrorsConstraint also satisfies + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxErrorsConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxErrorsConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorsConstraint() + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors=-1) + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors=0) + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions""" + instance, constructor_args = valid_instances + start_time = time.time() + + for num_errors in range(int(constructor_args["max_errors"] * 2)): + created_requests = (num_errors + 1) * 2 + processed_requests = num_errors + 1 + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=created_requests, + processed_requests=processed_requests, + errored_requests=num_errors, + ) + request = ScheduledRequestInfo( + request_id=f"test-{num_errors}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + errors_exceeded = num_errors >= constructor_args["max_errors"] + if not errors_exceeded: + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + else: + assert action.request_queuing == "stop" + assert action.request_processing == "stop_all" + + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_errors": constructor_args["max_errors"], + "errors_exceeded": errors_exceeded, + "current_errors": num_errors, + } + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxErrorsConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorsConstraint.model_validate(data) + assert reconstructed.max_errors == instance.max_errors + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxErrorsConstraint.validated_kwargs class method.""" + result = MaxErrorsConstraint.validated_kwargs(max_errors=10) + assert result == {"max_errors": 10, "current_index": -1} + + result = MaxErrorsConstraint.validated_kwargs(5.5) + assert result == {"max_errors": 5.5, "current_index": -1} + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxErrorsConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint is not instance + assert constraint.max_errors == instance.max_errors + assert instance.current_index == original_index + 1 + assert constraint.current_index == original_index + 1 + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxErrorsConstraint is properly registered with expected aliases.""" + expected_aliases = ["max_errors", "max_err", "max_error", "max_errs"] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxErrorsConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", ["max_errors", "max_err", "max_error", "max_errs"] + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_errors=10 + ) + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint.max_errors == 10 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 5) + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint.max_errors == 5 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_errors": {"max_errors": 15}} + ) + assert isinstance(resolved["max_errors"], MaxErrorsConstraint) + assert resolved["max_errors"].max_errors == 15 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_err": 8}) + assert isinstance(resolved["max_err"], MaxErrorsConstraint) + assert resolved["max_err"].max_errors == 8 + + # Test with instance + instance = MaxErrorsConstraint(max_errors=3) + resolved = ConstraintsInitializerFactory.resolve({"max_error": instance}) + assert resolved["max_error"] is instance + + +class TestMaxErrorRateConstraint: + """Test the MaxErrorRateConstraint implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "window_size": 40}, + {"max_error_rate": 0.5, "window_size": 50}, + {"max_error_rate": 0.05, "window_size": 55}, + ] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxErrorRateConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxErrorRateConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxErrorRateConstraint also satisfies + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxErrorRateConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxErrorRateConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorRateConstraint() + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=0.5, window_size=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions with sliding window behavior""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_error_rate = constructor_args["max_error_rate"] + window_size = constructor_args["window_size"] + safety_factor = 1.5 + total_errors = 0 + error_window = [] + + for request_num in range(window_size * 2): + error_probability = max_error_rate * safety_factor + + if random.random() < error_probability: + total_errors += 1 + status = "errored" + error_window.append(1) + else: + status = "completed" + error_window.append(0) + error_window = ( + error_window[-window_size:] + if len(error_window) > window_size + else error_window + ) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=request_num + 1, + processed_requests=request_num + 1, + ) + request = ScheduledRequestInfo( + request_id=f"test-{request_num}", + status=status, + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + error_count = sum(instance.error_window) + processed_requests = state.processed_requests + exceeded_min_processed = processed_requests >= window_size + current_error_rate = ( + error_count / float(min(processed_requests, window_size)) + if processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = current_error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + expected_queuing = "stop" if should_stop else "continue" + expected_processing = "stop_all" if should_stop else "continue" + + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + assert isinstance(action.metadata, dict) + assert action.metadata["max_error_rate"] == max_error_rate + assert action.metadata["window_size"] == window_size + assert action.metadata["error_count"] == error_count + assert action.metadata["current_error_rate"] == current_error_rate + assert action.metadata["exceeded_error_rate"] == exceeded_error_rate + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxErrorRateConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorRateConstraint.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.window_size == instance.window_size + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxErrorRateConstraint.validated_kwargs class method.""" + result = MaxErrorRateConstraint.validated_kwargs( + max_error_rate=0.1, window_size=50 + ) + assert result == { + "max_error_rate": 0.1, + "window_size": 50, + "error_window": [], + "current_index": -1, + } + + result = MaxErrorRateConstraint.validated_kwargs(0.05) + assert result == { + "max_error_rate": 0.05, + "window_size": 30, + "error_window": [], + "current_index": -1, + } + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxErrorRateConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_error_rate == instance.max_error_rate + assert constraint.window_size == instance.window_size + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxErrorRateConstraint is properly registered with expected aliases.""" + expected_aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxErrorRateConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", ["max_error_rate", "max_err_rate", "max_errors_rate"] + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_error_rate=0.1, window_size=50 + ) + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint.max_error_rate == 0.1 + assert constraint.window_size == 50 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 0.05) + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint.max_error_rate == 0.05 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_error_rate": {"max_error_rate": 0.15, "window_size": 100}} + ) + assert isinstance(resolved["max_error_rate"], MaxErrorRateConstraint) + assert resolved["max_error_rate"].max_error_rate == 0.15 + assert resolved["max_error_rate"].window_size == 100 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_err_rate": 0.08}) + assert isinstance(resolved["max_err_rate"], MaxErrorRateConstraint) + assert resolved["max_err_rate"].max_error_rate == 0.08 + + # Test with instance + instance = MaxErrorRateConstraint(max_error_rate=0.2, window_size=25) + resolved = ConstraintsInitializerFactory.resolve({"max_errors_rate": instance}) + assert resolved["max_errors_rate"] is instance + + +class TestMaxGlobalErrorRateConstraint: + """Test the MaxGlobalErrorRateConstraint implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "min_processed": 50}, + {"max_error_rate": 0.2, "min_processed": 100}, + {"max_error_rate": 0.05, "min_processed": 31}, + ] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxGlobalErrorRateConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxGlobalErrorRateConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraint also satisfies + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraint can be initialized + with valid parameters. + """ + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxGlobalErrorRateConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint() + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=0) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=0.5, min_processed=0) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions based on global error rate""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_error_rate = constructor_args["max_error_rate"] + min_processed = constructor_args["min_processed"] + safety_factor = 1.5 + total_requests = min_processed * 2 + total_errors = 0 + + for request_num in range(total_requests): + error_probability = max_error_rate * safety_factor + + if random.random() < error_probability: + total_errors += 1 + status = "errored" + else: + status = "completed" + + processed_requests = request_num + 1 + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=processed_requests + 10, + processed_requests=processed_requests, + errored_requests=total_errors, + ) + request = ScheduledRequestInfo( + request_id=f"test-{request_num}", + status=status, + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + + exceeded_min_processed = processed_requests >= min_processed + error_rate = ( + total_errors / float(processed_requests) + if processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + + expected_queuing = "stop" if should_stop else "continue" + expected_processing = "stop_all" if should_stop else "continue" + + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_error_rate": max_error_rate, + "min_processed": min_processed, + "processed_requests": processed_requests, + "errored_requests": total_errors, + "error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + } + + # Error constraints don't provide progress information + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxGlobalErrorRateConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxGlobalErrorRateConstraint.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.min_processed == instance.min_processed + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxGlobalErrorRateConstraint.validated_kwargs class method.""" + result = MaxGlobalErrorRateConstraint.validated_kwargs( + max_error_rate=0.1, min_processed=50 + ) + assert result == { + "max_error_rate": 0.1, + "min_processed": 50, + "current_index": -1, + } + + result = MaxGlobalErrorRateConstraint.validated_kwargs(0.05) + assert result == { + "max_error_rate": 0.05, + "min_processed": 30, + "current_index": -1, + } + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxGlobalErrorRateConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_error_rate == instance.max_error_rate + assert constraint.min_processed == instance.min_processed + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxGlobalErrorRateConstraint is properly registered with aliases.""" + expected_aliases = [ + "max_global_error_rate", + "max_global_err_rate", + "max_global_errors_rate", + ] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxGlobalErrorRateConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", + ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"], + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_error_rate=0.1, min_processed=50 + ) + assert isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint.max_error_rate == 0.1 + assert constraint.min_processed == 50 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 0.05) + assert isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint.max_error_rate == 0.05 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_global_error_rate": {"max_error_rate": 0.12, "min_processed": 100}} + ) + assert isinstance( + resolved["max_global_error_rate"], MaxGlobalErrorRateConstraint + ) + assert resolved["max_global_error_rate"].max_error_rate == 0.12 + assert resolved["max_global_error_rate"].min_processed == 100 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_global_err_rate": 0.08}) + assert isinstance(resolved["max_global_err_rate"], MaxGlobalErrorRateConstraint) + assert resolved["max_global_err_rate"].max_error_rate == 0.08 + + # Test with instance + instance = MaxGlobalErrorRateConstraint(max_error_rate=0.15, min_processed=75) + resolved = ConstraintsInitializerFactory.resolve( + {"max_global_errors_rate": instance} + ) + assert resolved["max_global_errors_rate"] is instance + + +class TestConstraintsInitializerFactory: + """Test the ConstraintsInitializerFactory implementation.""" + + @pytest.mark.sanity + def test_unregistered_key_fails(self): + """Test that unregistered keys raise ValueError.""" + unregistered_key = "nonexistent_constraint" + assert not ConstraintsInitializerFactory.is_registered(unregistered_key) + + with pytest.raises( + ValueError, match=f"Unknown constraint initializer key: {unregistered_key}" + ): + ConstraintsInitializerFactory.create(unregistered_key) + + with pytest.raises( + ValueError, match=f"Unknown constraint initializer key: {unregistered_key}" + ): + ConstraintsInitializerFactory.create_constraint(unregistered_key) + + @pytest.mark.smoke + def test_resolve_mixed_types(self): + """Test resolve method with mixed constraint types.""" + max_num_constraint = MaxNumberConstraint(max_num=25) + max_duration_initializer = MaxDurationConstraint(max_duration=120.0) + + mixed_spec = { + "max_number": max_num_constraint, + "max_duration": max_duration_initializer, + "max_errors": {"max_errors": 15}, + "max_error_rate": 0.08, + } + + resolved = ConstraintsInitializerFactory.resolve(mixed_spec) + + assert len(resolved) == 4 + assert all(isinstance(c, Constraint) for c in resolved.values()) + assert resolved["max_number"] is max_num_constraint + assert isinstance(resolved["max_duration"], MaxDurationConstraint) + assert isinstance(resolved["max_errors"], MaxErrorsConstraint) + assert isinstance(resolved["max_error_rate"], MaxErrorRateConstraint) + assert resolved["max_error_rate"].max_error_rate == 0.08 + + @pytest.mark.sanity + def test_resolve_with_invalid_key(self): + """Test that resolve raises ValueError for unregistered keys.""" + invalid_spec = { + "max_number": {"max_num": 100}, + "invalid_constraint": {"some_param": 42}, + } + + with pytest.raises( + ValueError, match="Unknown constraint initializer key: invalid_constraint" + ): + ConstraintsInitializerFactory.resolve(invalid_spec) + + @pytest.mark.smoke + def test_functional_constraint_creation(self): + """Test that created constraints are functionally correct.""" + constraint = ConstraintsInitializerFactory.create_constraint( + "max_number", max_num=10 + ) + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=5, + processed_requests=5, + ) + request = ScheduledRequestInfo( + request_id="test-request", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + + state_exceeded = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=15, + processed_requests=15, + ) + action_exceeded = constraint(state_exceeded, request) + assert action_exceeded.request_queuing == "stop" + assert action_exceeded.request_processing == "stop_local" diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py new file mode 100644 index 00000000..dac62da4 --- /dev/null +++ b/tests/unit/scheduler/test_objects.py @@ -0,0 +1,1299 @@ +from __future__ import annotations + +import inspect +import typing +from collections.abc import AsyncIterator +from typing import Any, Optional, TypeVar, Union + +import pytest +from pydantic import ValidationError +from typing_extensions import TypeAliasType + +from guidellm.scheduler import ( + BackendInterface, + BackendT, + MeasuredRequestTimings, + MeasuredRequestTimingsT, + MultiTurnRequestT, + RequestSchedulerTimings, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, +) +from guidellm.utils import StandardBaseModel + + +def test_request_t(): + """Validate that RequestT is a TypeVar usable for generics and isn't bound.""" + assert isinstance(RequestT, TypeVar) + assert RequestT.__name__ == "RequestT" + assert RequestT.__bound__ is None + assert RequestT.__constraints__ == () + + +def test_response_t(): + """Validate that ResponseT is a TypeVar usable for generics and isn't bound.""" + assert isinstance(ResponseT, TypeVar) + assert ResponseT.__name__ == "ResponseT" + assert ResponseT.__bound__ is None + assert ResponseT.__constraints__ == () + + +def test_request_timings_t(): + """Validate MeasuredRequestTimingsT is a TypeVar bound to MeasuredRequestTimings.""" + assert isinstance(MeasuredRequestTimingsT, TypeVar) + assert MeasuredRequestTimingsT.__name__ == "MeasuredRequestTimingsT" + assert MeasuredRequestTimingsT.__bound__ == MeasuredRequestTimings + assert MeasuredRequestTimingsT.__constraints__ == () + + +def test_backend_t(): + """Validate that BackendT is a TypeVar bound to BackendInterface.""" + assert isinstance(BackendT, TypeVar) + assert BackendT.__name__ == "BackendT" + assert BackendT.__bound__.__name__ == "BackendInterface" + assert BackendT.__constraints__ == () + + +def test_multi_turn_request_t(): + """Validate MultiTurnRequestT is a TypeAliasType for multi-turn requests.""" + assert isinstance(MultiTurnRequestT, TypeAliasType) + assert MultiTurnRequestT.__name__ == "MultiTurnRequestT" + + value = MultiTurnRequestT.__value__ + assert hasattr(value, "__origin__") + assert value.__origin__ is Union + + type_params = getattr(MultiTurnRequestT, "__type_params__", ()) + assert len(type_params) == 1 + assert type_params[0].__name__ == "RequestT" + + +class TestBackendInterface: + """Test the BackendInterface abstract base class.""" + + @pytest.mark.smoke + def test_abstract_methods_defined(self): + """Test that all expected abstract methods are defined.""" + expected_methods = { + "process_startup", + "validate", + "process_shutdown", + "resolve", + } + expected_properties = { + "processes_limit", + "requests_limit", + "info", + } + + for method_name in expected_methods: + assert hasattr(BackendInterface, method_name) + method = getattr(BackendInterface, method_name) + assert inspect.isfunction(method) or inspect.ismethod(method) + + for prop_name in expected_properties: + assert hasattr(BackendInterface, prop_name) + prop = getattr(BackendInterface, prop_name) + assert hasattr(prop, "__get__") + + @pytest.mark.smoke + def test_generic_type_parameters(self): + """Test that BackendInterface has the correct generic type parameters.""" + orig_bases = BackendInterface.__orig_bases__ + protocol_base = None + generic_base = None + + for base in orig_bases: + if hasattr(base, "__origin__"): + if base.__origin__ is typing.Generic: + generic_base = base + elif base.__name__ == "Protocol": + protocol_base = base + + assert protocol_base is not None, "Should inherit from Protocol" + assert generic_base is not None, "Should inherit from Generic" + + if hasattr(generic_base, "__args__"): + type_params = generic_base.__args__ + assert len(type_params) == 3, "Should have 3 type parameters" + param_names = [param.__name__ for param in type_params] + expected_names = ["RequestT", "MeasuredRequestTimingsT", "ResponseT"] + assert param_names == expected_names + + @pytest.mark.smoke + def test_implementation_construction(self): + """Test that a complete concrete implementation can be instantiated.""" + + class ConcreteBackend(BackendInterface[str, MeasuredRequestTimings, str]): + @property + def processes_limit(self) -> int | None: + return 4 + + @property + def requests_limit(self) -> int | None: + return 100 + + @property + def info(self) -> dict[str, Any]: + return {"model": "test", "version": "1.0"} + + async def process_startup(self) -> None: + pass + + async def validate(self) -> None: + pass + + async def process_shutdown(self) -> None: + pass + + async def resolve( + self, + request: str, + request_info: ScheduledRequestInfo[MeasuredRequestTimings], + history: list[tuple[str, str]] | None = None, + ) -> AsyncIterator[ + tuple[str, ScheduledRequestInfo[MeasuredRequestTimings]] + ]: + yield f"Response to: {request}", request_info + + backend = ConcreteBackend() + assert isinstance(backend, BackendInterface) + assert isinstance(backend, ConcreteBackend) + assert backend.processes_limit == 4 + assert backend.requests_limit == 100 + info = backend.info + assert info == {"model": "test", "version": "1.0"} + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_implementation_async_methods(self): # noqa: C901 + """Test that async methods work correctly in concrete implementation.""" + + class AsyncBackend(BackendInterface[dict, MeasuredRequestTimings, dict]): + def __init__(self): + self.startup_called = False + self.validate_called = False + self.shutdown_called = False + + @property + def processes_limit(self) -> int | None: + return None # Unlimited + + @property + def requests_limit(self) -> int | None: + return None # Unlimited + + @property + def info(self) -> dict[str, Any]: + return {"backend": "async_test"} + + async def process_startup(self) -> None: + self.startup_called = True + + async def validate(self) -> None: + self.validate_called = True + + async def process_shutdown(self) -> None: + self.shutdown_called = True + + async def resolve( + self, + request: dict, + request_info: ScheduledRequestInfo[MeasuredRequestTimings], + history: list[tuple[dict, dict]] | None = None, + ) -> AsyncIterator[ + tuple[dict, ScheduledRequestInfo[MeasuredRequestTimings]] + ]: + response = {"result": request.get("input", ""), "status": "success"} + yield response, request_info + + backend = AsyncBackend() + await backend.process_startup() + assert backend.startup_called + + await backend.validate() + assert backend.validate_called + + await backend.process_shutdown() + assert backend.shutdown_called + + request = {"input": "test_request"} + request_info = ScheduledRequestInfo( + request_id="test-123", + status="queued", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + results = [] + async for response, updated_info in backend.resolve(request, request_info): + results.append((response, updated_info)) + + assert len(results) == 1 + response, updated_info = results[0] + assert response == {"result": "test_request", "status": "success"} + assert updated_info == request_info + + @pytest.mark.smoke + def test_method_signatures(self): + """Test that abstract methods have the expected signatures.""" + info_prop = BackendInterface.info + assert isinstance(info_prop, property) + + processes_limit_prop = BackendInterface.processes_limit + assert isinstance(processes_limit_prop, property) + + requests_limit_prop = BackendInterface.requests_limit + assert isinstance(requests_limit_prop, property) + + startup_sig = inspect.signature(BackendInterface.process_startup) + assert len(startup_sig.parameters) == 1 # Only self + assert list(startup_sig.parameters.keys()) == ["self"] + + validate_sig = inspect.signature(BackendInterface.validate) + assert len(validate_sig.parameters) == 1 # Only self + assert list(validate_sig.parameters.keys()) == ["self"] + + shutdown_sig = inspect.signature(BackendInterface.process_shutdown) + assert len(shutdown_sig.parameters) == 1 # Only self + assert list(shutdown_sig.parameters.keys()) == ["self"] + + resolve_sig = inspect.signature(BackendInterface.resolve) + expected_params = ["self", "request", "request_info", "history"] + assert list(resolve_sig.parameters.keys()) == expected_params + + history_param = resolve_sig.parameters["history"] + assert history_param.default is None + + +class TestRequestSchedulerTimings: + """Test the RequestSchedulerTimings model class.""" + + CHECK_KEYS = [ + "targeted_start", + "queued", + "dequeued", + "scheduled_at", + "resolve_start", + "resolve_end", + "finalized", + ] + + @pytest.fixture( + params=[ + {}, + { + "targeted_start": None, + "queued": None, + "dequeued": None, + "scheduled_at": None, + "resolve_start": None, + "resolve_end": None, + "finalized": None, + }, + { + "targeted_start": 1000.0, + "queued": 200.0, + "dequeued": 800.0, + "scheduled_at": 900.0, + "resolve_start": 1000.5, + "resolve_end": 1100.0, + "finalized": 1100.5, + }, + { + "queued": 200.0, + "scheduled_at": 250.0, + "resolve_start": 1000.5, + "resolve_end": 1100.0, + }, + { + "targeted_start": 0.0, + "queued": 0.0, + "dequeued": 0.0, + "scheduled_at": 0.0, + "resolve_start": 0.0, + "resolve_end": 0.0, + "finalized": 0.0, + }, + ], + ids=[ + "default_empty", + "all_none_explicit", + "complete_sequence", + "partial_data", + "zero_timestamps", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of RequestSchedulerTimings.""" + constructor_args = request.param + instance = RequestSchedulerTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RequestSchedulerTimings inheritance and type relationships.""" + assert issubclass(RequestSchedulerTimings, StandardBaseModel) + assert hasattr(RequestSchedulerTimings, "model_dump") + assert hasattr(RequestSchedulerTimings, "model_validate") + + # Check all expected fields are defined + fields = RequestSchedulerTimings.model_fields + for key in self.CHECK_KEYS: + assert key in fields + field_info = fields[key] + assert field_info.annotation in (Union[float, None], Optional[float]) + assert field_info.default is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, RequestSchedulerTimings) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("targeted_start", "invalid_string"), + ("queued", "invalid_string"), + ("dequeued", [1, 2, 3]), + ("scheduled_at", {"key": "value"}), + ("resolve_start", {"key": "value"}), + ("resolve_end", [1, 2, 3]), + ("finalized", object()), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + RequestSchedulerTimings(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = RequestSchedulerTimings.model_validate(data) + assert isinstance(reconstructed, RequestSchedulerTimings) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestRequestTimings: + """Test the MeasuredRequestTimings model class.""" + + CHECK_KEYS = [ + "request_start", + "request_end", + ] + + @pytest.fixture( + params=[ + {}, + { + "request_start": None, + "request_end": None, + }, + { + "request_start": 1000.0, + "request_end": 1100.0, + }, + { + "request_start": 1000.0, + }, + { + "request_start": 0.0, + "request_end": 0.0, + }, + ], + ids=[ + "default_empty", + "all_none_explicit", + "complete_sequence", + "partial_data", + "zero_timestamps", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of MeasuredRequestTimings.""" + constructor_args = request.param + instance = MeasuredRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MeasuredRequestTimings inheritance and type relationships.""" + assert issubclass(MeasuredRequestTimings, StandardBaseModel) + assert hasattr(MeasuredRequestTimings, "model_dump") + assert hasattr(MeasuredRequestTimings, "model_validate") + + # Check all expected fields are defined + fields = MeasuredRequestTimings.model_fields + for key in self.CHECK_KEYS: + assert key in fields + field_info = fields[key] + assert field_info.annotation in (Union[float, None], Optional[float]) + assert field_info.default is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, MeasuredRequestTimings) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_start", "invalid_string"), + ("request_end", [1, 2, 3]), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + MeasuredRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = MeasuredRequestTimings.model_validate(data) + assert isinstance(reconstructed, MeasuredRequestTimings) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestScheduledRequestInfo: + CHECK_KEYS = [ + "request_id", + "status", + "error", + "scheduler_node_id", + "scheduler_process_id", + "scheduler_start_time", + "scheduler_timings", + "request_timings", + ] + + @pytest.fixture( + params=[ + # Minimal required configuration + { + "request_id": "test-req-123", + "status": "queued", + "scheduler_node_id": 1, + "scheduler_process_id": 0, + "scheduler_start_time": 1000.0, + }, + # Complete configuration with all fields + { + "request_id": "test-req-456", + "status": "completed", + "error": None, + "scheduler_node_id": 2, + "scheduler_process_id": 1, + "scheduler_start_time": 2000.0, + "scheduler_timings": { + "targeted_start": 1900.0, + "queued": 1950.0, + "dequeued": 2000.0, + "resolve_start": 2050.0, + "resolve_end": 2100.0, + "finalized": 2150.0, + }, + "request_timings": { + "request_start": 2060.0, + "request_end": 2110.0, + }, + }, + # Error state configuration + { + "request_id": "test-req-error", + "status": "errored", + "error": "Connection timeout", + "scheduler_node_id": 0, + "scheduler_process_id": 0, + "scheduler_start_time": 3000.0, + }, + # Different status values + { + "request_id": "test-req-pending", + "status": "pending", + "scheduler_node_id": 1, + "scheduler_process_id": 2, + "scheduler_start_time": 4000.0, + }, + { + "request_id": "test-req-in-progress", + "status": "in_progress", + "scheduler_node_id": 2, + "scheduler_process_id": 1, + "scheduler_start_time": 5000.0, + }, + ], + ids=[ + "minimal_required", + "complete_configuration", + "error_state", + "pending_status", + "in_progress_status", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of ScheduledRequestInfo. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + ScheduledRequestInfo and constructor_args are the kwargs used. + """ + constructor_args = request.param.copy() + + # Handle nested objects + if "scheduler_timings" in constructor_args: + constructor_args["scheduler_timings"] = RequestSchedulerTimings( + **constructor_args["scheduler_timings"] + ) + if "request_timings" in constructor_args: + constructor_args["request_timings"] = MeasuredRequestTimings( + **constructor_args["request_timings"] + ) + + instance = ScheduledRequestInfo(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ScheduledRequestInfo inheritance and type relationships.""" + assert issubclass(ScheduledRequestInfo, StandardBaseModel) + assert issubclass(ScheduledRequestInfo, typing.Generic) + assert hasattr(ScheduledRequestInfo, "model_dump") + assert hasattr(ScheduledRequestInfo, "model_validate") + + # Check computed properties + assert hasattr(ScheduledRequestInfo, "started_at") + assert hasattr(ScheduledRequestInfo, "completed_at") + assert isinstance(ScheduledRequestInfo.started_at, property) + assert isinstance(ScheduledRequestInfo.completed_at, property) + + # Check that it's properly generic + orig_bases = getattr(ScheduledRequestInfo, "__orig_bases__", ()) + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is typing.Generic + ), + None, + ) + assert generic_base is not None + + # Check required fields + fields = ScheduledRequestInfo.model_fields + for key in self.CHECK_KEYS: + assert key in fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ScheduledRequestInfo) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + if field in ["scheduler_timings", "request_timings"]: + actual_value = getattr(instance, field) + if expected_value is None: + assert actual_value is None or ( + field == "scheduler_timings" + and isinstance(actual_value, RequestSchedulerTimings) + ) + else: + assert isinstance(actual_value, type(expected_value)) + else: + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_id", None), # Required field + ("request_id", 123), # Wrong type + ("status", "invalid_status"), # Invalid literal + ("scheduler_node_id", "not_an_int"), + ("scheduler_process_id", -1.5), + ("scheduler_start_time", "not_a_float"), + ("error", 123), # Should be string or None + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + # Start with valid base config + base_kwargs = { + "request_id": "test-req", + "status": "queued", + "scheduler_node_id": 1, + "scheduler_process_id": 0, + "scheduler_start_time": 1000.0, + } + base_kwargs[field] = value + with pytest.raises(ValidationError): + ScheduledRequestInfo(**base_kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = ScheduledRequestInfo.model_validate(data) + assert isinstance(reconstructed, ScheduledRequestInfo) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + original_value = getattr(instance, field) + reconstructed_value = getattr(reconstructed, field) + + if field in ["scheduler_timings", "request_timings"]: + if original_value is not None and reconstructed_value is not None: + assert ( + original_value.model_dump() == reconstructed_value.model_dump() + ) + else: + assert original_value is None or isinstance( + original_value, + (RequestSchedulerTimings, MeasuredRequestTimings), + ) + assert reconstructed_value is None or isinstance( + reconstructed_value, + (RequestSchedulerTimings, MeasuredRequestTimings), + ) + else: + assert original_value == reconstructed_value + + @pytest.mark.smoke + def test_started_at_property(self): + """Test the started_at property logic.""" + # Test with request_timings.request_start (should take precedence) + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0), + request_timings=MeasuredRequestTimings(request_start=2100.0), + ) + assert instance.started_at == 2100.0 + + # Test with only scheduler_timings.resolve_start + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0), + ) + assert instance.started_at == 2000.0 + + # Test with no timing info + instance = ScheduledRequestInfo( + request_id="test-req", + status="queued", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + assert instance.started_at is None + + @pytest.mark.smoke + def test_completed_at_property(self): + """Test the completed_at property logic.""" + # Test with request_timings.request_end (should take precedence) + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0), + request_timings=MeasuredRequestTimings(request_end=2100.0), + ) + assert instance.completed_at == 2100.0 + + # Test with only scheduler_timings.resolve_end + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0), + ) + assert instance.completed_at == 2000.0 + + # Test with no timing info + instance = ScheduledRequestInfo( + request_id="test-req", + status="queued", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + assert instance.completed_at is None + + +class TestSchedulerState: + CHECK_KEYS = [ + "node_id", + "num_processes", + "start_time", + "end_time", + "end_queuing_time", + "end_queuing_constraints", + "end_processing_time", + "end_processing_constraints", + "scheduler_constraints", + "remaining_fraction", + "remaining_requests", + "remaining_duration", + "created_requests", + "queued_requests", + "pending_requests", + "processing_requests", + "processed_requests", + "successful_requests", + "errored_requests", + "cancelled_requests", + ] + + @pytest.fixture( + params=[ + # Minimal required configuration + { + "node_id": 0, + "num_processes": 1, + "start_time": 1000.0, + }, + # Complete configuration with all fields + { + "node_id": 1, + "num_processes": 4, + "start_time": 2000.0, + "end_time": 3000.0, + "end_queuing_time": 2500.0, + "end_queuing_constraints": { + "time_limit": SchedulerUpdateAction( + request_queuing="stop", metadata={"max_duration": 1500} + ) + }, + "end_processing_time": 2800.0, + "end_processing_constraints": { + "request_limit": SchedulerUpdateAction( + request_processing="stop_all", metadata={"max_requests": 1000} + ) + }, + "scheduler_constraints": { + "rate_limit": SchedulerUpdateAction(metadata={"max_rps": 100}) + }, + "remaining_fraction": 0.25, + "remaining_requests": 50, + "remaining_duration": 300.0, + "created_requests": 200, + "queued_requests": 180, + "pending_requests": 20, + "processing_requests": 10, + "processed_requests": 150, + "successful_requests": 140, + "errored_requests": 8, + "cancelled_requests": 2, + }, + # Partial configuration with some stats + { + "node_id": 2, + "num_processes": 2, + "start_time": 4000.0, + "created_requests": 50, + "processed_requests": 30, + "successful_requests": 28, + "errored_requests": 2, + }, + # Edge case: zero values + { + "node_id": 0, + "num_processes": 1, + "start_time": 0.0, + "created_requests": 0, + "processed_requests": 0, + "successful_requests": 0, + }, + ], + ids=[ + "minimal_required", + "complete_configuration", + "partial_stats", + "zero_values", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of SchedulerState. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + SchedulerState and constructor_args are the kwargs used. + """ + constructor_args = request.param + instance = SchedulerState(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SchedulerState inheritance and type relationships.""" + assert issubclass(SchedulerState, StandardBaseModel) + assert hasattr(SchedulerState, "model_dump") + assert hasattr(SchedulerState, "model_validate") + + # Check all expected fields are defined + fields = SchedulerState.model_fields + for key in self.CHECK_KEYS: + assert key in fields + + # Check field defaults for key counters + counter_fields = [ + "created_requests", + "queued_requests", + "pending_requests", + "processing_requests", + "processed_requests", + "successful_requests", + "errored_requests", + "cancelled_requests", + ] + for field in counter_fields: + field_info = fields[field] + assert field_info.default == 0 + + # Check that start_time has a default factory + start_time_field = fields["start_time"] + assert start_time_field.default_factory is not None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, SchedulerState) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("field", "value"), + [ + ("node_id", "not_an_int"), + ("start_time", "not_a_float"), + ("end_time", [1, 2, 3]), + ("remaining_fraction", "not_a_float"), + ("created_requests", "not_an_int"), + ("end_queuing_constraints", "not_a_dict"), + ("scheduler_constraints", ["not", "a", "dict"]), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + # Start with valid base config + base_kwargs = { + "node_id": 0, + "num_processes": 1, + "start_time": 1000.0, + } + base_kwargs[field] = value + with pytest.raises(ValidationError): + SchedulerState(**base_kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = SchedulerState.model_validate(data) + assert isinstance(reconstructed, SchedulerState) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestSchedulerUpdateAction: + CHECK_KEYS = [ + "request_queuing", + "request_processing", + "metadata", + "progress", + ] + + @pytest.fixture( + params=[ + # Default configuration + {}, + # All explicit default values + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": {}, + "progress": {}, + }, + # Stop queuing configuration + { + "request_queuing": "stop", + "request_processing": "continue", + "metadata": {"reason": "rate_limit_exceeded"}, + }, + # Stop local processing configuration + { + "request_queuing": "continue", + "request_processing": "stop_local", + "metadata": {"node_id": 1, "reason": "resource_exhausted"}, + }, + # Stop all processing configuration + { + "request_queuing": "stop", + "request_processing": "stop_all", + "metadata": { + "emergency_stop": True, + "reason": "critical_error", + "error_details": {"code": 500, "message": "Internal server error"}, + }, + }, + # Complex metadata configuration + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": { + "stats": {"processed": 100, "pending": 50}, + "constraints": {"max_rps": 10, "max_concurrent": 20}, + "config": {"batch_size": 32, "timeout": 30.0}, + }, + }, + # Progress with remaining_fraction only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_fraction": 0.75}, + }, + # Progress with remaining_requests only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_requests": 250.0}, + }, + # Progress with remaining_duration only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_duration": 120.5}, + }, + # Complete progress configuration + { + "request_queuing": "stop", + "request_processing": "stop_all", + "metadata": {"shutdown_reason": "completion"}, + "progress": { + "remaining_fraction": 0.0, + "remaining_requests": 0.0, + "remaining_duration": 0.0, + }, + }, + # Partial progress configuration + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": {"checkpoint": "mid_benchmark"}, + "progress": { + "remaining_fraction": 0.45, + "remaining_duration": 180.0, + }, + }, + ], + ids=[ + "default_empty", + "explicit_defaults", + "stop_queuing", + "stop_local_processing", + "stop_all_processing", + "complex_metadata", + "progress_fraction_only", + "progress_requests_only", + "progress_duration_only", + "complete_progress", + "partial_progress", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of SchedulerUpdateAction. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + SchedulerUpdateAction and constructor_args are the kwargs used. + """ + constructor_args = request.param + instance = SchedulerUpdateAction(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SchedulerUpdateAction inheritance and type relationships.""" + assert issubclass(SchedulerUpdateAction, StandardBaseModel) + assert hasattr(SchedulerUpdateAction, "model_dump") + assert hasattr(SchedulerUpdateAction, "model_validate") + + # Check all expected fields are defined + fields = SchedulerUpdateAction.model_fields + for key in self.CHECK_KEYS: + assert key in fields + + # Check field defaults + assert fields["request_queuing"].default == "continue" + assert fields["request_processing"].default == "continue" + metadata_field = fields["metadata"] + assert metadata_field.default_factory is not None + progress_field = fields["progress"] + assert progress_field.default_factory is not None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, SchedulerUpdateAction) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args or defaults + for field in self.CHECK_KEYS: + if field in constructor_args: + assert getattr(instance, field) == constructor_args[field] + elif field in ["request_queuing", "request_processing"]: + assert getattr(instance, field) == "continue" + elif field in ["metadata", "progress"]: + assert getattr(instance, field) == {} + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_queuing", "invalid_action"), + ("request_queuing", 123), + ("request_processing", "invalid_action"), + ("request_processing", ["stop"]), + ("metadata", "not_a_dict"), + ("metadata", [{"key": "value"}]), + ("progress", "not_a_dict"), + ("progress", [{"remaining_fraction": 0.5}]), + ("progress", {"remaining_fraction": "not_a_float"}), + ("progress", {"remaining_requests": "not_a_float"}), + ("progress", {"remaining_duration": "not_a_float"}), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + SchedulerUpdateAction(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = SchedulerUpdateAction.model_validate(data) + assert isinstance(reconstructed, SchedulerUpdateAction) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches expected values + for field in self.CHECK_KEYS: + if field in constructor_args: + assert getattr(reconstructed, field) == constructor_args[field] + elif field in ["request_queuing", "request_processing"]: + assert getattr(reconstructed, field) == "continue" + elif field in ["metadata", "progress"]: + assert getattr(reconstructed, field) == {} + + @pytest.mark.smoke + def test_progress_field_behavior(self): + """Test the progress field specific behavior and validation.""" + # Test empty progress (default) + instance = SchedulerUpdateAction() + assert instance.progress == {} + assert isinstance(instance.progress, dict) + + # Test progress with all valid fields + progress_data = { + "remaining_fraction": 0.75, + "remaining_requests": 100.0, + "remaining_duration": 30.5, + } + instance = SchedulerUpdateAction(progress=progress_data) + assert instance.progress == progress_data + + # Test progress with partial fields (TypedDict allows partial) + partial_progress = {"remaining_fraction": 0.25} + instance = SchedulerUpdateAction(progress=partial_progress) + assert instance.progress == partial_progress + + # Test progress with zero values + zero_progress = { + "remaining_fraction": 0.0, + "remaining_requests": 0.0, + "remaining_duration": 0.0, + } + instance = SchedulerUpdateAction(progress=zero_progress) + assert instance.progress == zero_progress + + # Test that progress field persists through marshalling + data = instance.model_dump() + assert "progress" in data + assert data["progress"] == zero_progress + + reconstructed = SchedulerUpdateAction.model_validate(data) + assert reconstructed.progress == zero_progress + + @pytest.mark.smoke + @pytest.mark.parametrize( + "progress_value", + [ + {"remaining_fraction": 0.0}, + {"remaining_fraction": 1.0}, + {"remaining_requests": 0.0}, + {"remaining_requests": 1000.0}, + {"remaining_duration": 0.0}, + {"remaining_duration": 3600.0}, + {"remaining_fraction": 0.5, "remaining_requests": 50.0}, + {"remaining_requests": 25.0, "remaining_duration": 120.0}, + {"remaining_fraction": 0.33, "remaining_duration": 45.0}, + ], + ) + def test_progress_valid_combinations(self, progress_value): + """Test various valid combinations of progress field values.""" + instance = SchedulerUpdateAction(progress=progress_value) + assert instance.progress == progress_value + + # Verify marshalling works correctly + data = instance.model_dump() + reconstructed = SchedulerUpdateAction.model_validate(data) + assert reconstructed.progress == progress_value + + @pytest.mark.smoke + def test_scheduler_update_action_progress_typeddict(self): + """Test the SchedulerUpdateActionProgress TypedDict behavior.""" + # Test that SchedulerUpdateActionProgress is a proper TypedDict + # Verify it's a TypedDict (has the special attributes) + assert hasattr(SchedulerUpdateActionProgress, "__annotations__") + assert hasattr(SchedulerUpdateActionProgress, "__total__") + assert hasattr(SchedulerUpdateActionProgress, "__required_keys__") + assert hasattr(SchedulerUpdateActionProgress, "__optional_keys__") + + # Check that all keys are optional (total=False) + expected_keys = { + "remaining_fraction", + "remaining_requests", + "remaining_duration", + } + actual_keys = set(SchedulerUpdateActionProgress.__annotations__.keys()) + assert actual_keys == expected_keys + assert SchedulerUpdateActionProgress.__total__ is False + assert SchedulerUpdateActionProgress.__required_keys__ == frozenset() + assert SchedulerUpdateActionProgress.__optional_keys__ == expected_keys + + # Test that type annotations are correct + annotations = SchedulerUpdateActionProgress.__annotations__ + assert "remaining_fraction" in annotations + assert "remaining_requests" in annotations + assert "remaining_duration" in annotations + + # Test creation of valid TypedDict instances + valid_progress_1: SchedulerUpdateActionProgress = {} + valid_progress_2: SchedulerUpdateActionProgress = {"remaining_fraction": 0.5} + valid_progress_3: SchedulerUpdateActionProgress = { + "remaining_fraction": 0.25, + "remaining_requests": 100.0, + "remaining_duration": 60.0, + } + + # All should be valid dict instances + assert isinstance(valid_progress_1, dict) + assert isinstance(valid_progress_2, dict) + assert isinstance(valid_progress_3, dict) diff --git a/tests/unit/scheduler/test_strategy.py b/tests/unit/scheduler/test_strategy.py new file mode 100644 index 00000000..8cb91d82 --- /dev/null +++ b/tests/unit/scheduler/test_strategy.py @@ -0,0 +1,1154 @@ +from __future__ import annotations + +import inspect +import math +import statistics +import time +from abc import ABC +from typing import Literal, TypeVar + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + AsyncConstantStrategy, + AsyncPoissonStrategy, + ConcurrentStrategy, + ConstantRateRequestTimings, + LastCompletionRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, + ScheduledRequestInfo, + ScheduledRequestTimings, + SchedulingStrategy, + StrategyT, + SynchronousStrategy, + ThroughputStrategy, +) +from guidellm.scheduler.strategy import ( + _exponential_decay_fraction, + _exponential_decay_tau, +) + + +def test_strategy_type(): + """Test that StrategyType is defined correctly as a Literal type.""" + # StrategyType is a type alias/literal type, we can't test its runtime value + # but we can test that it exists and is importable + from guidellm.scheduler.strategy import StrategyType + + assert StrategyType is not None + + +def test_strategy_t(): + """Test that StrategyT is filled out correctly as a TypeVar.""" + assert isinstance(StrategyT, type(TypeVar("test"))) + assert StrategyT.__name__ == "StrategyT" + assert StrategyT.__bound__ == SchedulingStrategy + assert StrategyT.__constraints__ == () + + +class TestExponentialDecay: + """Test suite for _exponential_decay_tau function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("max_progress", "convergence", "expected_range"), + [ + (1.0, 0.99, (0.21, 0.22)), + (5.0, 0.99, (1.08, 1.09)), + (10.0, 0.95, (3.33, 3.35)), + ], + ) + def test_tau_invocation(self, max_progress, convergence, expected_range): + """Test exponential decay tau calculation with valid inputs.""" + tau = _exponential_decay_tau(max_progress, convergence) + assert expected_range[0] <= tau <= expected_range[1] + expected_tau = max_progress / (-math.log(1 - convergence)) + assert tau == pytest.approx(expected_tau, rel=1e-10) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("progress", "tau", "expected_min", "expected_max"), + [ + (0.0, 1.0, 0.0, 0.0), # No progress = 0 + (1.0, 1.0, 0.6, 0.7), # 1 tau ≈ 63.2% + (2.0, 1.0, 0.85, 0.87), # 2 tau ≈ 86.5% + (3.0, 1.0, 0.95, 0.96), # 3 tau ≈ 95.0% + ], + ) + def test_exp_decay_invocation(self, progress, tau, expected_min, expected_max): + """Test exponential decay fraction calculation with valid inputs.""" + fraction = _exponential_decay_fraction(progress, tau) + assert expected_min <= fraction <= expected_max + expected_fraction = 1 - math.exp(-progress / tau) + assert fraction == pytest.approx(expected_fraction, rel=1e-10) + + @pytest.mark.smoke + def test_exp_boundary_conditions(self): + """Test boundary conditions for exponential decay fraction.""" + assert _exponential_decay_fraction(0.0, 1.0) == 0.0 + assert _exponential_decay_fraction(0.0, 10.0) == 0.0 + large_progress = 100.0 + fraction = _exponential_decay_fraction(large_progress, 1.0) + assert fraction > 0.99999 + + +class TestScheduledRequestTimings: + @pytest.mark.smoke + def test_signatures(self): + """Test that ScheduledRequestTimings is an abstract base class.""" + assert issubclass(ScheduledRequestTimings, ABC) + assert inspect.isabstract(ScheduledRequestTimings) + + abstract_methods = ScheduledRequestTimings.__abstractmethods__ + expected_methods = {"next_offset", "request_completed"} + assert abstract_methods == expected_methods + + # Validate method signatures + next_offset_method = ScheduledRequestTimings.next_offset + assert callable(next_offset_method) + request_completed_method = ScheduledRequestTimings.request_completed + assert callable(request_completed_method) + + # Check signature parameters using inspect + next_offset_sig = inspect.signature(next_offset_method) + assert len(next_offset_sig.parameters) == 1 + assert str(next_offset_sig.return_annotation) == "float" + request_completed_sig = inspect.signature(request_completed_method) + assert len(request_completed_sig.parameters) == 2 + params = list(request_completed_sig.parameters.values()) + param_annotation = params[1].annotation + assert param_annotation in {ScheduledRequestInfo, "ScheduledRequestInfo"} + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid implementations raise TypeError.""" + + class InvalidImplementation(ScheduledRequestTimings): + pass # Missing required abstract methods + + with pytest.raises(TypeError): + InvalidImplementation() + + @pytest.mark.smoke + def test_child_implementation(self): + """Test that concrete implementations can be constructed.""" + + class TestRequestTimings(ScheduledRequestTimings): + offset: float = 0.0 + + def next_offset(self) -> float: + self.offset += 1.0 + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): + pass + + timing = TestRequestTimings() + assert isinstance(timing, ScheduledRequestTimings) + + assert timing.next_offset() == 1.0 + assert timing.next_offset() == 2.0 + + mock_request = ScheduledRequestInfo( + request_id="test", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + timing.request_completed(mock_request) + + +class TestLastCompletionRequestTimings: + @pytest.fixture( + params=[ + {}, + {"offset": 10.0}, + {"startup_requests": 5, "startup_requests_delay": 0.5}, + { + "offset": 0.0, + "startup_requests": 0, + "startup_requests_delay": 0.0, + }, + { + "offset": 2.5, + "startup_requests": 3, + "startup_requests_delay": 1.0, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of LastCompletionRequestTimings.""" + constructor_args = request.param + instance = LastCompletionRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, LastCompletionRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("startup_requests", -1), + ("startup_requests_delay", -0.5), + ("offset", "invalid"), + ("startup_requests", 1.5), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + LastCompletionRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test the complete lifecycle of next_offset and request_completed calls.""" + instance, constructor_args = valid_instances + initial_offset = instance.offset + startup_requests = constructor_args.get("startup_requests", 0) + startup_delay = constructor_args.get("startup_requests_delay", 0.0) + request_times = [] + + for index in range(max(5, startup_requests + 2)): + offset = instance.next_offset() + assert isinstance(offset, (int, float)) + + if index < startup_requests: + expected_offset = initial_offset + (index + 1) * startup_delay + assert offset == pytest.approx(expected_offset, abs=1e-5) + + completion_time = time.time() + offset + request_times.append(completion_time) + + mock_request: ScheduledRequestInfo = ScheduledRequestInfo( + request_id=f"test-{index}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + mock_request.scheduler_timings.resolve_end = completion_time + instance.request_completed(mock_request) + + @pytest.mark.smoke + def test_marshalling( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = LastCompletionRequestTimings.model_validate(data) + assert isinstance(reconstructed, LastCompletionRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestNoDelayRequestTimings: + @pytest.fixture( + params=[ + {}, + {"offset": 0.2}, + {"startup_duration": 0.3, "startup_target_requests": 5}, + { + "offset": 0.15, + "startup_duration": 0.2, + "startup_target_requests": 20, + "startup_convergence": 0.9, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of NoDelayRequestTimings.""" + constructor_args = request.param + instance = NoDelayRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, NoDelayRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("offset", -1.0), + ("startup_duration", -1.0), + ("startup_target_requests", 0), + ("startup_target_requests", -1), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + NoDelayRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test the complete lifecycle of timing methods.""" + instance, constructor_args = valid_instances + startup_duration = constructor_args.get("startup_duration", 0.0) + base_offset = constructor_args.get("offset", 0.0) + start_time = time.time() + min_time = base_offset + startup_duration + 0.2 + end_time = start_time + min_time + last_offset = -1 * math.inf + + while (current_time := time.time()) < end_time: + offset = instance.next_offset() + + if startup_duration > 0 and (current_time - start_time) <= startup_duration: + assert offset < base_offset + startup_duration + assert offset > last_offset + elif startup_duration > 0: + assert offset == base_offset + startup_duration + else: + assert offset == base_offset + + last_offset = offset + time.sleep(0.025) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = NoDelayRequestTimings.model_validate(data) + assert isinstance(reconstructed, NoDelayRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestConstantRateRequestTimings: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0, "offset": 2.0}, + {"rate": 10.5, "offset": 1.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ConstantRateRequestTimings.""" + constructor_args = request.param + instance = ConstantRateRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ConstantRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ("offset", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + ConstantRateRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_constant_rate_behavior( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test that requests are scheduled at constant intervals.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + expected_interval = 1.0 / rate + base_offset = constructor_args.get("offset", 0.0) + num_requests = int(5 * rate) # simulate 5 seconds + + for ind in range(num_requests): + offset = instance.next_offset() + assert offset >= base_offset + assert offset == pytest.approx( + base_offset + ind * expected_interval, rel=1e-2 + ) + + @pytest.mark.smoke + def test_marshalling( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ConstantRateRequestTimings.model_validate(data) + assert isinstance(reconstructed, ConstantRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestPoissonRateRequestTimings: + @pytest.fixture( + params=[ + {"rate": 1.0}, + { + "rate": 5.0, + "random_seed": 123, + "offset": 1.0, + }, + { + "rate": 0.5, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of PoissonRateRequestTimings.""" + constructor_args = request.param + instance = PoissonRateRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[PoissonRateRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, PoissonRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ("offset", "invalid"), + ("random_seed", "invalid"), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + PoissonRateRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle(self, valid_instances: tuple[PoissonRateRequestTimings, dict]): + """Test that Poisson timing produces variable intervals.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + base_offset = constructor_args.get("offset", 0.0) + num_requests = 200 + last_offset = 0.0 + intervals = [] + + for index in range(num_requests): + offset = instance.next_offset() + + if index == 0: + assert offset == base_offset + else: + assert offset > last_offset + + intervals.append(offset - last_offset) + last_offset = offset + + expected_mean_interval = 1.0 / rate + actual_mean_interval = statistics.mean(intervals) + tolerance = 0.2 * expected_mean_interval + assert abs(actual_mean_interval - expected_mean_interval) < tolerance + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[PoissonRateRequestTimings, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = PoissonRateRequestTimings.model_validate(data) + assert isinstance(reconstructed, PoissonRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestSchedulingStrategy: + @pytest.mark.smoke + def test_class_signatures(self): + """Test SchedulingStrategy inheritance and type relationships.""" + # Inheritance and abstract class properties + assert issubclass(SchedulingStrategy, object) + assert hasattr(SchedulingStrategy, "info") + + # Validate expected methods exist + expected_methods = { + "processes_limit", + "requests_limit", + "create_request_timings", + } + strategy_methods = set(dir(SchedulingStrategy)) + for method in expected_methods: + assert method in strategy_methods + + # validate expected properties + processes_limit_prop = SchedulingStrategy.processes_limit + assert isinstance(processes_limit_prop, property) + requests_limit_prop = SchedulingStrategy.requests_limit + assert isinstance(requests_limit_prop, property) + create_request_timings_method = SchedulingStrategy.create_request_timings + assert callable(create_request_timings_method) + + # Validate method signature + sig = inspect.signature(create_request_timings_method) + params = list(sig.parameters.keys()) + expected_params = [ + "self", + "local_rank", + "local_world_size", + "local_max_concurrency", + ] + assert params == expected_params + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid implementations raise NotImplementedError.""" + + class InvalidStrategy(SchedulingStrategy): + type_: Literal["strategy"] = "strategy" # type: ignore[assignment,annotation-unchecked] + + strategy = InvalidStrategy() + with pytest.raises(NotImplementedError): + strategy.create_request_timings(0, 1, 1) + + @pytest.mark.smoke + def test_concrete_implementation(self): + """Test that concrete implementations can be constructed.""" + + class TestStrategy(SchedulingStrategy): + type_: Literal["strategy"] = "strategy" # type: ignore[assignment,annotation-unchecked] + + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, + ): + return LastCompletionRequestTimings() + + strategy = TestStrategy() + assert isinstance(strategy, SchedulingStrategy) + timing = strategy.create_request_timings(0, 1, 1) + assert isinstance(timing, ScheduledRequestTimings) + + +class TestSynchronousStrategy: + @pytest.mark.smoke + def test_initialization(self): + """Test initialization of SynchronousStrategy.""" + strategy = SynchronousStrategy() + assert strategy.type_ == "synchronous" + + @pytest.mark.smoke + def test_limits(self): + """Test that SynchronousStrategy enforces proper limits.""" + strategy = SynchronousStrategy() + assert strategy.processes_limit == 1 + assert strategy.requests_limit == 1 + + @pytest.mark.smoke + def test_create_timings_valid(self): + """Test creating timings with valid parameters.""" + strategy = SynchronousStrategy() + timing = strategy.create_request_timings(0, 1, 1) + assert isinstance(timing, LastCompletionRequestTimings) + + @pytest.mark.sanity + def test_create_timings_invalid(self): + """Test that invalid parameters raise ValueError.""" + strategy = SynchronousStrategy() + + with pytest.raises(ValueError): + strategy.create_request_timings(1, 1, 1) # rank != 0 + + with pytest.raises(ValueError): + strategy.create_request_timings(0, 2, 1) # world_size > 1 + + @pytest.mark.smoke + def test_string_representation(self): + """Test __str__ method for SynchronousStrategy.""" + strategy = SynchronousStrategy() + result = str(strategy) + assert result == "synchronous" + + @pytest.mark.smoke + def test_marshalling(self): + """Test marshalling to/from pydantic dict formats.""" + strategy = SynchronousStrategy() + data = strategy.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "synchronous" + + reconstructed = SynchronousStrategy.model_validate(data) + assert isinstance(reconstructed, SynchronousStrategy) + assert reconstructed.type_ == "synchronous" + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, SynchronousStrategy) + assert base_reconstructed.type_ == "synchronous" + + # Test model_validate_json pathway + json_str = strategy.model_dump_json() + json_reconstructed = SynchronousStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, SynchronousStrategy) + assert json_reconstructed.type_ == "synchronous" + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, SynchronousStrategy) + assert base_json_reconstructed.type_ == "synchronous" + + +class TestConcurrentStrategy: + @pytest.fixture( + params=[ + {"streams": 1}, + {"streams": 4}, + {"streams": 8, "startup_duration": 2.0}, + {"streams": 2, "startup_duration": 0.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ConcurrentStrategy.""" + constructor_args = request.param + instance = ConcurrentStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test initialization of ConcurrentStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("streams", 0), + ("streams", -1), + ("startup_duration", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"streams": 2} + kwargs[field] = value + with pytest.raises(ValidationError): + ConcurrentStrategy(**kwargs) + + @pytest.mark.smoke + def test_limits(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test that ConcurrentStrategy returns correct limits.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + assert instance.processes_limit == streams + assert instance.requests_limit == streams + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + startup_duration = constructor_args.get("startup_duration", 0.0) + + # Test with different rank and world_size combinations + for local_rank in range(min(streams, 2)): + for local_world_size in range(1, min(streams + 1, 3)): + if local_rank < local_world_size: + timing = instance.create_request_timings( + local_rank, local_world_size, streams + ) + assert isinstance(timing, LastCompletionRequestTimings) + + # Verify startup behavior + if startup_duration > 0: + # Check that timing has proper startup configuration + expected_delay_per_stream = startup_duration / streams + streams_per_worker = streams // local_world_size + expected_offset = ( + local_rank * streams_per_worker * expected_delay_per_stream + ) + assert timing.offset == pytest.approx(expected_offset, abs=1e-5) + + @pytest.mark.sanity + def test_create_timings_invalid( + self, valid_instances: tuple[ConcurrentStrategy, dict] + ): + """Test invalid inputs for create request timings.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + + # Test various invalid configurations + invalid_configs = [ + (streams, 1, 1), # rank >= streams + (0, streams + 1, 1), # world_size > streams + ] + + for local_rank, local_world_size, local_max_concurrency in invalid_configs: + if local_rank >= streams or local_world_size > streams: + with pytest.raises(ValueError): + instance.create_request_timings( + local_rank, local_world_size, local_max_concurrency + ) + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[ConcurrentStrategy, dict] + ): + """Test __str__ method for ConcurrentStrategy.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + result = str(instance) + assert result == f"concurrent@{streams}" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "concurrent" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ConcurrentStrategy.model_validate(data) + assert isinstance(reconstructed, ConcurrentStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, ConcurrentStrategy) + assert base_reconstructed.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = ConcurrentStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, ConcurrentStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, ConcurrentStrategy) + assert base_json_reconstructed.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value + + +class TestThroughputStrategy: + @pytest.fixture( + params=[ + {}, + {"max_concurrency": 10}, + {"startup_duration": 5.0}, + {"max_concurrency": 5, "startup_duration": 2.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ThroughputStrategy.""" + constructor_args = request.param + instance = ThroughputStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test initialization of ThroughputStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("max_concurrency", 0), + ("max_concurrency", -1), + ("startup_duration", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + ThroughputStrategy(**kwargs) + + @pytest.mark.smoke + def test_limits(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test that ThroughputStrategy returns correct limits.""" + instance, constructor_args = valid_instances + max_concurrency = constructor_args.get("max_concurrency") + assert instance.processes_limit == max_concurrency + assert instance.requests_limit == max_concurrency + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + startup_duration = constructor_args.get("startup_duration", 0.0) + + # Test with different configurations + for local_rank in range(3): + for local_world_size in range(1, 4): + for local_max_concurrency in range(1, 6): + timing = instance.create_request_timings( + local_rank, local_world_size, local_max_concurrency + ) + assert isinstance(timing, NoDelayRequestTimings) + + # Verify startup configuration + if startup_duration > 0: + assert timing.startup_duration == startup_duration + assert timing.startup_target_requests == local_max_concurrency + expected_offset = ( + 0.05 * startup_duration * (local_rank / local_world_size) + ) + assert timing.offset == pytest.approx(expected_offset, abs=1e-5) + else: + assert timing.startup_duration == 0.0 + assert timing.offset == 0.0 + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[ThroughputStrategy, dict] + ): + """Test __str__ method for ThroughputStrategy.""" + instance, _ = valid_instances + result = str(instance) + assert result == "throughput" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "throughput" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ThroughputStrategy.model_validate(data) + assert isinstance(reconstructed, ThroughputStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, ThroughputStrategy) + assert base_reconstructed.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = ThroughputStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, ThroughputStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, ThroughputStrategy) + assert base_json_reconstructed.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value + + +class TestAsyncConstantStrategy: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0}, + {"rate": 10.3, "max_concurrency": 8}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of AsyncConstantStrategy.""" + constructor_args = request.param + instance = AsyncConstantStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test initialization of AsyncConstantStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + AsyncConstantStrategy(**kwargs) + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + + # Test with different worker configurations + for local_world_size in range(1, 5): + timing = instance.create_request_timings(0, local_world_size, 1) + assert isinstance(timing, ConstantRateRequestTimings) + + # Rate should be distributed across workers + expected_worker_rate = rate / local_world_size + assert timing.rate == pytest.approx(expected_worker_rate, abs=1e-5) + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[AsyncConstantStrategy, dict] + ): + """Test __str__ method for AsyncConstantStrategy.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + result = str(instance) + assert result == f"constant@{rate:.2f}" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "constant" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = AsyncConstantStrategy.model_validate(data) + assert isinstance(reconstructed, AsyncConstantStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, AsyncConstantStrategy) + assert base_reconstructed.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = AsyncConstantStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, AsyncConstantStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, AsyncConstantStrategy) + assert base_json_reconstructed.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value + + +class TestAsyncPoissonStrategy: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0, "random_seed": 123}, + {"rate": 10.3, "random_seed": 456, "max_concurrency": 8}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of AsyncPoissonStrategy.""" + constructor_args = request.param + instance = AsyncPoissonStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test initialization of AsyncPoissonStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"rate": 1.0, "random_seed": 42} + kwargs[field] = value + with pytest.raises(ValidationError): + AsyncPoissonStrategy(**kwargs) + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + base_seed = constructor_args.get("random_seed", 42) + + # Test with different worker configurations + for local_rank in range(3): + for local_world_size in range(1, 4): + timing = instance.create_request_timings( + local_rank, local_world_size, 1 + ) + assert isinstance(timing, PoissonRateRequestTimings) + + # Rate should be distributed across workers + expected_worker_rate = rate / local_world_size + assert timing.rate == pytest.approx(expected_worker_rate, abs=1e-5) + + # Each worker should have a unique seed + expected_seed = base_seed + local_rank + assert timing.random_seed == expected_seed + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[AsyncPoissonStrategy, dict] + ): + """Test __str__ method for AsyncPoissonStrategy.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + result = str(instance) + assert result == f"poisson@{rate:.2f}" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "poisson" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = AsyncPoissonStrategy.model_validate(data) + assert isinstance(reconstructed, AsyncPoissonStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, AsyncPoissonStrategy) + assert base_reconstructed.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = AsyncPoissonStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, AsyncPoissonStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, AsyncPoissonStrategy) + assert base_json_reconstructed.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 53e8b664..792c9770 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -3,7 +3,7 @@ import pytest from guidellm import configure_logger, logger -from guidellm.config import LoggingSettings +from guidellm.settings import LoggingSettings @pytest.fixture(autouse=True) diff --git a/tests/unit/test_config.py b/tests/unit/test_settings.py similarity index 99% rename from tests/unit/test_config.py rename to tests/unit/test_settings.py index f5d9415c..42c8901d 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_settings.py @@ -1,6 +1,6 @@ import pytest -from guidellm.config import ( +from guidellm.settings import ( DatasetSettings, Environment, LoggingSettings, diff --git a/tests/unit/utils/test_statistics.py b/tests/unit/utils/test_statistics.py index 855bfa5f..fa8cccd0 100644 --- a/tests/unit/utils/test_statistics.py +++ b/tests/unit/utils/test_statistics.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from guidellm.utils import ( +from guidellm.objects import ( DistributionSummary, Percentiles, RunningStats,