diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 2b52bbc5b..0a36646c5 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -25,6 +25,7 @@ import asyncio import codecs +import json from pathlib import Path import click @@ -382,10 +383,43 @@ def benchmark(): default=BenchmarkGenerativeTextArgs.get_default("max_global_error_rate"), help="Maximum global error rate across all benchmarks.", ) -def run(**kwargs): +@click.option( + "--over-saturation", + "--detect-saturation", # alias + default=None, + help=( + "Enable over-saturation detection. " + "Use --over-saturation=True for boolean flag, " + "or a JSON dict with configuration " + '(e.g., \'{"enabled": true, "min_seconds": 30}\'). ' + "Defaults to None (disabled)." + ), + type=click.UNPROCESSED, +) +def run(**kwargs): # noqa: C901 # Only set CLI args that differ from click defaults kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs) + # Handle over_saturation parsing (can be bool flag or JSON dict string) + if "over_saturation" in kwargs and kwargs["over_saturation"] is not None: + over_sat = kwargs["over_saturation"] + if isinstance(over_sat, str): + try: + # Try parsing as JSON dict + kwargs["over_saturation"] = json.loads(over_sat) + except (json.JSONDecodeError, ValueError): + # If not valid JSON, treat as bool flag + kwargs["over_saturation"] = over_sat.lower() in ( + "true", + "1", + "yes", + "on", + ) + elif isinstance(over_sat, bool): + # Already a bool, keep as is + pass + # If it's already a dict, keep as is + # Handle remapping for request params request_type = kwargs.pop("request_type", None) request_formatter_kwargs = kwargs.pop("request_formatter_kwargs", None) diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index 5b57b22fe..efe3d304d 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -323,6 +323,7 @@ async def resolve_profile( max_errors: int | None, max_error_rate: float | None, max_global_error_rate: float | None, + over_saturation: bool | dict[str, Any] | None = None, console: Console | None = None, ) -> Profile: """ @@ -343,6 +344,7 @@ async def resolve_profile( :param max_errors: Maximum number of errors before stopping :param max_error_rate: Maximum error rate threshold before stopping :param max_global_error_rate: Maximum global error rate threshold before stopping + :param over_saturation: Over-saturation detection configuration (bool or dict) :param console: Console instance for progress reporting, or None :return: Configured Profile instance ready for benchmarking :raises ValueError: If constraints are provided with a pre-configured Profile @@ -359,6 +361,7 @@ async def resolve_profile( "max_errors": max_errors, "max_error_rate": max_error_rate, "max_global_error_rate": max_global_error_rate, + "over_saturation": over_saturation, }.items(): if val is not None: constraints[key] = val @@ -500,6 +503,7 @@ async def benchmark_generative_text( max_errors=args.max_errors, max_error_rate=args.max_error_rate, max_global_error_rate=args.max_global_error_rate, + over_saturation=args.over_saturation, console=console, ) output_formats = await resolve_output_formats( diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py index d7372a40c..bf744dd22 100644 --- a/src/guidellm/benchmark/progress.py +++ b/src/guidellm/benchmark/progress.py @@ -12,7 +12,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from datetime import 
datetime from typing import Any, Generic, Literal from rich.console import Group @@ -37,7 +36,7 @@ GenerativeBenchmarkAccumulator, ) from guidellm.scheduler import SchedulerState, SchedulingStrategy -from guidellm.utils import Colors, format_value_display +from guidellm.utils import Colors, format_value_display, safe_format_timestamp __all__ = ["BenchmarkerProgress", "GenerativeConsoleBenchmarkerProgress"] @@ -390,7 +389,7 @@ def formatted_start_time(self) -> str: if self.start_time < 0.0: return "--:--:--" - return datetime.fromtimestamp(self.start_time).strftime("%H:%M:%S") + return safe_format_timestamp(self.start_time, format_="%H:%M:%S") @property def formatted_progress_status(self) -> str: diff --git a/src/guidellm/benchmark/schemas/generative/entrypoints.py b/src/guidellm/benchmark/schemas/generative/entrypoints.py index a080daa03..2b633ae3f 100644 --- a/src/guidellm/benchmark/schemas/generative/entrypoints.py +++ b/src/guidellm/benchmark/schemas/generative/entrypoints.py @@ -283,6 +283,14 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: max_global_error_rate: float | None = Field( default=None, description="Maximum global error rate (0-1) before stopping" ) + over_saturation: bool | dict[str, Any] | None = Field( + default=None, + description=( + "Over-saturation detection configuration. Can be a bool to enable/disable " + "with defaults, or a dict with configuration parameters (enabled, " + "min_seconds, max_window_seconds, moe_threshold, etc.)." + ), + ) @field_validator("data", "data_args", "rate", mode="wrap") @classmethod diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index c03410767..ab4aeef7b 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -19,6 +19,8 @@ MaxErrorsConstraint, MaxGlobalErrorRateConstraint, MaxNumberConstraint, + OverSaturationConstraint, + OverSaturationConstraintInitializer, PydanticConstraintInitializer, SerializableConstraintInitializer, UnserializableConstraintInitializer, @@ -66,6 +68,8 @@ "MaxNumberConstraint", "MultiTurnRequestT", "NonDistributedEnvironment", + "OverSaturationConstraint", + "OverSaturationConstraintInitializer", "PydanticConstraintInitializer", "RequestT", "ResponseT", diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py deleted file mode 100644 index 21e0fe967..000000000 --- a/src/guidellm/scheduler/constraints.py +++ /dev/null @@ -1,1037 +0,0 @@ -""" -Constraint system for scheduler behavior control and request processing limits. - -Provides flexible constraints for managing scheduler behavior with configurable -thresholds based on time, error rates, and request counts. Constraints evaluate -scheduler state and individual requests to determine whether processing should -continue or stop based on predefined limits. The constraint system enables -sophisticated benchmark stopping criteria through composable constraint types. 
-""" - -from __future__ import annotations - -import time -from abc import ABC, abstractmethod -from typing import Any, Literal, Protocol, cast, runtime_checkable - -from pydantic import Field, field_validator - -from guidellm.scheduler.schemas import ( - SchedulerProgress, - SchedulerState, - SchedulerUpdateAction, -) -from guidellm.schemas import RequestInfo, StandardBaseModel -from guidellm.settings import settings -from guidellm.utils import InfoMixin, RegistryMixin - -__all__ = [ - "Constraint", - "ConstraintInitializer", - "ConstraintsInitializerFactory", - "MaxDurationConstraint", - "MaxErrorRateConstraint", - "MaxErrorsConstraint", - "MaxGlobalErrorRateConstraint", - "MaxNumberConstraint", - "PydanticConstraintInitializer", - "RequestsExhaustedConstraint", - "SerializableConstraintInitializer", - "UnserializableConstraintInitializer", -] - - -@runtime_checkable -class Constraint(Protocol): - """Protocol for constraint evaluation functions that control scheduler behavior.""" - - def __call__( - self, state: SchedulerState, request: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against scheduler state and request information. - - :param state: Current scheduler state with metrics and timing information - :param request: Individual request information and metadata - :return: Action indicating whether to continue or stop scheduler operations - """ - - -@runtime_checkable -class ConstraintInitializer(Protocol): - """Protocol for constraint initializer factory functions that create constraints.""" - - def create_constraint(self, **kwargs) -> Constraint: - """ - Create a constraint instance from configuration parameters. - - :param kwargs: Configuration parameters for constraint creation - :return: Configured constraint evaluation function - """ - - -@runtime_checkable -class SerializableConstraintInitializer(Protocol): - """Protocol for serializable constraint initializers supporting persistence.""" - - @classmethod - def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - """ - Validate and process arguments for constraint creation. - - :param args: Positional arguments for constraint configuration - :param kwargs: Keyword arguments for constraint configuration - :return: Validated parameter dictionary for constraint creation - """ - - @classmethod - def model_validate(cls, **kwargs) -> ConstraintInitializer: - """ - Create validated constraint initializer from configuration. - - :param kwargs: Configuration dictionary for initializer creation - :return: Validated constraint initializer instance - """ - - def model_dump(self) -> dict[str, Any]: - """ - Serialize constraint initializer to dictionary format. - - :return: Dictionary representation of constraint initializer - """ - - def create_constraint(self, **kwargs) -> Constraint: - """ - Create constraint instance from this initializer. - - :param kwargs: Additional configuration parameters - :return: Configured constraint evaluation function - """ - - -class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): - """ - Registry factory for creating and managing constraint initializers. - - Provides centralized access to registered constraint types with support for - creating constraints from configuration dictionaries, simple values, or - pre-configured instances. Handles constraint resolution and type validation - for the scheduler constraint system. 
- - Example: - :: - from guidellm.scheduler import ConstraintsInitializerFactory - - # Register new constraint type - @ConstraintsInitializerFactory.register("new_constraint") - class NewConstraint: - def create_constraint(self, **kwargs) -> Constraint: - return lambda state, request: SchedulerUpdateAction() - - # Create and use constraint - constraint = ConstraintsInitializerFactory.create_constraint("new_constraint") - """ - - @classmethod - def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: - """ - Create a constraint initializer for the specified key. - - :param key: Registered constraint initializer key - :param args: Positional arguments for initializer creation - :param kwargs: Keyword arguments for initializer creation - :return: Configured constraint initializer instance - :raises ValueError: If the key is not registered in the factory - """ - if cls.registry is None or key not in cls.registry: - raise ValueError(f"Unknown constraint initializer key: {key}") - - initializer_class = cls.registry[key] - - return ( - initializer_class(*args, **kwargs) # type: ignore[operator] - if not isinstance(initializer_class, type) - or not issubclass(initializer_class, SerializableConstraintInitializer) - else initializer_class( - **initializer_class.validated_kwargs(*args, **kwargs) # type: ignore[misc] - ) - ) - - @classmethod - def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: - """ - Serialize constraint initializer to dictionary format. - - :param initializer: Constraint initializer to serialize - :return: Dictionary representation or unserializable placeholder - """ - if isinstance(initializer, SerializableConstraintInitializer): - return initializer.model_dump() - else: - unserializable = UnserializableConstraintInitializer( - orig_info=InfoMixin.extract_from_obj(initializer) - ) - return unserializable.model_dump() - - @classmethod - def deserialize( - cls, initializer_dict: dict[str, Any] - ) -> SerializableConstraintInitializer | UnserializableConstraintInitializer: - """ - Deserialize constraint initializer from dictionary format. - - :param initializer_dict: Dictionary representation of constraint initializer - :return: Reconstructed constraint initializer instance - :raises ValueError: If constraint type is unknown or cannot be deserialized - """ - if initializer_dict.get("type_") == "unserializable": - return UnserializableConstraintInitializer.model_validate(initializer_dict) - - if ( - cls.registry is not None - and initializer_dict.get("type_") - and initializer_dict["type_"] in cls.registry - ): - initializer_class = cls.registry[initializer_dict["type_"]] - if hasattr(initializer_class, "model_validate"): - return initializer_class.model_validate(initializer_dict) # type: ignore[return-value] - else: - return initializer_class(**initializer_dict) # type: ignore[return-value,operator] - - raise ValueError( - f"Cannot deserialize unknown constraint initializer: " - f"{initializer_dict.get('type_', 'unknown')}" - ) - - @classmethod - def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: - """ - Create a constraint instance for the specified key. 
- - :param key: Registered constraint initializer key - :param args: Positional arguments for constraint creation - :param kwargs: Keyword arguments for constraint creation - :return: Configured constraint function ready for evaluation - :raises ValueError: If the key is not registered in the factory - """ - return cls.create(key, *args, **kwargs).create_constraint() - - @classmethod - def resolve( - cls, - initializers: dict[ - str, - Any | dict[str, Any] | Constraint | ConstraintInitializer, - ], - ) -> dict[str, Constraint]: - """ - Resolve mixed constraint specifications to callable constraints. - - :param initializers: Dictionary mapping constraint keys to specifications - :return: Dictionary mapping constraint keys to callable functions - :raises ValueError: If any key is not registered in the factory - """ - constraints = {} - - for key, val in initializers.items(): - if isinstance(val, Constraint): - constraints[key] = val - elif isinstance(val, ConstraintInitializer): - constraints[key] = val.create_constraint() - elif isinstance(val, dict): - constraints[key] = cls.create_constraint(key, **val) - else: - constraints[key] = cls.create_constraint(key, val) - - return constraints - - @classmethod - def resolve_constraints( - cls, - constraints: dict[str, Any | dict[str, Any] | Constraint], - ) -> dict[str, Constraint]: - """ - Resolve constraints from mixed constraint specifications. - - :param constraints: Dictionary mapping constraint keys to specifications - :return: Dictionary mapping constraint keys to callable functions - :raises ValueError: If any constraint key is not registered - """ - resolved_constraints = {} - - for key, val in constraints.items(): - if isinstance(val, Constraint): - resolved_constraints[key] = val - elif isinstance(val, dict): - resolved_constraints[key] = cls.create_constraint(key, **val) - else: - resolved_constraints[key] = cls.create_constraint(key, val) - - return resolved_constraints - - -class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin): - """ - Abstract base for Pydantic-based constraint initializers. - - Provides standardized serialization, validation, and metadata handling for - constraint initializers using Pydantic models. Subclasses implement specific - constraint creation logic while inheriting validation and persistence support. - """ - - type_: str = Field(description="Type identifier for the constraint initializer") - - @property - def info(self) -> dict[str, Any]: - """ - Extract serializable information from this constraint initializer. - - :return: Dictionary containing constraint configuration and metadata - """ - return self.model_dump() - - @classmethod - @abstractmethod - def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: - """ - Validate and process arguments for constraint creation. - - Must be implemented by subclasses to handle their specific parameter patterns - and validation requirements. - - :param args: Positional arguments passed to the constraint - :param kwargs: Keyword arguments passed to the constraint - :return: Validated dictionary of parameters for constraint creation - :raises NotImplementedError: Must be implemented by subclasses - """ - ... - - @abstractmethod - def create_constraint(self, **kwargs) -> Constraint: - """ - Create a constraint instance. - - Must be implemented by subclasses to return their specific constraint type - with appropriate configuration and validation. 
- - :param kwargs: Additional keyword arguments (usually unused) - :return: Configured constraint instance - :raises NotImplementedError: Must be implemented by subclasses - """ - ... - - -class UnserializableConstraintInitializer(PydanticConstraintInitializer): - """ - Placeholder for constraints that cannot be serialized or executed. - - Represents constraint initializers that failed serialization or contain - non-serializable components. Cannot be executed and raises errors when - invoked to prevent runtime failures from invalid constraint state. - """ - - type_: Literal["unserializable"] = "unserializable" # type: ignore[assignment] - orig_info: dict[str, Any] = Field( - default_factory=dict, - description="Original constraint information before serialization failure", - ) - - @classmethod - def validated_kwargs( - cls, orig_info: dict[str, Any] | None = None, **_kwargs - ) -> dict[str, Any]: - """ - Validate arguments for unserializable constraint creation. - - :param orig_info: Original constraint information before serialization failure - :param kwargs: Additional arguments (ignored) - :return: Validated parameters for unserializable constraint creation - """ - return {"orig_info": orig_info or {}} - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Raise error for unserializable constraint creation attempt. - - :param kwargs: Additional keyword arguments (unused) - :raises RuntimeError: Always raised since unserializable constraints - cannot be executed - """ - raise RuntimeError( - "Cannot create constraint from unserializable constraint instance. " - "This constraint cannot be serialized and therefore cannot be executed." - ) - - def __call__( - self, state: SchedulerState, request: RequestInfo - ) -> SchedulerUpdateAction: - """ - Raise error since unserializable constraints cannot be invoked. - - :param state: Current scheduler state (unused) - :param request: Individual request information (unused) - :raises RuntimeError: Always raised for unserializable constraints - """ - _ = (state, request) # Unused parameters - raise RuntimeError( - "Cannot invoke unserializable constraint instance. " - "This constraint was not properly serialized and cannot be executed." - ) - - -@ConstraintsInitializerFactory.register( # type: ignore[arg-type] - ["max_number", "max_num", "max_requests", "max_req"] -) -class MaxNumberConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on maximum request counts. - - Stops request queuing when created requests reach the limit and stops local - request processing when processed requests reach the limit. Provides progress - tracking based on remaining requests and completion fraction. - """ - - type_: Literal["max_number"] = "max_number" # type: ignore[assignment] - max_num: int | float | list[int | float] = Field( - description="Maximum number of requests allowed before triggering constraint", - ) - current_index: int = Field( - default=-1, description="Current index for list-based max_num values" - ) - - @classmethod - def validated_kwargs( - cls, max_num: int | float | list[int | float], **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxNumberConstraint creation. 
- - :param max_num: Maximum number of requests to allow - :param kwargs: Supports max_num, max_number, max_requests, max_req, - and optional type_ - :return: Validated dictionary with max_num and type_ fields - """ - aliases = ["max_number", "max_num", "max_requests", "max_req"] - for alias in aliases: - if max_num is None: - max_num = kwargs.get(alias) - - return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)} - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against current scheduler state and request count. - - :param state: Current scheduler state with request counts - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_num = ( - self.max_num - if isinstance(self.max_num, int | float) - else self.max_num[min(current_index, len(self.max_num) - 1)] - ) - - create_exceeded = state.created_requests >= max_num - processed_exceeded = state.processed_requests >= max_num - remaining_requests = min(max(0, max_num - state.processed_requests), max_num) - stop_time = ( - None if remaining_requests > 0 else request_info.completed_at or time.time() - ) - - return SchedulerUpdateAction( - request_queuing="stop" if create_exceeded else "continue", - request_processing="stop_local" if processed_exceeded else "continue", - metadata={ - "max_number": max_num, - "create_exceeded": create_exceeded, - "processed_exceeded": processed_exceeded, - "created_requests": state.created_requests, - "processed_requests": state.processed_requests, - "remaining_requests": remaining_requests, - "stop_time": stop_time, - }, - progress=SchedulerProgress( - remaining_requests=remaining_requests, - total_requests=max_num, - stop_time=stop_time, - ), - ) - - @field_validator("max_num") - @classmethod - def _validate_max_num( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - f"max_num must be set and truthful, received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0: - raise ValueError( - f"max_num must be a positive num, received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"] -) -class MaxDurationConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on maximum time duration. - - Stops both request queuing and processing when the elapsed time since scheduler - start exceeds the maximum duration. Provides progress tracking based on - remaining time and completion fraction. 
- """ - - type_: Literal["max_duration"] = "max_duration" # type: ignore[assignment] - max_duration: int | float | list[int | float] = Field( - description="Maximum duration in seconds before triggering constraint" - ) - current_index: int = Field(default=-1, description="Current index in duration list") - - @classmethod - def validated_kwargs( - cls, max_duration: int | float | list[int | float] | None = None, **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxDurationConstraint creation. - - :param max_duration: Maximum duration in seconds - :param kwargs: Supports max_duration, max_dur, max_sec, max_seconds, - max_min, max_minutes, and optional type_ - :return: Validated dictionary with max_duration and type_ fields - """ - seconds_aliases = ["max_dur", "max_sec", "max_seconds"] - for alias in seconds_aliases: - if max_duration is None: - max_duration = kwargs.get(alias) - minutes_aliases = ["max_min", "max_minutes"] - for alias in minutes_aliases: - minutes = kwargs.get(alias) - if minutes is not None and max_duration is None: - max_duration = minutes * 60 - - return { - "max_duration": max_duration, - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against current scheduler state and elapsed time. - - :param state: Current scheduler state with start time - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_duration = ( - self.max_duration - if isinstance(self.max_duration, int | float) - else self.max_duration[min(current_index, len(self.max_duration) - 1)] - ) - - current_time = time.time() - elapsed = current_time - state.start_time - duration_exceeded = elapsed >= max_duration - remaining_duration = min(max(0.0, max_duration - elapsed), max_duration) - stop_time = None if not duration_exceeded else state.start_time + max_duration - - return SchedulerUpdateAction( - request_queuing="stop" if duration_exceeded else "continue", - request_processing="stop_local" if duration_exceeded else "continue", - metadata={ - "max_duration": max_duration, - "elapsed_time": elapsed, - "duration_exceeded": duration_exceeded, - "start_time": state.start_time, - "current_time": current_time, - "stop_time": stop_time, - }, - progress=SchedulerProgress( - remaining_duration=remaining_duration, - total_duration=max_duration, - stop_time=stop_time, - ), - ) - - @field_validator("max_duration") - @classmethod - def _validate_max_duration( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_duration must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0: - raise ValueError( - "max_duration must be a positive num," - f"received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - 
["max_errors", "max_err", "max_error", "max_errs"] -) -class MaxErrorsConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on absolute error count. - - Stops both request queuing and all request processing when the total number - of errored requests reaches the maximum threshold. Uses global error tracking - across all requests for immediate constraint evaluation. - """ - - type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] - max_errors: int | float | list[int | float] = Field( - description="Maximum number of errors allowed before triggering constraint", - ) - current_index: int = Field(default=-1, description="Current index in error list") - - @classmethod - def validated_kwargs( - cls, max_errors: int | float | list[int | float] | None = None, **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxErrorsConstraint creation. - - :param max_errors: Maximum number of errors to allow - :param kwargs: Supports max_errors, max_err, max_error, max_errs, - and optional type_ - :return: Validated dictionary with max_errors and type_ fields - """ - aliases = ["max_errors", "max_err", "max_error", "max_errs"] - for alias in aliases: - if max_errors is None: - max_errors = kwargs.get(alias) - - return { - "max_errors": max_errors, - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against current error count. 
- - :param state: Current scheduler state with error counts - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_errors = ( - self.max_errors - if isinstance(self.max_errors, int | float) - else self.max_errors[min(current_index, len(self.max_errors) - 1)] - ) - errors_exceeded = state.errored_requests >= max_errors - stop_time = ( - None if not errors_exceeded else request_info.completed_at or time.time() - ) - - return SchedulerUpdateAction( - request_queuing="stop" if errors_exceeded else "continue", - request_processing="stop_all" if errors_exceeded else "continue", - metadata={ - "max_errors": max_errors, - "errors_exceeded": errors_exceeded, - "current_errors": state.errored_requests, - "stop_time": stop_time, - }, - progress=SchedulerProgress(stop_time=stop_time), - ) - - @field_validator("max_errors") - @classmethod - def _validate_max_errors( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_errors must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0: - raise ValueError( - f"max_errors must be a positive num,received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - ["max_error_rate", "max_err_rate", "max_errors_rate"] -) -class MaxErrorRateConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on sliding window error rate. - - Tracks error status of recent requests in a sliding window and stops all - processing when the error rate exceeds the threshold. Only applies the - constraint after processing enough requests to fill the minimum window size - for statistical significance. - """ - - type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] - max_error_rate: int | float | list[int | float] = Field( - description="Maximum error rate allowed (0.0, 1.0)" - ) - window_size: int | float = Field( - default=30, - gt=0, - description="Size of sliding window for calculating error rate", - ) - error_window: list[bool] = Field( - default_factory=list, - description="Sliding window tracking error status of recent requests", - ) - current_index: int = Field( - default=-1, description="Current index in the error window" - ) - - @classmethod - def validated_kwargs( - cls, max_error_rate: int | float | list[int | float], **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxErrorRateConstraint creation. 
- - :param max_error_rate: Maximum error rate to allow - :param kwargs: Supports max_error_rate, max_err_rate, max_errors_rate, - optional window_size, and optional type_ - :return: Validated dictionary with max_error_rate, window_size, - and type_ fields - """ - aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] - for alias in aliases: - if max_error_rate is None: - max_error_rate = kwargs.get(alias) - - return { - "max_error_rate": max_error_rate, - "window_size": kwargs.get( - "window_size", settings.constraint_error_window_size - ), - "error_window": kwargs.get("error_window", []), - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Create a new instance of MaxErrorRateConstraint (due to stateful window). - - :param kwargs: Additional keyword arguments (unused) - :return: New instance of the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against sliding window error rate. - - :param state: Current scheduler state with request counts - :param request_info: Individual request with completion status - :return: Action indicating whether to continue or stop operations - """ - current_index = max(0, self.current_index) - max_error_rate = ( - self.max_error_rate - if isinstance(self.max_error_rate, int | float) - else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] - ) - - if request_info.status in ["completed", "errored", "cancelled"]: - self.error_window.append(request_info.status == "errored") - if len(self.error_window) > self.window_size: - self.error_window.pop(0) - - error_count = sum(self.error_window) - window_requests = len(self.error_window) - error_rate = ( - error_count / float(window_requests) if window_requests > 0 else 0.0 - ) - exceeded_min_processed = state.processed_requests >= self.window_size - exceeded_error_rate = error_rate >= max_error_rate - exceeded = exceeded_min_processed and exceeded_error_rate - stop_time = None if not exceeded else request_info.completed_at or time.time() - - return SchedulerUpdateAction( - request_queuing="stop" if exceeded else "continue", - request_processing="stop_all" if exceeded else "continue", - metadata={ - "max_error_rate": max_error_rate, - "window_size": self.window_size, - "error_count": error_count, - "processed_count": state.processed_requests, - "current_window_size": len(self.error_window), - "current_error_rate": error_rate, - "exceeded_min_processed": exceeded_min_processed, - "exceeded_error_rate": exceeded_error_rate, - "exceeded": exceeded, - "stop_time": stop_time, - }, - ) - - @field_validator("max_error_rate") - @classmethod - def _validate_max_error_rate( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_error_rate must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0 or val >= 1: - raise ValueError( - "max_error_rate must be a number between 0 and 1," - f"received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"] -) -class 
MaxGlobalErrorRateConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on global error rate. - - Calculates error rate across all processed requests and stops all processing - when the rate exceeds the threshold. Only applies the constraint after - processing the minimum number of requests to ensure statistical significance - for global error rate calculations. - """ - - type_: Literal["max_global_error_rate"] = "max_global_error_rate" # type: ignore[assignment] - max_error_rate: int | float = Field( - description="Maximum error rate allowed (0.0 to 1.0)" - ) - min_processed: int | float | None = Field( - default=30, - gt=0, - description="Minimum requests processed before applying error rate constraint", - ) - current_index: int = Field( - default=-1, description="Current index for list-based max_error_rate values" - ) - - @classmethod - def validated_kwargs( - cls, max_error_rate: int | float | list[int | float], **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxGlobalErrorRateConstraint creation. - - :param max_error_rate: Maximum error rate to allow - :param kwargs: Supports max_global_error_rate, max_global_err_rate, - max_global_errors_rate, optional min_processed, and optional type_ - :return: Validated dictionary with max_error_rate, min_processed, - and type_ fields - """ - for alias in [ - "max_global_error_rate", - "max_global_err_rate", - "max_global_errors_rate", - ]: - if max_error_rate is None: - max_error_rate = kwargs.get(alias) - - return { - "max_error_rate": max_error_rate, - "min_processed": kwargs.get( - "min_processed", settings.constraint_error_min_processed - ), - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against global error rate. 
- - :param state: Current scheduler state with global request and error counts - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_error_rate = ( - self.max_error_rate - if isinstance(self.max_error_rate, int | float) - else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] - ) - - exceeded_min_processed = ( - self.min_processed is None or state.processed_requests >= self.min_processed - ) - error_rate = ( - state.errored_requests / float(state.processed_requests) - if state.processed_requests > 0 - else 0.0 - ) - exceeded_error_rate = error_rate >= max_error_rate - exceeded = exceeded_min_processed and exceeded_error_rate - stop_time = None if not exceeded else request_info.completed_at or time.time() - - return SchedulerUpdateAction( - request_queuing="stop" if exceeded else "continue", - request_processing="stop_all" if exceeded else "continue", - metadata={ - "max_error_rate": max_error_rate, - "min_processed": self.min_processed, - "processed_requests": state.processed_requests, - "errored_requests": state.errored_requests, - "error_rate": error_rate, - "exceeded_min_processed": exceeded_min_processed, - "exceeded_error_rate": exceeded_error_rate, - "exceeded": exceeded, - "stop_time": stop_time, - }, - progress=SchedulerProgress(stop_time=stop_time), - ) - - @field_validator("max_error_rate") - @classmethod - def _validate_max_error_rate( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_error_rate must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0 or val >= 1: - raise ValueError( - "max_error_rate must be a number between 0 and 1," - f"received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -class RequestsExhaustedConstraint(StandardBaseModel, InfoMixin): - type_: Literal["requests_exhausted"] = "requests_exhausted" # type: ignore[assignment] - num_requests: int - - @property - def info(self) -> dict[str, Any]: - """ - Extract serializable information from this constraint initializer. 
- - :return: Dictionary containing constraint configuration and metadata - """ - return self.model_dump() - - def __call__( - self, state: SchedulerState, request: RequestInfo - ) -> SchedulerUpdateAction: - _ = request # Unused parameter - create_exceeded = state.created_requests >= self.num_requests - processed_exceeded = state.processed_requests >= self.num_requests - remaining_requests = max(0, self.num_requests - state.processed_requests) - stop_time = ( - None if remaining_requests > 0 else request.completed_at or time.time() - ) - - return SchedulerUpdateAction( - request_queuing="stop" if create_exceeded else "continue", - request_processing="stop_local" if processed_exceeded else "continue", - metadata={ - "num_requests": self.num_requests, - "create_exceeded": create_exceeded, - "processed_exceeded": processed_exceeded, - "created_requests": state.created_requests, - "processed_requests": state.processed_requests, - "remaining_requests": remaining_requests, - "stop_time": stop_time, - }, - progress=SchedulerProgress( - remaining_requests=remaining_requests, - total_requests=self.num_requests, - stop_time=stop_time, - ), - ) diff --git a/src/guidellm/scheduler/constraints/__init__.py b/src/guidellm/scheduler/constraints/__init__.py new file mode 100644 index 000000000..1f5343a93 --- /dev/null +++ b/src/guidellm/scheduler/constraints/__init__.py @@ -0,0 +1,49 @@ +""" +Constraint system for scheduler behavior control and request processing limits. + +Provides flexible constraints for managing scheduler behavior with configurable +thresholds based on time, error rates, and request counts. Constraints evaluate +scheduler state and individual requests to determine whether processing should +continue or stop based on predefined limits. The constraint system enables +sophisticated benchmark stopping criteria through composable constraint types. +""" + +from .constraint import ( + Constraint, + ConstraintInitializer, + PydanticConstraintInitializer, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from .error import ( + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, +) +from .factory import ConstraintsInitializerFactory +from .request import ( + MaxDurationConstraint, + MaxNumberConstraint, + RequestsExhaustedConstraint, +) +from .saturation import ( + OverSaturationConstraint, + OverSaturationConstraintInitializer, +) + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "MaxDurationConstraint", + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", + "MaxNumberConstraint", + "OverSaturationConstraint", + "OverSaturationConstraintInitializer", + "PydanticConstraintInitializer", + "RequestsExhaustedConstraint", + "SerializableConstraintInitializer", + "UnserializableConstraintInitializer", +] diff --git a/src/guidellm/scheduler/constraints/constraint.py b/src/guidellm/scheduler/constraints/constraint.py new file mode 100644 index 000000000..dd901acfa --- /dev/null +++ b/src/guidellm/scheduler/constraints/constraint.py @@ -0,0 +1,325 @@ +""" +Core constraint system protocols and base classes. + +Defines the fundamental protocols and base classes that form the foundation of the +constraint system. Constraints control scheduler behavior by evaluating scheduler +state and individual requests to determine whether processing should continue or +stop based on predefined limits. 
The constraint system enables sophisticated +benchmark stopping criteria through composable constraint types with support for +serialization, validation, and dynamic instantiation. + +The module provides: +- Protocols defining the constraint interface contract + (Constraint, ConstraintInitializer) +- Base classes for Pydantic-based constraint initializers with serialization support +- Placeholder classes for handling unserializable constraint states + +Example: +:: + from guidellm.scheduler.constraints import ( + Constraint, + PydanticConstraintInitializer, + ) + + class MyConstraint(PydanticConstraintInitializer): + type_: str = "my_constraint" + + def create_constraint(self) -> Constraint: + def evaluate(state, request): + return SchedulerUpdateAction(request_queuing="continue") + return evaluate +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import Field + +from guidellm.scheduler.schemas import SchedulerState, SchedulerUpdateAction +from guidellm.schemas import RequestInfo, StandardBaseModel +from guidellm.utils import InfoMixin + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "PydanticConstraintInitializer", + "SerializableConstraintInitializer", + "UnserializableConstraintInitializer", +] + + +@runtime_checkable +class Constraint(Protocol): + """ + Protocol for constraint evaluation functions that control scheduler behavior. + + Defines the interface that all constraint implementations must follow. Constraints + are callable objects that evaluate scheduler state and request information to + determine whether processing should continue or stop. The protocol enables type + checking and runtime validation of constraint implementations while allowing + flexible implementation approaches (functions, classes, closures). + + Example: + :: + def my_constraint( + state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + if state.processing_requests > 100: + return SchedulerUpdateAction(request_queuing="stop") + return SchedulerUpdateAction(request_queuing="continue") + """ + + def __call__( + self, state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against scheduler state and request information. + + :param state: Current scheduler state with metrics and timing information + :param request: Individual request information and metadata + :return: Action indicating whether to continue or stop scheduler operations + """ + + +@runtime_checkable +class ConstraintInitializer(Protocol): + """ + Protocol for constraint initializer factory functions that create constraints. + + Defines the interface for factory objects that create constraint instances from + configuration parameters. Constraint initializers enable dynamic constraint + creation and configuration, supporting both simple boolean flags and complex + parameter dictionaries. The protocol allows type checking while maintaining + flexibility for different initialization patterns. 
+ + Example: + :: + class MaxRequestsInitializer: + def __init__(self, max_requests: int): + self.max_requests = max_requests + + def create_constraint(self) -> Constraint: + def evaluate(state, request): + if state.total_requests >= self.max_requests: + return SchedulerUpdateAction(request_queuing="stop") + return SchedulerUpdateAction(request_queuing="continue") + return evaluate + """ + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance from configuration parameters. + + :param kwargs: Configuration parameters for constraint creation + :return: Configured constraint evaluation function + """ + + +@runtime_checkable +class SerializableConstraintInitializer(Protocol): + """ + Protocol for serializable constraint initializers supporting persistence. + + Extends ConstraintInitializer with serialization capabilities, enabling constraint + configurations to be saved, loaded, and transmitted. Serializable initializers + support validation, model-based configuration, and dictionary-based serialization + for integration with configuration systems and persistence layers. + + Example: + :: + class SerializableInitializer: + @classmethod + def validated_kwargs(cls, **kwargs) -> dict[str, Any]: + return {"max_requests": kwargs.get("max_requests", 100)} + + @classmethod + def model_validate(cls, data: dict) -> ConstraintInitializer: + return cls(**cls.validated_kwargs(**data)) + + def model_dump(self) -> dict[str, Any]: + return {"type_": "max_requests", "max_requests": self.max_requests} + + def create_constraint(self) -> Constraint: + # ... create constraint + """ + + @classmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + :param args: Positional arguments for constraint configuration + :param kwargs: Keyword arguments for constraint configuration + :return: Validated parameter dictionary for constraint creation + """ + + @classmethod + def model_validate(cls, **kwargs) -> ConstraintInitializer: + """ + Create validated constraint initializer from configuration. + + :param kwargs: Configuration dictionary for initializer creation + :return: Validated constraint initializer instance + """ + + def model_dump(self) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :return: Dictionary representation of constraint initializer + """ + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create constraint instance from this initializer. + + :param kwargs: Additional configuration parameters + :return: Configured constraint evaluation function + """ + + +class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin): + """ + Abstract base for Pydantic-based constraint initializers. + + Provides standardized serialization, validation, and metadata handling for + constraint initializers using Pydantic models. Subclasses implement specific + constraint creation logic while inheriting validation and persistence support. + Integrates with the constraint factory system for dynamic instantiation and + configuration management. 
+ + Example: + :: + @ConstraintsInitializerFactory.register("max_duration") + class MaxDurationConstraintInitializer(PydanticConstraintInitializer): + type_: str = "max_duration" + max_seconds: float = Field(description="Maximum duration in seconds") + + def create_constraint(self) -> Constraint: + def evaluate(state, request): + if time.time() - state.start_time > self.max_seconds: + return SchedulerUpdateAction(request_queuing="stop") + return SchedulerUpdateAction(request_queuing="continue") + return evaluate + + :cvar type_: Type identifier for the constraint initializer + """ + + type_: str = Field(description="Type identifier for the constraint initializer") + + @property + def info(self) -> dict[str, Any]: + """ + Extract serializable information from this constraint initializer. + + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + @classmethod + @abstractmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + Must be implemented by subclasses to handle their specific parameter patterns + and validation requirements. This method processes raw input (booleans, dicts, + etc.) and converts them into validated parameter dictionaries suitable for + constraint initialization. + + :param args: Positional arguments passed to the constraint + :param kwargs: Keyword arguments passed to the constraint + :return: Validated dictionary of parameters for constraint creation + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + @abstractmethod + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance. + + Must be implemented by subclasses to return their specific constraint type + with appropriate configuration and validation. The returned constraint should + be ready for evaluation against scheduler state and requests. + + :param kwargs: Additional keyword arguments (usually unused) + :return: Configured constraint instance + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + +class UnserializableConstraintInitializer(PydanticConstraintInitializer): + """ + Placeholder for constraints that cannot be serialized or executed. + + Represents constraint initializers that failed serialization or contain + non-serializable components. Cannot be executed and raises errors when + invoked to prevent runtime failures from invalid constraint state. Used + by the factory system to preserve constraint information even when full + serialization is not possible. + + Example: + :: + # Created automatically by factory when serialization fails + unserializable = UnserializableConstraintInitializer( + orig_info={"type_": "custom", "data": non_serializable_object} + ) + + # Attempting to use it raises RuntimeError + constraint = unserializable.create_constraint() # Raises RuntimeError + + :cvar type_: Always "unserializable" to identify placeholder constraints + :cvar orig_info: Original constraint information before serialization failure + """ + + type_: Literal["unserializable"] = "unserializable" # type: ignore[assignment] + orig_info: dict[str, Any] = Field( + default_factory=dict, + description="Original constraint information before serialization failure", + ) + + @classmethod + def validated_kwargs( + cls, orig_info: dict[str, Any] | None = None, **_kwargs + ) -> dict[str, Any]: + """ + Validate arguments for unserializable constraint creation. 
+ + :param orig_info: Original constraint information before serialization failure + :param kwargs: Additional arguments (ignored) + :return: Validated parameters for unserializable constraint creation + """ + return {"orig_info": orig_info or {}} + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Raise error for unserializable constraint creation attempt. + + :param kwargs: Additional keyword arguments (unused) + :raises RuntimeError: Always raised since unserializable constraints + cannot be executed + """ + raise RuntimeError( + "Cannot create constraint from unserializable constraint instance. " + "This constraint cannot be serialized and therefore cannot be executed." + ) + + def __call__( + self, state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + """ + Raise error since unserializable constraints cannot be invoked. + + :param state: Current scheduler state (unused) + :param request: Individual request information (unused) + :raises RuntimeError: Always raised for unserializable constraints + """ + _ = (state, request) # Unused parameters + raise RuntimeError( + "Cannot invoke unserializable constraint instance. " + "This constraint was not properly serialized and cannot be executed." + ) diff --git a/src/guidellm/scheduler/constraints/error.py b/src/guidellm/scheduler/constraints/error.py new file mode 100644 index 000000000..d9ed7ca95 --- /dev/null +++ b/src/guidellm/scheduler/constraints/error.py @@ -0,0 +1,411 @@ +""" +Error-based constraint implementations. + +Provides constraint types for limiting benchmark execution based on error rates +and error counts. These constraints monitor request error status to determine +when to stop benchmark execution due to excessive errors. +""" + +from __future__ import annotations + +import time +from typing import Any, Literal, cast + +from pydantic import Field, field_validator + +from guidellm.scheduler.schemas import ( + SchedulerProgress, + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.schemas import RequestInfo +from guidellm.settings import settings + +from .constraint import Constraint, PydanticConstraintInitializer +from .factory import ConstraintsInitializerFactory + +__all__ = [ + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", +] + + +@ConstraintsInitializerFactory.register( + ["max_errors", "max_err", "max_error", "max_errs"] +) +class MaxErrorsConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on absolute error count. + + Stops both request queuing and all request processing when the total number + of errored requests reaches the maximum threshold. Uses global error tracking + across all requests for immediate constraint evaluation. + """ + + type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] + max_errors: int | float | list[int | float] = Field( + description="Maximum number of errors allowed before triggering constraint", + ) + current_index: int = Field(default=-1, description="Current index in error list") + + @classmethod + def validated_kwargs( + cls, max_errors: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorsConstraint creation. 
+ + :param max_errors: Maximum number of errors to allow + :param kwargs: Supports max_errors, max_err, max_error, max_errs, + and optional type_ + :return: Validated dictionary with max_errors and type_ fields + """ + aliases = ["max_errors", "max_err", "max_error", "max_errs"] + for alias in aliases: + if max_errors is None: + max_errors = kwargs.get(alias) + + return { + "max_errors": max_errors, + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current error count. + + :param state: Current scheduler state with error counts + :param request_info: Individual request, used for its completion timestamp + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_errors = ( + self.max_errors + if isinstance(self.max_errors, int | float) + else self.max_errors[min(current_index, len(self.max_errors) - 1)] + ) + errors_exceeded = state.errored_requests >= max_errors + stop_time = ( + None if not errors_exceeded else request_info.completed_at or time.time() + ) + + return SchedulerUpdateAction( + request_queuing="stop" if errors_exceeded else "continue", + request_processing="stop_all" if errors_exceeded else "continue", + metadata={ + "max_errors": max_errors, + "errors_exceeded": errors_exceeded, + "current_errors": state.errored_requests, + "stop_time": stop_time, + }, + progress=SchedulerProgress(stop_time=stop_time), + ) + + @field_validator("max_errors") + @classmethod + def _validate_max_errors( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_errors must be set and truthy, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, int | float) or val <= 0: + raise ValueError( + f"max_errors must be a positive number, received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( + ["max_error_rate", "max_err_rate", "max_errors_rate"] +) +class MaxErrorRateConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on sliding window error rate. + + Tracks error status of recent requests in a sliding window and stops all + processing when the error rate exceeds the threshold. Only applies the + constraint after processing enough requests to fill the minimum window size + for statistical significance.
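+ + Example (illustrative sketch; both fields are defined on this class below): + :: + # Stop once >=20% of the last 50 finished requests errored, + # evaluated only after at least window_size requests have completed. + initializer = MaxErrorRateConstraint(max_error_rate=0.2, window_size=50) + constraint = initializer.create_constraint()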
+ """ + + type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] + max_error_rate: int | float | list[int | float] = Field( + description="Maximum error rate allowed (0.0, 1.0)" + ) + window_size: int | float = Field( + default=30, + gt=0, + description="Size of sliding window for calculating error rate", + ) + error_window: list[bool] = Field( + default_factory=list, + description="Sliding window tracking error status of recent requests", + ) + current_index: int = Field( + default=-1, description="Current index in the error window" + ) + + @classmethod + def validated_kwargs( + cls, max_error_rate: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorRateConstraint creation. + + :param max_error_rate: Maximum error rate to allow + :param kwargs: Supports max_error_rate, max_err_rate, max_errors_rate, + optional window_size, and optional type_ + :return: Validated dictionary with max_error_rate, window_size, + and type_ fields + """ + aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] + for alias in aliases: + if max_error_rate is None: + max_error_rate = kwargs.get(alias) + + return { + "max_error_rate": max_error_rate, + "window_size": kwargs.get( + "window_size", settings.constraint_error_window_size + ), + "error_window": kwargs.get("error_window", []), + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a new instance of MaxErrorRateConstraint (due to stateful window). + + :param kwargs: Additional keyword arguments (unused) + :return: New instance of the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against sliding window error rate. 
+
+        :param state: Current scheduler state with request counts
+        :param request_info: Individual request with completion status
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_error_rate = (
+            self.max_error_rate
+            if isinstance(self.max_error_rate, int | float)
+            else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)]
+        )
+
+        if request_info.status in ["completed", "errored", "cancelled"]:
+            self.error_window.append(request_info.status == "errored")
+            if len(self.error_window) > self.window_size:
+                self.error_window.pop(0)
+
+        error_count = sum(self.error_window)
+        window_requests = len(self.error_window)
+        error_rate = (
+            error_count / float(window_requests) if window_requests > 0 else 0.0
+        )
+        exceeded_min_processed = state.processed_requests >= self.window_size
+        exceeded_error_rate = error_rate >= max_error_rate
+        exceeded = exceeded_min_processed and exceeded_error_rate
+        stop_time = None if not exceeded else request_info.completed_at or time.time()
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if exceeded else "continue",
+            request_processing="stop_all" if exceeded else "continue",
+            metadata={
+                "max_error_rate": max_error_rate,
+                "window_size": self.window_size,
+                "error_count": error_count,
+                "processed_count": state.processed_requests,
+                "current_window_size": len(self.error_window),
+                "current_error_rate": error_rate,
+                "exceeded_min_processed": exceeded_min_processed,
+                "exceeded_error_rate": exceeded_error_rate,
+                "exceeded": exceeded,
+                "stop_time": stop_time,
+            },
+        )
+
+    @field_validator("max_error_rate")
+    @classmethod
+    def _validate_max_error_rate(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    "max_error_rate must be set and truthy, "
+                    f"received {value} ({val} failed)"
+                )
+            if not isinstance(val, int | float) or val <= 0 or val >= 1:
+                raise ValueError(
+                    "max_error_rate must be a number between 0 and 1, "
+                    f"received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+@ConstraintsInitializerFactory.register(
+    ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"]
+)
+class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer):
+    """
+    Constraint that limits execution based on global error rate.
+
+    Calculates error rate across all processed requests and stops all processing
+    when the rate exceeds the threshold. Only applies the constraint after
+    processing the minimum number of requests to ensure statistical significance
+    for global error rate calculations.
+    """
+
+    type_: Literal["max_global_error_rate"] = "max_global_error_rate"  # type: ignore[assignment]
+    max_error_rate: int | float | list[int | float] = Field(
+        description="Maximum error rate allowed (0.0 to 1.0)"
+    )
+    min_processed: int | float | None = Field(
+        default=30,
+        gt=0,
+        description="Minimum requests processed before applying error rate constraint",
+    )
+    current_index: int = Field(
+        default=-1, description="Current index for list-based max_error_rate values"
+    )
+
+    @classmethod
+    def validated_kwargs(
+        cls, max_error_rate: int | float | list[int | float] | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Validate and process arguments for MaxGlobalErrorRateConstraint creation.
+
+        :param max_error_rate: Maximum error rate to allow
+        :param kwargs: Supports max_global_error_rate, max_global_err_rate,
+            max_global_errors_rate, optional min_processed, and optional type_
+        :return: Validated dictionary with max_error_rate, min_processed,
+            and type_ fields
+        """
+        for alias in [
+            "max_global_error_rate",
+            "max_global_err_rate",
+            "max_global_errors_rate",
+        ]:
+            if max_error_rate is None:
+                max_error_rate = kwargs.get(alias)
+
+        return {
+            "max_error_rate": max_error_rate,
+            "min_processed": kwargs.get(
+                "min_processed", settings.constraint_error_min_processed
+            ),
+            "current_index": kwargs.get("current_index", -1),
+        }
+
+    def create_constraint(self, **_kwargs) -> Constraint:
+        """
+        Return self as the constraint instance.
+
+        :param kwargs: Additional keyword arguments (unused)
+        :return: Self instance as the constraint
+        """
+        self.current_index += 1
+
+        return cast("Constraint", self.model_copy())
+
+    def __call__(
+        self, state: SchedulerState, request_info: RequestInfo
+    ) -> SchedulerUpdateAction:
+        """
+        Evaluate constraint against global error rate.
+
+        :param state: Current scheduler state with global request and error counts
+        :param request_info: Individual request info (used only for stop timing)
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_error_rate = (
+            self.max_error_rate
+            if isinstance(self.max_error_rate, int | float)
+            else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)]
+        )
+
+        exceeded_min_processed = (
+            self.min_processed is None or state.processed_requests >= self.min_processed
+        )
+        error_rate = (
+            state.errored_requests / float(state.processed_requests)
+            if state.processed_requests > 0
+            else 0.0
+        )
+        exceeded_error_rate = error_rate >= max_error_rate
+        exceeded = exceeded_min_processed and exceeded_error_rate
+        stop_time = None if not exceeded else request_info.completed_at or time.time()
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if exceeded else "continue",
+            request_processing="stop_all" if exceeded else "continue",
+            metadata={
+                "max_error_rate": max_error_rate,
+                "min_processed": self.min_processed,
+                "processed_requests": state.processed_requests,
+                "errored_requests": state.errored_requests,
+                "error_rate": error_rate,
+                "exceeded_min_processed": exceeded_min_processed,
+                "exceeded_error_rate": exceeded_error_rate,
+                "exceeded": exceeded,
+                "stop_time": stop_time,
+            },
+            progress=SchedulerProgress(stop_time=stop_time),
+        )
+
+    @field_validator("max_error_rate")
+    @classmethod
+    def _validate_max_error_rate(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    "max_error_rate must be set and truthy, "
+                    f"received {value} ({val} failed)"
+                )
+            if not isinstance(val, int | float) or val <= 0 or val >= 1:
+                raise ValueError(
+                    "max_error_rate must be a number between 0 and 1, "
+                    f"received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
diff --git a/src/guidellm/scheduler/constraints/factory.py b/src/guidellm/scheduler/constraints/factory.py
new file mode 100644
index 000000000..0b77b145b
--- /dev/null
+++ b/src/guidellm/scheduler/constraints/factory.py
@@ -0,0 +1,183 @@
+"""
+Factory for creating and managing constraint initializers.
+ +Provides centralized access to registered constraint types with support for +creating constraints from configuration dictionaries, simple values, or +pre-configured instances. +""" + +from __future__ import annotations + +from typing import Any + +from guidellm.utils import InfoMixin, RegistryMixin + +from .constraint import ( + Constraint, + ConstraintInitializer, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) + +__all__ = ["ConstraintsInitializerFactory"] + + +class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): + """ + Registry factory for creating and managing constraint initializers. + + Provides centralized access to registered constraint types with support for + creating constraints from configuration dictionaries, simple values, or + pre-configured instances. Handles constraint resolution and type validation + for the scheduler constraint system. + + Example: + :: + from guidellm.scheduler import ConstraintsInitializerFactory + + # Register new constraint type + @ConstraintsInitializerFactory.register("new_constraint") + class NewConstraint: + def create_constraint(self, **kwargs) -> Constraint: + return lambda state, request: SchedulerUpdateAction() + + # Create and use constraint + constraint = ConstraintsInitializerFactory.create_constraint("new_constraint") + """ + + @classmethod + def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: + """ + Create a constraint initializer for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for initializer creation + :param kwargs: Keyword arguments for initializer creation + :return: Configured constraint initializer instance + :raises ValueError: If the key is not registered in the factory + """ + if cls.registry is None or key not in cls.registry: + raise ValueError(f"Unknown constraint initializer key: {key}") + + initializer_class = cls.registry[key] + + return ( + initializer_class(*args, **kwargs) # type: ignore[operator] + if not isinstance(initializer_class, type) + or not issubclass(initializer_class, SerializableConstraintInitializer) + else initializer_class( + **initializer_class.validated_kwargs(*args, **kwargs) # type: ignore[misc] + ) + ) + + @classmethod + def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :param initializer: Constraint initializer to serialize + :return: Dictionary representation or unserializable placeholder + """ + if isinstance(initializer, SerializableConstraintInitializer): + return initializer.model_dump() + else: + unserializable = UnserializableConstraintInitializer( + orig_info=InfoMixin.extract_from_obj(initializer) + ) + return unserializable.model_dump() + + @classmethod + def deserialize( + cls, initializer_dict: dict[str, Any] + ) -> SerializableConstraintInitializer | UnserializableConstraintInitializer: + """ + Deserialize constraint initializer from dictionary format. 
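+
+        Round-trip sketch, assuming ``initializer`` is any registered
+        constraint initializer:
+        ::
+            data = ConstraintsInitializerFactory.serialize(initializer)
+            restored = ConstraintsInitializerFactory.deserialize(data)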
+ + :param initializer_dict: Dictionary representation of constraint initializer + :return: Reconstructed constraint initializer instance + :raises ValueError: If constraint type is unknown or cannot be deserialized + """ + if initializer_dict.get("type_") == "unserializable": + return UnserializableConstraintInitializer.model_validate(initializer_dict) + + if ( + cls.registry is not None + and initializer_dict.get("type_") + and initializer_dict["type_"] in cls.registry + ): + initializer_class = cls.registry[initializer_dict["type_"]] + if hasattr(initializer_class, "model_validate"): + return initializer_class.model_validate(initializer_dict) # type: ignore[return-value] + else: + return initializer_class(**initializer_dict) # type: ignore[return-value,operator] + + raise ValueError( + f"Cannot deserialize unknown constraint initializer: " + f"{initializer_dict.get('type_', 'unknown')}" + ) + + @classmethod + def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: + """ + Create a constraint instance for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for constraint creation + :param kwargs: Keyword arguments for constraint creation + :return: Configured constraint function ready for evaluation + :raises ValueError: If the key is not registered in the factory + """ + return cls.create(key, *args, **kwargs).create_constraint() + + @classmethod + def resolve( + cls, + initializers: dict[ + str, + Any | dict[str, Any] | Constraint | ConstraintInitializer, + ], + ) -> dict[str, Constraint]: + """ + Resolve mixed constraint specifications to callable constraints. + + :param initializers: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any key is not registered in the factory + """ + constraints = {} + + for key, val in initializers.items(): + if isinstance(val, Constraint): + constraints[key] = val + elif isinstance(val, ConstraintInitializer): + constraints[key] = val.create_constraint() + elif isinstance(val, dict): + constraints[key] = cls.create_constraint(key, **val) + else: + constraints[key] = cls.create_constraint(key, val) + + return constraints + + @classmethod + def resolve_constraints( + cls, + constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> dict[str, Constraint]: + """ + Resolve constraints from mixed constraint specifications. + + :param constraints: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any constraint key is not registered + """ + resolved_constraints = {} + + for key, val in constraints.items(): + if isinstance(val, Constraint): + resolved_constraints[key] = val + elif isinstance(val, dict): + resolved_constraints[key] = cls.create_constraint(key, **val) + else: + resolved_constraints[key] = cls.create_constraint(key, val) + + return resolved_constraints diff --git a/src/guidellm/scheduler/constraints/request.py b/src/guidellm/scheduler/constraints/request.py new file mode 100644 index 000000000..7ef5cab58 --- /dev/null +++ b/src/guidellm/scheduler/constraints/request.py @@ -0,0 +1,309 @@ +""" +Request-based constraint implementations. + +Provides constraint types for limiting benchmark execution based on request counts +and time duration. These constraints monitor request creation, processing, and +elapsed time to determine when to stop benchmark execution. 
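+
+Example:
+::
+    # Sketch: resolve request-count and duration limits through the factory
+    from guidellm.scheduler import ConstraintsInitializerFactory
+
+    constraints = ConstraintsInitializerFactory.resolve(
+        {"max_number": 1000, "max_duration": 300}
+    )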
+""" + +from __future__ import annotations + +import time +from typing import Any, Literal, cast + +from pydantic import Field, field_validator + +from guidellm.scheduler.schemas import ( + SchedulerProgress, + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.schemas import RequestInfo, StandardBaseModel +from guidellm.utils import InfoMixin + +from .constraint import Constraint, PydanticConstraintInitializer +from .factory import ConstraintsInitializerFactory + +__all__ = [ + "MaxDurationConstraint", + "MaxNumberConstraint", + "RequestsExhaustedConstraint", +] + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_number", "max_num", "max_requests", "max_req"] +) +class MaxNumberConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on maximum request counts. + + Stops request queuing when created requests reach the limit and stops local + request processing when processed requests reach the limit. Provides progress + tracking based on remaining requests and completion fraction. + """ + + type_: Literal["max_number"] = "max_number" # type: ignore[assignment] + max_num: int | float | list[int | float] = Field( + description="Maximum number of requests allowed before triggering constraint", + ) + current_index: int = Field( + default=-1, description="Current index for list-based max_num values" + ) + + @classmethod + def validated_kwargs( + cls, max_num: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxNumberConstraint creation. + + :param max_num: Maximum number of requests to allow + :param kwargs: Supports max_num, max_number, max_requests, max_req, + and optional type_ + :return: Validated dictionary with max_num and type_ fields + """ + aliases = ["max_number", "max_num", "max_requests", "max_req"] + for alias in aliases: + if max_num is None: + max_num = kwargs.get(alias) + + return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)} + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state and request count. 
+
+        :param state: Current scheduler state with request counts
+        :param request_info: Individual request info (used only for stop timing)
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_num = (
+            self.max_num
+            if isinstance(self.max_num, int | float)
+            else self.max_num[min(current_index, len(self.max_num) - 1)]
+        )
+
+        create_exceeded = state.created_requests >= max_num
+        processed_exceeded = state.processed_requests >= max_num
+        remaining_requests = min(max(0, max_num - state.processed_requests), max_num)
+        stop_time = (
+            None if remaining_requests > 0 else request_info.completed_at or time.time()
+        )
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if create_exceeded else "continue",
+            request_processing="stop_local" if processed_exceeded else "continue",
+            metadata={
+                "max_number": max_num,
+                "create_exceeded": create_exceeded,
+                "processed_exceeded": processed_exceeded,
+                "created_requests": state.created_requests,
+                "processed_requests": state.processed_requests,
+                "remaining_requests": remaining_requests,
+                "stop_time": stop_time,
+            },
+            progress=SchedulerProgress(
+                remaining_requests=remaining_requests,
+                total_requests=max_num,
+                stop_time=stop_time,
+            ),
+        )
+
+    @field_validator("max_num")
+    @classmethod
+    def _validate_max_num(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    f"max_num must be set and truthy, received {value} ({val} failed)"
+                )
+            if not isinstance(val, int | float) or val <= 0:
+                raise ValueError(
+                    f"max_num must be a positive number, received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+@ConstraintsInitializerFactory.register(
+    ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"]
+)
+class MaxDurationConstraint(PydanticConstraintInitializer):
+    """
+    Constraint that limits execution based on maximum time duration.
+
+    Stops both request queuing and processing when the elapsed time since scheduler
+    start exceeds the maximum duration. Provides progress tracking based on
+    remaining time and completion fraction.
+    """
+
+    type_: Literal["max_duration"] = "max_duration"  # type: ignore[assignment]
+    max_duration: int | float | list[int | float] = Field(
+        description="Maximum duration in seconds before triggering constraint"
+    )
+    current_index: int = Field(default=-1, description="Current index in duration list")
+
+    @classmethod
+    def validated_kwargs(
+        cls, max_duration: int | float | list[int | float] | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Validate and process arguments for MaxDurationConstraint creation.
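+
+        For example, ``validated_kwargs(max_min=5)`` converts minutes to seconds
+        and returns ``{"max_duration": 300, "current_index": -1}``.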
+
+        :param max_duration: Maximum duration in seconds
+        :param kwargs: Supports max_duration, max_dur, max_sec, max_seconds,
+            max_min, max_minutes, and optional type_
+        :return: Validated dictionary with max_duration and type_ fields
+        """
+        seconds_aliases = ["max_dur", "max_sec", "max_seconds"]
+        for alias in seconds_aliases:
+            if max_duration is None:
+                max_duration = kwargs.get(alias)
+        minutes_aliases = ["max_min", "max_minutes"]
+        for alias in minutes_aliases:
+            minutes = kwargs.get(alias)
+            if minutes is not None and max_duration is None:
+                max_duration = minutes * 60
+
+        return {
+            "max_duration": max_duration,
+            "current_index": kwargs.get("current_index", -1),
+        }
+
+    def create_constraint(self, **_kwargs) -> Constraint:
+        """
+        Return self as the constraint instance.
+
+        :param kwargs: Additional keyword arguments (unused)
+        :return: Self instance as the constraint
+        """
+        self.current_index += 1
+
+        return cast("Constraint", self.model_copy())
+
+    def __call__(
+        self, state: SchedulerState, request_info: RequestInfo
+    ) -> SchedulerUpdateAction:
+        """
+        Evaluate constraint against current scheduler state and elapsed time.
+
+        :param state: Current scheduler state with start time
+        :param request_info: Individual request information (unused)
+        :return: Action indicating whether to continue or stop operations
+        """
+        _ = request_info  # Unused parameter
+        current_index = max(0, self.current_index)
+        max_duration = (
+            self.max_duration
+            if isinstance(self.max_duration, int | float)
+            else self.max_duration[min(current_index, len(self.max_duration) - 1)]
+        )
+
+        current_time = time.time()
+        elapsed = current_time - state.start_time
+        duration_exceeded = elapsed >= max_duration
+        remaining_duration = min(max(0.0, max_duration - elapsed), max_duration)
+        stop_time = None if not duration_exceeded else state.start_time + max_duration
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if duration_exceeded else "continue",
+            request_processing="stop_local" if duration_exceeded else "continue",
+            metadata={
+                "max_duration": max_duration,
+                "elapsed_time": elapsed,
+                "duration_exceeded": duration_exceeded,
+                "start_time": state.start_time,
+                "current_time": current_time,
+                "stop_time": stop_time,
+            },
+            progress=SchedulerProgress(
+                remaining_duration=remaining_duration,
+                total_duration=max_duration,
+                stop_time=stop_time,
+            ),
+        )
+
+    @field_validator("max_duration")
+    @classmethod
+    def _validate_max_duration(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    "max_duration must be set and truthy, "
+                    f"received {value} ({val} failed)"
+                )
+            if not isinstance(val, int | float) or val <= 0:
+                raise ValueError(
+                    "max_duration must be a positive number, "
+                    f"received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+class RequestsExhaustedConstraint(StandardBaseModel, InfoMixin):
+    type_: Literal["requests_exhausted"] = "requests_exhausted"  # type: ignore[assignment]
+    num_requests: int
+
+    @property
+    def info(self) -> dict[str, Any]:
+        """
+        Extract serializable information from this constraint initializer.
+
+        :return: Dictionary containing constraint configuration and metadata
+        """
+        return self.model_dump()
+
+    def __call__(
+        self, state: SchedulerState, request: RequestInfo
+    ) -> SchedulerUpdateAction:
+        create_exceeded = state.created_requests >= self.num_requests
+        processed_exceeded = state.processed_requests >= self.num_requests
+        remaining_requests = max(0, self.num_requests - state.processed_requests)
+        stop_time = (
+            None if remaining_requests > 0 else request.completed_at or time.time()
+        )
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if create_exceeded else "continue",
+            request_processing="stop_local" if processed_exceeded else "continue",
+            metadata={
+                "num_requests": self.num_requests,
+                "create_exceeded": create_exceeded,
+                "processed_exceeded": processed_exceeded,
+                "created_requests": state.created_requests,
+                "processed_requests": state.processed_requests,
+                "remaining_requests": remaining_requests,
+                "stop_time": stop_time,
+            },
+            progress=SchedulerProgress(
+                remaining_requests=remaining_requests,
+                total_requests=self.num_requests,
+                stop_time=stop_time,
+            ),
+        )
diff --git a/src/guidellm/scheduler/constraints/saturation.py b/src/guidellm/scheduler/constraints/saturation.py
new file mode 100644
index 000000000..25fb115fb
--- /dev/null
+++ b/src/guidellm/scheduler/constraints/saturation.py
@@ -0,0 +1,666 @@
+"""
+Over-saturation detection constraint implementation.
+
+This module implements the Over-Saturation Detection (OSD) algorithm for detecting
+when a model becomes over-saturated during benchmarking. Over-saturation occurs when
+the response rate doesn't keep up with the request rate, leading to degraded
+performance.
+
+Algorithm Overview:
+-------------------
+The OSD algorithm uses statistical slope detection to identify over-saturation:
+
+1. **Slope Detection**: The algorithm tracks two key metrics over time:
+   - Concurrent requests: Number of requests being processed simultaneously
+   - Time-to-first-token (TTFT): Latency for the first token of each response
+
+2. **Statistical Analysis**: For each metric, the algorithm:
+   - Maintains a sliding window of recent data points
+   - Calculates the linear regression slope using online statistics
+   - Computes the margin of error (MOE) using t-distribution confidence intervals
+   - Detects positive slopes with low MOE, indicating degradation
+
+3. **Detection Criteria**: Over-saturation is detected when:
+   - Both concurrent requests and TTFT show statistically significant positive slopes
+   - The minimum duration threshold has been met
+   - Sufficient data points are available for reliable slope estimation
+
+4. **Window Management**: The algorithm maintains bounded memory by:
+   - Limiting window size by time (maximum_window_seconds)
+   - Limiting window size by ratio of total requests (maximum_window_ratio)
+   - Automatically pruning old data points
+
+5. **Constraint Integration**: When over-saturation is detected, the constraint:
+   - Stops request queuing to prevent further degradation
+   - Stops processing of existing requests (if enabled)
+   - Provides detailed metadata about detection state
+
+Key Parameters:
+---------------
+- minimum_duration: Minimum seconds before checking for over-saturation (default: 30.0)
+- minimum_ttft: Minimum TTFT threshold for violation counting (default: 2.5)
+- maximum_window_seconds: Maximum time window for data retention (default: 120.0)
+- moe_threshold: Margin of error threshold for slope detection (default: 2.0)
+- maximum_window_ratio: Maximum window size as ratio of total requests (default: 0.75)
+- minimum_window_size: Minimum data points required for slope estimation (default: 5)
+- confidence: Statistical confidence level for t-distribution (default: 0.95)
+
+The constraint integrates with the scheduler by evaluating each request update and
+providing scheduler actions (continue/stop) based on the current over-saturation state.
+"""
+
+from __future__ import annotations
+
+import math
+import time
+from typing import Any, Literal
+
+from pydantic import Field
+
+from guidellm.scheduler.schemas import (
+    SchedulerState,
+    SchedulerUpdateAction,
+)
+from guidellm.schemas import RequestInfo
+
+from .constraint import Constraint, PydanticConstraintInitializer
+from .factory import ConstraintsInitializerFactory
+
+__all__ = [
+    "OverSaturationConstraint",
+    "OverSaturationConstraintInitializer",
+    "SlopeChecker",
+    "approx_t_ppf",
+]
+
+
+def approx_t_ppf(p: float, df: float) -> float:
+    """
+    Approximate the percent point function (PPF) for the t-distribution.
+
+    Provides a fast approximation of the t-distribution PPF using numerical
+    methods from Abramowitz & Stegun. This function is significantly faster
+    than scipy.stats.t.ppf while providing sufficient accuracy for statistical
+    slope detection in over-saturation detection. Used internally by SlopeChecker
+    for calculating confidence intervals and margin of error.
+
+    Reference:
+        Milton Abramowitz and Irene A. Stegun (Eds.). (1965).
+        Handbook of Mathematical Functions: with Formulas, Graphs,
+        and Mathematical Tables. Dover Publications.
+
+        An electronic version of this book is available at:
+        https://personal.math.ubc.ca/~cbm/aands/.
+
+    :param p: The probability value (e.g., 0.975 for a 95% confidence interval)
+    :param df: The degrees of freedom for the t-distribution
+    :return: Approximate t-distribution PPF value, or NaN if df <= 0
+    """
+    dof = df
+    if dof <= 0:
+        return float("nan")
+
+    # 1. Approximate the PPF of the Normal distribution (z-score)
+    # Uses Abramowitz & Stegun formula 26.2.23.
+    c = [2.515517, 0.802853, 0.010328]
+    d = [1.432788, 0.189269, 0.001308]
+
+    numerical_stability_threshold = 0.5
+    if p < numerical_stability_threshold:
+        t = math.sqrt(-2.0 * math.log(p))
+        z = -(
+            t
+            - ((c[2] * t + c[1]) * t + c[0])
+            / (((d[2] * t + d[1]) * t + d[0]) * t + 1.0)
+        )
+    else:
+        t = math.sqrt(-2.0 * math.log(1.0 - p))
+        z = t - ((c[2] * t + c[1]) * t + c[0]) / (
+            ((d[2] * t + d[1]) * t + d[0]) * t + 1.0
+        )
+
+    # 2. Convert the z-score to a t-score
+    # Uses the Cornish-Fisher expansion (first few terms), which adjusts the
+    # normal quantile with odd powers of z:
+    #   t ~ z + (z^3 + z) / (4 * dof) + (5z^5 + 16z^3 + 3z) / (96 * dof^2)
+    z2 = z * z
+    z3 = z2 * z
+    z5 = z3 * z2
+
+    g1 = (z3 + z) / 4.0
+    g2 = (5.0 * z5 + 16.0 * z3 + 3.0 * z) / 96.0
+
+    # Adjust z using the degrees of freedom (dof)
+    return z + g1 / dof + g2 / (dof * dof)
+
+
+class SlopeChecker:
+    """
+    Helper class for online slope detection using linear regression.
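+
+    The slope is the standard least-squares estimate computed from running
+    sums: with S_xx = sum(x^2) - (sum(x))^2 / n and
+    S_xy = sum(x*y) - sum(x) * sum(y) / n, the fitted slope is S_xy / S_xx.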
+ + Maintains running statistics for efficient O(1) updates and provides + statistical slope detection with margin of error calculation. Uses online + algorithms to compute linear regression statistics incrementally without + storing all data points, enabling memory-efficient slope detection for + over-saturation detection. Supports adding and removing data points + dynamically while maintaining accurate statistical measures. + + Example: + :: + checker = SlopeChecker(moe_threshold=2.0, confidence=0.95) + checker.add_data_point(1.0, 2.0) + checker.add_data_point(2.0, 3.0) + checker.add_data_point(3.0, 4.0) + is_positive = checker.check_slope(3.0) # True for positive slope + """ + + def __init__( + self, moe_threshold: float = 1.0, confidence: float = 0.95, eps: float = 1e-12 + ) -> None: + """ + Initialize slope checker with statistical parameters. + + :param moe_threshold: Maximum margin of error threshold for slope detection + :param confidence: Statistical confidence level for t-distribution (0-1) + :param eps: Epsilon value for numerical stability in calculations + """ + self.n = 0 + self.sum_x = 0.0 + self.sum_y = 0.0 + self.sum_xy = 0.0 + self.sum_x2 = 0.0 + self.sum_y2 = 0.0 + self.moe_threshold = moe_threshold + self.eps = eps + self.confidence = confidence + self.slope: float | None = None + self.margin_of_error: float | None = None + + def add_data_point(self, x_new: float, y_new: float) -> None: + """ + Integrate a new data point into the accumulated statistics. + + Updates running sums for linear regression calculation in O(1) time. + The data point is incorporated into the statistical model without + storing the individual value, enabling memory-efficient slope detection. + + :param x_new: The new x-coordinate (typically time or duration) + :param y_new: The new y-coordinate (typically metric value like TTFT + or concurrent requests) + """ + self.n += 1 + self.sum_x += x_new + self.sum_y += y_new + self.sum_xy += x_new * y_new + self.sum_x2 += x_new**2 + self.sum_y2 += y_new**2 + + def remove_data_point(self, x_old: float, y_old: float) -> None: + """ + Remove a data point from the accumulated statistics. + + Updates running sums by subtracting the specified data point in O(1) time. + Used for window management when pruning old data points to maintain + bounded memory usage while preserving statistical accuracy. + + :param x_old: The x-coordinate to remove (typically time or duration) + :param y_old: The y-coordinate to remove (typically metric value) + """ + self.n -= 1 + self.sum_x -= x_old + self.sum_y -= y_old + self.sum_xy -= x_old * y_old + self.sum_x2 -= x_old**2 + self.sum_y2 -= y_old**2 + + def check_slope(self, effective_n: float) -> bool: + """ + Check if there is a statistically significant positive slope. + + Calculates linear regression slope and margin of error using online + statistics. Returns True if the slope is positive and the margin of + error is below the threshold, indicating statistically significant + degradation. Updates internal slope and margin_of_error attributes + for external inspection. + + :param effective_n: Effective sample size for slope estimation (may differ + from actual n for correlation adjustment) + :return: True if positive slope detected with margin of error below threshold + """ + minimal_n_for_slope_estimation = 3 + if effective_n < minimal_n_for_slope_estimation: + return False + + # Calculate sums of squares and cross-products + # These formulas are numerically stable for online calculation. 
+ centered_sum_xx = self.sum_x2 - (self.sum_x**2) / self.n + centered_sum_xy = self.sum_xy - (self.sum_x * self.sum_y) / self.n + centered_sum_yy = self.sum_y2 - (self.sum_y**2) / self.n + + # Safeguard against division by zero for SS_xx + centered_sum_xx_safe = max(centered_sum_xx, self.eps) + + slope = centered_sum_xy / centered_sum_xx_safe + + # Calculate Residual Sum of Squares (RSS) + # This is a direct calculation using the sums of squares. + residual_sum_of_squares = centered_sum_yy - ( + centered_sum_xy**2 / centered_sum_xx_safe + ) + + # Ensure RSS is non-negative due to potential floating point inaccuracies + residual_sum_of_squares = max(residual_sum_of_squares, 0.0) + + # Degrees of freedom for standard error (n - 2 for simple linear regression) + dof = effective_n - 2 + + residual_variance = residual_sum_of_squares / dof + standard_error = (residual_variance / centered_sum_xx_safe) ** 0.5 + + # t-critical value + alpha = 1 - self.confidence + t_crit = approx_t_ppf(1 - alpha / 2, df=dof) + + # Margin Of Error + margin_of_error = t_crit * standard_error / max(slope, self.eps) + + self.slope = slope + self.margin_of_error = margin_of_error + return (slope > 0) and (margin_of_error < self.moe_threshold) + + +class OverSaturationConstraint: # type: ignore[misc] + """ + Constraint that detects and stops execution when over-saturation is detected. + + This constraint implements the Over-Saturation Detection (OSD) algorithm to + identify when a model becomes over-saturated (response rate doesn't keep up with + request rate). When over-saturation is detected, the constraint stops request + queuing and optionally stops processing of existing requests. + + The constraint maintains internal state for tracking concurrent requests and + time-to-first-token (TTFT) metrics, using statistical slope detection to identify + performance degradation patterns. + """ + + def __init__( + self, + minimum_duration: float = 30.0, + minimum_ttft: float = 2.5, + maximum_window_seconds: float = 120.0, + moe_threshold: float = 2.0, + maximum_window_ratio: float = 0.75, + minimum_window_size: int = 5, + confidence: float = 0.95, + eps: float = 1e-12, + enabled: bool = True, + ) -> None: # noqa: PLR0913 + """ + Initialize the over-saturation constraint. + + Creates a new constraint instance with specified detection parameters. + The constraint will track concurrent requests and TTFT metrics, using + statistical slope detection to identify when the model becomes + over-saturated. All parameters have sensible defaults suitable for + most benchmarking scenarios. 
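+
+        A minimal construction sketch (parameter values are illustrative):
+        ::
+            constraint = OverSaturationConstraint(minimum_duration=10.0)
+            action = constraint(state, request_info)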
+ + :param minimum_duration: Minimum seconds before checking for over-saturation + (default: 30.0) + :param minimum_ttft: Minimum TTFT threshold in seconds for violation counting + (default: 2.5) + :param maximum_window_seconds: Maximum time window in seconds for data retention + (default: 120.0) + :param moe_threshold: Margin of error threshold for slope detection + (default: 2.0) + :param maximum_window_ratio: Maximum window size as ratio of total requests + (default: 0.75) + :param minimum_window_size: Minimum data points required for slope estimation + (default: 5) + :param confidence: Statistical confidence level for t-distribution (0-1) + (default: 0.95) + :param eps: Epsilon for numerical stability in calculations + (default: 1e-12) + :param enabled: Whether to actually stop when over-saturation is detected + (default: True) + """ + self.minimum_duration = minimum_duration + self.minimum_ttft = minimum_ttft + self.maximum_window_seconds = maximum_window_seconds + self.maximum_window_ratio = maximum_window_ratio + self.minimum_window_size = minimum_window_size + self.moe_threshold = moe_threshold + self.confidence = confidence + self.eps = eps + self.enabled = enabled + self.reset() + + def reset(self) -> None: + """ + Reset all internal state to initial values. + + Clears all tracked requests, resets counters, and reinitializes slope + checkers. Useful for reusing constraint instances across multiple + benchmark runs or resetting state after configuration changes. + """ + self.duration = 0.0 + self.started_requests: list[dict[str, Any]] = [] + self.finished_requests: list[dict[str, Any]] = [] + self.ttft_violations_counter = 0 + self.total_finished_ever = 0 + self.total_started_ever = 0 + self.concurrent_slope_checker = SlopeChecker( + moe_threshold=self.moe_threshold, confidence=self.confidence, eps=self.eps + ) + self.ttft_slope_checker = SlopeChecker( + moe_threshold=self.moe_threshold, confidence=self.confidence, eps=self.eps + ) + + def _add_finished(self, request: dict[str, Any]) -> None: + """Add a finished request to tracking.""" + ttft = request["ttft"] + duration = request["duration"] + if ttft is not None: + self.total_finished_ever += 1 + self.finished_requests.append(request) + if ttft > self.minimum_ttft: + self.ttft_violations_counter += 1 + self.ttft_slope_checker.add_data_point(duration, ttft) + + def _remove_finished(self, request: dict[str, Any]) -> None: + """Remove a finished request from tracking.""" + del self.finished_requests[0] + ttft = request["ttft"] + duration = request["duration"] + if ttft > self.minimum_ttft: + self.ttft_violations_counter -= 1 + self.ttft_slope_checker.remove_data_point(duration, ttft) + + def _add_started(self, request: dict[str, Any]) -> None: + """Add a started request to tracking.""" + concurrent = request["concurrent_requests"] + duration = request["duration"] + if concurrent is not None: + self.total_started_ever += 1 + self.started_requests.append(request) + self.concurrent_slope_checker.add_data_point(duration, concurrent) + + def _remove_started(self, request: dict[str, Any]) -> None: + """Remove a started request from tracking.""" + del self.started_requests[0] + concurrent = request["concurrent_requests"] + duration = request["duration"] + self.concurrent_slope_checker.remove_data_point(duration, concurrent) + + def _update_duration(self, duration: float) -> None: + """Update duration and prune old data points.""" + self.duration = duration + + maximum_finished_window_size = int( + self.total_finished_ever * 
self.maximum_window_ratio
+        )
+        while len(self.finished_requests) > maximum_finished_window_size:
+            self._remove_finished(self.finished_requests[0])
+
+        while self.finished_requests and (
+            duration - self.finished_requests[0]["duration"]
+            > self.maximum_window_seconds
+        ):
+            self._remove_finished(self.finished_requests[0])
+
+        maximum_started_window_size = int(
+            self.total_started_ever * self.maximum_window_ratio
+        )
+        while len(self.started_requests) > maximum_started_window_size:
+            self._remove_started(self.started_requests[0])
+
+        while self.started_requests and (
+            duration - self.started_requests[0]["duration"]
+            > self.maximum_window_seconds
+        ):
+            self._remove_started(self.started_requests[0])
+
+    def _check_alert(self) -> bool:
+        """
+        Check if over-saturation is currently detected.
+
+        :return: True if over-saturation is detected, False otherwise
+        """
+        # Use duration as the maximum n value since requests from the same
+        # second are highly correlated; this is simple and good enough given
+        # that the MOE threshold is configurable anyway.
+        concurrent_n = min(self.duration, self.concurrent_slope_checker.n)
+        ttft_n = min(self.duration, self.ttft_slope_checker.n)
+
+        if (
+            (self.duration < self.minimum_duration)
+            or (self.ttft_slope_checker.n > self.ttft_violations_counter * 2)
+            or (self.duration < self.minimum_ttft)
+            or (concurrent_n < self.minimum_window_size)
+        ):
+            return False
+
+        is_concurrent_slope_positive = self.concurrent_slope_checker.check_slope(
+            concurrent_n
+        )
+
+        if ttft_n < self.minimum_window_size:
+            return is_concurrent_slope_positive
+
+        is_ttft_slope_positive = self.ttft_slope_checker.check_slope(ttft_n)
+
+        return is_concurrent_slope_positive and is_ttft_slope_positive
+
+    def __call__(
+        self, state: SchedulerState, request_info: RequestInfo
+    ) -> SchedulerUpdateAction:
+        """
+        Evaluate constraint against current scheduler state.
+
+        :param state: Current scheduler state.
+        :param request_info: Individual request information.
+        :return: Action indicating whether to continue or stop operations.
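+
+        Each ``in_progress`` update contributes a (duration, concurrent_requests)
+        sample and each ``completed`` update with timings contributes a
+        (duration, TTFT) sample before the over-saturation check runs.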
+ """ + duration = time.time() - state.start_time + + if request_info.status == "in_progress": + concurrent_requests = state.processing_requests + self._add_started( + {"concurrent_requests": concurrent_requests, "duration": duration} + ) + elif ( + request_info.status == "completed" + and request_info.timings + and request_info.timings.first_token_iteration + and request_info.timings.request_start + ): + ttft = ( + request_info.timings.first_token_iteration + - request_info.timings.request_start + ) + self._add_finished({"ttft": ttft, "duration": duration}) + + self._update_duration(duration) + is_over_saturated = self._check_alert() + + ttft_slope = self.ttft_slope_checker.slope + ttft_slope_moe = self.ttft_slope_checker.margin_of_error + ttft_n = self.ttft_slope_checker.n + ttft_violations = self.ttft_violations_counter + concurrent_slope = self.concurrent_slope_checker.slope + concurrent_slope_moe = self.concurrent_slope_checker.margin_of_error + concurrent_n = self.concurrent_slope_checker.n + + should_stop = is_over_saturated and self.enabled + return SchedulerUpdateAction( + request_queuing="stop" if should_stop else "continue", + request_processing="stop_all" if should_stop else "continue", + metadata={ + "ttft_slope": ttft_slope, + "ttft_slope_moe": ttft_slope_moe, + "ttft_n": ttft_n, + "ttft_violations": ttft_violations, + "concurrent_slope": concurrent_slope, + "concurrent_slope_moe": concurrent_slope_moe, + "concurrent_n": concurrent_n, + "is_over_saturated": is_over_saturated, + }, + ) + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["over_saturation", "detect_saturation"] +) +class OverSaturationConstraintInitializer(PydanticConstraintInitializer): + """ + Factory for creating OverSaturationConstraint instances from configuration. + + Provides a Pydantic-based initializer for over-saturation detection constraints + with support for flexible configuration patterns. Supports both simple boolean + flags and detailed configuration dictionaries, enabling easy integration with + CLI arguments, configuration files, and programmatic constraint creation. 
+ + Example: + :: + # Simple boolean configuration + initializer = OverSaturationConstraintInitializer(enabled=True) + constraint = initializer.create_constraint() + + # Detailed configuration + initializer = OverSaturationConstraintInitializer( + enabled=True, + min_seconds=60.0, + max_window_seconds=300.0, + moe_threshold=1.5 + ) + constraint = initializer.create_constraint() + + :cvar type_: Always "over_saturation" to identify this constraint type + :cvar enabled: Whether to stop the benchmark if over-saturation is detected + :cvar min_seconds: Minimum seconds before checking for over-saturation + :cvar max_window_seconds: Maximum time window for data retention + :cvar moe_threshold: Margin of error threshold for slope detection + :cvar minimum_ttft: Minimum TTFT threshold for violation counting + :cvar maximum_window_ratio: Maximum window size as ratio of total requests + :cvar minimum_window_size: Minimum data points required for slope estimation + :cvar confidence: Statistical confidence level for t-distribution + """ + + type_: Literal["over_saturation"] = "over_saturation" # type: ignore[assignment] + enabled: bool = Field( + default=True, + description="Whether to stop the benchmark if the model is over-saturated", + ) + min_seconds: int | float = Field( + default=30.0, + ge=0, + description="Minimum seconds before checking for over-saturation", + ) + max_window_seconds: int | float = Field( + default=120.0, + ge=0, + description="Maximum over-saturation checking window size in seconds", + ) + moe_threshold: float = Field( + default=2.0, + ge=0, + description="Margin of error threshold for slope detection", + ) + minimum_ttft: float = Field( + default=2.5, + ge=0, + description="Minimum TTFT threshold for violation counting", + ) + maximum_window_ratio: float = Field( + default=0.75, + ge=0, + le=1.0, + description="Maximum window size as ratio of total requests", + ) + minimum_window_size: int = Field( + default=5, + ge=0, + description="Minimum data points required for slope estimation", + ) + confidence: float = Field( + default=0.95, + ge=0, + le=1.0, + description="Statistical confidence level for t-distribution", + ) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create an OverSaturationConstraint instance from this initializer. + + Constructs a new OverSaturationConstraint with the configuration parameters + specified in this initializer. The constraint will be ready for evaluation + against scheduler state and requests. + + :param _kwargs: Additional keyword arguments (unused) + :return: Configured OverSaturationConstraint instance ready for use + """ + return OverSaturationConstraint( # type: ignore[return-value] + minimum_duration=self.min_seconds, + minimum_ttft=self.minimum_ttft, + maximum_window_seconds=self.max_window_seconds, + moe_threshold=self.moe_threshold, + maximum_window_ratio=self.maximum_window_ratio, + minimum_window_size=self.minimum_window_size, + confidence=self.confidence, + enabled=self.enabled, + ) + + @classmethod + def validated_kwargs( + cls, over_saturation: bool | dict[str, Any] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for OverSaturationConstraint creation. + + Processes flexible input formats to create validated constraint configuration. + Supports boolean flags for simple enable/disable, dictionary inputs for detailed + configuration, and alias parameters for compatibility. Handles parameter + normalization and default value application. 
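+
+        For example, ``validated_kwargs(over_saturation=True)`` returns
+        ``{"enabled": True}``, while a dict input has its recognized keys merged
+        with the defaults shown below.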
+ + :param over_saturation: Boolean to enable/disable with defaults, or dictionary + with configuration parameters (min_seconds, max_window_seconds, etc.) + :param kwargs: Additional keyword arguments supporting aliases like + "detect_saturation" for compatibility + :return: Validated dictionary with constraint configuration ready for + initializer creation + """ + # Check for aliases in kwargs + aliases = ["over_saturation", "detect_saturation"] + result: bool | dict[str, Any] | None = over_saturation + + for alias in aliases: + alias_value = kwargs.get(alias) + if alias_value is not None: + result = alias_value + break + + if result is None: + return {} + + if isinstance(result, bool): + return {"enabled": result} + elif isinstance(result, dict): + # Extract configuration from dict + return { + "enabled": result.get("enabled", True), + "min_seconds": result.get("min_seconds", 30.0), + "max_window_seconds": result.get("max_window_seconds", 120.0), + "moe_threshold": result.get("moe_threshold", 2.0), + "minimum_ttft": result.get("minimum_ttft", 2.5), + "maximum_window_ratio": result.get("maximum_window_ratio", 0.75), + "minimum_window_size": result.get("minimum_window_size", 5), + "confidence": result.get("confidence", 0.95), + } + else: + # Convert to bool if it's truthy + return {"enabled": bool(result)} diff --git a/src/guidellm/settings.py b/src/guidellm/settings.py index 293416d7c..651f2ed13 100644 --- a/src/guidellm/settings.py +++ b/src/guidellm/settings.py @@ -154,6 +154,10 @@ class Settings(BaseSettings): constraint_error_window_size: float = 30 constraint_error_min_processed: float = 30 + # Constraint settings + constraint_over_saturation_min_seconds: float = 30.0 + constraint_over_saturation_max_window_seconds: float = 120.0 + # Data settings dataset: DatasetSettings = DatasetSettings() diff --git a/tests/e2e/test_over_saturated_benchmark.py b/tests/e2e/test_over_saturated_benchmark.py index 368e2c0f2..eb518decb 100644 --- a/tests/e2e/test_over_saturated_benchmark.py +++ b/tests/e2e/test_over_saturated_benchmark.py @@ -33,7 +33,6 @@ def server(): server.stop() # Teardown: Stop the server after tests are done -@pytest.mark.skip(reason="Skipping future feature test") @pytest.mark.timeout(60) def test_over_saturated_benchmark(server: VllmSimServer): """ @@ -50,7 +49,7 @@ def test_over_saturated_benchmark(server: VllmSimServer): client.start_benchmark( rate=rate, max_seconds=20, - stop_over_saturated=True, + over_saturation=True, extra_env={ "GUIDELLM__CONSTRAINT_OVER_SATURATION_MIN_SECONDS": "0", "GOMAXPROCS": "1", @@ -69,7 +68,53 @@ def test_over_saturated_benchmark(server: VllmSimServer): # Check that the max duration constraint was triggered assert_constraint_triggered( - benchmark, "stop_over_saturated", {"is_over_saturated": True} + benchmark, "over_saturation", {"is_over_saturated": True} + ) + + cleanup_report_file(report_path) + + +@pytest.mark.timeout(60) +def test_over_saturated_benchmark_with_dict_config(server: VllmSimServer): + """ + Test over-saturation detection with dictionary configuration instead of boolean. 
+ """ + report_path = Path("tests/e2e/over_saturated_benchmarks_dict.json") + rate = 100 + + # Create and configure the guidellm client + client = GuidellmClient(target=server.get_url(), output_path=report_path) + + cleanup_report_file(report_path) + # Start the benchmark with dictionary configuration for over-saturation + client.start_benchmark( + rate=rate, + max_seconds=20, + over_saturation={ + "enabled": True, + "min_seconds": 0, + "max_window_seconds": 120.0, + "moe_threshold": 2.0, + "minimum_window_size": 5, + }, + extra_env={ + "GOMAXPROCS": "1", + }, + ) + + # Wait for the benchmark to complete + client.wait_for_completion(timeout=55) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Check that the over-saturation constraint was triggered + assert_constraint_triggered( + benchmark, "over_saturation", {"is_over_saturated": True} ) cleanup_report_file(report_path) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index e63587e4e..663f22915 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -5,6 +5,7 @@ import sys import time from pathlib import Path +from typing import Any from loguru import logger @@ -44,7 +45,7 @@ def start_benchmark( max_seconds: int | None = None, max_requests: int | None = None, max_error_rate: float | None = None, - stop_over_saturated: bool | None = False, + over_saturation: bool | dict[str, Any] | None = None, data: str = "prompt_tokens=256,output_tokens=128", processor: str = "gpt2", additional_args: str = "", @@ -53,16 +54,18 @@ def start_benchmark( """ Start a guidellm benchmark command. - :param rate_type: Type of rate control (constant, etc.) + :param profile: Type of rate control (constant, etc.) :param rate: Request rate :param max_seconds: Maximum duration in seconds :param max_requests: Maximum number of requests :param max_error_rate: Maximum error rate before stopping - :param stop_over_saturated: Whether to stop the benchmark if the model is - over-saturated. + :param over_saturation: Over-saturation detection configuration (bool or dict). + When bool is True, passes --over-saturation=True to avoid Click parsing + issues. 
:param data: Data configuration string :param processor: Processor/tokenizer to use :param additional_args: Additional command line arguments + :param extra_env: Additional environment variables to set """ guidellm_exe = get_guidellm_executable() @@ -70,7 +73,7 @@ def start_benchmark( cmd_parts = [ *([f"{k}={v}" for k, v in extra_env.items()] if extra_env else []), "HF_HOME=/tmp/huggingface_cache", - f"{guidellm_exe} benchmark", + f"{guidellm_exe} benchmark run", f'--target "{self.target}"', f"--profile {profile}", f"--rate {rate}", @@ -85,8 +88,14 @@ def start_benchmark( if max_error_rate is not None: cmd_parts.append(f"--max-error-rate {max_error_rate}") - if stop_over_saturated: - cmd_parts.append("--stop-over-saturated") + if over_saturation is not None: + if isinstance(over_saturation, bool): + if over_saturation: + cmd_parts.append("--over-saturation=True") + elif isinstance(over_saturation, dict): + import json + + cmd_parts.append(f"--over-saturation '{json.dumps(over_saturation)}'") cmd_parts.extend( [ diff --git a/tests/unit/scheduler/test_over_saturation.py b/tests/unit/scheduler/test_over_saturation.py new file mode 100644 index 000000000..f87c40633 --- /dev/null +++ b/tests/unit/scheduler/test_over_saturation.py @@ -0,0 +1,578 @@ +"""Unit tests for over-saturation constraint implementation.""" + +import inspect +import time + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + OverSaturationConstraint, + OverSaturationConstraintInitializer, + PydanticConstraintInitializer, + SchedulerState, + SchedulerUpdateAction, + SerializableConstraintInitializer, +) +from guidellm.schemas import RequestInfo, RequestTimings + + +class TestOverSaturationConstraintInternal: + """Test the OverSaturationConstraint internal functionality.""" + + @pytest.fixture( + params=[ + {"minimum_duration": 30.0, "maximum_window_seconds": 120.0}, + {"minimum_duration": 10.0, "maximum_window_seconds": 60.0}, + {"minimum_duration": 60.0, "maximum_window_seconds": 240.0}, + ] + ) + def valid_instances(self, request): + """Create OverSaturationConstraint instances with valid parameters.""" + constructor_args = request.param + instance = OverSaturationConstraint(**constructor_args, enabled=True) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test OverSaturationConstraint initialization with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test that OverSaturationConstraint has correct default values.""" + constraint = OverSaturationConstraint(enabled=True) + + assert constraint.minimum_duration == 30.0 + assert constraint.minimum_ttft == 2.5 + assert constraint.maximum_window_seconds == 120.0 + assert constraint.moe_threshold == 2.0 + assert constraint.maximum_window_ratio == 0.75 + assert constraint.minimum_window_size == 5 + assert constraint.confidence == 0.95 + assert constraint.eps == 1e-12 + + @pytest.mark.smoke + def test_reset(self, valid_instances): + """Test that reset method properly initializes constraint state.""" + constraint, _ = valid_instances + constraint.reset() + + assert constraint.duration == 0.0 + assert constraint.started_requests == [] + assert constraint.finished_requests == [] + assert 
constraint.ttft_violations_counter == 0 + assert constraint.total_finished_ever == 0 + assert constraint.total_started_ever == 0 + assert hasattr(constraint, "concurrent_slope_checker") + assert hasattr(constraint, "ttft_slope_checker") + + @pytest.mark.sanity + def test_window_management_through_constraint(self): + """Test that constraint properly manages window sizes through usage.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + maximum_window_seconds=100.0, + maximum_window_ratio=0.5, + enabled=True, + ) + start_time = time.time() + + # Add many requests through constraint calls + for i in range(100): + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time - i, + processing_requests=i, + ) + request = RequestInfo( + request_id=f"test-{i}", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time - i, + ) + constraint(state, request) + + # Check that window management is working (through internal state) + # The constraint should have pruned old requests + assert len(constraint.started_requests) <= 50 # Should be limited by ratio + + +class TestOverSaturationConstraint: + """Test the OverSaturationConstraint implementation.""" + + @pytest.fixture + def constraint(self): + """Create a constraint for testing.""" + return OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + @pytest.fixture( + params=[ + {"enabled": True}, + {"enabled": False}, + ] + ) + def valid_instances(self, request): + """Create OverSaturationConstraint instances with valid parameters.""" + constructor_args = request.param + instance = OverSaturationConstraint( + minimum_duration=0.0, + minimum_window_size=3, + **constructor_args, + ) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that OverSaturationConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_protocol_method_signature(self): + """Test that OverSaturationConstraint has the correct method signature.""" + constraint = OverSaturationConstraint(enabled=True) + call_method = constraint.__call__ + sig = inspect.signature(call_method) + + expected_params = ["state", "request_info"] + assert list(sig.parameters.keys()) == expected_params + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test OverSaturationConstraint initialization with valid parameters.""" + constraint, constructor_args = valid_instances + + assert constraint.enabled == constructor_args["enabled"] + + @pytest.mark.sanity + def test_constraint_returns_continue_when_not_saturated(self, constraint): + """Test constraint returns continue when not over-saturated.""" + start_time = time.time() + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + request = RequestInfo( + request_id="test-1", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + assert isinstance(action.metadata, dict) + assert "is_over_saturated" in action.metadata + + @pytest.mark.sanity + def test_constraint_with_completed_request(self, constraint): + """Test constraint with 
completed request including timings.""" + start_time = time.time() + + # Create timings with first_iteration + timings = RequestTimings( + request_start=start_time + 0.1, first_iteration=start_time + 0.2 + ) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + request = RequestInfo( + request_id="test-1", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + timings=timings, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert "ttft_slope" in action.metadata + assert "ttft_n" in action.metadata + + @pytest.mark.sanity + def test_constraint_stops_when_over_saturated(self, constraint): + """Test constraint stops when over-saturated and flag is enabled.""" + start_time = time.time() + + # Simulate over-saturation by creating positive slopes through constraint calls + # Add many started requests with increasing concurrent count + for i in range(20): + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time - i, + processing_requests=i * 2, + ) + request = RequestInfo( + request_id=f"test-{i}", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time - i, + ) + constraint(state, request) + + # Add finished requests with increasing TTFT + for i in range(20): + timings = RequestTimings( + request_start=start_time - i - 10.0, + first_iteration=start_time - i - 10.0 + (1.0 + i * 0.1), + ) + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time - i - 10.0, + processing_requests=5, + ) + request = RequestInfo( + request_id=f"test-finished-{i}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time - i - 10.0, + timings=timings, + ) + constraint(state, request) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=40, + ) + + request = RequestInfo( + request_id="test-1", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # If over-saturated, should stop (but depends on slope detection) + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + # The exact action depends on whether detection triggers + assert action.request_queuing in ["continue", "stop"] + assert "is_over_saturated" in action.metadata + + @pytest.mark.sanity + def test_constraint_never_stops_when_flag_disabled(self): + """Test constraint never stops when enabled is False.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + minimum_window_size=3, + enabled=False, + ) + start_time = time.time() + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=100, # High concurrent requests + ) + + request = RequestInfo( + request_id="test-1", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # Even if over-saturated, should continue when flag is False + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + + +class TestOverSaturationConstraintInitializer: + """Test the OverSaturationConstraintInitializer implementation.""" + + @pytest.fixture( + params=[ + {"enabled": True}, + {"enabled": False}, + { + "enabled": True, + 
"min_seconds": 10.0, + "max_window_seconds": 60.0, + }, + ] + ) + def valid_instances(self, request): + """Create OverSaturationConstraintInitializer with valid parameters.""" + constructor_args = request.param + instance = OverSaturationConstraintInitializer(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_pydantic_constraint_initializer(self, valid_instances): + """Test that initializer is a PydanticConstraintInitializer.""" + instance, _ = valid_instances + assert isinstance(instance, PydanticConstraintInitializer) + assert isinstance(instance, SerializableConstraintInitializer) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """Test that initializer satisfies ConstraintInitializer protocol.""" + instance, _ = valid_instances + assert isinstance(instance, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + assert instance.type_ == "over_saturation" + assert instance.enabled == constructor_args["enabled"] + + if "min_seconds" in constructor_args: + assert instance.min_seconds == constructor_args["min_seconds"] + if "max_window_seconds" in constructor_args: + assert instance.max_window_seconds == constructor_args["max_window_seconds"] + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that initializer rejects invalid parameters.""" + # Invalid type for enabled + with pytest.raises(ValidationError): + OverSaturationConstraintInitializer(enabled="invalid") + + # Invalid type for min_seconds + with pytest.raises(ValidationError): + OverSaturationConstraintInitializer(enabled=True, min_seconds="invalid") + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test that create_constraint returns OverSaturationConstraint.""" + instance, _ = valid_instances + constraint = instance.create_constraint() + + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled == instance.enabled + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test validated_kwargs method with various inputs.""" + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation=True + ) + assert result == {"enabled": True} + + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation=False + ) + assert result == {"enabled": False} + + # Test with dict input + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation={"enabled": True, "min_seconds": 20.0} + ) + assert result["enabled"] is True + assert "min_seconds" in result + + # Test with aliases + result = OverSaturationConstraintInitializer.validated_kwargs( + detect_saturation=True + ) + assert result == {"enabled": True} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that initializer can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert data["type_"] == "over_saturation" + assert data["enabled"] == constructor_args["enabled"] + + reconstructed = OverSaturationConstraintInitializer.model_validate(data) + assert reconstructed.enabled == instance.enabled + + @pytest.mark.smoke + def test_factory_registration(self): + """Test that initializer is properly registered with expected aliases.""" + expected_aliases = [ + "over_saturation", + "detect_saturation", + ] + 
+ for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == OverSaturationConstraintInitializer + + @pytest.mark.smoke + @pytest.mark.parametrize("alias", ["over_saturation", "detect_saturation"]) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, enabled=True + ) + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled is True + + # Test with simple boolean value + constraint = ConstraintsInitializerFactory.create_constraint(alias, True) + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled is True + + constraint = ConstraintsInitializerFactory.create_constraint(alias, False) + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled is False + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"over_saturation": {"enabled": True}} + ) + assert isinstance(resolved["over_saturation"], OverSaturationConstraint) + assert resolved["over_saturation"].enabled is True + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"detect_saturation": True}) + assert isinstance(resolved["detect_saturation"], OverSaturationConstraint) + assert resolved["detect_saturation"].enabled is True + + # Test with instance + instance = OverSaturationConstraintInitializer(enabled=False) + constraint_instance = instance.create_constraint() + resolved = ConstraintsInitializerFactory.resolve( + {"over_saturation": constraint_instance} + ) + assert resolved["over_saturation"] is constraint_instance + + @pytest.mark.smoke + def test_functional_constraint_creation(self): + """Test that created constraints are functionally correct.""" + constraint = ConstraintsInitializerFactory.create_constraint( + "over_saturation", enabled=True + ) + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=5, + processed_requests=5, + processing_requests=3, + ) + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + # Should continue when not over-saturated + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + assert "is_over_saturated" in action.metadata + + +class TestSlopeChecker: + """Test the SlopeChecker implementation used by OverSaturationDetector.""" + + @pytest.fixture + def slope_checker(self): + """Create a SlopeChecker instance for testing.""" + from guidellm.scheduler.constraints.saturation import ( + SlopeChecker, + ) + + return SlopeChecker(moe_threshold=1.0, confidence=0.95) + + @pytest.mark.smoke + def test_initialization(self, slope_checker): + """Test SlopeChecker initialization.""" + assert slope_checker.n == 0 + assert slope_checker.sum_x == 0.0 + assert slope_checker.sum_y == 0.0 + assert slope_checker.moe_threshold == 1.0 + assert slope_checker.confidence == 0.95 + + @pytest.mark.sanity + def test_add_and_remove_data_points(self, 
slope_checker):
+        """Test adding and removing data points."""
+        # Add data points
+        slope_checker.add_data_point(1.0, 2.0)
+        slope_checker.add_data_point(2.0, 4.0)
+        slope_checker.add_data_point(3.0, 6.0)
+
+        assert slope_checker.n == 3
+        assert slope_checker.sum_x == 6.0
+        assert slope_checker.sum_y == 12.0
+
+        # Remove data point
+        slope_checker.remove_data_point(1.0, 2.0)
+
+        assert slope_checker.n == 2
+        assert slope_checker.sum_x == 5.0
+        assert slope_checker.sum_y == 10.0
+
+    @pytest.mark.sanity
+    def test_check_slope_with_positive_slope(self, slope_checker):
+        """Test check_slope with clear positive slope."""
+        # Create data with clear positive slope
+        for i in range(10):
+            slope_checker.add_data_point(float(i), float(i * 2))
+
+        result = slope_checker.check_slope(10.0)
+        assert result is True
+        assert slope_checker.slope is not None
+        assert slope_checker.slope > 0
+        assert slope_checker.margin_of_error is not None
+
+    @pytest.mark.sanity
+    def test_check_slope_requires_minimum_samples(self, slope_checker):
+        """Test that check_slope requires minimum samples."""
+        # Not enough samples
+        slope_checker.add_data_point(1.0, 2.0)
+        result = slope_checker.check_slope(1.0)
+        assert result is False
+
+        # Still not enough with 2 points
+        slope_checker.add_data_point(2.0, 4.0)
+        result = slope_checker.check_slope(2.0)
+        assert result is False
+
+        # Should work with 3+ points
+        slope_checker.add_data_point(3.0, 6.0)
+        result = slope_checker.check_slope(3.0)
+        # Might be True or False depending on confidence intervals,
+        # so only assert that the call completes and yields a boolean.
+        assert result in [True, False]
diff --git a/tests/unit/scheduler/test_over_saturation_comprehensive.py b/tests/unit/scheduler/test_over_saturation_comprehensive.py
new file mode 100644
index 000000000..0914f6967
--- /dev/null
+++ b/tests/unit/scheduler/test_over_saturation_comprehensive.py
@@ -0,0 +1,870 @@
+"""Comprehensive unit tests for over-saturation constraint implementation.
+
+This module provides thorough testing to validate that over-saturation detection
+and stopping features work correctly under various conditions and edge cases.
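+
+For orientation, a sketch of the CLI invocation these constraints ultimately
+serve (target URL and values are illustrative, not part of this test suite):
+
+    guidellm benchmark run --target "http://localhost:8000" \
+        --over-saturation '{"enabled": true, "min_seconds": 30}'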
+""" + +import math +import time +from unittest.mock import patch + +import pytest + +from guidellm.scheduler import ( + OverSaturationConstraint, + OverSaturationConstraintInitializer, + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.scheduler.constraints.saturation import ( + SlopeChecker, + approx_t_ppf, +) +from guidellm.schemas import RequestInfo, RequestTimings + + +class TestSlopeCheckerStatisticalAccuracy: + """Test the statistical accuracy of SlopeChecker implementation.""" + + @pytest.mark.sanity + def test_approx_t_ppf_accuracy(self): + """Test that approx_t_ppf produces reasonable approximations.""" + # Test known values for t-distribution + # For df=10, p=0.975 (95% confidence, two-tailed), t ≈ 2.228 + result = approx_t_ppf(0.975, 10) + assert 2.0 < result < 2.5, f"Expected ~2.228, got {result}" + + # For df=30, p=0.975, t ≈ 2.042 + result = approx_t_ppf(0.975, 30) + assert 1.9 < result < 2.2, f"Expected ~2.042, got {result}" + + # For large df, should approach normal distribution (z=1.96) + result = approx_t_ppf(0.975, 1000) + assert 1.8 < result < 2.1, f"Expected ~1.96, got {result}" + + @pytest.mark.sanity + def test_approx_t_ppf_edge_cases(self): + """Test approx_t_ppf with edge cases.""" + # Very small df + result = approx_t_ppf(0.975, 1) + assert result > 5.0, "t-value should be large for df=1" + + # Invalid df should return NaN + result = approx_t_ppf(0.975, 0) + assert math.isnan(result) + + result = approx_t_ppf(0.975, -1) + assert math.isnan(result) + + @pytest.mark.smoke + def test_slope_calculation_perfect_line(self): + """Test slope calculation with perfect linear data.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Perfect line: y = 2x + 1 + for i in range(10): + x = float(i) + y = 2.0 * x + 1.0 + checker.add_data_point(x, y) + + result = checker.check_slope(10.0) + assert result is True + assert abs(checker.slope - 2.0) < 0.001, ( + f"Expected slope ~2.0, got {checker.slope}" + ) + + @pytest.mark.smoke + def test_slope_calculation_zero_slope(self): + """Test slope calculation with horizontal line.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Horizontal line: y = 5 + for i in range(10): + x = float(i) + y = 5.0 + checker.add_data_point(x, y) + + result = checker.check_slope(10.0) + # Should not detect positive slope + if result: + assert checker.slope <= 0.1, f"Slope should be ~0, got {checker.slope}" + + @pytest.mark.sanity + def test_slope_calculation_negative_slope(self): + """Test slope calculation with negative slope.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Negative slope: y = -1.5x + 10 + for i in range(10): + x = float(i) + y = -1.5 * x + 10.0 + checker.add_data_point(x, y) + + result = checker.check_slope(10.0) + # Should not detect positive slope + assert result is False or checker.slope <= 0 + + @pytest.mark.sanity + def test_slope_calculation_with_noise(self): + """Test slope calculation with noisy data.""" + import random + + random.seed(42) # Reproducible results + + checker = SlopeChecker(moe_threshold=1.0, confidence=0.90) + + # Positive slope with noise: y = 1.5x + noise + for i in range(50): + x = float(i) + noise = random.uniform(-2.0, 2.0) + y = 1.5 * x + noise + checker.add_data_point(x, y) + + result = checker.check_slope(50.0) + if result: + assert 1.0 < checker.slope < 2.0, ( + f"Expected slope ~1.5, got {checker.slope}" + ) + + @pytest.mark.sanity + def test_margin_of_error_calculation(self): + """Test that margin of error is calculated correctly.""" 
+ checker = SlopeChecker(moe_threshold=0.5, confidence=0.95) + + # Add data with known properties + for i in range(20): + x = float(i) + y = 2.0 * x + 1.0 + checker.add_data_point(x, y) + + result = checker.check_slope(20.0) + assert result is True + assert checker.margin_of_error is not None + assert checker.margin_of_error >= 0 + # For perfect data, margin of error should be very small + assert checker.margin_of_error < 0.1 + + +class TestOverSaturationConstraintRobustness: + """Test the robustness of OverSaturationConstraint under various conditions.""" + + @pytest.mark.sanity + def test_constraint_with_empty_data(self): + """Test constraint behavior with no data.""" + constraint = OverSaturationConstraint(minimum_duration=0.0, enabled=True) + + # Should not alert with no data + assert constraint._check_alert() is False + + # Should handle update_duration gracefully + constraint._update_duration(100.0) + assert constraint._check_alert() is False + + @pytest.mark.sanity + def test_constraint_with_single_request(self): + """Test constraint behavior with single request.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=1, enabled=True + ) + + constraint._add_started({"concurrent_requests": 5, "duration": 1.0}) + constraint._add_finished({"ttft": 2.0, "duration": 2.0}) + constraint._update_duration(10.0) + + # Should not alert with insufficient data + assert constraint._check_alert() is False + + @pytest.mark.sanity + def test_constraint_with_identical_values(self): + """Test constraint with identical values (zero variance).""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Add identical values + for i in range(10): + constraint._add_started({"concurrent_requests": 5, "duration": float(i)}) + constraint._add_finished({"ttft": 1.0, "duration": float(i)}) + + constraint._update_duration(20.0) + result = constraint._check_alert() + + # Should not alert for flat data + assert result is False + + @pytest.mark.sanity + def test_constraint_extreme_values(self): + """Test constraint with extreme values.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Add extreme values + values = [0.1, 1000.0, 0.01, 5000.0, 0.001] + for i, val in enumerate(values): + constraint._add_started( + {"concurrent_requests": int(val), "duration": float(i)} + ) + constraint._add_finished({"ttft": val, "duration": float(i)}) + + constraint._update_duration(20.0) + # Should handle without crashing + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_constraint_precision_edge_cases(self): + """Test constraint with floating point precision edge cases.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Very small increments + base = 1e10 + for i in range(10): + constraint._add_started( + {"concurrent_requests": 5, "duration": base + i * 1e-10} + ) + constraint._add_finished({"ttft": 1.0, "duration": base + i * 1e-10}) + + constraint._update_duration(base + 100.0) + # Should handle without numerical issues + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_constraint_window_management_stress(self): + """Test constraint window management under stress.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + maximum_window_seconds=10.0, + minimum_window_size=5, + enabled=True, + ) + + # 
Add many requests over time + for i in range(1000): + duration = float(i * 0.1) # 100 seconds total + constraint._add_started( + {"concurrent_requests": i % 50, "duration": duration} + ) + constraint._add_finished({"ttft": (i % 100) * 0.01, "duration": duration}) + + # Periodic window updates + if i % 100 == 0: + constraint._update_duration(duration + 5.0) + + # Should maintain reasonable window size + assert len(constraint.started_requests) <= 200 # Should be pruned + assert len(constraint.finished_requests) <= 200 + + +class TestOverSaturationConstraintRealisticScenarios: + """Test detector with realistic request patterns.""" + + @pytest.mark.sanity + def test_gradual_performance_degradation(self): + """Test detection of gradual performance degradation.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + moe_threshold=1.5, + enabled=True, + ) + + # Simulate gradual degradation + for i in range(50): + # Gradually increasing concurrent requests + concurrent = 10 + i * 0.5 + # Gradually increasing TTFT + ttft = 1.0 + i * 0.1 + duration = float(i) + + constraint._add_started( + {"concurrent_requests": int(concurrent), "duration": duration} + ) + constraint._add_finished({"ttft": ttft, "duration": duration}) + + constraint._update_duration(60.0) + result = constraint._check_alert() + + # Should detect the degradation + assert result is True, "Should detect gradual performance degradation" + + @pytest.mark.sanity + def test_sudden_load_spike(self): + """Test detection of sudden load spike.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + moe_threshold=1.0, + enabled=True, + ) + + # Normal operations first + for i in range(20): + constraint._add_started({"concurrent_requests": 5, "duration": float(i)}) + constraint._add_finished({"ttft": 1.0, "duration": float(i)}) + + # Sudden spike + for i in range(20, 40): + constraint._add_started({"concurrent_requests": 50, "duration": float(i)}) + constraint._add_finished({"ttft": 5.0, "duration": float(i)}) + + constraint._update_duration(50.0) + result = constraint._check_alert() + + # Should detect the spike + assert result is True, "Should detect sudden load spike" + + @pytest.mark.sanity + def test_variable_but_stable_performance(self): + """Test that variable but stable performance doesn't trigger false positives.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + moe_threshold=2.0, + enabled=True, + ) + + import random + + random.seed(123) # Reproducible + + # Variable but centered around stable values + for i in range(100): + concurrent = 15 + random.randint(-5, 5) # 10-20 range + ttft = 2.0 + random.uniform(-0.5, 0.5) # 1.5-2.5 range + duration = float(i) + + constraint._add_started( + {"concurrent_requests": concurrent, "duration": duration} + ) + constraint._add_finished({"ttft": ttft, "duration": duration}) + + constraint._update_duration(120.0) + result = constraint._check_alert() + + # Should not trigger false positive + assert result is False, ( + "Should not trigger false positive for stable performance" + ) + + @pytest.mark.sanity + def test_recovery_after_degradation(self): + """Test that detector handles recovery after degradation.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + maximum_window_seconds=30.0, + enabled=True, + ) + + # Initial degradation + for i in range(20): + concurrent = 10 + i * 2 # Increasing load + ttft = 1.0 + i * 0.2 # Increasing 
TTFT + constraint._add_started( + {"concurrent_requests": concurrent, "duration": float(i)} + ) + constraint._add_finished({"ttft": ttft, "duration": float(i)}) + + constraint._update_duration(25.0) + degradation_result = constraint._check_alert() + + # Add recovery period - improved performance + for i in range(40, 60): + constraint._add_started({"concurrent_requests": 5, "duration": float(i)}) + constraint._add_finished({"ttft": 0.8, "duration": float(i)}) + + constraint._update_duration(65.0) + recovery_result = constraint._check_alert() + + # Should detect degradation initially, then not alert during recovery + # (depending on window management) + assert degradation_result in [True, False] # Could go either way + # After recovery with window management, should be less likely to alert + if len(constraint.finished_requests) < 15: # If old data was purged + assert recovery_result is False, "Should not alert after recovery" + + +class TestOverSaturationConstraintIntegration: + """Test integration between constraint and detector with complex scenarios.""" + + def create_realistic_constraint(self) -> OverSaturationConstraint: + """Create a constraint with realistic settings.""" + return OverSaturationConstraint( + minimum_duration=10.0, + minimum_window_size=5, + maximum_window_seconds=60.0, + moe_threshold=1.5, + confidence=0.90, + enabled=True, + ) + + @pytest.mark.sanity + def test_constraint_metadata_completeness(self): + """Test that constraint provides complete metadata.""" + constraint = self.create_realistic_constraint() + start_time = time.time() + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=10, + ) + + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + + # Verify metadata completeness + required_fields = [ + "is_over_saturated", + "concurrent_slope", + "concurrent_n", + "ttft_slope", + "ttft_n", + "ttft_violations", # Correct field name + # Note: total_started_ever, total_finished_ever, + # window sizes not in metadata + ] + + for field in required_fields: + assert field in action.metadata, f"Missing metadata field: {field}" + + @pytest.mark.sanity + def test_constraint_with_realistic_request_flow(self): + """Test constraint with realistic request flow.""" + constraint = self.create_realistic_constraint() + start_time = time.time() + actions = [] + + # Simulate 60 seconds of requests + for second in range(60): + current_time = start_time + second + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=10 + second, # Gradually increasing load + ) + + # Mix of request statuses + for req_num in range(3): # 3 requests per second + request_id = f"req-{second}-{req_num}" + + if req_num == 0: # Completed request + timings = RequestTimings( + request_start=current_time - 2.0, + first_iteration=current_time + - 2.0 + + (second * 0.05), # Gradually slower + ) + request = RequestInfo( + request_id=request_id, + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + timings=timings, + ) + else: # In progress request + request = RequestInfo( + request_id=request_id, + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + actions.append((second, action)) + + # Analyze results + 
stop_actions = [a for s, a in actions if a.request_queuing == "stop"] + + # Should eventually detect over-saturation + if len(stop_actions) > 0: + first_stop_second = min( + s for s, a in actions if a.request_queuing == "stop" + ) + assert first_stop_second >= 10, "Should not stop before minimum duration" + + @pytest.mark.sanity + def test_constraint_disabled_never_stops(self): + """Test that disabled constraint never stops regardless of load.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + minimum_window_size=3, + enabled=False, # Disabled + ) + + # Add obviously over-saturated data + for i in range(50): + constraint._add_started( + {"concurrent_requests": i * 10, "duration": float(i)} + ) + constraint._add_finished({"ttft": i * 2.0, "duration": float(i)}) + + constraint._update_duration(60.0) + + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=500, # Very high load + ) + + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + + # Should continue despite over-saturation + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + assert action.metadata["is_over_saturated"] in [True, False] # Could be either + + +class TestOverSaturationConstraintPerformance: + """Test performance characteristics of the constraint.""" + + @pytest.mark.sanity + def test_detector_memory_usage(self): + """Test that detector manages memory properly.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + maximum_window_seconds=10.0, + minimum_window_size=5, + enabled=True, + ) + + # Add many requests + for i in range(10000): + duration = float(i * 0.01) # 100 seconds total + constraint._add_started({"concurrent_requests": 10, "duration": duration}) + constraint._add_finished({"ttft": 1.0, "duration": duration}) + + if i % 1000 == 0: + constraint._update_duration(duration + 5.0) + + # Memory should be bounded due to window management + assert len(constraint.started_requests) < 2000, "Started requests not bounded" + assert len(constraint.finished_requests) < 2000, "Finished requests not bounded" + + @pytest.mark.sanity + def test_constraint_computational_efficiency(self): + """Test that constraint operations remain efficient.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=10, enabled=True + ) + + # Add baseline data + for i in range(100): + constraint._add_started({"concurrent_requests": 10, "duration": float(i)}) + constraint._add_finished({"ttft": 1.0, "duration": float(i)}) + + constraint._update_duration(120.0) + + # Time multiple check_alert calls + start_time = time.time() + for _ in range(100): + constraint._check_alert() + elapsed = time.time() - start_time + + # Should complete quickly (< 1 second for 100 calls) + assert elapsed < 1.0, f"Detection too slow: {elapsed:.3f}s for 100 calls" + + +class TestOverSaturationConstraintInitializerRobustness: + """Test robustness of the constraint initializer.""" + + @pytest.mark.smoke + def test_initializer_parameter_validation(self): + """Test parameter validation in initializer.""" + # Valid parameters + initializer = OverSaturationConstraintInitializer( + enabled=True, + min_seconds=5.0, + max_window_seconds=30.0, + moe_threshold=1.5, + confidence=0.95, + ) + + constraint = initializer.create_constraint() + assert 
constraint.enabled is True
+        assert constraint.minimum_duration == 5.0
+        assert constraint.maximum_window_seconds == 30.0
+
+    @pytest.mark.smoke
+    def test_initializer_with_extreme_parameters(self):
+        """Test initializer with extreme but valid parameters."""
+        # Very permissive settings - only test parameters actually supported
+        initializer = OverSaturationConstraintInitializer(
+            enabled=True,
+            min_seconds=0.1,
+            max_window_seconds=3600.0,  # 1 hour
+        )
+
+        constraint = initializer.create_constraint()
+
+        assert constraint.minimum_duration == 0.1
+        assert constraint.maximum_window_seconds == 3600.0
+        # Note: moe_threshold and confidence may have default values
+
+    @pytest.mark.smoke
+    def test_initializer_alias_precedence(self):
+        """Test alias precedence in validated_kwargs."""
+        # Both the canonical key and its alias are provided; the alias
+        # value is expected to take precedence.
+        result = OverSaturationConstraintInitializer.validated_kwargs(
+            over_saturation=False,  # Canonical parameter
+            detect_saturation=True,  # Alias
+        )
+
+        # detect_saturation should override over_saturation=False
+        assert result == {"enabled": True}
+
+    @pytest.mark.smoke
+    def test_constraint_creation_with_seeded_state(self):
+        """Test constraint behavior with manually seeded slope-checker state."""
+        constraint = OverSaturationConstraint(enabled=True)
+        # Set up constraint state to simulate over-saturation
+        constraint.ttft_slope_checker.slope = 1.5
+        constraint.ttft_slope_checker.margin_of_error = 0.3
+        constraint.ttft_slope_checker.n = 10
+        constraint.concurrent_slope_checker.slope = 2.0
+        constraint.concurrent_slope_checker.margin_of_error = 0.5
+        constraint.concurrent_slope_checker.n = 15
+        constraint.ttft_violations_counter = 5
+        constraint.duration = 30.0  # Set duration to pass minimum check
+
+        start_time = time.time()
+        state = SchedulerState(
+            node_id=0,
+            num_processes=1,
+            start_time=start_time,
+            processing_requests=10,
+        )
+
+        request = RequestInfo(
+            request_id="test-request",
+            status="in_progress",
+            scheduler_node_id=0,
+            scheduler_process_id=0,
+            scheduler_start_time=start_time,
+        )
+
+        action = constraint(state, request)
+
+        # Should provide metadata about saturation state
+        assert "is_over_saturated" in action.metadata
+
+
+class TestOverSaturationEdgeCasesAndRegression:
+    """Test edge cases and regression scenarios."""
+
+    @pytest.mark.sanity
+    def test_detector_with_malformed_request_data(self):
+        """Test detector requires proper request data structure."""
+        constraint = OverSaturationConstraint(minimum_duration=0.0, enabled=True)
+
+        # Missing fields should raise KeyError
+        with pytest.raises(KeyError):
+            constraint._add_started({})  # Missing required fields
+
+        with pytest.raises(KeyError):
+            constraint._add_finished({})
+
+        with pytest.raises(KeyError):
+            constraint._add_started({"concurrent_requests": 5})  # Missing duration
+
+        with pytest.raises(KeyError):
+            constraint._add_finished({"ttft": 1.0})  # Missing duration
+
+        # Valid data should work
+        constraint._add_started({"concurrent_requests": 5, "duration": 1.0})
+        constraint._add_finished({"ttft": 1.0, "duration": 1.0})
+
+        constraint._update_duration(10.0)
+        result = constraint._check_alert()
+        assert result in [True, False]
+
+    @pytest.mark.sanity
+    def test_constraint_with_missing_timings_data(self):
+        """Test constraint handles missing timings data gracefully."""
+        constraint = OverSaturationConstraint(
+            minimum_duration=0.0,
+            enabled=True,
+        )
+
+        start_time = time.time()
+        state = SchedulerState(
+            node_id=0,
+            num_processes=1,
+            start_time=start_time,
+            
processing_requests=5, + ) + + # Create request without timings (in_progress status) + request = RequestInfo( + request_id="test-request", + status="in_progress", # No timings expected for in_progress + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # Should not crash + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + + @pytest.mark.sanity + def test_detector_concurrent_modification_safety(self): + """Test detector behavior under concurrent-like modifications.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Add requests + requests = [] + for i in range(20): + req = {"concurrent_requests": i, "duration": float(i)} + constraint._add_started(req) + requests.append(req) + + # Remove some while iterating (simulating concurrent access pattern) + for i in range(0, 10, 2): # Remove every other early request + constraint._remove_started(requests[i]) + + # Should still function + constraint._update_duration(25.0) + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_slope_checker_numerical_stability(self): + """Test SlopeChecker numerical stability with challenging data.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Add data that could cause numerical instability + base = 1e15 # Very large numbers + for i in range(10): + x = base + i + y = base + i * 1e-10 # Very small slope relative to magnitude + checker.add_data_point(x, y) + + # Should handle without overflow/underflow + result = checker.check_slope(10.0) + assert result in [True, False] + + if checker.slope is not None: + assert not math.isnan(checker.slope) + assert not math.isinf(checker.slope) + + @pytest.mark.sanity + def test_detector_reset_clears_all_state(self): + """Test that detector reset completely clears state.""" + constraint = OverSaturationConstraint(minimum_duration=0.0, enabled=True) + + # Add data and trigger computation + for i in range(20): + constraint._add_started({"concurrent_requests": i, "duration": float(i)}) + constraint._add_finished({"ttft": i * 0.1, "duration": float(i)}) + + constraint._update_duration(25.0) + constraint._check_alert() # Populate computed values + + # Verify state exists + assert len(constraint.started_requests) > 0 + assert len(constraint.finished_requests) > 0 + assert constraint.total_started_ever > 0 + assert constraint.total_finished_ever > 0 + + # Reset + constraint.reset() + + # Verify complete reset + assert len(constraint.started_requests) == 0 + assert len(constraint.finished_requests) == 0 + assert constraint.total_started_ever == 0 + assert constraint.total_finished_ever == 0 + assert constraint.ttft_violations_counter == 0 + assert constraint.duration == 0.0 + + # Slope checkers should be reset too + assert constraint.concurrent_slope_checker.n == 0 + assert constraint.ttft_slope_checker.n == 0 + + @pytest.mark.sanity + @patch("time.time") + def test_constraint_time_calculation_accuracy(self, mock_time): + """Test that constraint calculates durations accurately.""" + # Mock time to control duration calculation + start_time = 1000.0 + current_time = 1030.0 # 30 seconds later + mock_time.return_value = current_time + + constraint = OverSaturationConstraint( + minimum_duration=25.0, enabled=True + ) # Should be met + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + request = RequestInfo( + 
request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # Call constraint - should update detector duration + constraint(state, request) + + # Verify duration was calculated correctly + assert abs(constraint.duration - 30.0) < 0.001, ( + f"Expected duration ~30.0, got {constraint.duration}" + ) + + @pytest.mark.sanity + def test_ttft_violation_counting_accuracy(self): + """Test TTFT violation counting is accurate.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + minimum_ttft=2.0, # Threshold + enabled=True, + ) + + # Add requests with known TTFT values + ttft_values = [1.0, 3.0, 1.5, 4.0, 2.1, 0.5, 5.0, 1.9] + expected_violations = sum( + 1 for ttft in ttft_values if ttft > 2.0 + ) # Should be 4 + + for i, ttft in enumerate(ttft_values): + constraint._add_finished({"ttft": ttft, "duration": float(i)}) + + assert constraint.ttft_violations_counter == expected_violations, ( + f"Expected {expected_violations} violations, " + f"got {constraint.ttft_violations_counter}" + )
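+
+
+# For reference, the internal bookkeeping protocol exercised throughout this
+# module (a sketch inferred from the tests above; the authoritative
+# definitions live in guidellm.scheduler.constraints.saturation):
+#
+#   constraint._add_started({"concurrent_requests": n, "duration": t})
+#   constraint._add_finished({"ttft": ttft_seconds, "duration": t})
+#   constraint._update_duration(elapsed)  # prunes entries outside the window
+#   constraint._check_alert()  # -> bool; True when sampled slopes indicate
+#                              # sustained over-saturation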