diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index af43e426..042b25b1 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -12,7 +12,7 @@ from guidellm.backend.backend import BackendType from guidellm.benchmark.profile import ProfileType from guidellm.objects.pydantic import StandardBaseModel -from guidellm.scheduler.strategy import StrategyType +from guidellm.scheduler.strategies import StrategyType __ALL__ = ["Scenario", "GenerativeTextScenario", "get_builtin_scenarios"] diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index d3aa0aab..64647424 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -1,47 +1,90 @@ -from .result import ( - SchedulerRequestInfo, - SchedulerRequestResult, - SchedulerResult, - SchedulerRunInfo, +from .constraints import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + MaxDurationConstraint, + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, + MaxNumberConstraint, + PydanticConstraintInitializer, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from .environments import Environment, NonDistributedEnvironment +from .objects import ( + BackendInterface, + BackendT, + MeasuredRequestTimings, + MultiTurnRequestT, + RequestSchedulerTimings, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, ) from .scheduler import Scheduler -from .strategy import ( +from .strategies import ( AsyncConstantStrategy, AsyncPoissonStrategy, ConcurrentStrategy, + ConstantRateRequestTimings, + LastCompletionRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, + ScheduledRequestTimings, SchedulingStrategy, + StrategyT, StrategyType, SynchronousStrategy, ThroughputStrategy, - strategy_display_str, -) -from .worker import ( - GenerativeRequestsWorker, - GenerativeRequestsWorkerDescription, - RequestsWorker, - ResolveStatus, - WorkerDescription, - WorkerProcessResult, ) +from .worker import WorkerProcess +from .worker_group import WorkerProcessGroup __all__ = [ "AsyncConstantStrategy", "AsyncPoissonStrategy", + "BackendInterface", + "BackendT", "ConcurrentStrategy", - "GenerativeRequestsWorker", - "GenerativeRequestsWorkerDescription", - "RequestsWorker", - "ResolveStatus", + "ConstantRateRequestTimings", + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "Environment", + "LastCompletionRequestTimings", + "MaxDurationConstraint", + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", + "MaxNumberConstraint", + "MeasuredRequestTimings", + "MultiTurnRequestT", + "NoDelayRequestTimings", + "NonDistributedEnvironment", + "PoissonRateRequestTimings", + "PydanticConstraintInitializer", + "RequestSchedulerTimings", + "RequestT", + "ResponseT", + "ScheduledRequestInfo", + "ScheduledRequestTimings", "Scheduler", - "SchedulerRequestInfo", - "SchedulerRequestResult", - "SchedulerResult", - "SchedulerRunInfo", + "SchedulerMessagingPydanticRegistry", + "SchedulerState", + "SchedulerUpdateAction", + "SchedulerUpdateActionProgress", "SchedulingStrategy", + "SerializableConstraintInitializer", + "StrategyT", "StrategyType", "SynchronousStrategy", "ThroughputStrategy", - "WorkerDescription", - "WorkerProcessResult", - "strategy_display_str", + "UnserializableConstraintInitializer", + 
"WorkerProcess", + "WorkerProcessGroup", ] diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py new file mode 100644 index 00000000..c724a74a --- /dev/null +++ b/src/guidellm/scheduler/constraints.py @@ -0,0 +1,1035 @@ +""" +Constraint system for scheduler behavior control and request processing limits. + +Provides flexible constraints for managing scheduler behavior with configurable +thresholds based on time, error rates, and request counts. Constraints evaluate +scheduler state and individual requests to determine whether processing should +continue or stop based on predefined limits. The constraint system enables +sophisticated benchmark stopping criteria through composable constraint types. +""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import Field, field_validator + +from guidellm.scheduler.objects import ( + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, +) +from guidellm.settings import settings +from guidellm.utils import InfoMixin, RegistryMixin, StandardBaseModel + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "MaxDurationConstraint", + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", + "MaxNumberConstraint", + "PydanticConstraintInitializer", + "RequestsExhaustedConstraint", + "SerializableConstraintInitializer", + "UnserializableConstraintInitializer", +] + + +@runtime_checkable +class Constraint(Protocol): + """Protocol for constraint evaluation functions that control scheduler behavior.""" + + def __call__( + self, state: SchedulerState, request: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against scheduler state and request information. + + :param state: Current scheduler state with metrics and timing information + :param request: Individual request information and metadata + :return: Action indicating whether to continue or stop scheduler operations + """ + + +@runtime_checkable +class ConstraintInitializer(Protocol): + """Protocol for constraint initializer factory functions that create constraints.""" + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance from configuration parameters. + + :param kwargs: Configuration parameters for constraint creation + :return: Configured constraint evaluation function + """ + + +@runtime_checkable +class SerializableConstraintInitializer(Protocol): + """Protocol for serializable constraint initializers supporting persistence.""" + + @classmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + :param args: Positional arguments for constraint configuration + :param kwargs: Keyword arguments for constraint configuration + :return: Validated parameter dictionary for constraint creation + """ + + @classmethod + def model_validate(cls, **kwargs) -> ConstraintInitializer: + """ + Create validated constraint initializer from configuration. + + :param kwargs: Configuration dictionary for initializer creation + :return: Validated constraint initializer instance + """ + + def model_dump(self) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. 
+ + :return: Dictionary representation of constraint initializer + """ + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create constraint instance from this initializer. + + :param kwargs: Additional configuration parameters + :return: Configured constraint evaluation function + """ + + +class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): + """ + Registry factory for creating and managing constraint initializers. + + Provides centralized access to registered constraint types with support for + creating constraints from configuration dictionaries, simple values, or + pre-configured instances. Handles constraint resolution and type validation + for the scheduler constraint system. + + Example: + :: + from guidellm.scheduler import ConstraintsInitializerFactory + + # Register new constraint type + @ConstraintsInitializerFactory.register("new_constraint") + class NewConstraint: + def create_constraint(self, **kwargs) -> Constraint: + return lambda state, request: SchedulerUpdateAction() + + # Create and use constraint + constraint = ConstraintsInitializerFactory.create_constraint("new_constraint") + """ + + @classmethod + def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: + """ + Create a constraint initializer for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for initializer creation + :param kwargs: Keyword arguments for initializer creation + :return: Configured constraint initializer instance + :raises ValueError: If the key is not registered in the factory + """ + if cls.registry is None or key not in cls.registry: + raise ValueError(f"Unknown constraint initializer key: {key}") + + initializer_class = cls.registry[key] + + return ( + initializer_class(*args, **kwargs) # type: ignore[operator] + if not isinstance(initializer_class, type) + or not issubclass(initializer_class, SerializableConstraintInitializer) + else initializer_class( + **initializer_class.validated_kwargs(*args, **kwargs) # type: ignore[misc] + ) + ) + + @classmethod + def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :param initializer: Constraint initializer to serialize + :return: Dictionary representation or unserializable placeholder + """ + if isinstance(initializer, SerializableConstraintInitializer): + return initializer.model_dump() + else: + unserializable = UnserializableConstraintInitializer( + orig_info=InfoMixin.extract_from_obj(initializer) + ) + return unserializable.model_dump() + + @classmethod + def deserialize( + cls, initializer_dict: dict[str, Any] + ) -> SerializableConstraintInitializer: + """ + Deserialize constraint initializer from dictionary format. 
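+
+        Example (round-trip sketch using a registered initializer)::
+
+            data = ConstraintsInitializerFactory.serialize(
+                MaxNumberConstraint(max_num=10)
+            )
+            restored = ConstraintsInitializerFactory.deserialize(data)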
+ + :param initializer_dict: Dictionary representation of constraint initializer + :return: Reconstructed constraint initializer instance + :raises ValueError: If constraint type is unknown or cannot be deserialized + """ + if initializer_dict.get("type_") == "unserializable": + return UnserializableConstraintInitializer.model_validate(initializer_dict) + + if ( + cls.registry is not None + and initializer_dict.get("type_") + and initializer_dict["type_"] in cls.registry + ): + initializer_class = cls.registry[initializer_dict["type_"]] + if hasattr(initializer_class, "model_validate"): + return initializer_class.model_validate(initializer_dict) # type: ignore[return-value] + else: + return initializer_class(**initializer_dict) # type: ignore[return-value,operator] + + raise ValueError( + f"Cannot deserialize unknown constraint initializer: " + f"{initializer_dict.get('type_', 'unknown')}" + ) + + @classmethod + def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: + """ + Create a constraint instance for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for constraint creation + :param kwargs: Keyword arguments for constraint creation + :return: Configured constraint function ready for evaluation + :raises ValueError: If the key is not registered in the factory + """ + return cls.create(key, *args, **kwargs).create_constraint() + + @classmethod + def resolve( + cls, + initializers: dict[ + str, + Any | dict[str, Any] | Constraint | ConstraintInitializer, + ], + ) -> dict[str, Constraint]: + """ + Resolve mixed constraint specifications to callable constraints. + + :param initializers: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any key is not registered in the factory + """ + constraints = {} + + for key, val in initializers.items(): + if isinstance(val, Constraint): + constraints[key] = val + elif isinstance(val, ConstraintInitializer): + constraints[key] = val.create_constraint() + elif isinstance(val, dict): + constraints[key] = cls.create_constraint(key, **val) + else: + constraints[key] = cls.create_constraint(key, val) + + return constraints + + @classmethod + def resolve_constraints( + cls, + constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> dict[str, Constraint]: + """ + Resolve constraints from mixed constraint specifications. + + :param constraints: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any constraint key is not registered + """ + resolved_constraints = {} + + for key, val in constraints.items(): + if isinstance(val, Constraint): + resolved_constraints[key] = val + elif isinstance(val, dict): + resolved_constraints[key] = cls.create_constraint(key, **val) + else: + resolved_constraints[key] = cls.create_constraint(key, val) + + return resolved_constraints + + +class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin): + """ + Abstract base for Pydantic-based constraint initializers. + + Provides standardized serialization, validation, and metadata handling for + constraint initializers using Pydantic models. Subclasses implement specific + constraint creation logic while inheriting validation and persistence support. 
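+
+    Example:
+    ::
+        # Minimal subclass sketch; "max_widgets" is a made-up field for
+        # illustration only
+        class MaxWidgetsConstraint(PydanticConstraintInitializer):
+            type_: Literal["max_widgets"] = "max_widgets"
+            max_widgets: int
+
+            @classmethod
+            def validated_kwargs(cls, max_widgets: int, **kwargs):
+                return {"max_widgets": max_widgets}
+
+            def create_constraint(self, **kwargs) -> Constraint:
+                return lambda state, request: SchedulerUpdateAction()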
+ """ + + type_: str = Field(description="Type identifier for the constraint initializer") + + @property + def info(self) -> dict[str, Any]: + """ + Extract serializable information from this constraint initializer. + + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + @classmethod + @abstractmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + Must be implemented by subclasses to handle their specific parameter patterns + and validation requirements. + + :param args: Positional arguments passed to the constraint + :param kwargs: Keyword arguments passed to the constraint + :return: Validated dictionary of parameters for constraint creation + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + @abstractmethod + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance. + + Must be implemented by subclasses to return their specific constraint type + with appropriate configuration and validation. + + :param kwargs: Additional keyword arguments (usually unused) + :return: Configured constraint instance + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + +class UnserializableConstraintInitializer(PydanticConstraintInitializer): + """ + Placeholder for constraints that cannot be serialized or executed. + + Represents constraint initializers that failed serialization or contain + non-serializable components. Cannot be executed and raises errors when + invoked to prevent runtime failures from invalid constraint state. + """ + + type_: Literal["unserializable"] = "unserializable" # type: ignore[assignment] + orig_info: dict[str, Any] = Field( + default_factory=dict, + description="Original constraint information before serialization failure", + ) + + @classmethod + def validated_kwargs( + cls, + orig_info: dict[str, Any] | None = None, + **kwargs, # noqa: ARG003 + ) -> dict[str, Any]: + """ + Validate arguments for unserializable constraint creation. + + :param orig_info: Original constraint information before serialization failure + :param kwargs: Additional arguments (ignored) + :return: Validated parameters for unserializable constraint creation + """ + return {"orig_info": orig_info or {}} + + def create_constraint( + self, + **kwargs, # noqa: ARG002 + ) -> Constraint: + """ + Raise error for unserializable constraint creation attempt. + + :param kwargs: Additional keyword arguments (unused) + :raises RuntimeError: Always raised since unserializable constraints + cannot be executed + """ + raise RuntimeError( + "Cannot create constraint from unserializable constraint instance. " + "This constraint cannot be serialized and therefore cannot be executed." + ) + + def __call__( + self, + state: SchedulerState, # noqa: ARG002 + request: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Raise error since unserializable constraints cannot be invoked. + + :param state: Current scheduler state (unused) + :param request: Individual request information (unused) + :raises RuntimeError: Always raised for unserializable constraints + """ + raise RuntimeError( + "Cannot invoke unserializable constraint instance. " + "This constraint was not properly serialized and cannot be executed." 
)
+
+
+@ConstraintsInitializerFactory.register(  # type: ignore[arg-type]
+    ["max_number", "max_num", "max_requests", "max_req"]
+)
+class MaxNumberConstraint(PydanticConstraintInitializer):
+    """
+    Constraint that limits execution based on maximum request counts.
+
+    Stops request queuing when created requests reach the limit and stops local
+    request processing when processed requests reach the limit. Provides progress
+    tracking based on remaining requests and completion fraction.
+    """
+
+    type_: Literal["max_number"] = "max_number"  # type: ignore[assignment]
+    max_num: int | float | list[int | float] = Field(
+        description="Maximum number of requests allowed before triggering constraint",
+    )
+    current_index: int = Field(
+        default=-1, description="Current index for list-based max_num values"
+    )
+
+    @classmethod
+    def validated_kwargs(
+        cls, max_num: int | float | list[int | float] | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Validate and process arguments for MaxNumberConstraint creation.
+
+        :param max_num: Maximum number of requests to allow
+        :param kwargs: Supports max_num, max_number, max_requests, max_req,
+            and optional type_
+        :return: Validated dictionary with max_num and current_index fields
+        """
+        aliases = ["max_number", "max_num", "max_requests", "max_req"]
+        for alias in aliases:
+            if max_num is None:
+                max_num = kwargs.get(alias)
+
+        return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)}
+
+    def create_constraint(self, **kwargs) -> Constraint:  # noqa: ARG002
+        """
+        Advance the rotation index and return a copy of this initializer
+        as the constraint instance.
+
+        :param kwargs: Additional keyword arguments (unused)
+        :return: Copy of this initializer to use as the constraint
+        """
+        self.current_index += 1
+
+        return self.model_copy()  # type: ignore[return-value]
+
+    def __call__(
+        self,
+        state: SchedulerState,
+        request_info: ScheduledRequestInfo,  # noqa: ARG002
+    ) -> SchedulerUpdateAction:
+        """
+        Evaluate constraint against current scheduler state and request count.
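+
+        Example (sketch; assumes a populated ``state`` and ``request_info``
+        from a running scheduler)::
+
+            constraint = MaxNumberConstraint(max_num=100).create_constraint()
+            action = constraint(state, request_info)
+            if action.request_queuing == "stop":
+                ...  # no further requests should be queued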
+
+        :param state: Current scheduler state with request counts
+        :param request_info: Individual request information (unused)
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_num = (
+            self.max_num
+            if isinstance(self.max_num, (int, float))
+            else self.max_num[min(current_index, len(self.max_num) - 1)]
+        )
+
+        create_exceeded = state.created_requests >= max_num
+        processed_exceeded = state.processed_requests >= max_num
+        remaining_requests = min(max(0, max_num - state.processed_requests), max_num)
+        remaining_fraction = remaining_requests / float(max_num)
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if create_exceeded else "continue",
+            request_processing="stop_local" if processed_exceeded else "continue",
+            metadata={
+                "max_number": max_num,
+                "create_exceeded": create_exceeded,
+                "processed_exceeded": processed_exceeded,
+                "created_requests": state.created_requests,
+                "processed_requests": state.processed_requests,
+                "remaining_fraction": remaining_fraction,
+                "remaining_requests": remaining_requests,
+            },
+            progress=SchedulerUpdateActionProgress(
+                remaining_fraction=remaining_fraction,
+                remaining_requests=remaining_requests,
+            ),
+        )
+
+    @field_validator("max_num")
+    @classmethod
+    def _validate_max_num(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    f"max_num must be set and truthy, received {value} ({val} failed)"
+                )
+            if not isinstance(val, (int, float)) or val <= 0:
+                raise ValueError(
+                    f"max_num must be a positive number, received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+@ConstraintsInitializerFactory.register(  # type: ignore[arg-type]
+    ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"]
+)
+class MaxDurationConstraint(PydanticConstraintInitializer):
+    """
+    Constraint that limits execution based on maximum time duration.
+
+    Stops both request queuing and processing when the elapsed time since scheduler
+    start exceeds the maximum duration. Provides progress tracking based on
+    remaining time and completion fraction.
+    """
+
+    type_: Literal["max_duration"] = "max_duration"  # type: ignore[assignment]
+    max_duration: int | float | list[int | float] = Field(
+        description="Maximum duration in seconds before triggering constraint"
+    )
+    current_index: int = Field(default=-1, description="Current index in duration list")
+
+    @classmethod
+    def validated_kwargs(
+        cls, max_duration: int | float | list[int | float] | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Validate and process arguments for MaxDurationConstraint creation.
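+
+        Example (alias-handling sketch)::
+
+            # minute aliases are converted to seconds when max_duration is unset
+            kwargs = MaxDurationConstraint.validated_kwargs(max_min=2)
+            assert kwargs["max_duration"] == 120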
+
+        :param max_duration: Maximum duration in seconds
+        :param kwargs: Supports max_duration, max_dur, max_sec, max_seconds,
+            max_min, max_minutes, and optional type_
+        :return: Validated dictionary with max_duration and current_index fields
+        """
+        seconds_aliases = ["max_dur", "max_sec", "max_seconds"]
+        for alias in seconds_aliases:
+            if max_duration is None:
+                max_duration = kwargs.get(alias)
+        minutes_aliases = ["max_min", "max_minutes"]
+        for alias in minutes_aliases:
+            minutes = kwargs.get(alias)
+            if minutes is not None and max_duration is None:
+                max_duration = minutes * 60
+
+        return {
+            "max_duration": max_duration,
+            "current_index": kwargs.get("current_index", -1),
+        }
+
+    def create_constraint(self, **kwargs) -> Constraint:  # noqa: ARG002
+        """
+        Advance the rotation index and return a copy of this initializer
+        as the constraint instance.
+
+        :param kwargs: Additional keyword arguments (unused)
+        :return: Copy of this initializer to use as the constraint
+        """
+        self.current_index += 1
+
+        return self.model_copy()  # type: ignore[return-value]
+
+    def __call__(
+        self,
+        state: SchedulerState,
+        request_info: ScheduledRequestInfo,  # noqa: ARG002
+    ) -> SchedulerUpdateAction:
+        """
+        Evaluate constraint against current scheduler state and elapsed time.
+
+        :param state: Current scheduler state with start time
+        :param request_info: Individual request information (unused)
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_duration = (
+            self.max_duration
+            if isinstance(self.max_duration, (int, float))
+            else self.max_duration[min(current_index, len(self.max_duration) - 1)]
+        )
+
+        current_time = time.time()
+        elapsed = current_time - state.start_time
+        duration_exceeded = elapsed >= max_duration
+        remaining_duration = min(max(0.0, max_duration - elapsed), max_duration)
+        remaining_fraction = remaining_duration / float(max_duration)
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if duration_exceeded else "continue",
+            request_processing="stop_local" if duration_exceeded else "continue",
+            metadata={
+                "max_duration": max_duration,
+                "elapsed_time": elapsed,
+                "duration_exceeded": duration_exceeded,
+                "start_time": state.start_time,
+                "current_time": current_time,
+            },
+            progress=SchedulerUpdateActionProgress(
+                remaining_fraction=remaining_fraction,
+                remaining_duration=remaining_duration,
+            ),
+        )
+
+    @field_validator("max_duration")
+    @classmethod
+    def _validate_max_duration(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    "max_duration must be set and truthy, "
+                    f"received {value} ({val} failed)"
+                )
+            if not isinstance(val, (int, float)) or val <= 0:
+                raise ValueError(
+                    "max_duration must be a positive number, "
+                    f"received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+@ConstraintsInitializerFactory.register(  # type: ignore[arg-type]
+    ["max_errors", "max_err", "max_error", "max_errs"]
+)
+class MaxErrorsConstraint(PydanticConstraintInitializer):
+    """
+    Constraint that limits execution based on absolute error count.
+
+    Stops both request queuing and all request processing when the total number
+    of errored requests reaches the maximum threshold. Uses global error tracking
+    across all requests for immediate constraint evaluation.
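+
+    Example:
+    ::
+        # Sketch: resolve this initializer through the registry by alias
+        constraint = ConstraintsInitializerFactory.create_constraint(
+            "max_errors", max_errors=10
+        )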
+ """ + + type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] + max_errors: int | float | list[int | float] = Field( + description="Maximum number of errors allowed before triggering constraint", + ) + current_index: int = Field(default=-1, description="Current index in error list") + + @classmethod + def validated_kwargs( + cls, max_errors: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorsConstraint creation. + + :param max_errors: Maximum number of errors to allow + :param kwargs: Supports max_errors, max_err, max_error, max_errs, + and optional type_ + :return: Validated dictionary with max_errors and type_ fields + """ + aliases = ["max_errors", "max_err", "max_error", "max_errs"] + for alias in aliases: + if max_errors is None: + max_errors = kwargs.get(alias) + + return { + "max_errors": max_errors, + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current error count. + + :param state: Current scheduler state with error counts + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_errors = ( + self.max_errors + if isinstance(self.max_errors, (int, float)) + else self.max_errors[min(current_index, len(self.max_errors) - 1)] + ) + errors_exceeded = state.errored_requests >= max_errors + + return SchedulerUpdateAction( + request_queuing="stop" if errors_exceeded else "continue", + request_processing="stop_all" if errors_exceeded else "continue", + metadata={ + "max_errors": max_errors, + "errors_exceeded": errors_exceeded, + "current_errors": state.errored_requests, + }, + ) + + @field_validator("max_errors") + @classmethod + def _validate_max_errors( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_errors must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, (int, float)) or val <= 0: + raise ValueError( + f"max_errors must be a positive num,received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_error_rate", "max_err_rate", "max_errors_rate"] +) +class MaxErrorRateConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on sliding window error rate. + + Tracks error status of recent requests in a sliding window and stops all + processing when the error rate exceeds the threshold. Only applies the + constraint after processing enough requests to fill the minimum window size + for statistical significance. 
+ """ + + type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] + max_error_rate: int | float | list[int | float] = Field( + description="Maximum error rate allowed (0.0, 1.0)" + ) + window_size: int | float = Field( + default=30, + gt=0, + description="Size of sliding window for calculating error rate", + ) + error_window: list[bool] = Field( + default_factory=list, + description="Sliding window tracking error status of recent requests", + ) + current_index: int = Field( + default=-1, description="Current index in the error window" + ) + + @classmethod + def validated_kwargs( + cls, max_error_rate: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorRateConstraint creation. + + :param max_error_rate: Maximum error rate to allow + :param kwargs: Supports max_error_rate, max_err_rate, max_errors_rate, + optional window_size, and optional type_ + :return: Validated dictionary with max_error_rate, window_size, + and type_ fields + """ + aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] + for alias in aliases: + if max_error_rate is None: + max_error_rate = kwargs.get(alias) + + return { + "max_error_rate": max_error_rate, + "window_size": kwargs.get( + "window_size", settings.constraint_error_window_size + ), + "error_window": kwargs.get("error_window", []), + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 + """ + Create a new instance of MaxErrorRateConstraint (due to stateful window). + + :param kwargs: Additional keyword arguments (unused) + :return: New instance of the constraint + """ + self.current_index += 1 + + return self.model_copy() # type: ignore[return-value] + + def __call__( + self, state: SchedulerState, request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against sliding window error rate. 
+
+        :param state: Current scheduler state with request counts
+        :param request_info: Individual request with completion status
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_error_rate = (
+            self.max_error_rate
+            if isinstance(self.max_error_rate, (int, float))
+            else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)]
+        )
+
+        if request_info.status in ["completed", "errored", "cancelled"]:
+            self.error_window.append(request_info.status == "errored")
+            if len(self.error_window) > self.window_size:
+                self.error_window.pop(0)
+
+        error_count = sum(self.error_window)
+        window_requests = len(self.error_window)
+        error_rate = (
+            error_count / float(window_requests) if window_requests > 0 else 0.0
+        )
+        exceeded_min_processed = state.processed_requests >= self.window_size
+        exceeded_error_rate = error_rate >= max_error_rate
+
+        return SchedulerUpdateAction(
+            request_queuing=(
+                "stop" if exceeded_min_processed and exceeded_error_rate else "continue"
+            ),
+            request_processing=(
+                "stop_all"
+                if exceeded_min_processed and exceeded_error_rate
+                else "continue"
+            ),
+            metadata={
+                "max_error_rate": max_error_rate,
+                "window_size": self.window_size,
+                "error_count": error_count,
+                "processed_count": state.processed_requests,
+                "current_window_size": len(self.error_window),
+                "current_error_rate": error_rate,
+                "exceeded_min_processed": exceeded_min_processed,
+                "exceeded_error_rate": exceeded_error_rate,
+            },
+        )
+
+    @field_validator("max_error_rate")
+    @classmethod
+    def _validate_max_error_rate(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    "max_error_rate must be set and truthy, "
+                    f"received {value} ({val} failed)"
+                )
+            if not isinstance(val, (int, float)) or val <= 0 or val >= 1:
+                raise ValueError(
+                    "max_error_rate must be a number between 0 and 1, "
+                    f"received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+@ConstraintsInitializerFactory.register(  # type: ignore[arg-type]
+    ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"]
+)
+class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer):
+    """
+    Constraint that limits execution based on global error rate.
+
+    Calculates error rate across all processed requests and stops all processing
+    when the rate exceeds the threshold. Only applies the constraint after
+    processing the minimum number of requests to ensure statistical significance
+    for global error rate calculations.
+    """
+
+    type_: Literal["max_global_error_rate"] = "max_global_error_rate"  # type: ignore[assignment]
+    max_error_rate: int | float | list[int | float] = Field(
+        description="Maximum error rate allowed, in the open interval (0.0, 1.0)"
+    )
+    min_processed: int | float | None = Field(
+        default=30,
+        gt=0,
+        description="Minimum requests processed before applying error rate constraint",
+    )
+    current_index: int = Field(
+        default=-1, description="Current index for list-based max_error_rate values"
+    )
+
+    @classmethod
+    def validated_kwargs(
+        cls, max_error_rate: int | float | list[int | float] | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Validate and process arguments for MaxGlobalErrorRateConstraint creation.
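+
+        Example (alias-handling sketch)::
+
+            kwargs = MaxGlobalErrorRateConstraint.validated_kwargs(
+                max_global_err_rate=0.05
+            )
+            assert kwargs["max_error_rate"] == 0.05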
+
+        :param max_error_rate: Maximum error rate to allow
+        :param kwargs: Supports max_global_error_rate, max_global_err_rate,
+            max_global_errors_rate, optional min_processed, and optional type_
+        :return: Validated dictionary with max_error_rate, min_processed,
+            and current_index fields
+        """
+        for alias in [
+            "max_global_error_rate",
+            "max_global_err_rate",
+            "max_global_errors_rate",
+        ]:
+            if max_error_rate is None:
+                max_error_rate = kwargs.get(alias)
+
+        return {
+            "max_error_rate": max_error_rate,
+            "min_processed": kwargs.get(
+                "min_processed", settings.constraint_error_min_processed
+            ),
+            "current_index": kwargs.get("current_index", -1),
+        }
+
+    def create_constraint(self, **kwargs) -> Constraint:  # noqa: ARG002
+        """
+        Advance the rotation index and return a copy of this initializer
+        as the constraint instance.
+
+        :param kwargs: Additional keyword arguments (unused)
+        :return: Copy of this initializer to use as the constraint
+        """
+        self.current_index += 1
+
+        return self.model_copy()  # type: ignore[return-value]
+
+    def __call__(
+        self,
+        state: SchedulerState,
+        request_info: ScheduledRequestInfo,  # noqa: ARG002
+    ) -> SchedulerUpdateAction:
+        """
+        Evaluate constraint against global error rate.
+
+        :param state: Current scheduler state with global request and error counts
+        :param request_info: Individual request information (unused)
+        :return: Action indicating whether to continue or stop operations
+        """
+        current_index = max(0, self.current_index)
+        max_error_rate = (
+            self.max_error_rate
+            if isinstance(self.max_error_rate, (int, float))
+            else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)]
+        )
+
+        exceeded_min_processed = (
+            self.min_processed is None or state.processed_requests >= self.min_processed
+        )
+        error_rate = (
+            state.errored_requests / float(state.processed_requests)
+            if state.processed_requests > 0
+            else 0.0
+        )
+        exceeded_error_rate = error_rate >= max_error_rate
+        should_stop = exceeded_min_processed and exceeded_error_rate
+
+        return SchedulerUpdateAction(
+            request_queuing="stop" if should_stop else "continue",
+            request_processing="stop_all" if should_stop else "continue",
+            metadata={
+                "max_error_rate": max_error_rate,
+                "min_processed": self.min_processed,
+                "processed_requests": state.processed_requests,
+                "errored_requests": state.errored_requests,
+                "error_rate": error_rate,
+                "exceeded_min_processed": exceeded_min_processed,
+                "exceeded_error_rate": exceeded_error_rate,
+            },
+        )
+
+    @field_validator("max_error_rate")
+    @classmethod
+    def _validate_max_error_rate(
+        cls, value: int | float | list[int | float]
+    ) -> int | float | list[int | float]:
+        if not isinstance(value, list):
+            value = [value]
+        for val in value:
+            if not val:
+                raise ValueError(
+                    "max_error_rate must be set and truthy, "
+                    f"received {value} ({val} failed)"
+                )
+            if not isinstance(val, (int, float)) or val <= 0 or val >= 1:
+                raise ValueError(
+                    "max_error_rate must be a number between 0 and 1, "
+                    f"received {value} ({val} failed)"
+                )
+
+        return value[0] if isinstance(value, list) and len(value) == 1 else value
+
+
+class RequestsExhaustedConstraint(StandardBaseModel, InfoMixin):
+    """
+    Constraint that stops scheduling once a fixed request budget is used up.
+
+    Stops request queuing when created requests reach num_requests and stops
+    local request processing when processed requests reach num_requests.
+    Typically applied when the request source contains a finite set of requests.
+    """
+
+    type_: Literal["requests_exhausted"] = "requests_exhausted"  # type: ignore[assignment]
+    num_requests: int
+
+    @property
+    def info(self) -> dict[str, Any]:
+        """
+        Extract serializable information from this constraint initializer.
+ + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + create_exceeded = state.created_requests >= self.num_requests + processed_exceeded = state.processed_requests >= self.num_requests + remaining_fraction = min( + max(0.0, 1.0 - state.processed_requests / float(self.num_requests)), 1.0 + ) + remaining_requests = max(0, self.num_requests - state.processed_requests) + + return SchedulerUpdateAction( + request_queuing="stop" if create_exceeded else "continue", + request_processing="stop_local" if processed_exceeded else "continue", + metadata={ + "num_requests": self.num_requests, + "create_exceeded": create_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": state.processed_requests, + "remaining_fraction": remaining_fraction, + "remaining_requests": remaining_requests, + }, + progress=SchedulerUpdateActionProgress( + remaining_fraction=remaining_fraction, + remaining_requests=remaining_requests, + ), + ) diff --git a/src/guidellm/scheduler/environments.py b/src/guidellm/scheduler/environments.py new file mode 100644 index 00000000..6234f8f6 --- /dev/null +++ b/src/guidellm/scheduler/environments.py @@ -0,0 +1,273 @@ +""" +Environment abstractions for coordinating scheduler execution across distributed nodes. + +Provides environment abstractions that handle synchronization, timing coordination, +error propagation, and lifecycle management for scheduler execution across single +or multiple nodes. The Environment protocol defines the interface for distributed +coordination while NonDistributedEnvironment provides a minimal implementation +for single-node execution. + +Environment Execution Flow: +1. sync_run_params() - Distribute workload and synchronize parameters across nodes +2. sync_run_start() - Coordinate synchronized start time for all nodes +3. update_run_iteration() - Update state after each request (called per iteration) +4. sync_run_error() - Handle and propagate errors across nodes +5. sync_run_end() - Aggregate results and cleanup at completion +""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterable +from typing import ( + Generic, +) + +from guidellm.scheduler.constraints import Constraint +from guidellm.scheduler.objects import ( + MultiTurnRequestT, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, +) +from guidellm.scheduler.strategies import SchedulingStrategy +from guidellm.settings import settings +from guidellm.utils import InfoMixin + +__all__ = ["Environment", "NonDistributedEnvironment"] + + +class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin): + """ + Abstract base for coordinating scheduler execution across distributed nodes. + + Defines the interface for managing distributed scheduler execution including + parameter synchronization, timing coordination, state updates, error propagation, + and result aggregation. Implementations handle the complexity of distributed + coordination while providing a unified interface for scheduler orchestration. 
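+
+    Example:
+    ::
+        # Sketch of a distributed implementation; coordination details elided
+        class TwoNodeEnvironment(Environment):
+            async def sync_run_params(self, requests, strategy, constraints):
+                ...  # partition requests between the two nodes
+
+            async def sync_run_start(self) -> float:
+                ...  # agree on a shared start timestamp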
+ """ + + @abstractmethod + async def sync_run_params( + self, + requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + strategy: SchedulingStrategy, + constraints: dict[str, Constraint], + ) -> tuple[ + Iterable[RequestT | MultiTurnRequestT[RequestT]], + SchedulingStrategy, + dict[str, Constraint], + ]: + """ + Synchronize execution parameters across nodes and resolve local scope. + + Coordinates parameter distribution and validation across active nodes. + In distributed environments, handles node assignment and workload partitioning. + In non-distributed environments, typically returns parameters unchanged. + + :param requests: Complete set of requests to process across all nodes + :param strategy: Scheduling strategy to apply during execution + :param constraints: Runtime constraints to enforce during execution + :return: Tuple of (local_requests, strategy, constraints) for this node + :raises Exception: If parameter synchronization fails or nodes inconsistent + """ + ... + + @abstractmethod + async def sync_run_start(self) -> float: + """ + Coordinate synchronized start time across all nodes. + + Ensures all nodes begin processing simultaneously for accurate benchmarking + and consistent timing measurements across distributed execution. + + :return: Unix timestamp when all nodes should begin processing + :raises Exception: If startup synchronization fails across nodes + """ + ... + + @abstractmethod + async def update_run_iteration( + self, + response: ResponseT | None, + request: RequestT, + request_info: ScheduledRequestInfo, + state: SchedulerState, + ): + """ + Update environment state with completed request iteration results. + + Called after each request processing to update execution progress and + synchronize any required state across nodes in distributed environments. + Generally, distributed is expected to store the iteration updates until + all nodes have processed and sync_run_end is called to retrieve them. + + :param response: Response generated for the request, if successful + :param request: The processed request + :param request_info: Metadata about request processing including timings + :param state: Current scheduler state with metrics and progress + :raises Exception: If state update fails or indicates critical errors + """ + ... + + @abstractmethod + async def sync_run_error(self, err: list[Exception] | Exception): + """ + Handle and propagate errors across all active nodes. + + Coordinates error handling when failures occur, ensuring all nodes are + notified for appropriate cleanup or shutdown procedures. + + :param err: The exception(s) that occurred during execution + """ + ... + + @abstractmethod + async def sync_run_end( + self, + ) -> AsyncIterator[ + tuple[ + ResponseT, + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo, + SchedulerState, + ] + ]: + """ + Finalize execution and aggregate results from all nodes. + + Handles cleanup, result synchronization, and error propagation at execution + completion. Collects and yields results from worker nodes in distributed + environments. + + :return: Iterator of (response, request, request_info, state) tuples from + remote nodes in distributed environments, empty for non-distributed + :raises Exception: Any errors that occurred during execution + """ + ... + + +class NonDistributedEnvironment(Environment): + """ + Single-node scheduler execution environment with minimal coordination overhead. 
+
+    Simplified environment for running schedulers on a single node without
+    distributed coordination requirements. Implements the Environment interface
+    with no-op synchronization for local testing, development, and
+    single-machine benchmarking.
+
+    Example:
+    ::
+        from guidellm.scheduler import (
+            MaxNumberConstraint,
+            NonDistributedEnvironment,
+            ScheduledRequestInfo,
+            SchedulerState,
+            SynchronousStrategy,
+        )
+
+
+        # Definitions
+        env = NonDistributedEnvironment()
+        requests = [f"req_{ind}" for ind in range(5)]
+        strategy = SynchronousStrategy()
+        constraints = {"max_num": MaxNumberConstraint(max_num=5)}
+        state = SchedulerState()
+
+        # Run environment
+        local_req, local_strat, local_const = await env.sync_run_params(
+            requests, strategy, constraints
+        )
+        start_time = await env.sync_run_start()
+        for req in local_req:
+            state.processed_requests += 1
+            await env.update_run_iteration(
+                f"resp_{req}", req, ScheduledRequestInfo(), state
+            )
+        async for nonlocal_req in env.sync_run_end():
+            state.processed_requests += 1
+    """
+
+    def __init__(self):
+        """Initialize with empty error storage for single-node execution."""
+        self.run_errors: list[Exception] = []
+
+    async def sync_run_params(
+        self,
+        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        strategy: SchedulingStrategy,
+        constraints: dict[str, Constraint],
+    ) -> tuple[
+        Iterable[RequestT | MultiTurnRequestT[RequestT]],
+        SchedulingStrategy,
+        dict[str, Constraint],
+    ]:
+        """
+        Return parameters unchanged for single-node execution.
+
+        :param requests: Requests to process locally
+        :param strategy: Scheduling strategy to apply during execution
+        :param constraints: Runtime constraints to enforce during execution
+        :return: Tuple containing the original (requests, strategy, constraints)
+        """
+        return requests, strategy, constraints
+
+    async def sync_run_start(self) -> float:
+        """
+        Return current time plus configured delay for single-node startup.
+
+        :return: Unix timestamp for when the run should start
+        """
+        return time.time() + settings.scheduler_start_delay_non_distributed
+
+    async def update_run_iteration(
+        self,
+        response: ResponseT | None,
+        request: RequestT,
+        request_info: ScheduledRequestInfo,
+        state: SchedulerState,
+    ):
+        """
+        No-op for single-node execution with no distributed state synchronization.
+
+        :param response: Response generated for the request, if successful
+        :param request: The request that was processed
+        :param request_info: Metadata about request processing including timings
+        :param state: Current scheduler state with metrics and progress
+        """
+
+    async def sync_run_error(self, err: list[Exception] | Exception):
+        """
+        Store error for later propagation during run finalization.
+
+        :param err: The exception(s) that occurred during execution
+        """
+        err = [err] if not isinstance(err, list) else err
+        self.run_errors.extend(err)
+
+    async def sync_run_end(
+        self,
+    ) -> AsyncIterator[
+        tuple[
+            ResponseT,
+            RequestT | MultiTurnRequestT[RequestT],
+            ScheduledRequestInfo,
+            SchedulerState,
+        ]
+    ]:
+        """
+        Finalize single-node execution and propagate any stored errors.
+ + :return: Empty iterator since there are no remote nodes + :raises Exception: Any error stored during execution via sync_run_error + """ + if self.run_errors: + if len(self.run_errors) == 1: + raise self.run_errors[0] + else: + raise RuntimeError( + f"Errors occurred during execution: {self.run_errors}" + ) + + return + yield # needed to force generator compilation diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py new file mode 100644 index 00000000..b7f2efc3 --- /dev/null +++ b/src/guidellm/scheduler/objects.py @@ -0,0 +1,468 @@ +""" +Core data structures and interfaces for the GuideLLM scheduler system. + +Provides type-safe abstractions for distributed request processing, timing +measurements, and backend interfaces for benchmarking operations. Central to +the scheduler architecture, enabling request lifecycle tracking, backend +coordination, and state management across distributed worker processes. +""" + +from __future__ import annotations + +import time +import uuid +from collections.abc import AsyncIterator +from typing import ( + Any, + ClassVar, + Generic, + Literal, + Protocol, + TypeVar, + Union, +) + +from pydantic import Field, computed_field +from typing_extensions import TypeAliasType, TypedDict + +from guidellm.utils import ( + PydanticClassRegistryMixin, + RegistryMixin, + StandardBaseModel, +) +from guidellm.utils.registry import RegistryObjT + +__all__ = [ + "BackendInterface", + "BackendT", + "MeasuredRequestTimings", + "MultiTurnRequestT", + "RequestSchedulerTimings", + "RequestT", + "ResponseT", + "ScheduledRequestInfo", + "SchedulerMessagingPydanticRegistry", + "SchedulerState", + "SchedulerUpdateAction", + "SchedulerUpdateActionProgress", +] + +RequestT = TypeVar("RequestT") +"""Generic request object type for scheduler processing.""" + +ResponseT = TypeVar("ResponseT") +"""Generic response object type returned by backend processing.""" + +MultiTurnRequestT = TypeAliasType( + "MultiTurnRequestT", + Union[ + list[Union[RequestT, tuple[RequestT, float]]], + tuple[Union[RequestT, tuple[RequestT, float]]], + ], + type_params=(RequestT,), +) +"""Multi-turn request structure supporting conversation history with optional delays.""" + + +class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): + """ + Registry for enabling a generic interface to define the pydantic class types used + for inter-process messaging within the scheduler. + """ + + +@SchedulerMessagingPydanticRegistry.register() +class RequestSchedulerTimings(StandardBaseModel): + """ + Scheduler-level timing measurements for request lifecycle tracking. + All timestamps are expected to be in Unix time (seconds since epoch). 
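+
+    Example:
+    ::
+        timings = RequestSchedulerTimings(queued=time.time())
+        timings.dequeued = time.time()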
+ """ + + targeted_start: float | None = Field( + default=None, + description="When the request was initially targeted for execution", + ) + queued: float | None = Field( + default=None, + description="When the request was placed into the processing queue", + ) + dequeued: float | None = Field( + default=None, + description="When the request was removed from the queue for processing", + ) + scheduled_at: float | None = Field( + default=None, description="When the request was scheduled for processing" + ) + resolve_start: float | None = Field( + default=None, description="When backend resolution of the request began" + ) + resolve_end: float | None = Field( + default=None, description="When backend resolution of the request completed" + ) + finalized: float | None = Field( + default=None, + description="When the request was processed/acknowledged by the scheduler", + ) + + +@SchedulerMessagingPydanticRegistry.register() +class MeasuredRequestTimings(PydanticClassRegistryMixin["MeasuredRequestTimings"]): + """ + Base timing measurements for backend request processing. + All timestamps are expected to be in Unix time (seconds since epoch). + """ + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[MeasuredRequestTimings]: + if cls.__name__ == "MeasuredRequestTimings": + return cls + + return MeasuredRequestTimings + + schema_discriminator: ClassVar[str] = "timings_type" + + timings_type: Literal["measured_request_timings"] = Field( + default="measured_request_timings", + description="Type identifier for the timing measurement", + ) + request_start: float | None = Field( + default=None, description="When the backend began processing the request" + ) + request_end: float | None = Field( + default=None, description="When the backend completed processing the request" + ) + + +@SchedulerMessagingPydanticRegistry.register() +class ScheduledRequestInfo(StandardBaseModel): + """ + Complete request information including status, timings, and metadata. + + Central data structure for tracking request lifecycle from creation through + completion, containing scheduling metadata, timing measurements, and processing + status. Used by scheduler components to coordinate request processing across + distributed worker processes. 
+ + Example: + :: + from guidellm.scheduler.objects import ScheduledRequestInfo + + # Create request info with automatic ID generation + request_info = ScheduledRequestInfo() + request_info.status = "in_progress" + request_info.scheduler_timings.queued = time.time() + + # Check processing completion + if request_info.completed_at: + duration = request_info.completed_at - request_info.started_at + """ + + request_id: str = Field( + description="Unique identifier for the request", + default_factory=lambda: str(uuid.uuid4()), + ) + status: Literal[ + "queued", "pending", "in_progress", "completed", "errored", "cancelled" + ] = Field(description="Current processing status of the request", default="queued") + scheduler_node_id: int = Field( + description="ID/rank of the scheduler node handling the request", + default=-1, + ) + scheduler_process_id: int = Field( + description="ID/rank of the node's scheduler process handling the request", + default=-1, + ) + scheduler_start_time: float = Field( + description="Unix timestamp for the local time when scheduler processing began", + default=-1, + ) + + error: str | None = Field( + default=None, description="Error message if the request.status is 'errored'" + ) + scheduler_timings: RequestSchedulerTimings = Field( + default_factory=RequestSchedulerTimings, + description="Scheduler-level timing measurements for request lifecycle", + ) + request_timings: MeasuredRequestTimings | None = Field( + default=None, + description="Backend-specific timing measurements for request processing", + ) + + @computed_field # type: ignore[misc] + @property + def started_at(self) -> float | None: + """ + Get the effective request processing start time. + + :return: Unix timestamp when processing began, or None if not started. + """ + request_start = ( + self.request_timings.request_start if self.request_timings else None + ) + + return request_start or self.scheduler_timings.resolve_start + + @computed_field # type: ignore[misc] + @property + def completed_at(self) -> float | None: + """ + Get the effective request processing completion time. + + :return: Unix timestamp when processing completed, or None if not completed. + """ + request_end = self.request_timings.request_end if self.request_timings else None + + return request_end or self.scheduler_timings.resolve_end + + def model_copy(self, **kwargs) -> ScheduledRequestInfo: # type: ignore[override] # noqa: ARG002 + """ + Create a deep copy of the request info with copied timing objects. + + :return: New ScheduledRequestInfo instance with independent timing objects + """ + return super().model_copy( + update={ + "scheduler_timings": self.scheduler_timings.model_copy(), + "request_timings": ( + self.request_timings.model_copy() if self.request_timings else None + ), + }, + deep=False, + ) + + +class BackendInterface(Protocol, Generic[RequestT, ResponseT]): + """ + Abstract interface for request processing backends. + + Defines the contract for backend implementations that process requests within + the scheduler system. Backends handle initialization, validation, processing, + and shutdown lifecycle management. Must ensure all properties are pickleable + before process_startup is invoked for multi-process environments. 
+ + Example: + :: + from guidellm.scheduler.objects import BackendInterface + + class CustomBackend(BackendInterface): + @property + def processes_limit(self) -> int: + return 4 + + async def resolve(self, request, request_info, history=None): + # Process request and yield responses + yield response, updated_request_info + """ + + @property + def processes_limit(self) -> int | None: + """ + :return: Maximum worker processes supported, or None if unlimited + """ + + @property + def requests_limit(self) -> int | None: + """ + :return: Maximum concurrent requests supported, or None if unlimited + """ + + @property + def info(self) -> dict[str, Any]: + """ + :return: Backend metadata including model initialization and configuration + """ + + async def process_startup(self) -> None: + """ + Perform backend initialization and startup procedures. + + :raises: Implementation-specific exceptions for startup failures. + """ + + async def validate(self) -> None: + """ + Validate backend configuration and operational status. + + :raises: Implementation-specific exceptions for validation failures. + """ + + async def process_shutdown(self) -> None: + """ + Perform backend cleanup and shutdown procedures. + + :raises: Implementation-specific exceptions for shutdown failures. + """ + + async def resolve( + self, + request: RequestT, + request_info: ScheduledRequestInfo, + history: list[tuple[RequestT, ResponseT]] | None = None, + ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo]]: + """ + Process a request and yield incremental response updates. + + :param request: The request object to process + :param request_info: Scheduling metadata and timing information + :param history: Optional conversation history for multi-turn requests + :yield: Tuples of (response, updated_request_info) for each response chunk + :raises: Implementation-specific exceptions for processing failures + """ + + +BackendT = TypeVar("BackendT", bound=BackendInterface) +"""Generic backend interface type for request processing.""" + + +class SchedulerUpdateActionProgress(TypedDict, total=False): + """ + Progress information for a scheduler update action. + + Optional progress tracking data that provides estimates for remaining work + in scheduler operations. Used by constraints and monitoring systems to + track execution progress and make termination decisions. + """ + + remaining_fraction: float | None + remaining_requests: float | None + remaining_duration: float | None + + +class SchedulerUpdateAction(StandardBaseModel): + """ + Scheduler behavior control directives and actions. + + Encapsulates control signals for scheduler operations including request + queuing and processing directives. Used by constraints to communicate + termination conditions and progress information to scheduler components. 
+
+    Example:
+    ::
+        from guidellm.scheduler.objects import SchedulerUpdateAction
+
+        # Signal to stop queuing but continue processing
+        action = SchedulerUpdateAction(
+            request_queuing="stop",
+            request_processing="continue",
+            metadata={"reason": "max_requests_reached"}
+        )
+    """
+
+    request_queuing: Literal["continue", "stop"] = Field(
+        default="continue", description="Action to take for request queuing operations"
+    )
+    request_processing: Literal["continue", "stop_local", "stop_all"] = Field(
+        default="continue",
+        description="Action to take for request processing operations",
+    )
+    metadata: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional context and data for the scheduler action",
+    )
+    progress: SchedulerUpdateActionProgress = Field(
+        default_factory=lambda: SchedulerUpdateActionProgress(),
+        description="Progress information for the scheduler action",
+    )
+
+
+class SchedulerState(StandardBaseModel):
+    """
+    Scheduler operation state tracking and statistics.
+
+    Comprehensive state container for tracking scheduler execution progress,
+    request counts, timing information, and constraint enforcement. Central
+    to scheduler coordination and provides real-time metrics for monitoring
+    and decision-making across distributed worker processes.
+
+    Example:
+    ::
+        from guidellm.scheduler.objects import SchedulerState
+
+        # Initialize scheduler state
+        state = SchedulerState(node_id=0, num_processes=4)
+
+        # Track request processing
+        state.created_requests += 1
+        state.queued_requests += 1
+
+        # Monitor completion progress
+        completion_rate = state.processed_requests / state.created_requests
+    """
+
+    node_id: int = Field(
+        description="Unique identifier for this scheduler node", default=-1
+    )
+    num_processes: int = Field(
+        description="Number of worker processes in this scheduler", default=-1
+    )
+    start_time: float = Field(
+        description="Unix timestamp when the scheduler started",
+        default_factory=time.time,
+    )
+    end_time: float | None = Field(
+        default=None, description="Unix timestamp when the scheduler stopped"
+    )
+    end_queuing_time: float | None = Field(
+        default=None, description="When request queuing stopped, if applicable"
+    )
+    end_queuing_constraints: dict[str, SchedulerUpdateAction] = Field(
+        default_factory=dict,
+        description="Constraints that triggered queuing termination",
+    )
+    end_processing_time: float | None = Field(
+        default=None, description="When request processing stopped, if applicable"
+    )
+    end_processing_constraints: dict[str, SchedulerUpdateAction] = Field(
+        default_factory=dict,
+        description="Constraints that triggered processing termination",
+    )
+    scheduler_constraints: dict[str, SchedulerUpdateAction] = Field(
+        default_factory=dict,
+        description=(
+            "The latest state from all constraints applied during the scheduler run"
+        ),
+    )
+
+    remaining_fraction: float | None = Field(
+        default=None,
+        description=(
+            "Estimated fraction for the remaining progress of the run, if known"
+        ),
+    )
+    remaining_requests: float | None = Field(
+        default=None,
+        description="Estimated number of requests remaining to be processed, if known",
+    )
+    remaining_duration: float | None = Field(
+        default=None,
+        description=(
+            "Estimated time remaining in seconds for the scheduler run, if known"
+        ),
+    )
+
+    created_requests: int = Field(
+        default=0, description="Total number of requests created"
+    )
+    queued_requests: int = Field(
+        default=0, description="Total number of requests queued for processing"
+    )
+    pending_requests: int = Field(
+        default=0,
+        description="Total number of requests pending processing within a worker",
+    )
+    processing_requests: int = Field(
+        default=0, description="Number of requests currently being processed"
+    )
+    processed_requests: int = Field(
+        default=0, description="Total number of requests that completed processing"
+    )
+    successful_requests: int = Field(
+        default=0, description="Number of requests that completed successfully"
+    )
+    errored_requests: int = Field(
+        default=0, description="Number of requests that failed with errors"
+    )
+    cancelled_requests: int = Field(
+        default=0, description="Number of requests that were cancelled"
+    )
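These objects are the full contract a constraint works with: it receives the live
`SchedulerState` plus the triggering `ScheduledRequestInfo` and returns a
`SchedulerUpdateAction`. A minimal sketch of a custom constraint following the
`Constraint` protocol from `constraints.py` (the `max_processed_constraint` helper
is hypothetical, shown only to illustrate the shape):
::
    from guidellm.scheduler.objects import (
        ScheduledRequestInfo,
        SchedulerState,
        SchedulerUpdateAction,
        SchedulerUpdateActionProgress,
    )

    def max_processed_constraint(limit: int):
        """Stop queuing and processing once `limit` (> 0) requests finished."""

        def _constraint(
            state: SchedulerState, request: ScheduledRequestInfo
        ) -> SchedulerUpdateAction:
            # request is unused here; the protocol passes it for per-request checks
            remaining = max(limit - state.processed_requests, 0)
            stop = remaining <= 0
            return SchedulerUpdateAction(
                request_queuing="stop" if stop else "continue",
                request_processing="stop_all" if stop else "continue",
                metadata={"processed": state.processed_requests, "limit": limit},
                progress=SchedulerUpdateActionProgress(
                    remaining_fraction=remaining / limit,
                    remaining_requests=remaining,
                ),
            )

        return _constraint

Such a callable could then be passed directly as a constraint kwarg, e.g.
``Scheduler().run(..., my_limit=max_processed_constraint(500))``, since the run
constraints accept ready-made constraint instances alongside primitives and dicts.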
diff --git a/src/guidellm/scheduler/queues.py b/src/guidellm/scheduler/queues.py
deleted file mode 100644
index 6ccc6704..00000000
--- a/src/guidellm/scheduler/queues.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""
-Helper module for importing the correct queue types.
-"""
-
-from dataclasses import dataclass
-from queue import Empty as QueueEmpty
-from queue import Full as QueueFull
-from queue import Queue
-from typing import Generic
-
-from guidellm.request.types import RequestT, ResponseT
-from guidellm.scheduler.result import WorkerProcessRequest, WorkerProcessResult
-
-__all__ = [
-    "MPQueues",
-    "Queue",
-    "QueueEmpty",
-    "QueueFull",
-]
-
-
-@dataclass
-class MPQueues(Generic[RequestT, ResponseT]):
-    requests: Queue[WorkerProcessRequest[RequestT, ResponseT]]
-    responses: Queue[WorkerProcessResult[RequestT, ResponseT]]
diff --git a/src/guidellm/scheduler/result.py b/src/guidellm/scheduler/result.py
deleted file mode 100644
index 04fbf931..00000000
--- a/src/guidellm/scheduler/result.py
+++ /dev/null
@@ -1,155 +0,0 @@
-from dataclasses import dataclass
-from typing import (
-    Generic,
-    Literal,
-    Optional,
-)
-
-from guidellm.objects import StandardBaseModel
-from guidellm.request.types import RequestT, ResponseT
-from guidellm.scheduler.strategy import SchedulingStrategy
-
-__all__ = [
-    "SchedulerRequestInfo",
-    "SchedulerRequestResult",
-    "SchedulerResult",
-    "SchedulerRunInfo",
-    "WorkerProcessRequest",
-    "WorkerProcessResult",
-]
-
-
-class SchedulerRunInfo(StandardBaseModel):
-    """
-    Information about the current run of the scheduler.
-    This class holds metadata about the scheduling run,
-    including the start and end times, the number of processes,
-    and the scheduling strategy used.
-    It also tracks the number of requests created, queued, pending,
-    and completed during the run.
-
-    :param start_time: The start time of the scheduling run.
-    :param end_time: The end time of the scheduling run;
-        if None, then this will be math.inf.
-    :param end_number: The maximum number of requests to be processed;
-        if None, then this will be math.inf.
-    :param processes: The number of processes used in the scheduling run.
-    :param strategy: The scheduling strategy used in the run.
-        This should be an instance of SchedulingStrategy.
-    :param created_requests: The number of requests created during the run.
-    :param queued_requests: The number of requests queued during the run.
-    :param scheduled_requests: The number of requests scheduled during the run.
-        (requests pending being sent to the worker but recieved by a process)
-    :param processing_requests: The number of requests actively being run.
-    :param completed_requests: The number of requests completed during the run.
- """ - - start_time: float - end_time: float - end_number: float - processes: int - strategy: SchedulingStrategy - - created_requests: int = 0 - queued_requests: int = 0 - scheduled_requests: int = 0 - processing_requests: int = 0 - completed_requests: int = 0 - - -class SchedulerRequestInfo(StandardBaseModel): - """ - Information about a specific request run through the scheduler. - This class holds metadata about the request, including - the targeted start time, queued time, start time, end time, - and the process ID that handled the request. - - :param targeted_start_time: The targeted start time for the request (time.time()). - :param queued_time: The time the request was queued (time.time()). - :param scheduled_time: The time the request was scheduled (time.time()) - (any sleep time before the request was sent to the worker). - :param worker_start: The time the worker started processing request (time.time()). - :param worker_end: The time the worker finished processing request. (time.time()). - :param process_id: The ID of the underlying process that handled the request. - """ - - requested: bool = False - completed: bool = False - errored: bool = False - canceled: bool = False - - targeted_start_time: float = -1 - queued_time: float = -1 - dequeued_time: float = -1 - scheduled_time: float = -1 - worker_start: float = -1 - request_start: float = -1 - request_end: float = -1 - worker_end: float = -1 - process_id: int = -1 - - -class SchedulerResult(StandardBaseModel): - """ - The yielded, iterative result for a scheduler run. - These are triggered on the start and end of the run, - as well as on the start and end of each request. - Depending on the type, it will hold the request and response - along with information and statistics about the request and general run. - - :param type_: The type of the result, which can be one of: - - "run_start": Indicates the start of the run. - - "run_complete": Indicates the completion of the run (teardown happens after). - - "request_start": Indicates the start of a request. - - "request_complete": Indicates the completion of a request. - :param request: The request that was processed. - :param response: The response from the worker for the request. - :param request_info: Information about the request, including - the targeted start time, queued time, start time, end time, - and the process ID that handled the request. - :param run_info: Information about the current run of the scheduler, - including the start and end times, the number of processes, - and the scheduling strategy used. - It also tracks the number of requests created, queued, pending, - and completed during the run. 
- """ - - pydantic_type: Literal["scheduler_result"] = "scheduler_result" - type_: Literal[ - "run_start", - "run_complete", - "request_scheduled", - "request_start", - "request_complete", - ] - run_info: SchedulerRunInfo - - -class SchedulerRequestResult( - SchedulerResult, - Generic[RequestT, ResponseT], -): - pydantic_type: Literal["scheduler_request_result"] = "scheduler_request_result" # type: ignore[assignment] - type_: Literal[ - "request_scheduled", - "request_start", - "request_complete", - ] - request: RequestT - request_info: SchedulerRequestInfo - response: Optional[ResponseT] = None - - -@dataclass -class WorkerProcessRequest(Generic[RequestT, ResponseT]): - request: RequestT - timeout_time: float - queued_time: float - - -@dataclass -class WorkerProcessResult(Generic[RequestT, ResponseT]): - type_: Literal["request_scheduled", "request_start", "request_complete"] - request: RequestT - response: Optional[ResponseT] - info: SchedulerRequestInfo diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index 11e1102a..de0660e2 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -1,390 +1,165 @@ -import asyncio -import math -import time -from collections.abc import AsyncGenerator, Iterable, Iterator -from concurrent.futures import ProcessPoolExecutor -from multiprocessing import Manager -from threading import Event -from typing import ( - Any, - Generic, - Optional, - Union, -) +""" +Thread-safe singleton scheduler for distributed load generation workload coordination. + +Provides the core orchestration engine that coordinates request processing across +worker processes and distributed environments. Manages timing synchronization, +resource allocation, constraint enforcement, and result aggregation for +load generation operations. Integrates with backends, environments, and strategies +to enable scalable load testing across various scenarios including LLM inference. +""" -from loguru import logger +from __future__ import annotations -from guidellm.request.types import ( +from collections.abc import AsyncIterator, Iterable +from typing import Any, Generic + +from guidellm.scheduler.constraints import ( + Constraint, + ConstraintsInitializerFactory, +) +from guidellm.scheduler.environments import Environment, NonDistributedEnvironment +from guidellm.scheduler.objects import ( + BackendInterface, + MultiTurnRequestT, RequestT, ResponseT, + ScheduledRequestInfo, + SchedulerState, ) -from guidellm.scheduler.queues import MPQueues, Queue, QueueEmpty -from guidellm.scheduler.result import ( - SchedulerRequestResult, - SchedulerResult, - SchedulerRunInfo, - WorkerProcessRequest, - WorkerProcessResult, -) -from guidellm.scheduler.strategy import SchedulingStrategy -from guidellm.scheduler.worker import ( - RequestsWorker, -) -from guidellm.settings import settings +from guidellm.scheduler.strategies import SchedulingStrategy +from guidellm.scheduler.worker_group import WorkerProcessGroup +from guidellm.utils.singleton import ThreadSafeSingletonMixin __all__ = ["Scheduler"] -class Scheduler(Generic[RequestT, ResponseT]): +class Scheduler( + Generic[RequestT, ResponseT], + ThreadSafeSingletonMixin, +): """ - A class that handles the scheduling of requests to a worker. - This class is responsible for managing the lifecycle of the requests, - including their creation, queuing, and processing. - It uses a multiprocessing approach to handle requests concurrently - and efficiently, based on the specified scheduling strategy. 
- The Scheduler class is designed to work with a RequestsWorker, - which is an abstract base class that defines the interface for a worker - that can resolve requests asynchronously or synchronously. - The Scheduler class also supports different scheduling strategies, - including synchronous, throughput, and concurrent strategies. - - :param worker: The worker that will process the requests. - This should be an instance of RequestsWorker. - :param request_loader: An iterable that generates requests. - This can be a list, generator, or any other iterable. - The requests will be processed by the worker. + Thread-safe singleton scheduler for distributed benchmarking workload coordination. + + Orchestrates request processing across worker processes with distributed timing + coordination, constraint enforcement, and result aggregation. Provides a unified + interface for executing benchmarking operations while abstracting the complexity + of multi-process coordination, environment synchronization, and resource management. + Implements singleton pattern to ensure consistent execution state across concurrent + benchmark operations. + + Example: + :: + from guidellm.scheduler import Scheduler + from guidellm.backend import OpenAIBackend + from guidellm.scheduler import NonDistributedEnvironment, SynchronousStrategy + + scheduler = Scheduler() + async for response, request, info, state in scheduler.run( + requests=request_list, + backend=backend, + strategy=SynchronousStrategy(), + env=NonDistributedEnvironment(), + max_requests=1000 + ): + print(f"Processed: {request} with info: {info} and response: {response}") """ - def __init__( - self, - worker: RequestsWorker[RequestT, ResponseT], - request_loader: Iterable[RequestT], - ): - if not isinstance(worker, RequestsWorker): - raise ValueError(f"Invalid worker: {worker}") - - if not isinstance(request_loader, Iterable): - raise ValueError(f"Invalid request_loader: {request_loader}") - - self.worker = worker - self.request_loader = request_loader - async def run( self, - scheduling_strategy: SchedulingStrategy, - max_number: Optional[int] = None, - max_duration: Optional[float] = None, - ) -> AsyncGenerator[ - Union[SchedulerResult, SchedulerRequestResult[RequestT, ResponseT]], None + requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + backend: BackendInterface[RequestT, ResponseT], + strategy: SchedulingStrategy, + env: Environment | None, + **constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> AsyncIterator[ + tuple[ + ResponseT | None, + RequestT, + ScheduledRequestInfo, + SchedulerState, + ] ]: """ - The main method that runs the scheduler. - This method is a generator that yields SchedulerResult objects - at the start and end of the run, as well as at the start and end - of each request. - It uses multiprocessing to handle requests concurrently - and efficiently, based on the specified scheduling strategy. - The method also handles the lifecycle of the requests, - including their creation, queuing, and processing. - The method is designed to be used as an asynchronous generator, - allowing it to be used with asyncio and other asynchronous frameworks. - - :param scheduling_strategy: The scheduling strategy to use. - Specifies the times at which requests will be sent as well how many - worker processes are used and if requests are scheduled sync or async. - This can be one of the following: - - "synchronous": Requests are sent synchronously. - - "throughput": Requests are sent at the maximum rate possible. 
-          - An instance of SchedulingStrategy.
-        :param max_number: The maximum number of requests to process.
-            If None, then no limit is set and either the iterator must be exhaustible
-            or the max_duration must be set.
-        :param max_duration: The maximum duration for the scheduling run.
-            If None, then no limit is set and either the iterator must be exhaustible
-            or the max_number must be set.
-        :return: An asynchronous generator that yields SchedulerResult objects.
-            Each SchedulerResult object contains information about the request,
-            the response, and the run information.
+        Execute distributed request processing with coordinated timing and constraints.
+
+        Orchestrates the complete benchmarking workflow across worker processes with
+        environment synchronization, constraint enforcement, and error handling.
+        Manages resource lifecycle from initialization through cleanup while yielding
+        real-time processing updates for monitoring and aggregation.
+
+        :param requests: Request collection to process. Supports single requests or
+            multi-turn sequences with optional inter-request delays
+        :param backend: Backend interface for request processing and response generation
+        :param strategy: Scheduling strategy controlling request timing and distribution
+        :param env: Environment interface for distributed coordination and
+            synchronization
+        :param constraints: Runtime constraints for execution control (max_requests,
+            max_duration, max_error_rate, etc.). Values can be primitives, dictionaries,
+            or constraint instances
+        :yields: Request updates as (response, request, request_info, scheduler_state)
+            tuples. Each request will generate three ordered updates:
+            queued, in_progress, completed | errored | cancelled.
+        :raises Exception: Worker process errors, environment synchronization failures,
+            or constraint evaluation errors are propagated after cleanup
         """
-        if scheduling_strategy is None or not isinstance(
-            scheduling_strategy, SchedulingStrategy
-        ):
-            raise ValueError(f"Invalid scheduling strategy: {scheduling_strategy}")
-
-        if max_number is not None and max_number < 1:
-            raise ValueError(f"Invalid max_number: {max_number}")
+        with self.thread_lock:
+            if env is None:
+                env = NonDistributedEnvironment()
 
-        if max_duration is not None and max_duration < 0:
-            raise ValueError(f"Invalid max_duration: {max_duration}")
-
-        with (
-            Manager() as manager,
-            ProcessPoolExecutor(
-                max_workers=scheduling_strategy.processes_limit
-            ) as executor,
-        ):
-            requests_iter: Optional[Iterator[Any]] = None
-            scheduling_strategy.start_time = (
-                time.time() + settings.scheduler_start_delay
-            )  # Add a small delay to allow processes to start
-            futures, queues, stop_event = await self._start_processes(
-                manager, executor, scheduling_strategy
-            )
-            run_info, requests_iter, times_iter = self._run_setup(
-                futures, scheduling_strategy, max_number, max_duration
-            )
-
-            # Add some initial requests to the queue
-            requests_iter = self._add_requests(
-                requests_iter,
-                queues.requests,
-                times_iter,
-                run_info,
-            )
-            # Wait for the test to start
-            await asyncio.sleep(time.time() - scheduling_strategy.start_time)
-            yield SchedulerResult(
-                type_="run_start",
-                run_info=run_info,
-            )
+            worker_group: WorkerProcessGroup[RequestT, ResponseT] | None = None
+            # Any issues during the run will raise an error (local or remote),
+            # be caught and passed to the environment,
+            # and will ensure cleanup before raising the error.
try: - while True: - # check errors and raise them - for future in futures: - if future.done() and (err := future.exception()) is not None: - raise err - - if ( - requests_iter is None - and run_info.processing_requests <= 0 - and ( # Ensure we have met one of the end conditions - time.time() >= run_info.end_time - or run_info.completed_requests >= run_info.end_number - ) - ): - # we've exhausted all requests we've wanted to run - # and yielded all responses - break - - requests_iter = self._add_requests( - requests_iter, - queues.requests, - times_iter, - run_info, - ) - await asyncio.sleep(0) # enable requests to start - - iter_result = self._check_result_ready( - queues.responses, - run_info, - ) - if iter_result is not None: - yield iter_result - - # yield control to the event loop - await asyncio.sleep(settings.default_async_loop_sleep) - except Exception as err: - raise RuntimeError(f"Scheduler run failed: {err}") from err - - yield SchedulerResult( - type_="run_complete", - run_info=run_info, - ) - - await self._stop_processes(futures, stop_event) - - async def _start_processes( - self, - manager, - executor: ProcessPoolExecutor, - scheduling_strategy: SchedulingStrategy, - ) -> tuple[ - list[asyncio.Future], - MPQueues[RequestT, ResponseT], - Event, - ]: - await self.worker.prepare_multiprocessing() - queues: MPQueues[RequestT, ResponseT] = MPQueues( - requests=manager.Queue( - maxsize=scheduling_strategy.processing_requests_limit - ), - responses=manager.Queue(), - ) - stop_event = manager.Event() - - num_processes = min( - scheduling_strategy.processes_limit, - scheduling_strategy.processing_requests_limit, - ) - requests_limit_split = ( - scheduling_strategy.processing_requests_limit - // scheduling_strategy.processes_limit - ) - requests_limit_remain = ( - scheduling_strategy.processing_requests_limit - % scheduling_strategy.processes_limit - ) - process_ids = (id_ for id_ in range(num_processes)) - process_requests_limits = ( - requests_limit_split + 1 - if i < requests_limit_remain - else requests_limit_split - for i in range(num_processes) - ) - - futures = [] - loop = asyncio.get_event_loop() - for id_, requests_limit in zip(process_ids, process_requests_limits): - futures.append( - loop.run_in_executor( - executor, - self.worker.process_loop_asynchronous, - queues, - scheduling_strategy, - stop_event, - requests_limit, - id_, - num_processes, + # Setup local run parameters, sync with the environment + constraints = ConstraintsInitializerFactory.resolve_constraints( + constraints ) - ) - - await asyncio.sleep(0.1) # give time for processes to start - - return futures, queues, stop_event - - def _run_setup( - self, - processes: list[asyncio.Future], - scheduling_strategy: SchedulingStrategy, - max_number: Optional[int], - max_duration: Optional[float], - ) -> tuple[SchedulerRunInfo, Iterator[Any], Iterator[float]]: - requests_iter = iter(self.request_loader) - times_iter = iter(scheduling_strategy.request_times()) - end_time = scheduling_strategy.start_time + (max_duration or math.inf) - end_number = max_number or math.inf - - try: - # update end number if the request loader is finite and less than max - iter_length = len(self.request_loader) # type: ignore[arg-type] - if 0 < iter_length < end_number: - end_number = iter_length - except Exception: # noqa: BLE001, S110 - pass - - if end_number == math.inf and end_time is None: - logger.warning( - "No end number or end time set, " - "scheduler will run indefinitely until the request loader is exhausted." 
- ) - - info = SchedulerRunInfo( - start_time=scheduling_strategy.start_time, - end_time=end_time, - end_number=end_number, - processes=len(processes), - strategy=scheduling_strategy, - ) - - return info, requests_iter, times_iter - - def _add_requests( - self, - requests_iter: Optional[Iterator[Any]], - requests_queue: Queue[WorkerProcessRequest[RequestT, ResponseT]], - times_iter: Iterator[float], - run_info: SchedulerRunInfo, - ) -> Optional[Iterator[Any]]: - if requests_iter is not None: - try: - added_count = 0 - - while not requests_queue.full() and added_count < ( - run_info.strategy.queued_requests_limit - or settings.min_queued_requests - ): - if run_info.created_requests >= run_info.end_number: - raise StopIteration - - if ( - next(times_iter) >= run_info.end_time - or time.time() >= run_info.end_time - ): - raise StopIteration - - work_req = WorkerProcessRequest[RequestT, ResponseT]( - request=next(requests_iter), - timeout_time=run_info.end_time, - queued_time=time.time(), + ( + local_requests, + local_strategy, + local_constraints, + ) = await env.sync_run_params(requests, strategy, constraints) + + # Setup the worker group, sync start with the environment + worker_group = WorkerProcessGroup[RequestT, ResponseT]( + requests=None, + cycle_requests=local_requests, + backend=backend, + strategy=local_strategy, + constraints=local_constraints, + ) + await worker_group.create_processes() + local_start_time = await env.sync_run_start() + await worker_group.start(local_start_time) + + # Yield any updates and sync with the environment for non-local updates + async for ( + response, + request, + request_info, + state, + ) in worker_group.request_updates(): + await env.update_run_iteration( + response, request, request_info, state ) - requests_queue.put(work_req) - - run_info.created_requests += 1 - run_info.queued_requests += 1 - added_count += 1 - except StopIteration: - # we've reached the limit number, limit time, or exhausted the requests - # set to None to stop adding more and tell the loop no more requests - requests_iter = None - - return requests_iter - - def _check_result_ready( - self, - responses_queue: Queue[WorkerProcessResult[RequestT, ResponseT]], - run_info: SchedulerRunInfo, - ) -> Optional[SchedulerRequestResult[RequestT, ResponseT]]: - try: - process_response: WorkerProcessResult[RequestT, ResponseT] = ( - responses_queue.get_nowait() - ) - except QueueEmpty: - return None - - if process_response.type_ == "request_scheduled": - run_info.queued_requests -= 1 - run_info.scheduled_requests += 1 - - return SchedulerRequestResult( - type_="request_scheduled", - run_info=run_info, - request=process_response.request, - request_info=process_response.info, - response=None, - ) - - if process_response.type_ == "request_start": - run_info.scheduled_requests -= 1 - run_info.processing_requests += 1 - - return SchedulerRequestResult( - type_="request_start", - run_info=run_info, - request=process_response.request, - request_info=process_response.info, - response=None, - ) - - if process_response.type_ == "request_complete": - run_info.processing_requests -= 1 - run_info.completed_requests += 1 - - return SchedulerRequestResult( - type_="request_complete", - run_info=run_info, - request=process_response.request, - request_info=process_response.info, - response=process_response.response, - ) - raise ValueError(f"Invalid process response type: {process_response}") - - async def _stop_processes( - self, - futures: list[asyncio.Future], - stop_event: Event, - ): - # stop all processes 
- stop_event.set() - - await asyncio.gather(*futures) + yield response, request, request_info, state + except Exception as err: # noqa: BLE001 + await env.sync_run_error(err) + finally: + # Ensure all worker processes are cleaned up for error or completion + if worker_group is not None: + err = await worker_group.shutdown() + if err is not None: + await env.sync_run_error(err) + + # Ensure any errors are raised and all responses + # are yielded for aggregation on the primary node + async for ( + response, + request, + request_info, + state, + ) in env.sync_run_end(): + yield response, request, request_info, state diff --git a/src/guidellm/scheduler/strategies.py b/src/guidellm/scheduler/strategies.py new file mode 100644 index 00000000..8c791671 --- /dev/null +++ b/src/guidellm/scheduler/strategies.py @@ -0,0 +1,700 @@ +""" +Request scheduling strategies for controlling how benchmark requests are processed. + +This module provides timing implementations and concrete strategies that control request +concurrency, timing patterns, and throughput characteristics to simulate real-world +usage scenarios. The scheduling system separates timing logic from strategy constraints, +enabling flexible combination of timing behaviors with process and concurrency limits. +""" + +from __future__ import annotations + +import math +import random +import time +from abc import ABC, abstractmethod +from typing import Annotated, ClassVar, Literal, TypeVar + +from pydantic import Field, PrivateAttr + +from guidellm.scheduler.objects import ScheduledRequestInfo +from guidellm.utils import InfoMixin, PydanticClassRegistryMixin, StandardBaseModel + +__all__ = [ + "AsyncConstantStrategy", + "AsyncPoissonStrategy", + "ConcurrentStrategy", + "ConstantRateRequestTimings", + "LastCompletionRequestTimings", + "NoDelayRequestTimings", + "PoissonRateRequestTimings", + "ScheduledRequestTimings", + "SchedulingStrategy", + "StrategyT", + "StrategyType", + "SynchronousStrategy", + "ThroughputStrategy", +] + + +StrategyType = Annotated[ + Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], + "Valid strategy type identifiers for scheduling request patterns", +] + + +def _exponential_decay_tau(max_progress: float, convergence: float = 0.99) -> float: + """ + Calculate tau value for exponential decay to reach target progress level. + + :param max_progress: The max progress value to reach + :param convergence: The target convergence level for reaching max_progress + :return: The calculated tau value for the given max_progress and convergence + """ + return max_progress / (-math.log(1 - convergence)) + + +def _exponential_decay_fraction(progress: float, tau: float = 1.0) -> float: + """ + Calculate completion fraction based on exponential decay curve. + + :param progress: The current progress value (>=0) + :param tau: The scale factor for the exponential decay + :return: The fraction of completion based on exponential decay (0 -> 1) + """ + return 1 - math.exp(-progress / tau) + + +class ScheduledRequestTimings(StandardBaseModel, ABC): + """ + Abstract base class for controlling when requests are scheduled. + + Defines the interface for timing implementations that determine request scheduling + behavior. Different implementations provide various patterns like synchronous, + constant-rate, or stochastic scheduling to simulate real-world scenarios. + """ + + @abstractmethod + def next_offset(self) -> float: + """ + Calculate the time offset for the next request to be scheduled. 
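+
+        Offsets are expressed in seconds relative to the scheduler start time;
+        an offset at or before the current time effectively schedules the
+        request immediately.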
+ + :return: The offset in seconds from scheduler start time for next request + """ + + @abstractmethod + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion and update internal timing state. + + :param request_info: Information about the completed request including + timing details and completion status + """ + + +class LastCompletionRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for synchronous and concurrent scheduling strategies. + + Schedules the next request immediately after the last request completes, enabling + sequential or limited concurrent processing with completion-based timing control. + """ + + offset: float = Field( + default=0.0, + description="Current time offset in seconds from scheduler start time", + ) + startup_requests: int = Field( + default=0, + description="Number of initial requests to schedule with equal spacing", + ge=0, + ) + startup_requests_delay: float = Field( + default=0.0, + description="Delay in seconds between startup requests", + ge=0, + ) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: + """ + Get the current offset value and apply startup delay if applicable. + + :return: The current offset value in seconds from scheduler start time + """ + self._requests_count += 1 + + if self._requests_count <= self.startup_requests: + self.offset += self.startup_requests_delay + + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Update timing state based on the completed request. + + :param request_info: Information about the completed request + """ + if ( + self._requests_count > self.startup_requests + and request_info.completed_at is not None + ): + # set the next sync offset to the time when the previous request completed + self.offset = request_info.completed_at - request_info.scheduler_start_time + + +class NoDelayRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for throughput-maximizing scheduling strategies. + + Schedules requests with minimal delay to achieve maximum throughput, with optional + startup ramping to gradually increase request processing during initialization. + """ + + offset: float = Field( + default=0.0, + description="Base time offset in seconds from scheduler start time", + ge=0, + ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for gradual startup ramp", + ge=0, + ) + startup_target_requests: int = Field( + default=1, + description="Target number of requests to converge to during startup", + gt=0, + ) + startup_convergence: float = Field( + default=0.99, + description="Target convergence rate during startup phase", + ) + _start_time: float | None = PrivateAttr(None) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: + """ + Calculate offset with optional startup adjustment. 
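+
+        During the startup window, the returned offset ramps from ``offset``
+        toward ``offset + startup_duration`` using the exponential decay
+        helpers above, spreading early requests out rather than releasing
+        them all at once; once ``startup_duration`` has elapsed, the offset
+        is constant and requests are scheduled without delay.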
+ + :return: Static offset plus any startup adjustment + """ + if self._start_time is None: + self._start_time = time.time() + + self._requests_count += 1 + elapsed = time.time() - self._start_time + + if self.startup_duration > 0 and elapsed < self.startup_duration: + startup_percent = _exponential_decay_fraction( + self._requests_count, + _exponential_decay_tau( + self.startup_target_requests, self.startup_convergence + ), + ) + else: + startup_percent = 1.0 + + return self.offset + startup_percent * self.startup_duration + + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion (no action needed for throughput strategy). + + :param request_info: Information about the completed request (unused) + """ + + +class ConstantRateRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for constant-rate scheduling strategies. + + Schedules requests at a fixed rate with evenly spaced intervals to provide + predictable timing behavior for steady-state load simulation. + """ + + rate: float = Field( + description="Target rate in requests per second", + gt=0, + ) + offset: float = Field( + default=0.0, + description="Base time offset in seconds from scheduler start time", + ge=0, + ) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: + """ + Calculate the offset for the next request at a constant rate. + + :return: The offset in seconds for the next request + """ + num_requests = self._requests_count + self._requests_count += 1 + interval = 1.0 / self.rate + + return self.offset + interval * num_requests + + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion (no action needed for constant rate strategy). + + :param request_info: Information about the completed request (unused) + """ + + +class PoissonRateRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for Poisson-distributed scheduling strategies. + + Schedules requests following a Poisson process with exponentially distributed + inter-arrival times to simulate realistic traffic patterns with random variance. + """ + + rate: float = Field( + description="Target average rate in requests per second", + gt=0, + ) + random_seed: int = Field( + default=42, + description="Seed for random number generator for reproducible behavior", + ) + offset: float = Field( + default=0.0, + description="Base time offset in seconds from scheduler start time", + ) + _requests_count: int = PrivateAttr(0) + _random: random.Random | None = PrivateAttr(None) + + def next_offset(self) -> float: + """ + Calculate the offset for the next request using Poisson distribution. + + :return: The cumulative offset in seconds for the next request + """ + self._requests_count += 1 + + if self._random is None: + self._random = random.Random(self.random_seed) + else: + next_delay = self._random.expovariate(self.rate) + self.offset += next_delay + + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle request completion (no action needed for Poisson rate strategy). + + :param request_info: Information about the completed request (unused) + """ + + +class SchedulingStrategy(PydanticClassRegistryMixin["SchedulingStrategy"], InfoMixin): + """ + Abstract base class for scheduling strategies controlling request processing. + + Defines the interface for strategies that combine timing implementations with + process and concurrency constraints to enable various benchmark scenarios. 
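+
+    A hypothetical subclass sketch (illustrative only; the concrete strategies
+    below follow this same registration and timing-factory pattern):
+
+    Example:
+    ::
+        @SchedulingStrategy.register("fixed")
+        class FixedRateStrategy(SchedulingStrategy):
+            type_: Literal["fixed"] = "fixed"  # type: ignore[assignment]
+            rate: float = 1.0
+
+            def create_request_timings(
+                self, local_rank, local_world_size, local_max_concurrency
+            ):
+                # split the target rate evenly across worker processes
+                return ConstantRateRequestTimings(rate=self.rate / local_world_size)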
+    """
+
+    schema_discriminator: ClassVar[str] = "type_"
+
+    @classmethod
+    def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]:
+        if cls.__name__ == "SchedulingStrategy":
+            return cls
+
+        return SchedulingStrategy
+
+    type_: Literal["strategy"] = Field(
+        description="The type of scheduling strategy to schedule requests with",
+    )
+
+    @property
+    def processes_limit(self) -> int | None:
+        """
+        Get the maximum number of worker processes supported by this strategy.
+
+        :return: Maximum number of worker processes, None if unlimited
+        """
+        return None
+
+    @property
+    def requests_limit(self) -> int | None:
+        """
+        Get the maximum number of concurrent requests supported by this strategy.
+
+        :return: Maximum number of concurrent requests, None if unlimited
+        """
+        return None
+
+    def create_request_timings(
+        self, local_rank: int, local_world_size: int, local_max_concurrency: int
+    ) -> ScheduledRequestTimings:
+        """
+        Create a timing instance to define scheduling behavior for a worker process.
+
+        :param local_rank: The rank of the worker process within local world size
+        :param local_world_size: Total number of worker processes in local world
+        :param local_max_concurrency: Maximum concurrent requests for the worker
+        :return: A ScheduledRequestTimings instance for the worker process
+        :raises NotImplementedError: Must be implemented by subclasses
+        """
+        raise NotImplementedError(
+            "create_request_timings method must be implemented by subclasses."
+        )
+
+
+StrategyT = TypeVar("StrategyT", bound=SchedulingStrategy)
+
+
+@SchedulingStrategy.register("synchronous")
+class SynchronousStrategy(SchedulingStrategy):
+    """
+    Sequential request processing strategy with single-process constraint.
+
+    Processes requests one at a time in strict sequential order, providing predictable
+    timing behavior ideal for measuring maximum sequential throughput and ensuring
+    request isolation.
+    """
+
+    type_: Literal["synchronous"] = "synchronous"  # type: ignore[assignment]
+
+    def __str__(self) -> str:
+        """
+        Return string representation of the strategy.
+
+        :return: String identifier for synchronous strategy
+        """
+        return "synchronous"
+
+    @property
+    def processes_limit(self) -> int | None:
+        """
+        Get maximum number of worker processes for synchronous scheduling.
+
+        :return: Always returns 1 to enforce single-process constraint
+        """
+        return 1
+
+    @property
+    def requests_limit(self) -> int | None:
+        """
+        Get maximum number of concurrent requests for synchronous scheduling.
+
+        :return: Always returns 1 to enforce single-request constraint
+        """
+        return 1
+
+    def create_request_timings(
+        self,
+        local_rank: int,
+        local_world_size: int,
+        local_max_concurrency: int,  # noqa: ARG002
+    ) -> ScheduledRequestTimings:
+        """
+        Create timing implementation for synchronous request scheduling.
+
+        :param local_rank: The rank of the worker process (must be 0)
+        :param local_world_size: Total number of worker processes (must be 1)
+        :param local_max_concurrency: Maximum concurrent requests (unused)
+        :return: LastCompletionRequestTimings instance for sequential processing
+        :raises ValueError: If multiple workers or non-zero rank specified
+        """
+        if local_world_size > 1 or local_rank != 0:
+            raise ValueError(
+                "SynchronousStrategy can only be used with a single worker process."
+            )
+
+        return LastCompletionRequestTimings()
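+
+# Illustrative usage sketch (values arbitrary): the synchronous strategy pairs
+# a single worker with completion-based timings, so each request is released
+# only after the previous one finishes:
+#
+#     timings = SynchronousStrategy().create_request_timings(
+#         local_rank=0, local_world_size=1, local_max_concurrency=1
+#     )
+#     first_offset = timings.next_offset()  # 0.0 for the first request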
+
+
+@SchedulingStrategy.register("concurrent")
+class ConcurrentStrategy(SchedulingStrategy):
+    """
+    Parallel request processing strategy with controlled concurrency limits.
+
+    Enables concurrent request processing up to a specified number of streams,
+    providing balanced throughput while maintaining predictable resource usage
+    and completion-based timing coordination.
+    """
+
+    type_: Literal["concurrent"] = "concurrent"  # type: ignore[assignment]
+    streams: int = Field(
+        description="Number of concurrent streams for scheduling requests",
+        gt=0,
+    )
+    startup_duration: float = Field(
+        default=0.0,
+        description="Duration in seconds for distributing startup requests",
+        ge=0,
+    )
+
+    def __str__(self) -> str:
+        """
+        Return string representation of the strategy.
+
+        :return: String identifier with stream count
+        """
+        return f"concurrent@{self.streams}"
+
+    @property
+    def processes_limit(self) -> int:
+        """
+        Get maximum number of worker processes for concurrent scheduling.
+
+        :return: Number of streams as maximum worker processes
+        """
+        return self.streams
+
+    @property
+    def requests_limit(self) -> int:
+        """
+        Get maximum number of concurrent requests for concurrent scheduling.
+
+        :return: Number of streams as maximum concurrent requests
+        """
+        return self.streams
+
+    def create_request_timings(
+        self,
+        local_rank: int,
+        local_world_size: int,
+        local_max_concurrency: int,  # noqa: ARG002
+    ) -> LastCompletionRequestTimings:
+        """
+        Create timing implementation for concurrent request scheduling.
+
+        :param local_rank: The rank of the worker process (must be < streams)
+        :param local_world_size: Total worker processes (must not exceed streams)
+        :param local_max_concurrency: Maximum concurrent requests (unused)
+        :return: LastCompletionRequestTimings instance for stream-based processing
+        :raises ValueError: If worker configuration exceeds stream limits
+        """
+        if local_world_size > self.streams:
+            raise ValueError(
+                "ConcurrentStrategy can only be used with up to "
+                f"{self.streams} worker processes."
+            )
+
+        if local_rank >= self.streams:
+            raise ValueError(
+                f"Local rank {local_rank} exceeds the number of streams {self.streams}."
+            )
+
+        if self.startup_duration > 0:
+            # Ensure equal global distribution of the startup for concurrent streams
+            # Ex: for 10 streams, 2 workers, and an 8-second startup duration,
+            # the first worker should start at 0.0, 1.6, 3.2, 4.8, 6.4
+            # and the second worker should start at 0.8, 2.4, 4.0, 5.6, 7.2
+            delay_per_stream = self.startup_duration / self.streams
+            streams_per_worker = self.streams // local_world_size
+
+            # Offset each worker by one stream delay so workers interleave
+            # across the startup window, matching the example above
+            offset = local_rank * delay_per_stream
+            startup_requests = streams_per_worker + (
+                1
+                if local_world_size > 1 and local_rank < self.streams % local_world_size
+                else 0
+            )
+            startup_requests_delay = delay_per_stream * local_world_size
+        else:
+            offset = 0.0
+            startup_requests = 0
+            startup_requests_delay = 0.0
+
+        return LastCompletionRequestTimings(
+            offset=offset,
+            startup_requests=startup_requests,
+            startup_requests_delay=startup_requests_delay,
+        )
+
+
+@SchedulingStrategy.register("throughput")
+class ThroughputStrategy(SchedulingStrategy):
+    """
+    Maximum throughput strategy with optional concurrency limits.
+
+    Schedules requests to maximize system throughput by allowing unlimited concurrent
+    processing with optional constraints and startup ramping for controlled ramp-up.
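+
+    A short illustrative configuration (values arbitrary):
+
+    Example:
+    ::
+        strategy = ThroughputStrategy(max_concurrency=64, startup_duration=5.0)
+        assert strategy.processes_limit == 64  # bounded by max_concurrency
+        assert strategy.requests_limit == 64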
+ """ + + type_: Literal["throughput"] = "throughput" # type: ignore[assignment] + max_concurrency: int | None = Field( + default=None, + description="Maximum number of concurrent requests to schedule", + gt=0, + ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for startup request distribution", + ge=0, + ) + + def __str__(self) -> str: + """ + Return string representation of the strategy. + + :return: String identifier for throughput strategy + """ + return "throughput" + + @property + def processes_limit(self) -> int | None: + """ + Get maximum number of worker processes for throughput scheduling. + + :return: The max_concurrency value if set, otherwise None for unlimited + """ + return self.max_concurrency + + @property + def requests_limit(self) -> int | None: + """ + Get maximum number of concurrent requests for throughput scheduling. + + :return: The max_concurrency value if set, otherwise None for unlimited + """ + return self.max_concurrency + + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: + """ + Create timing implementation for throughput request scheduling. + + :param local_rank: The rank of the worker process + :param local_world_size: Total number of worker processes + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: NoDelayRequestTimings instance for immediate request scheduling + """ + if self.startup_duration > 0: + # Vary offset by up to 5% of the startup duration for a bit of variance + offset = 0.05 * self.startup_duration * (local_rank / local_world_size) + # Use local_max_concurrency as the target requests for startup convergence + startup_target_requests = local_max_concurrency + else: + offset = 0.0 + startup_target_requests = 1 + + return NoDelayRequestTimings( + startup_duration=self.startup_duration, + startup_target_requests=startup_target_requests, + offset=offset, + ) + + +@SchedulingStrategy.register("constant") +class AsyncConstantStrategy(ThroughputStrategy): + """ + Asynchronous constant-rate scheduling strategy for predictable load patterns. + + Schedules requests at a fixed rate distributed evenly across worker processes, + providing predictable timing behavior for steady-state load simulation and + consistent system performance measurement. + """ + + type_: Literal["constant"] = "constant" # type: ignore[assignment] + rate: float = Field( + description="Rate for scheduling requests asynchronously in requests/second", + gt=0, + ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for startup request distribution", + ge=0, + ) + + def __str__(self) -> str: + """ + Return string representation of the strategy. + + :return: String identifier with rate value + """ + return f"constant@{self.rate:.2f}" + + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 + ) -> ScheduledRequestTimings: + """ + Create timing implementation for constant-rate request scheduling. 
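+
+        The strategy's total rate is split evenly across workers
+        (``rate / local_world_size``) and each worker is offset by
+        ``local_rank / rate`` seconds, so the interleaved per-worker schedules
+        approximate the requested global rate.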
+ + :param local_rank: The rank of the worker process + :param local_world_size: Total number of worker processes for rate division + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: ConstantRateRequestTimings instance with per-worker rate + """ + # Divide the rate evenly across all worker processes + worker_rate = self.rate / local_world_size + # Start each worker with an offset to interleave rates + worker_offset = (1 / self.rate) * local_rank + + return ConstantRateRequestTimings( + rate=worker_rate, + offset=worker_offset, + ) + + +@SchedulingStrategy.register("poisson") +class AsyncPoissonStrategy(ThroughputStrategy): + """ + Asynchronous Poisson-distributed scheduling strategy for realistic load simulation. + + Schedules requests following a Poisson process with exponentially distributed + inter-arrival times, providing realistic simulation of user behavior and network + traffic patterns with random variance around the target rate. + """ + + type_: Literal["poisson"] = "poisson" # type: ignore[assignment] + rate: float = Field( + description="Rate for scheduling requests asynchronously in requests/second", + gt=0, + ) + startup_duration: float = Field( + default=0.0, + description="Duration in seconds for startup request distribution", + ge=0, + ) + random_seed: int = Field( + default=42, + description="Random seed to use for Poisson distribution", + ) + + def __str__(self) -> str: + """ + Return string representation of the strategy. + + :return: String identifier with rate value + """ + return f"poisson@{self.rate:.2f}" + + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 + ) -> ScheduledRequestTimings: + """ + Create timing implementation for Poisson-distributed request scheduling. + + :param local_rank: The rank of the worker process for seed generation + :param local_world_size: Total number of worker processes for rate division + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: PoissonRateRequestTimings instance with per-worker rate and unique seed + """ + # Divide the rate evenly across all worker processes + worker_rate = self.rate / local_world_size + # Use a different seed for each worker to ensure different sequences + worker_seed = self.random_seed + local_rank + # Start each worker with an offset to interleave rates + worker_offset = (1 / self.rate) * local_rank + + return PoissonRateRequestTimings( + rate=worker_rate, + random_seed=worker_seed, + offset=worker_offset, + ) diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py deleted file mode 100644 index 81ff6558..00000000 --- a/src/guidellm/scheduler/strategy.py +++ /dev/null @@ -1,495 +0,0 @@ -import math -import random -import time -from collections.abc import Generator -from typing import ( - Literal, - Optional, - Union, -) - -from pydantic import Field - -from guidellm.objects import StandardBaseModel -from guidellm.settings import settings - -__all__ = [ - "AsyncConstantStrategy", - "AsyncPoissonStrategy", - "ConcurrentStrategy", - "SchedulingStrategy", - "StrategyType", - "SynchronousStrategy", - "ThroughputStrategy", - "strategy_display_str", -] - - -StrategyType = Literal["synchronous", "concurrent", "throughput", "constant", "poisson"] - - -class SchedulingStrategy(StandardBaseModel): - """ - An abstract base class for scheduling strategies. 
- This class defines the interface for scheduling requests and provides - a common structure for all scheduling strategies. - Subclasses should implement the `request_times` method to provide - specific scheduling behavior. - - :param type_: The type of scheduling strategy to use. - This should be one of the predefined strategy types. - """ - - type_: Literal["strategy"] = Field( - description="The type of scheduling strategy schedule requests with.", - ) - start_time: float = Field( - default_factory=time.time, - description="The start time for the scheduling strategy.", - ) - - @property - def processing_mode(self) -> Literal["sync", "async"]: - """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. - This property should be implemented by subclasses to return - the appropriate processing mode. - - :return: The processing mode for the scheduling strategy, - either 'sync' or 'async'. - """ - return "async" - - @property - def processes_limit(self) -> int: - """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. - - :return: The number of processes for the scheduling strategy. - """ - return settings.max_worker_processes - - @property - def queued_requests_limit(self) -> Optional[int]: - """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: The maximum number of queued requests for the scheduling strategy. - """ - return settings.max_concurrency - - @property - def processing_requests_limit(self) -> int: - """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: The maximum number of processing requests for the scheduling strategy. - """ - return settings.max_concurrency - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields timestamps for when requests should be sent. - This method should be implemented by subclasses to provide specific - scheduling behavior. - - :return: A generator that yields timestamps for request scheduling - or -1 for requests that should be sent immediately. - """ - raise NotImplementedError("Subclasses must implement request_times() method.") - - -class SynchronousStrategy(SchedulingStrategy): - """ - A class representing a synchronous scheduling strategy. - This strategy schedules requests synchronously, one at a time, - with the maximum rate possible. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for synchronous scheduling. - - :param type_: The synchronous StrategyType to schedule requests synchronously. - """ - - type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] - - @property - def processing_mode(self) -> Literal["sync"]: - """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. 
- - :return: 'sync' for synchronous scheduling strategy - for the single worker process. - """ - return "sync" - - @property - def processes_limit(self) -> int: - """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. - - :return: 1 for the synchronous scheduling strategy to limit - the worker processes to one. - """ - return 1 - - @property - def queued_requests_limit(self) -> int: - """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: 1 for the synchronous scheduling strategy to limit - the queued requests to one that is ready to be processed. - """ - return 1 - - @property - def processing_requests_limit(self) -> int: - """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: 1 for the synchronous scheduling strategy to limit - the processing requests to one that is ready to be processed. - """ - return 1 - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields time.time() so requests are sent immediately, - while scheduling them synchronously. - - :return: A generator that yields time.time() for immediate request scheduling. - """ - init_time = self.start_time - while True: - yield max(init_time, time.time()) - - -class ConcurrentStrategy(SchedulingStrategy): - """ - A class representing a concurrent scheduling strategy. - This strategy schedules requests concurrently with the specified - number of streams. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for concurrent scheduling. - - :param type_: The concurrent StrategyType to schedule requests concurrently. - :param streams: The number of concurrent streams to use for scheduling requests. - Each stream runs synchronously with the maximum rate possible. - This must be a positive integer. - """ - - type_: Literal["concurrent"] = "concurrent" # type: ignore[assignment] - streams: int = Field( - description=( - "The number of concurrent streams to use for scheduling requests. " - "Each stream runs sychronously with the maximum rate possible. " - "This must be a positive integer." - ), - gt=0, - ) - - @property - def processing_mode(self) -> Literal["sync"]: - """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. - - :return: 'sync' for synchronous scheduling strategy - for the multiple worker processes equal to streams. - """ - return "sync" - - @property - def processes_limit(self) -> int: - """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. - - :return: {self.streams} for the concurrent scheduling strategy to limit - the worker processes to the number of streams. 
- """ - - return min(self.streams, settings.max_worker_processes) - - @property - def queued_requests_limit(self) -> int: - """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: {self.streams} for the concurrent scheduling strategy to limit - the queued requests to the number of streams that are ready to be processed. - """ - return self.streams - - @property - def processing_requests_limit(self) -> int: - """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: {self.streams} for the concurrent scheduling strategy to limit - the processing requests to the number of streams that ready to be processed. - """ - return self.streams - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields time.time() so requests are sent - immediately, while scheduling them concurrently with the specified - number of streams. - - :return: A generator that yields time.time() for immediate request scheduling. - """ - init_time = self.start_time - while True: - yield max(init_time, time.time()) - - -class ThroughputStrategy(SchedulingStrategy): - """ - A class representing a throughput scheduling strategy. - This strategy schedules as many requests asynchronously as possible, - with the maximum rate possible. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for throughput scheduling. - - :param type_: The throughput StrategyType to schedule requests asynchronously. - """ - - type_: Literal["throughput"] = "throughput" # type: ignore[assignment] - max_concurrency: Optional[int] = Field( - default=None, - description=( - "The maximum number of concurrent requests to schedule. " - "If set to None, the concurrency value from settings will be used. " - "This must be a positive integer greater than 0." - ), - gt=0, - ) - - @property - def processing_mode(self) -> Literal["async"]: - """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. - - :return: 'async' for asynchronous scheduling strategy - for the multiple worker processes handling requests. - """ - return "async" - - @property - def queued_requests_limit(self) -> int: - """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: The processing requests limit to ensure that there are enough - requests even for the worst case scenario where the max concurrent - requests are pulled at once for processing. - """ - return self.processing_requests_limit - - @property - def processing_requests_limit(self) -> int: - """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. - - :return: {self.max_concurrency} for the throughput scheduling strategy to limit - the processing requests to the maximum concurrency. 
- If max_concurrency is None, then the default processing requests limit - will be used. - """ - return self.max_concurrency or super().processing_requests_limit - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields the start time.time() so requests are sent - immediately, while scheduling as many asynchronously as possible. - - :return: A generator that yields the start time.time() - for immediate request scheduling. - """ - init_time = self.start_time - while True: - yield init_time - - -class AsyncConstantStrategy(ThroughputStrategy): - """ - A class representing an asynchronous constant scheduling strategy. - This strategy schedules requests asynchronously at a constant request rate - in requests per second. - If initial_burst is set, it will send an initial burst of math.floor(rate) - requests to reach the target rate. - This is useful to ensure that the target rate is reached quickly - and then maintained. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for asynchronous constant scheduling. - - :param type_: The constant StrategyType to schedule requests asynchronously. - :param rate: The rate at which to schedule requests asynchronously in - requests per second. This must be a positive float. - :param initial_burst: True to send an initial burst of requests - (math.floor(self.rate)) to reach target rate. - False to not send an initial burst. - """ - - type_: Literal["constant"] = "constant" # type: ignore[assignment] - rate: float = Field( - description=( - "The rate at which to schedule requests asynchronously in " - "requests per second. This must be a positive float." - ), - gt=0, - ) - initial_burst: bool = Field( - default=True, - description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." - ), - ) - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields timestamps for when requests should be sent. - This method schedules requests asynchronously at a constant rate - in requests per second. - If burst_time is set, it will send an initial burst of requests - to reach the target rate. - This is useful to ensure that the target rate is reached quickly - and then maintained. - - :return: A generator that yields timestamps for request scheduling. - """ - constant_increment = 1.0 / self.rate - - init_time = self.start_time - # handle bursts first to get to the desired rate - if self.initial_burst is not None: - # send an initial burst equal to the rate - # to reach the target rate - burst_count = math.floor(self.rate) - for _ in range(burst_count): - yield init_time - - init_time += constant_increment - - counter = 0 - - # continue with constant rate after bursting - while True: - yield init_time + constant_increment * counter - counter += 1 - - -class AsyncPoissonStrategy(ThroughputStrategy): - """ - A class representing an asynchronous Poisson scheduling strategy. - This strategy schedules requests asynchronously at a Poisson request rate - in requests per second. - If initial_burst is set, it will send an initial burst of math.floor(rate) - requests to reach the target rate. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for asynchronous Poisson scheduling. - - :param type_: The Poisson StrategyType to schedule requests asynchronously. 
- :param rate: The rate at which to schedule requests asynchronously in - requests per second. This must be a positive float. - :param initial_burst: True to send an initial burst of requests - (math.floor(self.rate)) to reach target rate. - False to not send an initial burst. - """ - - type_: Literal["poisson"] = "poisson" # type: ignore[assignment] - rate: float = Field( - description=( - "The rate at which to schedule requests asynchronously in " - "requests per second. This must be a positive float." - ), - gt=0, - ) - initial_burst: bool = Field( - default=True, - description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." - ), - ) - random_seed: int = Field( - default=42, - description=("The random seed to use for the Poisson distribution. "), - ) - - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields timestamps for when requests should be sent. - This method schedules requests asynchronously at a Poisson rate - in requests per second. - The inter arrival time between requests is exponentially distributed - based on the rate. - - :return: A generator that yields timestamps for request scheduling. - """ - init_time = self.start_time - if self.initial_burst is not None: - # send an initial burst equal to the rate - # to reach the target rate - burst_count = math.floor(self.rate) - for _ in range(burst_count): - yield init_time - else: - yield init_time - - # set the random seed for reproducibility - rand = random.Random(self.random_seed) # noqa: S311 - - while True: - inter_arrival_time = rand.expovariate(self.rate) - init_time += inter_arrival_time - yield init_time - - -def strategy_display_str(strategy: Union[StrategyType, SchedulingStrategy]) -> str: - strategy_type = strategy if isinstance(strategy, str) else strategy.type_ - strategy_instance = strategy if isinstance(strategy, SchedulingStrategy) else None - - if strategy_type == "concurrent": - rate = f"@{strategy_instance.streams}" if strategy_instance else "@##" # type: ignore[attr-defined] - elif strategy_type in ("constant", "poisson"): - rate = f"@{strategy_instance.rate:.2f}" if strategy_instance else "@#.##" # type: ignore[attr-defined] - else: - rate = "" - - return f"{strategy_type}{rate}" diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index ba36559e..5f2fb74b 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -1,472 +1,389 @@ -import asyncio -import math -import time -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from dataclasses import dataclass -from itertools import islice -from threading import Event -from typing import ( - Any, - Generic, - Literal, - Optional, - Union, -) +""" +Individual worker process management for multi-process request execution. -from loguru import logger -from pydantic import Field +Manages worker processes that handle request scheduling, backend processing, and +coordination in distributed benchmark environments. Workers consume requests from +queues, apply timing strategies, process requests through backends, and publish +status updates while maintaining synchronization across the process group. 
+""" -from guidellm.backend import ( - Backend, - BackendType, - RequestArgs, - ResponseSummary, - StreamingTextResponse, +from __future__ import annotations + +import asyncio +import time +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from typing import Annotated, Generic, Literal + +try: + import uvloop + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = True +except ImportError: + uvloop = None + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = False + + +from guidellm.scheduler.objects import ( + BackendInterface, + MultiTurnRequestT, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, ) -from guidellm.objects import StandardBaseModel -from guidellm.request import GenerationRequest -from guidellm.request.types import RequestT, ResponseT -from guidellm.scheduler.queues import MPQueues, Queue, QueueEmpty -from guidellm.scheduler.result import ( - SchedulerRequestInfo, - WorkerProcessRequest, - WorkerProcessResult, +from guidellm.scheduler.strategies import ScheduledRequestTimings +from guidellm.utils import ( + InterProcessMessaging, + wait_for_sync_barrier, + wait_for_sync_event, + wait_for_sync_objects, ) -from guidellm.scheduler.strategy import SchedulingStrategy - -__all__ = [ - "GenerativeRequestsWorker", - "GenerativeRequestsWorkerDescription", - "RequestsWorker", - "ResolveStatus", - "WorkerDescription", -] - - -@dataclass -class ResolveStatus: - requested: bool - completed: bool - errored: bool - canceled: bool - request_start: float - request_end: float +__all__ = ["WorkerProcess"] -class WorkerDescription(StandardBaseModel): - type_: Literal["worker"] = "worker" - - -class RequestsWorker(ABC, Generic[RequestT, ResponseT]): +class WorkerProcess(Generic[RequestT, ResponseT]): """ - An abstract base class for a worker that processes requests. - This class defines the interface for a worker that can resolve requests - asynchronously or synchronously within the Scheduler class. - Subclasses must implement the `resolve` method, - which takes a request directly given from the load generator, - along with the desired start_time for the request and a timeout_time. - The `resolve` method should return the response from the backend. + Individual worker process for distributed request execution and coordination. + + Manages the complete request lifecycle from queue consumption through backend + processing and status publication. Coordinates with other workers through + barriers and events while maintaining configurable concurrency limits and + timing strategies for request scheduling. 
+ + Example: + :: + worker = WorkerProcess( + messaging=messaging_interface, + async_limit=10, + startup_barrier=barrier, + shutdown_event=shutdown, + error_event=error, + backend=backend_instance, + request_timings=timing_strategy + ) + worker.run() """ - @property - @abstractmethod - def description(self) -> WorkerDescription: + def __init__( + self, + messaging: InterProcessMessaging[ + tuple[ + ResponseT | None, + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo, + ], + ], + backend: BackendInterface[RequestT, ResponseT], + request_timings: ScheduledRequestTimings, + async_limit: int, + startup_barrier: ProcessingBarrier, + requests_generated_event: ProcessingEvent, + constraint_reached_event: ProcessingEvent, + shutdown_event: ProcessingEvent, + error_event: ProcessingEvent, + ): """ - An abstract property that must be implemented by subclasses. - This property should return a Serializable class representing the information - about the worker instance. + Initialize worker process instance. + + :param messaging: Inter-process communication interface for request coordination + :param backend: Backend instance for processing requests + :param request_timings: Timing strategy for request scheduling + :param async_limit: Maximum concurrent requests this worker can handle + :param startup_barrier: Multiprocessing barrier for coordinated startup + :param requests_generated_event: Event signaling when request generation is + complete + :param constraint_reached_event: Event signaling when processing constraints + are met + :param shutdown_event: Event for signaling graceful shutdown + :param error_event: Event for signaling error conditions across processes """ - ... + self.messaging = messaging + self.backend = backend + self.request_timings = request_timings + self.async_limit = async_limit + self.startup_barrier = startup_barrier + self.requests_generated_event = requests_generated_event + self.constraint_reached_event = constraint_reached_event + self.shutdown_event = shutdown_event + self.error_event = error_event + + # Internal states + self.startup_completed = False + self.backend_started = False + self.messaging_started = False + + def run(self): + """ + Main entry point for worker process execution. - @abstractmethod - async def prepare_multiprocessing(self): + Initializes asyncio event loop with optional uvloop optimization and starts + worker async operations. Handles event loop cleanup for forked processes. + + :raises RuntimeError: If worker encounters unrecoverable error during execution """ - An abstract method that must be implemented by subclasses. - This is useful for workers that have instance state that can not - be shared across processes and should be cleared out and re-initialized - for each new process. + try: + if HAS_UVLOOP: + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + asyncio.run(self.run_async()) + except Exception as err: + self.error_event.set() + raise RuntimeError( + f"Worker process {self.messaging.worker_index} encountered an " + f"error: {err}" + ) from err + + async def run_async(self): """ - ... + Execute main asynchronous worker process logic. - @abstractmethod - async def resolve( - self, - request: RequestT, - timeout_time: float, - ) -> tuple[ResolveStatus, ResponseT]: - """ - An abstract method that must be implemented by subclasses. - This method should handle the resolution of a request through asyncio, - including any necessary backend processing and response handling. 
- - :param request: The request to be resolved generated by the load generator. - :param timeout_time: The timeout time for the request, if there is no timeout - given, then this will be math.inf. - :return: The response from the worker. + Orchestrates concurrent execution of request processing and shutdown monitoring + tasks. Handles task cleanup, error propagation, and cancellation coordination + when any task completes or fails. + + :raises RuntimeError: If worker tasks encounter unrecoverable errors + :raises asyncio.CancelledError: If worker process was cancelled """ - ... + stop_task = asyncio.create_task(self._stop_monitor()) + request_proc_task = asyncio.create_task(self._process_requests()) + caller_cancelled = False - async def send_result( - self, - results_queue: Queue[WorkerProcessResult[RequestT, ResponseT]], - result: WorkerProcessResult[RequestT, ResponseT], - ): - await asyncio.to_thread(results_queue.put, result) # type: ignore[attr-defined] + try: + await asyncio.wait( + [stop_task, request_proc_task], + return_when=asyncio.FIRST_COMPLETED, + ) + except asyncio.CancelledError: + caller_cancelled = True - async def resolve_scheduler_request( - self, - process_request: WorkerProcessRequest[RequestT, ResponseT], - dequeued_time: float, - start_time: float, - results_queue: Queue[WorkerProcessResult[RequestT, ResponseT]], - process_id: int, - ): - request = process_request.request - timeout_time = process_request.timeout_time - queued_time = process_request.queued_time - - info = SchedulerRequestInfo( - targeted_start_time=start_time, - queued_time=queued_time, - dequeued_time=dequeued_time, - scheduled_time=time.time(), - process_id=process_id, - ) - result: WorkerProcessResult[RequestT, ResponseT] = WorkerProcessResult( - type_="request_scheduled", - request=request, - response=None, - info=info, - ) - asyncio.create_task(self.send_result(results_queue, result)) + stop_task.cancel() + request_proc_task.cancel() - if (wait_time := start_time - time.time()) > 0: - await asyncio.sleep(wait_time) + try: + # Ensure all child tasks cancel correctly + await asyncio.wait( + [stop_task, request_proc_task], return_when=asyncio.ALL_COMPLETED + ) + except asyncio.CancelledError: + caller_cancelled = True + + if ( + task_err := ( + request_proc_task.exception() + if not request_proc_task.cancelled() + else stop_task.exception() + if not stop_task.cancelled() + else None + ) + ) is not None: + raise RuntimeError( + f"Worker process {self.messaging.worker_index} encountered an " + f"error: {task_err}" + ) from task_err - info.worker_start = time.time() - result = WorkerProcessResult( - type_="request_start", - request=request, - response=None, - info=info, - ) - asyncio.create_task(self.send_result(results_queue, result)) - - status, response = await self.resolve(request, timeout_time) - info.worker_end = time.time() - info.requested = status.requested - info.completed = status.completed - info.errored = status.errored - info.canceled = status.canceled - info.request_start = status.request_start - info.request_end = status.request_end - result = WorkerProcessResult( - type_="request_complete", - request=request, - response=response, - info=info, - ) - asyncio.create_task(self.send_result(results_queue, result)) + if caller_cancelled: + raise asyncio.CancelledError("Worker process was cancelled") - def process_loop_asynchronous( + async def _stop_monitor( self, - queues: MPQueues[RequestT, ResponseT], - strategy: SchedulingStrategy, - stop_event: Event, - max_concurrency: int, - 
process_id: int,
-        num_processes: int,
-    ):
-        async def _process_runner():
-            lock = asyncio.Semaphore(max_concurrency)
-            times_iter = islice(
-                strategy.request_times(),
-                process_id,
-                None,
-                num_processes,
-            )
+    ) -> Literal["error_event", "shutdown_event"]:
+        exit_key = await wait_for_sync_objects(
+            {
+                "error_event": self.error_event,
+                "shutdown_event": self.shutdown_event,
+            },
+            poll_interval=self.messaging.poll_interval,
+        )
 
-            start_time = None
-            while not stop_event.is_set():
-                if start_time is None:
-                    start_time = next(times_iter)
-
-                # Yield control to the event loop. Sleep if we are way ahead
-                await asyncio.sleep(start_time - time.time() - 1)
-                await lock.acquire()
-
-                try:
-                    process_request = queues.requests.get_nowait()
-                    dequeued_time = time.time()
-                except QueueEmpty:
-                    lock.release()
-                    continue
-
-                def _request_callback(
-                    _: asyncio.Future[WorkerProcessRequest[RequestT, ResponseT]],
-                ):
-                    nonlocal lock
-                    lock.release()
-
-                task = asyncio.create_task(
-                    self.resolve_scheduler_request(
-                        process_request=process_request,
-                        dequeued_time=dequeued_time,
-                        start_time=start_time,
-                        results_queue=queues.responses,
-                        process_id=process_id,
-                    )
-                )
-                task.add_done_callback(_request_callback)
-                start_time = None
+        if exit_key == "error_event":
+            raise RuntimeError(
+                f"Worker process {self.messaging.worker_index} received error signal."
+            )
 
+        return exit_key
+
+    async def _process_requests(self):
         try:
-            asyncio.run(_process_runner())
-        except Exception as exc:  # noqa: BLE001
-            logger.error(
-                f"Error in worker process {process_id}: {exc}",
-                exc_info=True,
-                stack_info=True,
+            # 1. Start up synchronization (backend, messaging, and other processes)
+            # 2. Messaging startup, receive requests until requests_generated event
+            await self._processing_startup()
+
+            # 3. Run process requests loop until constraint_reached event
+            processing_task = asyncio.create_task(self._process_requests_loop())
+            await wait_for_sync_event(
+                self.constraint_reached_event,
+                poll_interval=self.messaging.poll_interval,
             )
+            processing_task.cancel()
+
+            # 4. Cancel pending requests until proc canceled (manual, shutdown, error)
+            await self._cancel_requests_loop()
+        finally:
+            # 5. On cancel, shutdown event, error event, or internal error:
+            #    attempt to shut down this worker cleanly (stop backend and messaging)
+            await self._processing_shutdown()
+
+    async def _processing_startup(self):
+        # Get backend ready
+        await self.backend.process_startup()
+        self.backend_started = True
+        await self.backend.validate()
+
+        # Get messaging system ready
+        await self.messaging.start(
+            receive_stop_criteria=[self.requests_generated_event],
+            pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()),
+        )
+        self.messaging_started = True
 
+        # Wait for all processes to be ready
+        await wait_for_sync_barrier(
+            self.startup_barrier,
+            poll_interval=self.messaging.poll_interval,
+        )
 
-class GenerativeRequestsWorkerDescription(WorkerDescription):
-    type_: Literal["generative_requests_worker"] = "generative_requests_worker"  # type: ignore[assignment]
-    backend_type: BackendType
-    backend_target: str
-    backend_model: str
-    backend_info: dict[str, Any] = Field(
-        default_factory=dict,
-    )
-
-
-class GenerativeRequestsWorker(RequestsWorker[GenerationRequest, ResponseSummary]):
-    """
-    A class that handles the execution of requests using a backend.
-    This class is responsible for sending requests to the backend,
-    handling responses, and managing errors.
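+        # Barrier passed: every worker finished startup, safe to begin work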
+ self.startup_completed = True - :param backend: The backend to use for handling requests. - This should be an instance of Backend such as an OpenAIHTTPBackend. - """ + async def _processing_shutdown(self): + if self.backend_started: + await self.backend.process_shutdown() + self.backend_started = False - def __init__(self, backend: Backend): - self.backend = backend + if self.messaging_started: + await self.messaging.stop() + self.messaging_started = False - @property - def description(self) -> GenerativeRequestsWorkerDescription: - """ - Get the description of the worker. - :return: The description of the worker. - """ - return GenerativeRequestsWorkerDescription( - backend_type=self.backend.type_, - backend_target=self.backend.target, - backend_model=self.backend.model or "None", - backend_info=self.backend.info, - ) + self.startup_completed = False - async def prepare_multiprocessing(self): - """ - Prepare the worker for multiprocessing. - This is useful for workers that have instance state that can not - be shared across processes and should be cleared out and re-initialized - for each new process. - """ - await self.backend.prepare_multiprocessing() + async def _process_requests_loop(self): + try: + # Run request processing + async_semaphore = asyncio.Semaphore(self.async_limit) + pending_tasks: set[asyncio.Task] = set() + + def _task_done(task): + pending_tasks.discard(task) + async_semaphore.release() + + if not task.cancelled() and (exception := task.exception()): + raise exception + + # Main loop; loop until canceled + while True: + await async_semaphore.acquire() + request_task = asyncio.create_task(self._process_next_request()) + pending_tasks.add(request_task) + request_task.add_done_callback(_task_done) + except asyncio.CancelledError as err: + for task in pending_tasks: + task.cancel() + await asyncio.gather(*pending_tasks, return_exceptions=True) + + raise err + + async def _cancel_requests_loop(self): + while True: + try: + request: RequestT + request_info: ScheduledRequestInfo + request, request_info = await self.messaging.get( + timeout=self.messaging.poll_interval + ) + except asyncio.TimeoutError: + continue - def process_loop_asynchronous( - self, - queues: MPQueues[GenerationRequest, ResponseSummary], - strategy: SchedulingStrategy, - stop_event: Event, - max_concurrency: int, - process_id: int, - num_processes: int, - ): - asyncio.run(self.backend.validate()) - super().process_loop_asynchronous( - queues=queues, - strategy=strategy, - stop_event=stop_event, - max_concurrency=max_concurrency, - process_id=process_id, - num_processes=num_processes, - ) + request_info.scheduler_node_id = self.messaging.worker_index + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", None, request, request_info) - async def resolve( - self, - request: GenerationRequest, - timeout_time: float, - ) -> tuple[ResolveStatus, ResponseSummary]: - """ - Resolve a request by sending it to the backend and handling the response. - This method sends the request to the backend, waits for a response, - and handles any errors that may occur during the process. - - :param request: The request to resolve. - :param timeout_time: The time to wait for a response before timing out. - If timeout_time is math.inf, the request will not timeout. - :return: A ResponseSummary object containing the response from the backend. - If an error occurs, the ResponseSummary will contain the error message. 
- """ - resolve_start_time = time.time() - response = None - error: Optional[str] = None - status = ResolveStatus( - requested=False, - completed=False, - errored=False, - canceled=False, - request_start=-1, - request_end=-1, - ) + async def _process_next_request(self): + request: RequestT | MultiTurnRequestT[RequestT] | None = None + request_info: ScheduledRequestInfo | None = None + response: ResponseT | None = None try: - if timeout_time < time.time(): - raise asyncio.TimeoutError( - "The timeout time has already passed." - ) # exit early - - status.requested = True - request_func, request_kwargs = self._create_request_func_kwargs(request) - - async def _runner(): - # wrap function so we can enforce timeout and - # still return the latest state from the backend - async for resp in request_func(**request_kwargs): # type: ignore[operator] - nonlocal response - response = resp - - await asyncio.wait_for( - _runner(), - timeout=timeout_time - time.time() if timeout_time < math.inf else None, - ) + # Pull request from the queue + request, request_info = await self.messaging.get() - if not response: - raise ValueError( - f"No response received for request: {request} " - f"and backend: {self.backend}" - ) - if not isinstance(response, ResponseSummary): - raise ValueError( - f"Received no ResponseSummary for request: {request} " - f"and backend: {self.backend}, received: {response}" - ) + if isinstance(request, (list, tuple)): + raise NotImplementedError("Multi-turn requests are not yet supported") - status.completed = True - except asyncio.TimeoutError: - error = "TimeoutError: The request timed out before completing." - status.errored = True - status.canceled = True + # Calculate targeted start and set pending state for request + request_info.scheduler_node_id = self.messaging.worker_index + request_info.scheduler_timings.dequeued = time.time() + target_start = ( + request_info.scheduler_start_time + self.request_timings.next_offset() + ) + request_info.scheduler_timings.targeted_start = target_start + self._send_update("pending", response, request, request_info) + + # Schedule the request + current_time = time.time() + request_info.scheduler_timings.scheduled_at = current_time + if target_start > current_time: + await asyncio.sleep(target_start - current_time) + # Adapt delay so that scheduled at reflects the sleep time + request_info.scheduler_timings.scheduled_at = target_start + + # Process the request with the backend + request_info.scheduler_timings.resolve_start = time.time() + self._send_update("in_progress", response, request, request_info) + async for resp, info in self.backend.resolve(request, request_info, None): + response = resp + request_info = info + + # Complete the request + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("completed", response, request, request_info) + + response = request = request_info = None + except asyncio.CancelledError: + # Handle cancellation + if request is not None and request_info is not None: + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", response, request, request_info) + raise except Exception as exc: # noqa: BLE001 - error = str(exc) - status.errored = True - - return self._handle_response( - status=status, - request=request, - response=response, - error=error, - resolve_start_time=resolve_start_time, - ) + if request is not None and request_info is not None: + request_info.error = str(exc) + 
request_info.scheduler_timings.resolve_end = time.time() + self._send_update("errored", response, request, request_info) - def _create_request_func_kwargs( + def _send_update( self, - request: GenerationRequest, - ) -> tuple[ - AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None], - dict[str, Any], - ]: - request_func: AsyncGenerator[ - Union[StreamingTextResponse, ResponseSummary], None - ] - request_kwargs: dict[str, Any] - - if request.request_type == "text_completions": - request_func = self.backend.text_completions # type: ignore[assignment] - request_kwargs = { - "prompt": request.content, - "request_id": request.request_id, - "prompt_token_count": request.stats.get("prompt_tokens", None), - "output_token_count": request.constraints.get("output_tokens", None), - **request.params, - } - elif request.request_type == "chat_completions": - request_func = self.backend.chat_completions # type: ignore[assignment] - request_kwargs = { - "content": request.content, - "request_id": request.request_id, - "prompt_token_count": request.stats.get("prompt_tokens", None), - "output_token_count": request.constraints.get("output_tokens", None), - **request.params, - } - else: - raise ValueError( - f"Invalid request type: {request.request_type} for {request}" - ) + new_status: Literal[ + "pending", "in_progress", "completed", "errored", "cancelled" + ], + response: ResponseT | None, + request: RequestT | MultiTurnRequestT[RequestT], + request_info: ScheduledRequestInfo, + ): + prev_status = request_info.status - return request_func, request_kwargs + if new_status == prev_status: + # already sent this update, don't send again + return - def _handle_response( - self, - status: ResolveStatus, - request: GenerationRequest, - response: Any, - error: Optional[str], - resolve_start_time: float, - ) -> tuple[ResolveStatus, ResponseSummary]: - if response is None or not isinstance( - response, (ResponseSummary, StreamingTextResponse) - ): - # nothing received or invalid response, fill in defaults for error - if response: - error = str( - ValueError( - f"Invalid response: {type(response)} for request: {request}; " - ) - ) + (error or "") - - response = ResponseSummary( - value="", - request_args=RequestArgs( - target=self.backend.target, - headers={}, - params={}, - payload={}, - ), - start_time=resolve_start_time, - end_time=status.request_end, - first_iter_time=None, - last_iter_time=None, - request_id=request.request_id, - error=error or "Unknown error", + try: + request_info.status = new_status + request_info = ( + request_info.model_copy() + if new_status not in {"completed", "errored", "cancelled"} + else request_info # last update, don't need to copy ) - elif isinstance(response, StreamingTextResponse): - response = ResponseSummary( - value=response.value, - request_args=RequestArgs( - target=self.backend.target, - headers={}, - params={}, - payload={}, - ), - start_time=response.start_time, - end_time=time.time(), - first_iter_time=response.first_iter_time, - last_iter_time=response.time if response.iter_count > 0 else None, - request_prompt_tokens=request.stats.get("prompt_tokens", None), - request_output_tokens=request.constraints.get("output_tokens", None), - response_prompt_tokens=None, - response_output_tokens=response.iter_count, - request_id=request.request_id, - error=error or "Unknown error", + self.messaging.put_sync( + (response, request, request_info), + timeout=-1, ) - - response.error = error - status.request_start = response.start_time - status.request_end = 
response.end_time
-
-        return status, response
+            prev_status = new_status
+        except Exception as exc:
+            # Reset status to the last value that succeeded (or started with)
+            # so calling logic can retry after handling the error, if possible
+            request_info.status = prev_status
+            raise exc
diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py
new file mode 100644
index 00000000..c1d516f1
--- /dev/null
+++ b/src/guidellm/scheduler/worker_group.py
@@ -0,0 +1,681 @@
+"""
+Multi-process worker group orchestration for distributed request scheduling.
+
+Provides infrastructure for coordinating worker processes with shared state
+management, inter-process communication, and lifecycle coordination. Handles
+dynamic scaling, load balancing, constraint evaluation, and graceful shutdown
+across distributed workers processing concurrent requests.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import math
+import threading
+import time
+import uuid
+from collections.abc import AsyncIterator, Generator, Iterable, Iterator
+from multiprocessing import get_context
+from multiprocessing.context import BaseContext
+from multiprocessing.managers import BaseManager
+from multiprocessing.process import BaseProcess
+from multiprocessing.synchronize import Barrier, Event
+from typing import Generic, NamedTuple
+
+from guidellm.scheduler.constraints import Constraint, RequestsExhaustedConstraint
+from guidellm.scheduler.objects import (
+    BackendInterface,
+    MultiTurnRequestT,
+    RequestT,
+    ResponseT,
+    ScheduledRequestInfo,
+    SchedulerMessagingPydanticRegistry,
+    SchedulerState,
+    SchedulerUpdateAction,
+)
+from guidellm.scheduler.strategies import SchedulingStrategy
+from guidellm.scheduler.worker import WorkerProcess
+from guidellm.settings import settings
+from guidellm.utils import (
+    InterProcessMessaging,
+    InterProcessMessagingManagerQueue,
+    InterProcessMessagingPipe,
+    InterProcessMessagingQueue,
+    wait_for_sync_objects,
+)
+
+__all__ = ["WorkerGroupState", "WorkerProcessGroup"]
+
+
+class WorkerProcessGroup(Generic[RequestT, ResponseT]):
+    """
+    Orchestrates multiple worker processes for distributed request processing.
+
+    Manages process lifecycle, request distribution, response collection, and state
+    synchronization across workers. Handles dynamic scaling, load balancing, and
+    constraint evaluation with graceful shutdown coordination for high-throughput
+    request processing workloads.
+
+    Example:
+        ::
+            from guidellm.scheduler.worker_group import WorkerProcessGroup
+
+            group = WorkerProcessGroup(
+                requests=request_iterable,
+                cycle_requests=None,
+                backend=backend_instance,
+                strategy=scheduling_strategy,
+                constraints={"max_time": time_constraint}
+            )
+
+            await group.create_processes()
+            await group.start(time.time())
+
+            async for response, request, info, state in group.request_updates():
+                if response is not None:
+                    # Process completed request
+                    handle_response(response)
+
+            await group.shutdown()
+    """
+
+    def __init__(
+        self,
+        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
+        cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
+        backend: BackendInterface[RequestT, ResponseT],
+        strategy: SchedulingStrategy,
+        constraints: dict[str, Constraint],
+    ):
+        """
+        Initialize a worker process group for distributed request processing.
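+
+        The two request arguments select between one-shot and sustained runs;
+        a rough sketch (other arguments elided):
+
+        Example:
+            ::
+                # Finite run: process each request exactly once
+                WorkerProcessGroup(requests=reqs, cycle_requests=None, ...)
+
+                # Sustained run: cycle over reqs until a constraint stops it
+                WorkerProcessGroup(requests=None, cycle_requests=reqs, ...)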
+
+        :param requests: Finite iterable of requests to process sequentially
+        :param cycle_requests: Iterable of requests to cycle through indefinitely
+        :param backend: Backend interface for processing requests
+        :param strategy: Scheduling strategy for request timing and distribution
+        :param constraints: Named constraints for controlling execution behavior
+        :raises ValueError: If neither requests nor cycle_requests are provided,
+            or if cycle_requests is an Iterator rather than Iterable
+        """
+        if not requests and not cycle_requests:
+            raise ValueError(
+                "At least one of 'requests' or 'cycle_requests' must be provided. "
+                f"Got requests: {requests}, cycle_requests: {cycle_requests}"
+            )
+
+        if isinstance(cycle_requests, Iterator):
+            raise ValueError(
+                f"cycle_requests must be an Iterable or None, not an Iterator. "
+                f"Got {type(cycle_requests)}"
+            )
+
+        self.requests = requests
+        self.cycle_requests = cycle_requests
+        self.backend = backend
+        self.strategy = strategy
+        self.constraints = constraints
+
+        # Multiprocessing contexts and primitives, created in create_processes
+        self.mp_context: BaseContext = None
+        self.mp_manager: BaseManager = None
+        self.processes: list[BaseProcess] = None
+        self.startup_barrier: Barrier = None
+        self.requests_generated_event: Event = None
+        self.constraint_reached_event: Event = None
+        self.shutdown_event: Event = None
+        self.error_event: Event = None
+
+        # Scheduler and messaging state, created in start
+        self.state: WorkerGroupState[RequestT, ResponseT] = None
+        self.messaging: InterProcessMessaging[
+            tuple[
+                RequestT | MultiTurnRequestT[RequestT],
+                ScheduledRequestInfo,
+            ],
+            tuple[
+                ResponseT | None,
+                RequestT | MultiTurnRequestT[RequestT],
+                ScheduledRequestInfo,
+                SchedulerState,
+            ],
+        ] = None
+
+    async def create_processes(self):
+        """
+        Create and initialize worker processes for distributed request processing.
+
+        Sets up multiprocessing infrastructure and worker processes based on
+        strategy constraints, backend capabilities, and system configuration.
+        Determines optimal process count and concurrency limits, then spawns
+        worker processes with distributed request handling capabilities.
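+
+        As a rough worked example (assuming ``max_concurrency=100`` and
+        ``max_worker_processes=8`` from settings, with no tighter strategy or
+        backend limits): ``num_processes = 8``, ``per_proc_max_conc = 100 // 8
+        = 12``, and the remainder ``100 % 8 = 4`` grants one extra slot each
+        to ranks 0-3.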
+ + :raises RuntimeError: If process initialization or startup fails + """ + # Processes limits and params + max_conc: int = min( + self.strategy.requests_limit or math.inf, + self.backend.requests_limit or math.inf, + ) + if max_conc == math.inf: + # if concurrency not specified, use settings + max_conc = settings.max_concurrency + if max_conc <= 0: + raise RuntimeError("max_concurrency resolved to 0; increase limits/config") + + # Calculate number of processes, ensure we don't exceed the max concurrency, + # or limits from the backend, strategy, or user settings + num_processes = int( + min( + max_conc, + self.strategy.processes_limit or math.inf, + self.backend.processes_limit or math.inf, + settings.max_worker_processes, + ) + ) + if num_processes <= 0: + raise RuntimeError("num_processes resolved to 0; increase limits/config") + + per_proc_max_conc = max_conc // num_processes + max_pending_size = max( + 1, math.floor(max_conc * settings.mp_max_pending_buffer_percent) + ) + per_proc_max_buffer_size = max( + 1, math.floor(per_proc_max_conc * settings.mp_max_worker_buffer_percent) + ) + + # Initialize multiprocessing components + self.mp_context: BaseContext = get_context(settings.mp_context_type) + self.mp_manager = self.mp_context.Manager() + self.startup_barrier = self.mp_context.Barrier(num_processes + 1) + self.requests_generated_event = self.mp_context.Event() + self.constraint_reached_event = self.mp_context.Event() + self.shutdown_event = self.mp_context.Event() + self.error_event = self.mp_context.Event() + + if settings.mp_messaging_object == "queue": + self.messaging = InterProcessMessagingQueue( + mp_context=self.mp_context, + serialization=settings.mp_serialization, + encoding=settings.mp_encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=settings.mp_requests_send_buffer_size, + poll_interval=settings.mp_poll_interval, + ) + elif settings.mp_messaging_object == "manager_queue": + self.messaging = InterProcessMessagingManagerQueue( + manager=self.mp_manager, + mp_context=self.mp_context, + serialization=settings.mp_serialization, + encoding=settings.mp_encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=settings.mp_requests_send_buffer_size, + poll_interval=settings.mp_poll_interval, + ) + elif settings.mp_messaging_object == "pipe": + self.messaging = InterProcessMessagingPipe( + num_workers=num_processes, + mp_context=self.mp_context, + serialization=settings.mp_serialization, + encoding=settings.mp_encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=settings.mp_requests_send_buffer_size, + poll_interval=settings.mp_poll_interval, + ) + + # Initialize worker processes + self.processes = [] + for rank in range(num_processes): + # Distribute any remainder across the first N ranks + async_limit = per_proc_max_conc + ( + 1 if rank < (max_conc % num_processes) else 0 + ) + + worker = WorkerProcess[RequestT, ResponseT]( + messaging=self.messaging.create_worker_copy( + worker_index=rank, + max_buffer_send_size=None, + max_buffer_receive_size=per_proc_max_buffer_size, + ), + backend=self.backend, + request_timings=self.strategy.create_request_timings( + local_rank=rank, + local_world_size=num_processes, + local_max_concurrency=async_limit, + ), + async_limit=async_limit, + startup_barrier=self.startup_barrier, + requests_generated_event=self.requests_generated_event, + constraint_reached_event=self.constraint_reached_event, + shutdown_event=self.shutdown_event, + error_event=self.error_event, + ) + proc = 
self.mp_context.Process(target=worker.run, daemon=False) + proc.start() + self.processes.append(proc) + + wait_key = await wait_for_sync_objects( + { + "startup_barrier": self.startup_barrier, + "shutdown_event": self.shutdown_event, + "error_event": self.error_event, + }, + poll_interval=settings.mp_poll_interval, + ) + + if wait_key == "error_event": + raise RuntimeError( + "Worker process group startup failed: error_event is set" + ) + + async def start(self, start_time: float): + """ + Begin request processing at the specified start time. + + Initializes scheduler state and background tasks, then waits until the + specified start time before beginning operations. Sets up inter-process + communication and coordinates synchronized startup across all workers. + + :param start_time: Unix timestamp when processing should begin + :raises RuntimeError: If workers encounter errors during startup or + if create_processes() was not called first + """ + if not self.processes: + raise RuntimeError("create_processes() must be called before start()") + + stop_send_requests_event = threading.Event() + send_requests_stopped_event = threading.Event() + self.state = WorkerGroupState[RequestT, ResponseT]( + start_time=start_time, + processes=self.processes, + constraints=self.constraints, + stop_send_requests_event=stop_send_requests_event, + send_requests_stopped_event=send_requests_stopped_event, + requests_generated_event=self.requests_generated_event, + constraint_reached_event=self.constraint_reached_event, + shutdown_event=self.shutdown_event, + error_event=self.error_event, + ) + await self.messaging.start( + send_items=self.state.requests_generator( + self.requests, self.cycle_requests + ), + receive_callback=self.state.received_callback, + send_stopped_event=send_requests_stopped_event, + send_stop_criteria=[stop_send_requests_event], + receive_stop_criteria=[self.shutdown_event], + pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()), + ) + + if (wait_time := start_time - time.time()) > 0: + await asyncio.sleep(wait_time) + if self.error_event.is_set(): + raise RuntimeError( + "error_event is set in WorkerProcessGroup, " + "indicating an error occurred in one of the worker processes." + ) + + async def request_updates( + self, + ) -> AsyncIterator[ + tuple[ + ResponseT | None, + RequestT, + ScheduledRequestInfo, + SchedulerState, + ] + ]: + """ + Yield request processing updates as they become available. + + Returns an async iterator of request updates including response, request, + request scheduling info, and scheduler state. Updates occur on request queued, + processing start, and completion. Response is None until processing completes. + + :return: Async iterator yielding (response, request, request_info, state) + tuples where response is None until processing is complete + :raises RuntimeError: If workers encounter unrecoverable errors + """ + while True: + if self.error_event.is_set(): + raise RuntimeError( + "error_event is set in WorkerProcessGroup, " + "indicating an error occurred in one of the worker processes." + ) + + try: + ( + response, + request, + request_info, + scheduler_state, + ) = await self.messaging.get(timeout=settings.mp_poll_interval) + + yield response, request, request_info, scheduler_state + except asyncio.TimeoutError: + if self.shutdown_event.is_set(): + # Everything yielded, exit + break + + async def shutdown(self) -> list[Exception]: # noqa: C901 + """ + Gracefully shut down the worker process group and clean up resources. 
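+
+        Callers are expected to inspect the returned exception list; a minimal
+        usage sketch (``group`` is assumed to be a started process group):
+
+        Example:
+            ::
+                errors = await group.shutdown()
+                if errors:
+                    raise errors[0]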
+ + Performs safe shutdown of worker processes, background tasks, and + multiprocessing resources. Coordinates orderly termination across + all workers and collects any exceptions encountered during shutdown. + + :return: List of exceptions encountered during shutdown; empty if no errors + """ + exceptions: list[Exception] = [] + if self.shutdown_event is not None: + self.shutdown_event.set() + + # Clear out start values + if self.messaging is not None: + try: + await asyncio.wait_for(self.messaging.stop(), timeout=5.0) + except Exception as err: # noqa: BLE001 + exceptions.append(err) + self.messaging = None + self.state = None + + # Clear out create processes values + if self.processes is not None: + for proc in self.processes: + try: + await asyncio.to_thread(proc.join, timeout=5.0) + if proc.exitcode is not None and proc.exitcode > 0: + exceptions.append( + RuntimeError( + f"Worker {proc.pid} exited with code {proc.exitcode}" + ) + ) + except Exception as err: # noqa: BLE001 + exceptions.append(err) + self.processes = None + self.startup_barrier = None + self.requests_generated_event = None + self.constraint_reached_event = None + self.shutdown_event = None + self.error_event = None + if self.mp_manager is not None: + try: + self.mp_manager.shutdown() + except Exception as err: # noqa: BLE001 + exceptions.append(err) + self.mp_manager = None + self.mp_context = None + + return exceptions + + +class _StateUpdate(NamedTuple): + state: SchedulerState + stop_queueing: bool + stop_processing: bool + + +class WorkerGroupState(Generic[RequestT, ResponseT]): + """ + Manages scheduler state and synchronization for worker process groups. + + Handles request generation, state updates, constraint evaluation, and + coordination between worker processes. Provides thread-safe state management + with request lifecycle tracking and constraint-based termination logic. + """ + + def __init__( + self, + start_time: float, + processes: list[BaseProcess], + constraints: dict[str, Constraint], + stop_send_requests_event: threading.Event, + send_requests_stopped_event: threading.Event, + requests_generated_event: Event, + constraint_reached_event: Event, + shutdown_event: Event, + error_event: Event, + ): + """ + Initialize worker group state management. 
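+
+        Note: this object lives only in the parent process (worker processes
+        interact with it indirectly through messaging callbacks), which is why
+        a plain ``threading.Lock`` guards updates rather than a
+        multiprocessing lock.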
+
+        :param start_time: Unix timestamp when processing should begin
+        :param processes: List of worker process instances
+        :param constraints: Named constraints for controlling execution behavior
+        :param stop_send_requests_event: Threading event to stop queuing requests
+        :param send_requests_stopped_event: Threading event set when sending stops
+        :param requests_generated_event: Multiprocessing event for generation completion
+        :param constraint_reached_event: Multiprocessing event for constraint stopping
+        :param shutdown_event: Multiprocessing event for coordinated shutdown
+        :param error_event: Multiprocessing event for error condition signaling
+        """
+        self.start_time = start_time
+        self.processes = processes
+        self.constraints = constraints
+        self.stop_send_requests_event = stop_send_requests_event
+        self.send_requests_stopped_event = send_requests_stopped_event
+        self.requests_generated_event = requests_generated_event
+        self.constraint_reached_event = constraint_reached_event
+        self.shutdown_event = shutdown_event
+        self.error_event = error_event
+
+        self._update_lock: threading.Lock = threading.Lock()
+        self._state: SchedulerState = SchedulerState(
+            node_id=0,
+            num_processes=len(processes),
+            start_time=start_time,
+        )
+        self._queued_requests = set()
+        self._pending_requests = set()
+        self._processing_requests = set()
+
+    def requests_generator(
+        self,
+        requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
+        cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None,
+    ) -> Generator[
+        tuple[RequestT | MultiTurnRequestT[RequestT], ScheduledRequestInfo], None, None
+    ]:
+        """
+        Generate request-info pairs for worker processing with constraint evaluation.
+
+        Processes finite requests sequentially then cycles through repeating requests
+        indefinitely. Creates scheduling metadata for each request and evaluates
+        constraints to determine when to stop request generation.
+
+        :param requests: Finite iterable of requests to process sequentially
+        :param cycle_requests: Iterable of requests to cycle through indefinitely
+        :return: Generator yielding (request, request_info) tuples
+        """
+
+        def _iter():
+            if requests:
+                yield from requests
+
+            if cycle_requests:
+                while True:
+                    yield from cycle_requests
+
+        count = 0
+        for request in _iter():
+            count += 1
+
+            if hasattr(request, "request_id"):
+                request_id = request.request_id
+            elif hasattr(request, "id"):
+                request_id = request.id
+            else:
+                request_id = str(uuid.uuid4())
+            request_info = ScheduledRequestInfo(
+                request_id=request_id,
+                status="queued",
+                scheduler_process_id=0,
+                scheduler_start_time=self.start_time,
+            )
+            state_update = self._locked_update(request_info)
+            yield (request, request_info)
+
+            if state_update.stop_queueing:
+                self.stop_send_requests_event.set()
+                return
+
+        # Reached the end; inject a RequestsExhaustedConstraint to record the total
+        self._locked_update(
+            info=None,
+            requests_exhausted=RequestsExhaustedConstraint(num_requests=count),
+        )
+        self.stop_send_requests_event.set()
+
+    def received_callback(
+        self,
+        update: tuple[
+            ResponseT | None,
+            RequestT | MultiTurnRequestT,
+            ScheduledRequestInfo,
+        ],
+    ) -> tuple[
+        ResponseT | None,
+        RequestT | MultiTurnRequestT,
+        ScheduledRequestInfo,
+        SchedulerState,
+    ]:
+        """
+        Process received request updates and inject current scheduler state.
+
+        Updates internal state tracking based on request status changes and
+        evaluates constraints to determine if processing should be terminated.
+        Triggers shutdown when stop conditions are met.
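+
+        A rough sketch of the transformation (names are illustrative):
+
+        Example:
+            ::
+                update = (response, request, request_info)
+                response, request, request_info, state = (
+                    group_state.received_callback(update)
+                )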
+ + :param update: Tuple containing response, request, and request info + :return: Updated tuple with injected scheduler state + """ + response, request, request_info = update + state_update = self._locked_update(info=request_info) + + # Check if we need to tell workers to stop pulling new requests + # based on no more requests sent and all requests removed from queue + if ( + state_update.state.queued_requests == 0 + and self.send_requests_stopped_event.is_set() + and not self.requests_generated_event.is_set() + ): + self.requests_generated_event.set() + + # Check if we need to tell workers to stop processing requests (constraints) + if state_update.stop_processing and not self.constraint_reached_event.is_set(): + self.constraint_reached_event.set() + + # Check if all requests have been processed and can shutdown + if ( + state_update.state.processed_requests == state_update.state.created_requests + and self.send_requests_stopped_event.is_set() + and self.requests_generated_event.is_set() + and self.constraint_reached_event.is_set() + and not self.shutdown_event.is_set() + ): + self.shutdown_event.set() + + return ( + response, + request, + request_info, + state_update.state, # inject state for updates to be yielded back + ) + + def _locked_update( + self, + info: ScheduledRequestInfo | None = None, + **add_constraints: dict[str, Constraint], + ) -> _StateUpdate: + with self._update_lock: + if add_constraints: + self.constraints.update(add_constraints) + + if info is not None: + self._state.end_time = time.time() # Always update in case last update + self._update_state_request_counts(info) + self._update_with_constraints(info) + + state_copy: SchedulerState = self._state.model_copy() + + return _StateUpdate( + state_copy, + state_copy.end_queuing_time is not None, + state_copy.end_processing_time is not None, + ) + + def _update_state_request_counts(self, info: ScheduledRequestInfo): + if info.status == "queued": + self._queued_requests.add(info.request_id) + self._state.queued_requests = len(self._queued_requests) + self._state.created_requests += 1 + elif info.status == "pending": + self._queued_requests.remove(info.request_id) + self._state.queued_requests = len(self._queued_requests) + self._pending_requests.add(info.request_id) + self._state.pending_requests = len(self._pending_requests) + elif info.status == "in_progress": + self._pending_requests.remove(info.request_id) + self._state.pending_requests = len(self._pending_requests) + self._processing_requests.add(info.request_id) + self._state.processing_requests = len(self._processing_requests) + elif info.status == "completed": + self._processing_requests.remove(info.request_id) + self._state.processing_requests = len(self._processing_requests) + self._state.processed_requests += 1 + self._state.successful_requests += 1 + elif info.status in ("errored", "cancelled"): + if info.request_id in self._queued_requests: + self._queued_requests.remove(info.request_id) + self._state.queued_requests = len(self._queued_requests) + elif info.request_id in self._pending_requests: + self._pending_requests.remove(info.request_id) + self._state.pending_requests = len(self._pending_requests) + elif info.request_id in self._processing_requests: + self._processing_requests.remove(info.request_id) + self._state.processing_requests = len(self._processing_requests) + + self._state.processed_requests += 1 + self._state.errored_requests += 1 if info.status == "errored" else 0 + self._state.cancelled_requests += 1 if info.status == "cancelled" else 0 + 
else: + raise ValueError(f"Unknown request_info status {info.status} for {info}") + + def _update_with_constraints(self, info: ScheduledRequestInfo): + actions: dict[str, SchedulerUpdateAction] = { + name: const(self._state, info) for name, const in self.constraints.items() + } + self._state.scheduler_constraints = actions + stop_queuing_actions = {} + stop_processing_actions = {} + + for key, action in actions.items(): + # Action updates + if ( + self._state.end_queuing_time is None + and action.request_queuing == "stop" + ): + stop_queuing_actions[key] = action + if ( + self._state.end_processing_time is None + and action.request_processing in ("stop_local", "stop_all") + ): + stop_processing_actions[key] = action + + for progress_key in ( + "remaining_fraction", + "remaining_requests", + "remaining_duration", + ): + if (new_val := action.progress.get(progress_key)) is not None and ( + getattr(self._state, progress_key) is None + or new_val < getattr(self._state, progress_key) + ): + setattr(self._state, progress_key, new_val) + + if stop_queuing_actions: + self._state.end_queuing_constraints = stop_queuing_actions + self._state.end_queuing_time = time.time() + + if stop_processing_actions: + self._state.end_processing_constraints = stop_processing_actions + self._state.end_processing_time = time.time() diff --git a/tests/unit/scheduler/__init__.py b/tests/unit/scheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/scheduler/test_constraints.py b/tests/unit/scheduler/test_constraints.py new file mode 100644 index 00000000..931af413 --- /dev/null +++ b/tests/unit/scheduler/test_constraints.py @@ -0,0 +1,1418 @@ +import inspect +import random +import time +from abc import ABC +from typing import Protocol + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + MaxDurationConstraint, + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, + MaxNumberConstraint, + PydanticConstraintInitializer, + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from guidellm.utils import InfoMixin, StandardBaseModel + + +class TestConstraint: + """Test the Constraint protocol.""" + + @pytest.mark.smoke + def test_is_protocol(self): + """Test that Constraint is a protocol and runtime checkable.""" + assert issubclass(Constraint, Protocol) + assert hasattr(Constraint, "_is_protocol") + assert Constraint._is_protocol is True + assert hasattr(Constraint, "_is_runtime_protocol") + assert Constraint._is_runtime_protocol is True + + @pytest.mark.smoke + def test_protocol_method_signature(self): + """Test that the Constraint protocol has the correct method signature.""" + call_method = Constraint.__call__ + sig = inspect.signature(call_method) + + expected_params = ["self", "state", "request"] + assert list(sig.parameters.keys()) == expected_params + + params = sig.parameters + assert "state" in params + assert "request" in params + + @pytest.mark.smoke + def test_runtime_is_constraint(self): + """Test that Constraint can be checked at runtime using isinstance.""" + + class ValidConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + valid_instance = ValidConstraint() + assert isinstance(valid_instance, Constraint) + + class InvalidConstraint: + pass + 
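+        # Lacking __call__, this must fail the runtime protocol check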
+        invalid_instance = InvalidConstraint()
+        assert not isinstance(invalid_instance, Constraint)
+
+    @pytest.mark.smoke
+    def test_runtime_is_not_initializer(self):
+        """
+        Test that a class not implementing the ConstraintInitializer
+        protocol is not recognized as such.
+        """
+
+        class ValidConstraint:
+            def __call__(
+                self,
+                state: SchedulerState,
+                request: ScheduledRequestInfo,
+            ) -> SchedulerUpdateAction:
+                return SchedulerUpdateAction()
+
+        not_initializer_instance = ValidConstraint()
+        assert not isinstance(not_initializer_instance, ConstraintInitializer)
+
+
+class TestConstraintInitializer:
+    """Test the ConstraintInitializer protocol."""
+
+    @pytest.mark.smoke
+    def test_is_protocol(self):
+        """Test that ConstraintInitializer is a protocol and runtime checkable."""
+        assert issubclass(ConstraintInitializer, Protocol)
+        assert hasattr(ConstraintInitializer, "_is_protocol")
+        assert ConstraintInitializer._is_protocol is True
+        assert hasattr(ConstraintInitializer, "_is_runtime_protocol")
+        assert ConstraintInitializer._is_runtime_protocol is True
+
+    @pytest.mark.smoke
+    def test_protocol_method_signature(self):
+        """Test that ConstraintInitializer protocol has correct method signature."""
+        create_constraint_method = ConstraintInitializer.create_constraint
+        sig = inspect.signature(create_constraint_method)
+
+        expected_params = ["self", "kwargs"]
+        assert list(sig.parameters.keys()) == expected_params
+        kwargs_param = sig.parameters["kwargs"]
+        assert kwargs_param.kind == kwargs_param.VAR_KEYWORD
+
+    @pytest.mark.smoke
+    def test_runtime_is_initializer(self):
+        """Test that ConstraintInitializer can be checked at runtime."""
+
+        class ValidInitializer:
+            def create_constraint(self, **kwargs) -> Constraint:
+                class SimpleConstraint:
+                    def __call__(
+                        self,
+                        state: SchedulerState,
+                        request: ScheduledRequestInfo,
+                    ) -> SchedulerUpdateAction:
+                        return SchedulerUpdateAction()
+
+                return SimpleConstraint()
+
+        valid_instance = ValidInitializer()
+        assert isinstance(valid_instance, ConstraintInitializer)
+
+    @pytest.mark.smoke
+    def test_runtime_is_not_constraint(self):
+        """
+        Test that a class not implementing the Constraint protocol
+        is not recognized as such.
+ """ + + class ValidInitializer: + def create_constraint(self, **kwargs) -> Constraint: + class SimpleConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + return SimpleConstraint() + + not_constraint_instance = ValidInitializer() + assert not isinstance(not_constraint_instance, Constraint) + + +class TestSerializableConstraintInitializer: + """Test the SerializableConstraintInitializer protocol.""" + + @pytest.mark.smoke + def test_is_protocol(self): + """Test SerializableConstraintInitializer is a protocol and checkable.""" + assert issubclass(SerializableConstraintInitializer, Protocol) + assert hasattr(SerializableConstraintInitializer, "_is_protocol") + assert SerializableConstraintInitializer._is_protocol is True + assert hasattr(SerializableConstraintInitializer, "_is_runtime_protocol") + assert SerializableConstraintInitializer._is_runtime_protocol is True + + @pytest.mark.smoke + def test_protocol_method_signatures(self): + """Test SerializableConstraintInitializer protocol has correct signatures.""" + methods = [ + "validated_kwargs", + "model_validate", + "model_dump", + "create_constraint", + ] + + for method_name in methods: + assert hasattr(SerializableConstraintInitializer, method_name) + + @pytest.mark.smoke + def test_runtime_is_serializable_initializer(self): + """Test that SerializableConstraintInitializer can be checked at runtime.""" + + class ValidSerializableInitializer: + @classmethod + def validated_kwargs(cls, *args, **kwargs): + return kwargs + + @classmethod + def model_validate(cls, **kwargs): + return cls() + + def model_dump(self): + return {} + + def create_constraint(self, **kwargs): + class SimpleConstraint: + def __call__(self, state, request): + return SchedulerUpdateAction() + + return SimpleConstraint() + + valid_instance = ValidSerializableInitializer() + assert isinstance(valid_instance, SerializableConstraintInitializer) + + +class TestPydanticConstraintInitializer: + """Test the PydanticConstraintInitializer implementation.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticConstraintInitializer inheritance and abstract methods.""" + assert issubclass(PydanticConstraintInitializer, StandardBaseModel) + assert issubclass(PydanticConstraintInitializer, ABC) + assert issubclass(PydanticConstraintInitializer, InfoMixin) + + @pytest.mark.smoke + def test_abstract_methods(self): + """Test that PydanticConstraintInitializer has required abstract methods.""" + abstract_methods = PydanticConstraintInitializer.__abstractmethods__ + expected_methods = {"validated_kwargs", "create_constraint"} + assert abstract_methods == expected_methods + + @pytest.mark.sanity + def test_cannot_instantiate_directly(self): + """Test that PydanticConstraintInitializer cannot be instantiated directly.""" + with pytest.raises(TypeError): + PydanticConstraintInitializer(type_="test") + + +class TestUnserializableConstraintInitializer: + """Test the UnserializableConstraintInitializer implementation.""" + + @pytest.fixture( + params=[ + {"orig_info": {}}, + {"orig_info": {"class": "SomeClass", "module": "some.module"}}, + ] + ) + def valid_instances(self, request): + """Fixture providing test data for UnserializableConstraintInitializer.""" + constructor_args = request.param + instance = UnserializableConstraintInitializer(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test 
UnserializableConstraintInitializer inheritance.""" + assert issubclass( + UnserializableConstraintInitializer, PydanticConstraintInitializer + ) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test UnserializableConstraintInitializer initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, UnserializableConstraintInitializer) + assert instance.type_ == "unserializable" + assert instance.orig_info == constructor_args["orig_info"] + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test validated_kwargs class method.""" + result = UnserializableConstraintInitializer.validated_kwargs( + orig_info={"test": "data"} + ) + assert result == {"orig_info": {"test": "data"}} + + result = UnserializableConstraintInitializer.validated_kwargs() + assert result == {"orig_info": {}} + + @pytest.mark.sanity + def test_create_constraint_raises(self, valid_instances): + """Test that create_constraint raises RuntimeError.""" + instance, _ = valid_instances + with pytest.raises( + RuntimeError, match="Cannot create constraint from unserializable" + ): + instance.create_constraint() + + @pytest.mark.sanity + def test_call_raises(self, valid_instances): + """Test that calling constraint raises RuntimeError.""" + instance, _ = valid_instances + state = SchedulerState(node_id="test_node", num_processes=1, start_time=0.0) + request = ScheduledRequestInfo( + request_id="test_request", + status="pending", + scheduler_node_id="test_node", + scheduler_process_id=1, + scheduler_start_time=0.0, + ) + + with pytest.raises( + RuntimeError, match="Cannot invoke unserializable constraint" + ): + instance(state, request) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test UnserializableConstraintInitializer serialization/deserialization.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert data["type_"] == "unserializable" + assert data["orig_info"] == constructor_args["orig_info"] + + reconstructed = UnserializableConstraintInitializer.model_validate(data) + assert reconstructed.type_ == instance.type_ + assert reconstructed.orig_info == instance.orig_info + + +class TestMaxNumberConstraint: + """Test the MaxNumberConstraint implementation.""" + + @pytest.fixture(params=[{"max_num": 100}, {"max_num": 50.5}, {"max_num": 1}]) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxNumberConstraint(**constructor_args) + + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxNumberConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """Test MaxNumberConstraint satisfies the ConstraintInitializer protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxNumberConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxNumberConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxNumberConstraint() 
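+        # max_num is required and must be positive; zero, negative, and
+        # non-numeric values below are all expected to fail validation.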
+ with pytest.raises(ValidationError): + MaxNumberConstraint(max_num=-1) + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num=0) + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions and progress""" + instance, constructor_args = valid_instances + start_time = time.time() + + for num_requests in range(0, int(constructor_args["max_num"]) * 2 + 1, 1): + state = SchedulerState( + start_time=start_time, + created_requests=num_requests, + processed_requests=num_requests, + errored_requests=0, + ) + request_info = ScheduledRequestInfo( + request_id="test", status="completed", created_at=start_time + ) + + action = instance(state, request_info) + assert isinstance(action, SchedulerUpdateAction) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxNumberConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxNumberConstraint.model_validate(data) + assert reconstructed.max_num == instance.max_num + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_create_constraint_functionality(self, valid_instances): + """Test the constraint initializer functionality.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == constructor_args["max_num"] + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxNumberConstraint.validated_kwargs class method.""" + result = MaxNumberConstraint.validated_kwargs(max_num=100) + assert result == {"max_num": 100, "current_index": -1} + + result = MaxNumberConstraint.validated_kwargs(50.5) + assert result == {"max_num": 50.5, "current_index": -1} + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxNumberConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxNumberConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_num == instance.max_num + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxNumberConstraint is properly registered with expected aliases.""" + expected_aliases = ["max_number", "max_num", "max_requests", "max_req"] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxNumberConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", ["max_number", "max_num", "max_requests", "max_req"] + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint(alias, max_num=100) + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == 100 + + # Test 
with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 50) + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == 50 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_number": {"max_num": 200}} + ) + assert isinstance(resolved["max_number"], MaxNumberConstraint) + assert resolved["max_number"].max_num == 200 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_num": 150}) + assert isinstance(resolved["max_num"], MaxNumberConstraint) + assert resolved["max_num"].max_num == 150 + + # Test with instance + instance = MaxNumberConstraint(max_num=75) + resolved = ConstraintsInitializerFactory.resolve({"max_requests": instance}) + assert resolved["max_requests"] is instance + + +class TestMaxDurationConstraint: + """Test the MaxDurationConstraint implementation.""" + + @pytest.fixture( + params=[{"max_duration": 2.0}, {"max_duration": 1}, {"max_duration": 0.5}] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxDurationConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxDurationConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxDurationConstraint also satisfies + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxDurationConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxDurationConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxDurationConstraint() + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration=-1) + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration=0) + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions and progress through a time loop""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_duration = constructor_args["max_duration"] + sleep_interval = max_duration * 0.05 + target_duration = max_duration * 1.5 + + elapsed = 0.0 + step = 0 + + while elapsed <= target_duration: + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=step + 1, + processed_requests=step, + ) + request = ScheduledRequestInfo( + request_id=f"test-{step}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + + duration_exceeded = elapsed >= max_duration + + if not duration_exceeded: + assert action.request_queuing == "continue" + 
assert action.request_processing == "continue" + else: + assert action.request_queuing == "stop" + assert action.request_processing == "stop_local" + assert isinstance(action.metadata, dict) + assert action.metadata["max_duration"] == max_duration + assert action.metadata["elapsed_time"] == pytest.approx(elapsed, abs=0.01) + assert action.metadata["duration_exceeded"] == duration_exceeded + assert action.metadata["start_time"] == start_time + assert isinstance(action.progress, dict) + expected_remaining_fraction = max(0.0, 1.0 - elapsed / max_duration) + expected_remaining_duration = max(0.0, max_duration - elapsed) + assert action.progress["remaining_fraction"] == pytest.approx( + expected_remaining_fraction, abs=0.1 + ) + assert action.progress["remaining_duration"] == pytest.approx( + expected_remaining_duration, abs=0.1 + ) + time.sleep(sleep_interval) + elapsed = time.time() - start_time + step += 1 + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxDurationConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxDurationConstraint.model_validate(data) + assert reconstructed.max_duration == instance.max_duration + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_create_constraint_functionality(self, valid_instances): + """Test the constraint initializer functionality.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == constructor_args["max_duration"] + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxDurationConstraint.validated_kwargs class method.""" + result = MaxDurationConstraint.validated_kwargs(max_duration=60.0) + assert result == {"max_duration": 60.0, "current_index": -1} + + result = MaxDurationConstraint.validated_kwargs(30) + assert result == {"max_duration": 30, "current_index": -1} + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxDurationConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxDurationConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_duration == instance.max_duration + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxDurationConstraint is properly registered with expected aliases.""" + expected_aliases = [ + "max_duration", + "max_dur", + "max_sec", + "max_seconds", + "max_min", + "max_minutes", + ] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxDurationConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", + ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"], + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict 
configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_duration=60.0 + ) + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == 60.0 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 30.0) + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == 30.0 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_duration": {"max_duration": 120.0}} + ) + assert isinstance(resolved["max_duration"], MaxDurationConstraint) + assert resolved["max_duration"].max_duration == 120.0 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_sec": 90.0}) + assert isinstance(resolved["max_sec"], MaxDurationConstraint) + assert resolved["max_sec"].max_duration == 90.0 + + # Test with instance + instance = MaxDurationConstraint(max_duration=45.0) + resolved = ConstraintsInitializerFactory.resolve({"max_minutes": instance}) + assert resolved["max_minutes"] is instance + + +class TestMaxErrorsConstraint: + """Test the MaxErrorsConstraint implementation.""" + + @pytest.fixture(params=[{"max_errors": 10}, {"max_errors": 5.5}, {"max_errors": 1}]) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxErrorsConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxErrorsConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxErrorsConstraint also satisfies + the ConstraintInitializer protocol. 
+ """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxErrorsConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxErrorsConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorsConstraint() + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors=-1) + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors=0) + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions""" + instance, constructor_args = valid_instances + start_time = time.time() + + for num_errors in range(int(constructor_args["max_errors"] * 2)): + created_requests = (num_errors + 1) * 2 + processed_requests = num_errors + 1 + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=created_requests, + processed_requests=processed_requests, + errored_requests=num_errors, + ) + request = ScheduledRequestInfo( + request_id=f"test-{num_errors}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + errors_exceeded = num_errors >= constructor_args["max_errors"] + if not errors_exceeded: + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + else: + assert action.request_queuing == "stop" + assert action.request_processing == "stop_all" + + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_errors": constructor_args["max_errors"], + "errors_exceeded": errors_exceeded, + "current_errors": num_errors, + } + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxErrorsConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorsConstraint.model_validate(data) + assert reconstructed.max_errors == instance.max_errors + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxErrorsConstraint.validated_kwargs class method.""" + result = MaxErrorsConstraint.validated_kwargs(max_errors=10) + assert result == {"max_errors": 10, "current_index": -1} + + result = MaxErrorsConstraint.validated_kwargs(5.5) + assert result == {"max_errors": 5.5, "current_index": -1} + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxErrorsConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint is not instance + assert constraint.max_errors == instance.max_errors + assert instance.current_index == original_index + 1 + assert constraint.current_index == 
original_index + 1 + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxErrorsConstraint is properly registered with expected aliases.""" + expected_aliases = ["max_errors", "max_err", "max_error", "max_errs"] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxErrorsConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", ["max_errors", "max_err", "max_error", "max_errs"] + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_errors=10 + ) + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint.max_errors == 10 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 5) + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint.max_errors == 5 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_errors": {"max_errors": 15}} + ) + assert isinstance(resolved["max_errors"], MaxErrorsConstraint) + assert resolved["max_errors"].max_errors == 15 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_err": 8}) + assert isinstance(resolved["max_err"], MaxErrorsConstraint) + assert resolved["max_err"].max_errors == 8 + + # Test with instance + instance = MaxErrorsConstraint(max_errors=3) + resolved = ConstraintsInitializerFactory.resolve({"max_error": instance}) + assert resolved["max_error"] is instance + + +class TestMaxErrorRateConstraint: + """Test the MaxErrorRateConstraint implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "window_size": 40}, + {"max_error_rate": 0.5, "window_size": 50}, + {"max_error_rate": 0.05, "window_size": 55}, + ] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxErrorRateConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxErrorRateConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxErrorRateConstraint also satisfies + the ConstraintInitializer protocol. 
+ """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxErrorRateConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxErrorRateConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorRateConstraint() + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=0.5, window_size=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions with sliding window behavior""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_error_rate = constructor_args["max_error_rate"] + window_size = constructor_args["window_size"] + safety_factor = 1.5 + total_errors = 0 + error_window = [] + + for request_num in range(window_size * 2): + error_probability = max_error_rate * safety_factor + + if random.random() < error_probability: + total_errors += 1 + status = "errored" + error_window.append(1) + else: + status = "completed" + error_window.append(0) + error_window = ( + error_window[-window_size:] + if len(error_window) > window_size + else error_window + ) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=request_num + 1, + processed_requests=request_num + 1, + ) + request = ScheduledRequestInfo( + request_id=f"test-{request_num}", + status=status, + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + error_count = sum(instance.error_window) + processed_requests = state.processed_requests + exceeded_min_processed = processed_requests >= window_size + current_error_rate = ( + error_count / float(min(processed_requests, window_size)) + if processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = current_error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + expected_queuing = "stop" if should_stop else "continue" + expected_processing = "stop_all" if should_stop else "continue" + + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + assert isinstance(action.metadata, dict) + assert action.metadata["max_error_rate"] == max_error_rate + assert action.metadata["window_size"] == window_size + assert action.metadata["error_count"] == error_count + assert action.metadata["current_error_rate"] == current_error_rate + assert action.metadata["exceeded_error_rate"] == exceeded_error_rate + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxErrorRateConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in 
constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorRateConstraint.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.window_size == instance.window_size + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxErrorRateConstraint.validated_kwargs class method.""" + result = MaxErrorRateConstraint.validated_kwargs( + max_error_rate=0.1, window_size=50 + ) + assert result == { + "max_error_rate": 0.1, + "window_size": 50, + "error_window": [], + "current_index": -1, + } + + result = MaxErrorRateConstraint.validated_kwargs(0.05) + assert result == { + "max_error_rate": 0.05, + "window_size": 30, + "error_window": [], + "current_index": -1, + } + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxErrorRateConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_error_rate == instance.max_error_rate + assert constraint.window_size == instance.window_size + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxErrorRateConstraint is properly registered with expected aliases.""" + expected_aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxErrorRateConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", ["max_error_rate", "max_err_rate", "max_errors_rate"] + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_error_rate=0.1, window_size=50 + ) + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint.max_error_rate == 0.1 + assert constraint.window_size == 50 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 0.05) + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint.max_error_rate == 0.05 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_error_rate": {"max_error_rate": 0.15, "window_size": 100}} + ) + assert isinstance(resolved["max_error_rate"], MaxErrorRateConstraint) + assert resolved["max_error_rate"].max_error_rate == 0.15 + assert resolved["max_error_rate"].window_size == 100 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_err_rate": 0.08}) + assert isinstance(resolved["max_err_rate"], MaxErrorRateConstraint) + assert resolved["max_err_rate"].max_error_rate == 0.08 + + # Test with instance + instance = MaxErrorRateConstraint(max_error_rate=0.2, window_size=25) + resolved = ConstraintsInitializerFactory.resolve({"max_errors_rate": instance}) + 
assert resolved["max_errors_rate"] is instance + + +class TestMaxGlobalErrorRateConstraint: + """Test the MaxGlobalErrorRateConstraint implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "min_processed": 50}, + {"max_error_rate": 0.2, "min_processed": 100}, + {"max_error_rate": 0.05, "min_processed": 31}, + ] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxGlobalErrorRateConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxGlobalErrorRateConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraint also satisfies + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraint can be initialized + with valid parameters. + """ + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxGlobalErrorRateConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint() + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=0) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=0.5, min_processed=0) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions based on global error rate""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_error_rate = constructor_args["max_error_rate"] + min_processed = constructor_args["min_processed"] + safety_factor = 1.5 + total_requests = min_processed * 2 + total_errors = 0 + + for request_num in range(total_requests): + error_probability = max_error_rate * safety_factor + + if random.random() < error_probability: + total_errors += 1 + status = "errored" + else: + status = "completed" + + processed_requests = request_num + 1 + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=processed_requests + 10, + processed_requests=processed_requests, + errored_requests=total_errors, + ) + request = ScheduledRequestInfo( + request_id=f"test-{request_num}", + status=status, + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + + exceeded_min_processed = processed_requests >= min_processed + error_rate = ( + total_errors / float(processed_requests) + if processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + + expected_queuing = "stop" if should_stop else 
"continue" + expected_processing = "stop_all" if should_stop else "continue" + + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_error_rate": max_error_rate, + "min_processed": min_processed, + "processed_requests": processed_requests, + "errored_requests": total_errors, + "error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + } + + # Error constraints don't provide progress information + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxGlobalErrorRateConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxGlobalErrorRateConstraint.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.min_processed == instance.min_processed + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test MaxGlobalErrorRateConstraint.validated_kwargs class method.""" + result = MaxGlobalErrorRateConstraint.validated_kwargs( + max_error_rate=0.1, min_processed=50 + ) + assert result == { + "max_error_rate": 0.1, + "min_processed": 50, + "current_index": -1, + } + + result = MaxGlobalErrorRateConstraint.validated_kwargs(0.05) + assert result == { + "max_error_rate": 0.05, + "min_processed": 30, + "current_index": -1, + } + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test MaxGlobalErrorRateConstraint.create_constraint method.""" + instance, constructor_args = valid_instances + original_index = instance.current_index + constraint = instance.create_constraint() + + assert isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint is not instance # Should return a copy + assert constraint.max_error_rate == instance.max_error_rate + assert constraint.min_processed == instance.min_processed + assert instance.current_index == original_index + 1 # Original is incremented + assert constraint.current_index == original_index + 1 # Copy has incremented + + @pytest.mark.smoke + def test_factory_registration(self): + """Test MaxGlobalErrorRateConstraint is properly registered with aliases.""" + expected_aliases = [ + "max_global_error_rate", + "max_global_err_rate", + "max_global_errors_rate", + ] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == MaxGlobalErrorRateConstraint + + @pytest.mark.smoke + @pytest.mark.parametrize( + "alias", + ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"], + ) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration + constraint = ConstraintsInitializerFactory.create_constraint( + alias, max_error_rate=0.1, min_processed=50 + ) + assert isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint.max_error_rate == 0.1 + assert constraint.min_processed == 50 + + # Test with simple value + constraint = ConstraintsInitializerFactory.create_constraint(alias, 0.05) + assert 
isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint.max_error_rate == 0.05 + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"max_global_error_rate": {"max_error_rate": 0.12, "min_processed": 100}} + ) + assert isinstance( + resolved["max_global_error_rate"], MaxGlobalErrorRateConstraint + ) + assert resolved["max_global_error_rate"].max_error_rate == 0.12 + assert resolved["max_global_error_rate"].min_processed == 100 + + # Test with simple value + resolved = ConstraintsInitializerFactory.resolve({"max_global_err_rate": 0.08}) + assert isinstance(resolved["max_global_err_rate"], MaxGlobalErrorRateConstraint) + assert resolved["max_global_err_rate"].max_error_rate == 0.08 + + # Test with instance + instance = MaxGlobalErrorRateConstraint(max_error_rate=0.15, min_processed=75) + resolved = ConstraintsInitializerFactory.resolve( + {"max_global_errors_rate": instance} + ) + assert resolved["max_global_errors_rate"] is instance + + +class TestConstraintsInitializerFactory: + """Test the ConstraintsInitializerFactory implementation.""" + + @pytest.mark.sanity + def test_unregistered_key_fails(self): + """Test that unregistered keys raise ValueError.""" + unregistered_key = "nonexistent_constraint" + assert not ConstraintsInitializerFactory.is_registered(unregistered_key) + + with pytest.raises( + ValueError, match=f"Unknown constraint initializer key: {unregistered_key}" + ): + ConstraintsInitializerFactory.create(unregistered_key) + + with pytest.raises( + ValueError, match=f"Unknown constraint initializer key: {unregistered_key}" + ): + ConstraintsInitializerFactory.create_constraint(unregistered_key) + + @pytest.mark.smoke + def test_resolve_mixed_types(self): + """Test resolve method with mixed constraint types.""" + max_num_constraint = MaxNumberConstraint(max_num=25) + max_duration_initializer = MaxDurationConstraint(max_duration=120.0) + + mixed_spec = { + "max_number": max_num_constraint, + "max_duration": max_duration_initializer, + "max_errors": {"max_errors": 15}, + "max_error_rate": 0.08, + } + + resolved = ConstraintsInitializerFactory.resolve(mixed_spec) + + assert len(resolved) == 4 + assert all(isinstance(c, Constraint) for c in resolved.values()) + assert resolved["max_number"] is max_num_constraint + assert isinstance(resolved["max_duration"], MaxDurationConstraint) + assert isinstance(resolved["max_errors"], MaxErrorsConstraint) + assert isinstance(resolved["max_error_rate"], MaxErrorRateConstraint) + assert resolved["max_error_rate"].max_error_rate == 0.08 + + @pytest.mark.sanity + def test_resolve_with_invalid_key(self): + """Test that resolve raises ValueError for unregistered keys.""" + invalid_spec = { + "max_number": {"max_num": 100}, + "invalid_constraint": {"some_param": 42}, + } + + with pytest.raises( + ValueError, match="Unknown constraint initializer key: invalid_constraint" + ): + ConstraintsInitializerFactory.resolve(invalid_spec) + + @pytest.mark.smoke + def test_functional_constraint_creation(self): + """Test that created constraints are functionally correct.""" + constraint = ConstraintsInitializerFactory.create_constraint( + "max_number", max_num=10 + ) + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=5, + processed_requests=5, + ) + request = ScheduledRequestInfo( + request_id="test-request", + 
status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + + state_exceeded = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=15, + processed_requests=15, + ) + action_exceeded = constraint(state_exceeded, request) + assert action_exceeded.request_queuing == "stop" + assert action_exceeded.request_processing == "stop_local" diff --git a/tests/unit/scheduler/test_environment.py b/tests/unit/scheduler/test_environment.py new file mode 100644 index 00000000..c73abe42 --- /dev/null +++ b/tests/unit/scheduler/test_environment.py @@ -0,0 +1,329 @@ +import inspect +import time +from abc import ABC +from typing import Generic +from unittest.mock import patch + +import pytest + +from guidellm.scheduler import ( + Environment, + MaxNumberConstraint, + NonDistributedEnvironment, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, + SynchronousStrategy, +) +from guidellm.utils import InfoMixin + + +class TestEnvironment: + @pytest.mark.smoke + def test_class_signatures(self): + """Test Environment inheritance and type relationships.""" + # Inheritance and abstract class properties + assert issubclass(Environment, ABC) + assert issubclass(Environment, Generic) + assert issubclass(Environment, InfoMixin) + assert inspect.isabstract(Environment) + assert hasattr(Environment, "info") + + # Abstract methods validation + expected_abstract_methods = { + "sync_run_params", + "sync_run_start", + "update_run_iteration", + "sync_run_error", + "sync_run_end", + } + assert Environment.__abstractmethods__ == expected_abstract_methods + + # Method signatures and async properties + method_signatures = { + "sync_run_params": ["self", "requests", "strategy", "constraints"], + "sync_run_start": ["self"], + "update_run_iteration": [ + "self", + "response", + "request", + "request_info", + "state", + ], + "sync_run_error": ["self", "err"], + "sync_run_end": ["self"], + } + + for method_name, expected_params in method_signatures.items(): + method = getattr(Environment, method_name) + sig = inspect.signature(method) + + # Check parameter names and count + param_names = list(sig.parameters.keys()) + assert param_names == expected_params + + # Check async nature + assert inspect.iscoroutinefunction(method) or inspect.isasyncgenfunction( + method + ) + + # Generic type parameters + orig_bases = getattr(Environment, "__orig_bases__", ()) + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is Generic + ), + None, + ) + assert generic_base is not None + type_args = getattr(generic_base, "__args__", ()) + assert RequestT in type_args + assert ResponseT in type_args + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid implementations raise TypeError.""" + + class InvalidImplementation(Environment): + pass + + with pytest.raises(TypeError): + InvalidImplementation() + + @pytest.mark.sanity + def test_partial_invalid_implementation(self): + """Test that partial implementations raise TypeError.""" + + class PartialImplementation(Environment): + async def sync_run_params(self, requests, strategy, constraints): + return requests, strategy, constraints + + async def sync_run_start(self): + return 0.0 + + # Missing other required methods + + with 
pytest.raises(TypeError):
+            PartialImplementation()
+
+    @pytest.mark.smoke
+    def test_implementation_construction(self):
+        """Test that concrete implementations can be constructed."""
+
+        class TestEnvironment(Environment):
+            async def sync_run_params(self, requests, strategy, constraints):
+                return requests, strategy, constraints
+
+            async def sync_run_start(self):
+                return 0.0
+
+            async def update_run_iteration(
+                self, response, request, request_info, state
+            ):
+                pass
+
+            async def sync_run_error(self, err):
+                pass
+
+            async def sync_run_end(self):
+                yield
+
+        env = TestEnvironment()
+        assert isinstance(env, Environment)
+
+
+class TestNonDistributedEnvironment:
+    @pytest.fixture
+    def valid_instances(self):
+        """Fixture providing test data for NonDistributedEnvironment."""
+        instance = NonDistributedEnvironment()
+        return instance, {}
+
+    @pytest.mark.smoke
+    def test_class_signatures(self, valid_instances):
+        """Test NonDistributedEnvironment inheritance and type relationships."""
+        instance, constructor_args = valid_instances
+        assert issubclass(NonDistributedEnvironment, Environment)
+        assert issubclass(NonDistributedEnvironment, InfoMixin)
+        assert not inspect.isabstract(NonDistributedEnvironment)
+
+        # Should inherit from Environment
+        assert isinstance(instance, Environment)
+        assert issubclass(NonDistributedEnvironment, Environment)
+
+        # Should implement all required methods
+        required_methods = [
+            "sync_run_params",
+            "sync_run_start",
+            "update_run_iteration",
+            "sync_run_error",
+            "sync_run_end",
+        ]
+
+        for method_name in required_methods:
+            assert hasattr(instance, method_name)
+            assert callable(getattr(instance, method_name))
+
+    @pytest.mark.smoke
+    def test_initialization(self, valid_instances):
+        """Test NonDistributedEnvironment initialization."""
+        instance, constructor_args = valid_instances
+        assert isinstance(instance, NonDistributedEnvironment)
+        assert isinstance(instance, Environment)
+        assert instance.run_errors == []
+
+    @pytest.mark.sanity
+    def test_invalid_initialization(self):
+        """Test that initialization doesn't accept invalid arguments."""
+        with pytest.raises(TypeError):
+            NonDistributedEnvironment("invalid_arg")
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        ("requests", "strategy", "constraints"),
+        [
+            (
+                ["request1", "request2"],
+                SynchronousStrategy(),
+                {"max_requests": MaxNumberConstraint(max_num=10)},
+            ),
+            (
+                [],
+                SynchronousStrategy(),
+                {},
+            ),
+            (
+                ["single_request"],
+                SynchronousStrategy(),
+                {"max_requests": MaxNumberConstraint(max_num=1)},
+            ),
+            (
+                range(5),
+                SynchronousStrategy(),
+                {"max_requests": MaxNumberConstraint(max_num=5)},
+            ),
+        ],
+        ids=[
+            "multiple_requests",
+            "empty_requests",
+            "single_request",
+            "range_requests",
+        ],
+    )
+    async def test_sync_run_params(
+        self, valid_instances, requests, strategy, constraints
+    ):
+        """Test sync_run_params returns parameters unchanged."""
+        instance, constructor_args = valid_instances
+
+        (
+            returned_requests,
+            returned_strategy,
+            returned_constraints,
+        ) = await instance.sync_run_params(requests, strategy, constraints)
+
+        assert returned_requests is requests
+        assert returned_strategy is strategy
+        assert returned_constraints is constraints
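+
+    # sync_run_start should return time.time() plus the configured
+    # non-distributed start delay, so each parametrized case below expects
+    # start_time == mock_time + delay.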
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        ("mock_time", "delay", "expected"),
+        [
+            (1000.0, 0.0, 1000.0),
+            (500.0, 1.5, 501.5),
+            (100.0, 10.0, 110.0),
+            (0.0, 2.5, 2.5),
+        ],
+        ids=["no_delay", "small_delay", "large_delay", "zero_time"],
+    )
+    async def test_sync_run_start(self, valid_instances, mock_time, delay, expected):
+        """Test sync_run_start uses configuration value correctly."""
+        instance, constructor_args = valid_instances
+
+        with (
+            patch("time.time", return_value=mock_time),
+            patch("guidellm.scheduler.environments.settings") as mock_settings,
+        ):
+            mock_settings.scheduler_start_delay_non_distributed = delay
+            start_time = await instance.sync_run_start()
+            assert start_time == expected
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        ("response", "req"),
+        [
+            ("mock_response", "mock_request"),
+            (None, "mock_request"),
+            ("mock_response", None),
+            (None, None),
+        ],
+        ids=["both_present", "no_response", "no_request", "both_none"],
+    )
+    async def test_update_run_iteration(self, valid_instances, response, req):
+        """Test update_run_iteration no-op behavior."""
+        instance, constructor_args = valid_instances
+
+        mock_request_info = ScheduledRequestInfo(
+            request_id="test-123",
+            status="completed",
+            scheduler_node_id=0,
+            scheduler_process_id=0,
+            scheduler_start_time=time.time(),
+        )
+        mock_state = SchedulerState(
+            node_id=0,
+            num_processes=1,
+            start_time=time.time(),
+        )
+
+        # Should not raise any errors and is a no-op
+        await instance.update_run_iteration(
+            response, req, mock_request_info, mock_state
+        )
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    async def test_sync_run_error(self, valid_instances):
+        """Test sync_run_error stores errors correctly."""
+        instance, constructor_args = valid_instances
+
+        error1 = RuntimeError("First error")
+        error2 = ValueError("Second error")
+
+        await instance.sync_run_error(error1)
+        assert error1 in instance.run_errors
+        assert len(instance.run_errors) == 1
+
+        await instance.sync_run_error(error2)
+        assert len(instance.run_errors) == 2
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    async def test_sync_run_end(self, valid_instances):
+        """Test sync_run_end behavior with no errors and multiple errors."""
+        instance, constructor_args = valid_instances
+
+        # No errors - empty iterator
+        results = []
+        async for result in instance.sync_run_end():
+            results.append(result)
+        assert results == []
+
+        # Single error - raises original error
+        error = RuntimeError("Test error")
+        await instance.sync_run_error(error)
+        with pytest.raises(RuntimeError):
+            async for _ in instance.sync_run_end():
+                pass
+
+        # Multiple errors - raises RuntimeError with combined message
+        await instance.sync_run_error(ValueError("Second error"))
+        with pytest.raises(RuntimeError) as exc_info:
+            async for _ in instance.sync_run_end():
+                pass
+        assert "Errors occurred during execution" in str(exc_info.value)
diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py
new file mode 100644
index 00000000..df794ff8
--- /dev/null
+++ b/tests/unit/scheduler/test_objects.py
@@ -0,0 +1,1286 @@
+from __future__ import annotations
+
+import inspect
+import typing
+from collections.abc import AsyncIterator
+from typing import Any, Optional, TypeVar, Union
+
+import pytest
+from pydantic import ValidationError
+from typing_extensions import TypeAliasType
+
+from guidellm.scheduler import (
+    BackendInterface,
+    BackendT,
+    MeasuredRequestTimings,
+    MultiTurnRequestT,
+    RequestSchedulerTimings,
+    RequestT,
+    ResponseT,
+    ScheduledRequestInfo,
+    SchedulerState,
+    SchedulerUpdateAction,
+    SchedulerUpdateActionProgress,
+)
+from guidellm.utils import StandardBaseModel
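+
+
+# The module-level tests below pin down the typing contracts (TypeVars and
+# the multi-turn type alias) that the scheduler's generic interfaces use.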
bound.""" + assert isinstance(RequestT, TypeVar) + assert RequestT.__name__ == "RequestT" + assert RequestT.__bound__ is None + assert RequestT.__constraints__ == () + + +def test_response_t(): + """Validate that ResponseT is a TypeVar usable for generics and isn't bound.""" + assert isinstance(ResponseT, TypeVar) + assert ResponseT.__name__ == "ResponseT" + assert ResponseT.__bound__ is None + assert ResponseT.__constraints__ == () + + +def test_backend_t(): + """Validate that BackendT is a TypeVar bound to BackendInterface.""" + assert isinstance(BackendT, TypeVar) + assert BackendT.__name__ == "BackendT" + assert BackendT.__bound__.__name__ == "BackendInterface" + assert BackendT.__constraints__ == () + + +def test_multi_turn_request_t(): + """Validate MultiTurnRequestT is a TypeAliasType for multi-turn requests.""" + assert isinstance(MultiTurnRequestT, TypeAliasType) + assert MultiTurnRequestT.__name__ == "MultiTurnRequestT" + + value = MultiTurnRequestT.__value__ + assert hasattr(value, "__origin__") + assert value.__origin__ is Union + + type_params = getattr(MultiTurnRequestT, "__type_params__", ()) + assert len(type_params) == 1 + assert type_params[0].__name__ == "RequestT" + + +class TestBackendInterface: + """Test the BackendInterface abstract base class.""" + + @pytest.mark.smoke + def test_abstract_methods_defined(self): + """Test that all expected abstract methods are defined.""" + expected_methods = { + "process_startup", + "validate", + "process_shutdown", + "resolve", + } + expected_properties = { + "processes_limit", + "requests_limit", + "info", + } + + for method_name in expected_methods: + assert hasattr(BackendInterface, method_name) + method = getattr(BackendInterface, method_name) + assert inspect.isfunction(method) or inspect.ismethod(method) + + for prop_name in expected_properties: + assert hasattr(BackendInterface, prop_name) + prop = getattr(BackendInterface, prop_name) + assert hasattr(prop, "__get__") + + @pytest.mark.smoke + def test_generic_type_parameters(self): + """Test that BackendInterface has the correct generic type parameters.""" + orig_bases = BackendInterface.__orig_bases__ + protocol_base = None + generic_base = None + + for base in orig_bases: + if hasattr(base, "__origin__"): + if base.__origin__ is typing.Generic: + generic_base = base + elif base.__name__ == "Protocol": + protocol_base = base + + assert protocol_base is not None, "Should inherit from Protocol" + assert generic_base is not None, "Should inherit from Generic" + + if hasattr(generic_base, "__args__"): + type_params = generic_base.__args__ + assert len(type_params) == 3, "Should have 3 type parameters" + param_names = [param.__name__ for param in type_params] + expected_names = ["RequestT", "ResponseT"] + assert param_names == expected_names + + @pytest.mark.smoke + def test_implementation_construction(self): + """Test that a complete concrete implementation can be instantiated.""" + + class ConcreteBackend(BackendInterface[str, MeasuredRequestTimings, str]): + @property + def processes_limit(self) -> int | None: + return 4 + + @property + def requests_limit(self) -> int | None: + return 100 + + @property + def info(self) -> dict[str, Any]: + return {"model": "test", "version": "1.0"} + + async def process_startup(self) -> None: + pass + + async def validate(self) -> None: + pass + + async def process_shutdown(self) -> None: + pass + + async def resolve( + self, + request: str, + request_info: ScheduledRequestInfo, + history: list[tuple[str, str]] | None = None, + ) -> 
AsyncIterator[tuple[str, ScheduledRequestInfo]]: + yield f"Response to: {request}", request_info + + backend = ConcreteBackend() + assert isinstance(backend, BackendInterface) + assert isinstance(backend, ConcreteBackend) + assert backend.processes_limit == 4 + assert backend.requests_limit == 100 + info = backend.info + assert info == {"model": "test", "version": "1.0"} + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_implementation_async_methods(self): # noqa: C901 + """Test that async methods work correctly in concrete implementation.""" + + class AsyncBackend(BackendInterface[dict, MeasuredRequestTimings, dict]): + def __init__(self): + self.startup_called = False + self.validate_called = False + self.shutdown_called = False + + @property + def processes_limit(self) -> int | None: + return None # Unlimited + + @property + def requests_limit(self) -> int | None: + return None # Unlimited + + @property + def info(self) -> dict[str, Any]: + return {"backend": "async_test"} + + async def process_startup(self) -> None: + self.startup_called = True + + async def validate(self) -> None: + self.validate_called = True + + async def process_shutdown(self) -> None: + self.shutdown_called = True + + async def resolve( + self, + request: dict, + request_info: ScheduledRequestInfo, + history: list[tuple[dict, dict]] | None = None, + ) -> AsyncIterator[tuple[dict, ScheduledRequestInfo]]: + response = {"result": request.get("input", ""), "status": "success"} + yield response, request_info + + backend = AsyncBackend() + await backend.process_startup() + assert backend.startup_called + + await backend.validate() + assert backend.validate_called + + await backend.process_shutdown() + assert backend.shutdown_called + + request = {"input": "test_request"} + request_info = ScheduledRequestInfo( + request_id="test-123", + status="queued", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + results = [] + async for response, updated_info in backend.resolve(request, request_info): + results.append((response, updated_info)) + + assert len(results) == 1 + response, updated_info = results[0] + assert response == {"result": "test_request", "status": "success"} + assert updated_info == request_info + + @pytest.mark.smoke + def test_method_signatures(self): + """Test that abstract methods have the expected signatures.""" + info_prop = BackendInterface.info + assert isinstance(info_prop, property) + + processes_limit_prop = BackendInterface.processes_limit + assert isinstance(processes_limit_prop, property) + + requests_limit_prop = BackendInterface.requests_limit + assert isinstance(requests_limit_prop, property) + + startup_sig = inspect.signature(BackendInterface.process_startup) + assert len(startup_sig.parameters) == 1 # Only self + assert list(startup_sig.parameters.keys()) == ["self"] + + validate_sig = inspect.signature(BackendInterface.validate) + assert len(validate_sig.parameters) == 1 # Only self + assert list(validate_sig.parameters.keys()) == ["self"] + + shutdown_sig = inspect.signature(BackendInterface.process_shutdown) + assert len(shutdown_sig.parameters) == 1 # Only self + assert list(shutdown_sig.parameters.keys()) == ["self"] + + resolve_sig = inspect.signature(BackendInterface.resolve) + expected_params = ["self", "request", "request_info", "history"] + assert list(resolve_sig.parameters.keys()) == expected_params + + history_param = resolve_sig.parameters["history"] + assert history_param.default is None + + +class TestRequestSchedulerTimings: + 
"""Test the RequestSchedulerTimings model class.""" + + CHECK_KEYS = [ + "targeted_start", + "queued", + "dequeued", + "scheduled_at", + "resolve_start", + "resolve_end", + "finalized", + ] + + @pytest.fixture( + params=[ + {}, + { + "targeted_start": None, + "queued": None, + "dequeued": None, + "scheduled_at": None, + "resolve_start": None, + "resolve_end": None, + "finalized": None, + }, + { + "targeted_start": 1000.0, + "queued": 200.0, + "dequeued": 800.0, + "scheduled_at": 900.0, + "resolve_start": 1000.5, + "resolve_end": 1100.0, + "finalized": 1100.5, + }, + { + "queued": 200.0, + "scheduled_at": 250.0, + "resolve_start": 1000.5, + "resolve_end": 1100.0, + }, + { + "targeted_start": 0.0, + "queued": 0.0, + "dequeued": 0.0, + "scheduled_at": 0.0, + "resolve_start": 0.0, + "resolve_end": 0.0, + "finalized": 0.0, + }, + ], + ids=[ + "default_empty", + "all_none_explicit", + "complete_sequence", + "partial_data", + "zero_timestamps", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of RequestSchedulerTimings.""" + constructor_args = request.param + instance = RequestSchedulerTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RequestSchedulerTimings inheritance and type relationships.""" + assert issubclass(RequestSchedulerTimings, StandardBaseModel) + assert hasattr(RequestSchedulerTimings, "model_dump") + assert hasattr(RequestSchedulerTimings, "model_validate") + + # Check all expected fields are defined + fields = RequestSchedulerTimings.model_fields + for key in self.CHECK_KEYS: + assert key in fields + field_info = fields[key] + assert field_info.annotation in (Union[float, None], Optional[float]) + assert field_info.default is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, RequestSchedulerTimings) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("targeted_start", "invalid_string"), + ("queued", "invalid_string"), + ("dequeued", [1, 2, 3]), + ("scheduled_at", {"key": "value"}), + ("resolve_start", {"key": "value"}), + ("resolve_end", [1, 2, 3]), + ("finalized", object()), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + RequestSchedulerTimings(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = RequestSchedulerTimings.model_validate(data) + assert isinstance(reconstructed, RequestSchedulerTimings) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in 
constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestRequestTimings: + """Test the MeasuredRequestTimings model class.""" + + CHECK_KEYS = [ + "request_start", + "request_end", + ] + + @pytest.fixture( + params=[ + {}, + { + "request_start": None, + "request_end": None, + }, + { + "request_start": 1000.0, + "request_end": 1100.0, + }, + { + "request_start": 1000.0, + }, + { + "request_start": 0.0, + "request_end": 0.0, + }, + ], + ids=[ + "default_empty", + "all_none_explicit", + "complete_sequence", + "partial_data", + "zero_timestamps", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of MeasuredRequestTimings.""" + constructor_args = request.param + instance = MeasuredRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MeasuredRequestTimings inheritance and type relationships.""" + assert issubclass(MeasuredRequestTimings, StandardBaseModel) + assert hasattr(MeasuredRequestTimings, "model_dump") + assert hasattr(MeasuredRequestTimings, "model_validate") + + # Check all expected fields are defined + fields = MeasuredRequestTimings.model_fields + for key in self.CHECK_KEYS: + assert key in fields + field_info = fields[key] + assert field_info.annotation in (Union[float, None], Optional[float]) + assert field_info.default is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, MeasuredRequestTimings) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_start", "invalid_string"), + ("request_end", [1, 2, 3]), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + MeasuredRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = MeasuredRequestTimings.model_validate(data) + assert isinstance(reconstructed, MeasuredRequestTimings) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestScheduledRequestInfo: + CHECK_KEYS = [ + "request_id", + "status", + "error", + "scheduler_node_id", + "scheduler_process_id", + "scheduler_start_time", + "scheduler_timings", + "request_timings", + ] + + @pytest.fixture( + params=[ + # Minimal required configuration + { + "request_id": "test-req-123", + "status": "queued", + "scheduler_node_id": 1, + "scheduler_process_id": 0, + "scheduler_start_time": 1000.0, + }, + # Complete configuration with 
all fields + { + "request_id": "test-req-456", + "status": "completed", + "error": None, + "scheduler_node_id": 2, + "scheduler_process_id": 1, + "scheduler_start_time": 2000.0, + "scheduler_timings": { + "targeted_start": 1900.0, + "queued": 1950.0, + "dequeued": 2000.0, + "resolve_start": 2050.0, + "resolve_end": 2100.0, + "finalized": 2150.0, + }, + "request_timings": { + "request_start": 2060.0, + "request_end": 2110.0, + }, + }, + # Error state configuration + { + "request_id": "test-req-error", + "status": "errored", + "error": "Connection timeout", + "scheduler_node_id": 0, + "scheduler_process_id": 0, + "scheduler_start_time": 3000.0, + }, + # Different status values + { + "request_id": "test-req-pending", + "status": "pending", + "scheduler_node_id": 1, + "scheduler_process_id": 2, + "scheduler_start_time": 4000.0, + }, + { + "request_id": "test-req-in-progress", + "status": "in_progress", + "scheduler_node_id": 2, + "scheduler_process_id": 1, + "scheduler_start_time": 5000.0, + }, + ], + ids=[ + "minimal_required", + "complete_configuration", + "error_state", + "pending_status", + "in_progress_status", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of ScheduledRequestInfo. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + ScheduledRequestInfo and constructor_args are the kwargs used. + """ + constructor_args = request.param.copy() + + # Handle nested objects + if "scheduler_timings" in constructor_args: + constructor_args["scheduler_timings"] = RequestSchedulerTimings( + **constructor_args["scheduler_timings"] + ) + if "request_timings" in constructor_args: + constructor_args["request_timings"] = MeasuredRequestTimings( + **constructor_args["request_timings"] + ) + + instance = ScheduledRequestInfo(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ScheduledRequestInfo inheritance and type relationships.""" + assert issubclass(ScheduledRequestInfo, StandardBaseModel) + assert issubclass(ScheduledRequestInfo, typing.Generic) + assert hasattr(ScheduledRequestInfo, "model_dump") + assert hasattr(ScheduledRequestInfo, "model_validate") + + # Check computed properties + assert hasattr(ScheduledRequestInfo, "started_at") + assert hasattr(ScheduledRequestInfo, "completed_at") + assert isinstance(ScheduledRequestInfo.started_at, property) + assert isinstance(ScheduledRequestInfo.completed_at, property) + + # Check that it's properly generic + orig_bases = getattr(ScheduledRequestInfo, "__orig_bases__", ()) + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is typing.Generic + ), + None, + ) + assert generic_base is not None + + # Check required fields + fields = ScheduledRequestInfo.model_fields + for key in self.CHECK_KEYS: + assert key in fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ScheduledRequestInfo) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + if field in ["scheduler_timings", "request_timings"]: + actual_value = getattr(instance, field) + if expected_value is None: + assert actual_value is None or ( + field == "scheduler_timings" + and isinstance(actual_value, 
RequestSchedulerTimings) + ) + else: + assert isinstance(actual_value, type(expected_value)) + else: + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_id", None), # Required field + ("request_id", 123), # Wrong type + ("status", "invalid_status"), # Invalid literal + ("scheduler_node_id", "not_an_int"), + ("scheduler_process_id", -1.5), + ("scheduler_start_time", "not_a_float"), + ("error", 123), # Should be string or None + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + # Start with valid base config + base_kwargs = { + "request_id": "test-req", + "status": "queued", + "scheduler_node_id": 1, + "scheduler_process_id": 0, + "scheduler_start_time": 1000.0, + } + base_kwargs[field] = value + with pytest.raises(ValidationError): + ScheduledRequestInfo(**base_kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = ScheduledRequestInfo.model_validate(data) + assert isinstance(reconstructed, ScheduledRequestInfo) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + original_value = getattr(instance, field) + reconstructed_value = getattr(reconstructed, field) + + if field in ["scheduler_timings", "request_timings"]: + if original_value is not None and reconstructed_value is not None: + assert ( + original_value.model_dump() == reconstructed_value.model_dump() + ) + else: + assert original_value is None or isinstance( + original_value, + (RequestSchedulerTimings, MeasuredRequestTimings), + ) + assert reconstructed_value is None or isinstance( + reconstructed_value, + (RequestSchedulerTimings, MeasuredRequestTimings), + ) + else: + assert original_value == reconstructed_value + + @pytest.mark.smoke + def test_started_at_property(self): + """Test the started_at property logic.""" + # Test with request_timings.request_start (should take precedence) + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0), + request_timings=MeasuredRequestTimings(request_start=2100.0), + ) + assert instance.started_at == 2100.0 + + # Test with only scheduler_timings.resolve_start + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0), + ) + assert instance.started_at == 2000.0 + + # Test with no timing info + instance = ScheduledRequestInfo( + request_id="test-req", + status="queued", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + assert instance.started_at is None + + @pytest.mark.smoke + def test_completed_at_property(self): + """Test the completed_at property logic.""" + # Test with request_timings.request_end (should take precedence) + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + 
scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0), + request_timings=MeasuredRequestTimings(request_end=2100.0), + ) + assert instance.completed_at == 2100.0 + + # Test with only scheduler_timings.resolve_end + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0), + ) + assert instance.completed_at == 2000.0 + + # Test with no timing info + instance = ScheduledRequestInfo( + request_id="test-req", + status="queued", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + assert instance.completed_at is None + + +class TestSchedulerState: + CHECK_KEYS = [ + "node_id", + "num_processes", + "start_time", + "end_time", + "end_queuing_time", + "end_queuing_constraints", + "end_processing_time", + "end_processing_constraints", + "scheduler_constraints", + "remaining_fraction", + "remaining_requests", + "remaining_duration", + "created_requests", + "queued_requests", + "pending_requests", + "processing_requests", + "processed_requests", + "successful_requests", + "errored_requests", + "cancelled_requests", + ] + + @pytest.fixture( + params=[ + # Minimal required configuration + { + "node_id": 0, + "num_processes": 1, + "start_time": 1000.0, + }, + # Complete configuration with all fields + { + "node_id": 1, + "num_processes": 4, + "start_time": 2000.0, + "end_time": 3000.0, + "end_queuing_time": 2500.0, + "end_queuing_constraints": { + "time_limit": SchedulerUpdateAction( + request_queuing="stop", metadata={"max_duration": 1500} + ) + }, + "end_processing_time": 2800.0, + "end_processing_constraints": { + "request_limit": SchedulerUpdateAction( + request_processing="stop_all", metadata={"max_requests": 1000} + ) + }, + "scheduler_constraints": { + "rate_limit": SchedulerUpdateAction(metadata={"max_rps": 100}) + }, + "remaining_fraction": 0.25, + "remaining_requests": 50, + "remaining_duration": 300.0, + "created_requests": 200, + "queued_requests": 180, + "pending_requests": 20, + "processing_requests": 10, + "processed_requests": 150, + "successful_requests": 140, + "errored_requests": 8, + "cancelled_requests": 2, + }, + # Partial configuration with some stats + { + "node_id": 2, + "num_processes": 2, + "start_time": 4000.0, + "created_requests": 50, + "processed_requests": 30, + "successful_requests": 28, + "errored_requests": 2, + }, + # Edge case: zero values + { + "node_id": 0, + "num_processes": 1, + "start_time": 0.0, + "created_requests": 0, + "processed_requests": 0, + "successful_requests": 0, + }, + ], + ids=[ + "minimal_required", + "complete_configuration", + "partial_stats", + "zero_values", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of SchedulerState. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + SchedulerState and constructor_args are the kwargs used. 
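+        Counter fields omitted from a param set fall back to their default of 0.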
+ """ + constructor_args = request.param + instance = SchedulerState(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SchedulerState inheritance and type relationships.""" + assert issubclass(SchedulerState, StandardBaseModel) + assert hasattr(SchedulerState, "model_dump") + assert hasattr(SchedulerState, "model_validate") + + # Check all expected fields are defined + fields = SchedulerState.model_fields + for key in self.CHECK_KEYS: + assert key in fields + + # Check field defaults for key counters + counter_fields = [ + "created_requests", + "queued_requests", + "pending_requests", + "processing_requests", + "processed_requests", + "successful_requests", + "errored_requests", + "cancelled_requests", + ] + for field in counter_fields: + field_info = fields[field] + assert field_info.default == 0 + + # Check that start_time has a default factory + start_time_field = fields["start_time"] + assert start_time_field.default_factory is not None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, SchedulerState) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("field", "value"), + [ + ("node_id", "not_an_int"), + ("start_time", "not_a_float"), + ("end_time", [1, 2, 3]), + ("remaining_fraction", "not_a_float"), + ("created_requests", "not_an_int"), + ("end_queuing_constraints", "not_a_dict"), + ("scheduler_constraints", ["not", "a", "dict"]), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + # Start with valid base config + base_kwargs = { + "node_id": 0, + "num_processes": 1, + "start_time": 1000.0, + } + base_kwargs[field] = value + with pytest.raises(ValidationError): + SchedulerState(**base_kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = SchedulerState.model_validate(data) + assert isinstance(reconstructed, SchedulerState) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestSchedulerUpdateAction: + CHECK_KEYS = [ + "request_queuing", + "request_processing", + "metadata", + "progress", + ] + + @pytest.fixture( + params=[ + # Default configuration + {}, + # All explicit default values + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": {}, + "progress": {}, + }, + # Stop queuing configuration + { + "request_queuing": "stop", + "request_processing": "continue", + "metadata": {"reason": "rate_limit_exceeded"}, + }, + # Stop local processing configuration + { + "request_queuing": "continue", + 
"request_processing": "stop_local", + "metadata": {"node_id": 1, "reason": "resource_exhausted"}, + }, + # Stop all processing configuration + { + "request_queuing": "stop", + "request_processing": "stop_all", + "metadata": { + "emergency_stop": True, + "reason": "critical_error", + "error_details": {"code": 500, "message": "Internal server error"}, + }, + }, + # Complex metadata configuration + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": { + "stats": {"processed": 100, "pending": 50}, + "constraints": {"max_rps": 10, "max_concurrent": 20}, + "config": {"batch_size": 32, "timeout": 30.0}, + }, + }, + # Progress with remaining_fraction only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_fraction": 0.75}, + }, + # Progress with remaining_requests only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_requests": 250.0}, + }, + # Progress with remaining_duration only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_duration": 120.5}, + }, + # Complete progress configuration + { + "request_queuing": "stop", + "request_processing": "stop_all", + "metadata": {"shutdown_reason": "completion"}, + "progress": { + "remaining_fraction": 0.0, + "remaining_requests": 0.0, + "remaining_duration": 0.0, + }, + }, + # Partial progress configuration + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": {"checkpoint": "mid_benchmark"}, + "progress": { + "remaining_fraction": 0.45, + "remaining_duration": 180.0, + }, + }, + ], + ids=[ + "default_empty", + "explicit_defaults", + "stop_queuing", + "stop_local_processing", + "stop_all_processing", + "complex_metadata", + "progress_fraction_only", + "progress_requests_only", + "progress_duration_only", + "complete_progress", + "partial_progress", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of SchedulerUpdateAction. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + SchedulerUpdateAction and constructor_args are the kwargs used. 
+ """ + constructor_args = request.param + instance = SchedulerUpdateAction(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SchedulerUpdateAction inheritance and type relationships.""" + assert issubclass(SchedulerUpdateAction, StandardBaseModel) + assert hasattr(SchedulerUpdateAction, "model_dump") + assert hasattr(SchedulerUpdateAction, "model_validate") + + # Check all expected fields are defined + fields = SchedulerUpdateAction.model_fields + for key in self.CHECK_KEYS: + assert key in fields + + # Check field defaults + assert fields["request_queuing"].default == "continue" + assert fields["request_processing"].default == "continue" + metadata_field = fields["metadata"] + assert metadata_field.default_factory is not None + progress_field = fields["progress"] + assert progress_field.default_factory is not None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, SchedulerUpdateAction) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args or defaults + for field in self.CHECK_KEYS: + if field in constructor_args: + assert getattr(instance, field) == constructor_args[field] + elif field in ["request_queuing", "request_processing"]: + assert getattr(instance, field) == "continue" + elif field in ["metadata", "progress"]: + assert getattr(instance, field) == {} + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_queuing", "invalid_action"), + ("request_queuing", 123), + ("request_processing", "invalid_action"), + ("request_processing", ["stop"]), + ("metadata", "not_a_dict"), + ("metadata", [{"key": "value"}]), + ("progress", "not_a_dict"), + ("progress", [{"remaining_fraction": 0.5}]), + ("progress", {"remaining_fraction": "not_a_float"}), + ("progress", {"remaining_requests": "not_a_float"}), + ("progress", {"remaining_duration": "not_a_float"}), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + SchedulerUpdateAction(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = SchedulerUpdateAction.model_validate(data) + assert isinstance(reconstructed, SchedulerUpdateAction) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches expected values + for field in self.CHECK_KEYS: + if field in constructor_args: + assert getattr(reconstructed, field) == constructor_args[field] + elif field in ["request_queuing", "request_processing"]: + assert getattr(reconstructed, field) == "continue" + elif field in ["metadata", "progress"]: + assert getattr(reconstructed, field) == {} + + @pytest.mark.smoke + def test_progress_field_behavior(self): + """Test the progress field specific behavior and validation.""" + # Test empty progress (default) + instance = 
SchedulerUpdateAction() + assert instance.progress == {} + assert isinstance(instance.progress, dict) + + # Test progress with all valid fields + progress_data = { + "remaining_fraction": 0.75, + "remaining_requests": 100.0, + "remaining_duration": 30.5, + } + instance = SchedulerUpdateAction(progress=progress_data) + assert instance.progress == progress_data + + # Test progress with partial fields (TypedDict allows partial) + partial_progress = {"remaining_fraction": 0.25} + instance = SchedulerUpdateAction(progress=partial_progress) + assert instance.progress == partial_progress + + # Test progress with zero values + zero_progress = { + "remaining_fraction": 0.0, + "remaining_requests": 0.0, + "remaining_duration": 0.0, + } + instance = SchedulerUpdateAction(progress=zero_progress) + assert instance.progress == zero_progress + + # Test that progress field persists through marshalling + data = instance.model_dump() + assert "progress" in data + assert data["progress"] == zero_progress + + reconstructed = SchedulerUpdateAction.model_validate(data) + assert reconstructed.progress == zero_progress + + @pytest.mark.smoke + @pytest.mark.parametrize( + "progress_value", + [ + {"remaining_fraction": 0.0}, + {"remaining_fraction": 1.0}, + {"remaining_requests": 0.0}, + {"remaining_requests": 1000.0}, + {"remaining_duration": 0.0}, + {"remaining_duration": 3600.0}, + {"remaining_fraction": 0.5, "remaining_requests": 50.0}, + {"remaining_requests": 25.0, "remaining_duration": 120.0}, + {"remaining_fraction": 0.33, "remaining_duration": 45.0}, + ], + ) + def test_progress_valid_combinations(self, progress_value): + """Test various valid combinations of progress field values.""" + instance = SchedulerUpdateAction(progress=progress_value) + assert instance.progress == progress_value + + # Verify marshalling works correctly + data = instance.model_dump() + reconstructed = SchedulerUpdateAction.model_validate(data) + assert reconstructed.progress == progress_value + + @pytest.mark.smoke + def test_scheduler_update_action_progress_typeddict(self): + """Test the SchedulerUpdateActionProgress TypedDict behavior.""" + # Test that SchedulerUpdateActionProgress is a proper TypedDict + # Verify it's a TypedDict (has the special attributes) + assert hasattr(SchedulerUpdateActionProgress, "__annotations__") + assert hasattr(SchedulerUpdateActionProgress, "__total__") + assert hasattr(SchedulerUpdateActionProgress, "__required_keys__") + assert hasattr(SchedulerUpdateActionProgress, "__optional_keys__") + + # Check that all keys are optional (total=False) + expected_keys = { + "remaining_fraction", + "remaining_requests", + "remaining_duration", + } + actual_keys = set(SchedulerUpdateActionProgress.__annotations__.keys()) + assert actual_keys == expected_keys + assert SchedulerUpdateActionProgress.__total__ is False + assert SchedulerUpdateActionProgress.__required_keys__ == frozenset() + assert SchedulerUpdateActionProgress.__optional_keys__ == expected_keys + + # Test that type annotations are correct + annotations = SchedulerUpdateActionProgress.__annotations__ + assert "remaining_fraction" in annotations + assert "remaining_requests" in annotations + assert "remaining_duration" in annotations + + # Test creation of valid TypedDict instances + valid_progress_1: SchedulerUpdateActionProgress = {} + valid_progress_2: SchedulerUpdateActionProgress = {"remaining_fraction": 0.5} + valid_progress_3: SchedulerUpdateActionProgress = { + "remaining_fraction": 0.25, + "remaining_requests": 100.0, + "remaining_duration": 
60.0,
+        }
+
+        # All should be valid dict instances
+        assert isinstance(valid_progress_1, dict)
+        assert isinstance(valid_progress_2, dict)
+        assert isinstance(valid_progress_3, dict)
diff --git a/tests/unit/scheduler/test_scheduler.py b/tests/unit/scheduler/test_scheduler.py
new file mode 100644
index 00000000..33efc27f
--- /dev/null
+++ b/tests/unit/scheduler/test_scheduler.py
@@ -0,0 +1,253 @@
+from __future__ import annotations
+
+import asyncio
+import inspect
+import random
+import uuid
+from functools import wraps
+from typing import Any, Generic
+
+import pytest
+from pydantic import BaseModel, Field
+
+from guidellm.scheduler import (
+    BackendInterface,
+    MaxNumberConstraint,
+    NonDistributedEnvironment,
+    ScheduledRequestInfo,
+    Scheduler,
+    SchedulerState,
+    SynchronousStrategy,
+)
+from guidellm.utils.singleton import ThreadSafeSingletonMixin
+
+
+def async_timeout(delay: float):
+    """Decorator to add timeout to async test functions."""
+
+    def decorator(func):
+        @wraps(func)
+        async def new_func(*args, **kwargs):
+            return await asyncio.wait_for(func(*args, **kwargs), timeout=delay)
+
+        return new_func
+
+    return decorator
+
+
+class MockRequest(BaseModel):
+    payload: str
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
+
+
+class MockBackend(BackendInterface):
+    """Mock backend for integration testing with predictable responses."""
+
+    def __init__(
+        self,
+        processes_limit_value: int | None = None,
+        requests_limit_value: int | None = None,
+        error_rate: float = 0.2,
+        response_delay: float = 0.0,
+    ):
+        self._processes_limit = processes_limit_value
+        self._requests_limit = requests_limit_value
+        self._error_rate = error_rate
+        self._response_delay = response_delay
+
+    @property
+    def processes_limit(self) -> int | None:
+        return self._processes_limit
+
+    @property
+    def requests_limit(self) -> int | None:
+        return self._requests_limit
+
+    @property
+    def info(self) -> dict[str, Any]:
+        return {"type": "mock_integration", "delay": self._response_delay}
+
+    async def process_startup(self):
+        pass
+
+    async def validate(self):
+        pass
+
+    async def process_shutdown(self):
+        pass
+
+    async def resolve(self, request: MockRequest, request_info, request_history=None):
+        """Return predictable response based on input request."""
+        await asyncio.sleep(self._response_delay)
+
+        if (
+            self._error_rate
+            and self._error_rate > 0
+            and random.random() < self._error_rate
+        ):
+            raise RuntimeError(f"mock_error_for_{request.payload}")
+
+        yield f"response_for_{request.payload}"
+
+
+class TestScheduler:
+    """Test suite for Scheduler class."""
+
+    @pytest.fixture
+    def valid_instances(self):
+        """Fixture providing test data for Scheduler."""
+        # Clear singleton state between tests
+        if hasattr(Scheduler, "singleton_instance"):
+            Scheduler.singleton_instance = None
+
+        instance = Scheduler()
+        constructor_args = {}
+        return instance, constructor_args
+
+    @pytest.mark.smoke
+    def test_class_signatures(self):
+        """Test Scheduler inheritance and type relationships."""
+        # Clear singleton before testing
+        if hasattr(Scheduler, "singleton_instance"):
+            Scheduler.singleton_instance = None
+
+        assert issubclass(Scheduler, ThreadSafeSingletonMixin)
+        assert issubclass(Scheduler, Generic)
+        assert hasattr(Scheduler, "run")
+        assert callable(Scheduler.run)
+
+        # Check method signature
+        run_sig = inspect.signature(Scheduler.run)
+        expected_params = [
+            "self",
+            "requests",
+            "backend",
+            "strategy",
+            "env",
+            "constraints",
+        ]
+        param_names = list(run_sig.parameters.keys())
+        assert param_names == expected_params
+
+        # run should be an async generator function (it yields results as they
+        # complete); fall back to the return annotation for wrapped callables
+        assert (
+            inspect.isasyncgenfunction(Scheduler.run)
+            or "AsyncIterator" in str(run_sig.return_annotation)
+        )
+
+    @pytest.mark.smoke
+    def test_initialization(self, valid_instances):
+        """Test Scheduler initialization as singleton."""
+        instance1, _ = valid_instances
+        instance2 = Scheduler()
+
+        assert isinstance(instance1, Scheduler)
+        assert instance1 is instance2
+        assert id(instance1) == id(instance2)
+        assert hasattr(instance1, "thread_lock")
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @async_timeout(10.0)
+    @pytest.mark.parametrize(
+        ("num_requests", "constraint_args"),
+        [
+            (5, {"max_number": MaxNumberConstraint(max_num=10)}),
+            (20, {"max_number": MaxNumberConstraint(max_num=25)}),
+            (1, {"max_number": MaxNumberConstraint(max_num=5)}),
+        ],
+    )
+    async def test_run_basic_functionality(
+        self, valid_instances, num_requests, constraint_args
+    ):
+        """Test Scheduler.run basic functionality with various parameters."""
+        instance, _ = valid_instances
+        requests = [MockRequest(payload=f"req_{i}") for i in range(num_requests)]
+        backend = MockBackend(error_rate=0.0, response_delay=0.001)
+        strategy = SynchronousStrategy()
+        env = NonDistributedEnvironment()
+
+        results = []
+        async for response, _request, info, _state in instance.run(
+            requests=requests,
+            backend=backend,
+            strategy=strategy,
+            env=env,
+            **constraint_args,
+        ):
+            results.append((response, _request, info, _state))
+
+        assert len(results) > 0
+        assert all(isinstance(r[1], MockRequest) for r in results)
+        assert all(isinstance(r[2], ScheduledRequestInfo) for r in results)
+        assert all(isinstance(r[3], SchedulerState) for r in results)
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @async_timeout(10.0)
+    async def test_run_with_errors(self, valid_instances):
+        """Test Scheduler.run error handling."""
+        instance, _ = valid_instances
+        requests = [MockRequest(payload=f"req_{i}") for i in range(5)]
+        backend = MockBackend(error_rate=1.0)  # Force all requests to error
+        strategy = SynchronousStrategy()
+        env = NonDistributedEnvironment()
+
+        error_count = 0
+        async for response, _request, info, _state in instance.run(
+            requests=requests,
+            backend=backend,
+            strategy=strategy,
+            env=env,
+            max_number=MaxNumberConstraint(max_num=10),
+        ):
+            if info.status == "errored":
+                error_count += 1
+                assert response is None
+                assert info.error is not None
+
+        assert error_count > 0
+
+    @pytest.mark.sanity
+    @pytest.mark.asyncio
+    @async_timeout(10.0)
+    async def test_run_invalid_parameters(self, valid_instances):
+        """Test Scheduler.run with invalid parameters."""
+        instance, _ = valid_instances
+
+        with pytest.raises((TypeError, ValueError, AttributeError)):
+            async for _ in instance.run(
+                requests=None,  # Invalid requests
+                backend=None,  # Invalid backend
+                strategy=SynchronousStrategy(),
+                env=NonDistributedEnvironment(),
+            ):
+                pass
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @async_timeout(10.0)
+    async def test_run_constraint_variations(self, valid_instances):
+        """Test Scheduler.run with different constraint types."""
+        instance, _ = valid_instances
+        requests = [MockRequest(payload=f"req_{i}") for i in range(3)]
+        backend = MockBackend(error_rate=0.0, response_delay=0.001)
+        strategy = SynchronousStrategy()
+        env = NonDistributedEnvironment()
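+
+        # max_duration is passed below as a plain float; the scheduler is assumed
+        # to coerce such raw values into constraint instances internally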
+        # Test with multiple constraints
+        results = []
+        async for response, request, info, state in instance.run(
+            requests=requests,
+            backend=backend,
+            strategy=strategy,
+            env=env,
+            max_number=MaxNumberConstraint(max_num=5),
+            max_duration=5.0,  # Should be converted to a constraint
+        ):
+            results.append((response, request, info, state))
+
+        assert len(results) > 0
diff --git a/tests/unit/scheduler/test_strategies.py b/tests/unit/scheduler/test_strategies.py
new file mode 100644
index 00000000..67a2d77d
--- /dev/null
+++ b/tests/unit/scheduler/test_strategies.py
@@ -0,0 +1,1154 @@
+from __future__ import annotations
+
+import inspect
+import math
+import statistics
+import time
+from abc import ABC
+from typing import Literal, TypeVar
+
+import pytest
+from pydantic import ValidationError
+
+from guidellm.scheduler import (
+    AsyncConstantStrategy,
+    AsyncPoissonStrategy,
+    ConcurrentStrategy,
+    ConstantRateRequestTimings,
+    LastCompletionRequestTimings,
+    NoDelayRequestTimings,
+    PoissonRateRequestTimings,
+    ScheduledRequestInfo,
+    ScheduledRequestTimings,
+    SchedulingStrategy,
+    StrategyT,
+    SynchronousStrategy,
+    ThroughputStrategy,
+)
+from guidellm.scheduler.strategies import (
+    _exponential_decay_fraction,
+    _exponential_decay_tau,
+)
+
+
+def test_strategy_type():
+    """Test that StrategyType is defined correctly as a Literal type."""
+    # StrategyType is a Literal alias, so there is little runtime surface to
+    # inspect beyond confirming that it exists and is importable
+    from guidellm.scheduler.strategies import StrategyType
+
+    assert StrategyType is not None
+
+
+def test_strategy_t():
+    """Test that StrategyT is filled out correctly as a TypeVar."""
+    assert isinstance(StrategyT, TypeVar)
+    assert StrategyT.__name__ == "StrategyT"
+    assert StrategyT.__bound__ == SchedulingStrategy
+    assert StrategyT.__constraints__ == ()
+
+
+class TestExponentialDecay:
+    """Test suite for the exponential decay helper functions."""
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize(
+        ("max_progress", "convergence", "expected_range"),
+        [
+            (1.0, 0.99, (0.21, 0.22)),
+            (5.0, 0.99, (1.08, 1.09)),
+            (10.0, 0.95, (3.33, 3.35)),
+        ],
+    )
+    def test_tau_invocation(self, max_progress, convergence, expected_range):
+        """Test exponential decay tau calculation with valid inputs."""
+        tau = _exponential_decay_tau(max_progress, convergence)
+        assert expected_range[0] <= tau <= expected_range[1]
+        expected_tau = max_progress / (-math.log(1 - convergence))
+        assert tau == pytest.approx(expected_tau, rel=1e-10)
+
+    @pytest.mark.smoke
+    @pytest.mark.parametrize(
+        ("progress", "tau", "expected_min", "expected_max"),
+        [
+            (0.0, 1.0, 0.0, 0.0),  # No progress = 0
+            (1.0, 1.0, 0.6, 0.7),  # 1 tau ≈ 63.2%
+            (2.0, 1.0, 0.85, 0.87),  # 2 tau ≈ 86.5%
+            (3.0, 1.0, 0.95, 0.96),  # 3 tau ≈ 95.0%
+        ],
+    )
+    def test_exp_decay_invocation(self, progress, tau, expected_min, expected_max):
+        """Test exponential decay fraction calculation with valid inputs."""
+        fraction = _exponential_decay_fraction(progress, tau)
+        assert expected_min <= fraction <= expected_max
+        expected_fraction = 1 - math.exp(-progress / tau)
+        assert fraction == pytest.approx(expected_fraction, rel=1e-10)
+
+    @pytest.mark.smoke
+    def test_exp_boundary_conditions(self):
+        """Test boundary conditions for exponential decay fraction."""
+        assert _exponential_decay_fraction(0.0, 1.0) == 0.0
+        assert _exponential_decay_fraction(0.0, 10.0) == 0.0
+        large_progress = 100.0
+        fraction = _exponential_decay_fraction(large_progress, 1.0)
+        assert fraction > 0.99999
+
+
+class
TestScheduledRequestTimings: + @pytest.mark.smoke + def test_signatures(self): + """Test that ScheduledRequestTimings is an abstract base class.""" + assert issubclass(ScheduledRequestTimings, ABC) + assert inspect.isabstract(ScheduledRequestTimings) + + abstract_methods = ScheduledRequestTimings.__abstractmethods__ + expected_methods = {"next_offset", "request_completed"} + assert abstract_methods == expected_methods + + # Validate method signatures + next_offset_method = ScheduledRequestTimings.next_offset + assert callable(next_offset_method) + request_completed_method = ScheduledRequestTimings.request_completed + assert callable(request_completed_method) + + # Check signature parameters using inspect + next_offset_sig = inspect.signature(next_offset_method) + assert len(next_offset_sig.parameters) == 1 + assert str(next_offset_sig.return_annotation) == "float" + request_completed_sig = inspect.signature(request_completed_method) + assert len(request_completed_sig.parameters) == 2 + params = list(request_completed_sig.parameters.values()) + param_annotation = params[1].annotation + assert param_annotation in {ScheduledRequestInfo, "ScheduledRequestInfo"} + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid implementations raise TypeError.""" + + class InvalidImplementation(ScheduledRequestTimings): + pass # Missing required abstract methods + + with pytest.raises(TypeError): + InvalidImplementation() + + @pytest.mark.smoke + def test_child_implementation(self): + """Test that concrete implementations can be constructed.""" + + class TestRequestTimings(ScheduledRequestTimings): + offset: float = 0.0 + + def next_offset(self) -> float: + self.offset += 1.0 + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): + pass + + timing = TestRequestTimings() + assert isinstance(timing, ScheduledRequestTimings) + + assert timing.next_offset() == 1.0 + assert timing.next_offset() == 2.0 + + mock_request = ScheduledRequestInfo( + request_id="test", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + timing.request_completed(mock_request) + + +class TestLastCompletionRequestTimings: + @pytest.fixture( + params=[ + {}, + {"offset": 10.0}, + {"startup_requests": 5, "startup_requests_delay": 0.5}, + { + "offset": 0.0, + "startup_requests": 0, + "startup_requests_delay": 0.0, + }, + { + "offset": 2.5, + "startup_requests": 3, + "startup_requests_delay": 1.0, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of LastCompletionRequestTimings.""" + constructor_args = request.param + instance = LastCompletionRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, LastCompletionRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("startup_requests", -1), + ("startup_requests_delay", -0.5), + ("offset", "invalid"), + ("startup_requests", 1.5), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + 
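# pydantic surfaces mistyped or out-of-range fields as ValidationError +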
LastCompletionRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test the complete lifecycle of next_offset and request_completed calls.""" + instance, constructor_args = valid_instances + initial_offset = instance.offset + startup_requests = constructor_args.get("startup_requests", 0) + startup_delay = constructor_args.get("startup_requests_delay", 0.0) + request_times = [] + + for index in range(max(5, startup_requests + 2)): + offset = instance.next_offset() + assert isinstance(offset, (int, float)) + + if index < startup_requests: + expected_offset = initial_offset + (index + 1) * startup_delay + assert offset == pytest.approx(expected_offset, abs=1e-5) + + completion_time = time.time() + offset + request_times.append(completion_time) + + mock_request: ScheduledRequestInfo = ScheduledRequestInfo( + request_id=f"test-{index}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + mock_request.scheduler_timings.resolve_end = completion_time + instance.request_completed(mock_request) + + @pytest.mark.smoke + def test_marshalling( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = LastCompletionRequestTimings.model_validate(data) + assert isinstance(reconstructed, LastCompletionRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestNoDelayRequestTimings: + @pytest.fixture( + params=[ + {}, + {"offset": 0.2}, + {"startup_duration": 0.3, "startup_target_requests": 5}, + { + "offset": 0.15, + "startup_duration": 0.2, + "startup_target_requests": 20, + "startup_convergence": 0.9, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of NoDelayRequestTimings.""" + constructor_args = request.param + instance = NoDelayRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, NoDelayRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("offset", -1.0), + ("startup_duration", -1.0), + ("startup_target_requests", 0), + ("startup_target_requests", -1), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + NoDelayRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test the complete lifecycle of timing methods.""" + instance, constructor_args = valid_instances + startup_duration = constructor_args.get("startup_duration", 0.0) + base_offset = constructor_args.get("offset", 0.0) + start_time = time.time() + min_time = base_offset + startup_duration + 0.2 + end_time = start_time + min_time + last_offset = -1 * math.inf + + while (current_time := time.time()) < end_time: + offset = 
instance.next_offset() + + if startup_duration > 0 and (current_time - start_time) <= startup_duration: + assert offset < base_offset + startup_duration + assert offset > last_offset + elif startup_duration > 0: + assert offset == base_offset + startup_duration + else: + assert offset == base_offset + + last_offset = offset + time.sleep(0.025) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = NoDelayRequestTimings.model_validate(data) + assert isinstance(reconstructed, NoDelayRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestConstantRateRequestTimings: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0, "offset": 2.0}, + {"rate": 10.5, "offset": 1.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ConstantRateRequestTimings.""" + constructor_args = request.param + instance = ConstantRateRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ConstantRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ("offset", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + ConstantRateRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_constant_rate_behavior( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test that requests are scheduled at constant intervals.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + expected_interval = 1.0 / rate + base_offset = constructor_args.get("offset", 0.0) + num_requests = int(5 * rate) # simulate 5 seconds + + for ind in range(num_requests): + offset = instance.next_offset() + assert offset >= base_offset + assert offset == pytest.approx( + base_offset + ind * expected_interval, rel=1e-2 + ) + + @pytest.mark.smoke + def test_marshalling( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ConstantRateRequestTimings.model_validate(data) + assert isinstance(reconstructed, ConstantRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestPoissonRateRequestTimings: + @pytest.fixture( + params=[ + {"rate": 1.0}, + { + "rate": 5.0, + "random_seed": 123, + "offset": 1.0, + }, + { + "rate": 0.5, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of PoissonRateRequestTimings.""" + 
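# Poisson arrivals have exponentially distributed gaps with mean 1/rate; + # test_lifecycle below checks the sample mean against that expectation. +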
constructor_args = request.param + instance = PoissonRateRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[PoissonRateRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, PoissonRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ("offset", "invalid"), + ("random_seed", "invalid"), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + PoissonRateRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle(self, valid_instances: tuple[PoissonRateRequestTimings, dict]): + """Test that Poisson timing produces variable intervals.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + base_offset = constructor_args.get("offset", 0.0) + num_requests = 200 + last_offset = 0.0 + intervals = [] + + for index in range(num_requests): + offset = instance.next_offset() + + if index == 0: + assert offset == base_offset + else: + assert offset > last_offset + + intervals.append(offset - last_offset) + last_offset = offset + + expected_mean_interval = 1.0 / rate + actual_mean_interval = statistics.mean(intervals) + tolerance = 0.2 * expected_mean_interval + assert abs(actual_mean_interval - expected_mean_interval) < tolerance + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[PoissonRateRequestTimings, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = PoissonRateRequestTimings.model_validate(data) + assert isinstance(reconstructed, PoissonRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestSchedulingStrategy: + @pytest.mark.smoke + def test_class_signatures(self): + """Test SchedulingStrategy inheritance and type relationships.""" + # Inheritance and abstract class properties + assert issubclass(SchedulingStrategy, object) + assert hasattr(SchedulingStrategy, "info") + + # Validate expected methods exist + expected_methods = { + "processes_limit", + "requests_limit", + "create_request_timings", + } + strategy_methods = set(dir(SchedulingStrategy)) + for method in expected_methods: + assert method in strategy_methods + + # validate expected properties + processes_limit_prop = SchedulingStrategy.processes_limit + assert isinstance(processes_limit_prop, property) + requests_limit_prop = SchedulingStrategy.requests_limit + assert isinstance(requests_limit_prop, property) + create_request_timings_method = SchedulingStrategy.create_request_timings + assert callable(create_request_timings_method) + + # Validate method signature + sig = inspect.signature(create_request_timings_method) + params = list(sig.parameters.keys()) + expected_params = [ + "self", + "local_rank", + "local_world_size", + "local_max_concurrency", + ] + assert params == expected_params + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid 
implementations raise NotImplementedError.""" + + class InvalidStrategy(SchedulingStrategy): + type_: Literal["strategy"] = "strategy" # type: ignore[assignment,annotation-unchecked] + + strategy = InvalidStrategy() + with pytest.raises(NotImplementedError): + strategy.create_request_timings(0, 1, 1) + + @pytest.mark.smoke + def test_concrete_implementation(self): + """Test that concrete implementations can be constructed.""" + + class TestStrategy(SchedulingStrategy): + type_: Literal["strategy"] = "strategy" # type: ignore[assignment,annotation-unchecked] + + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, + ): + return LastCompletionRequestTimings() + + strategy = TestStrategy() + assert isinstance(strategy, SchedulingStrategy) + timing = strategy.create_request_timings(0, 1, 1) + assert isinstance(timing, ScheduledRequestTimings) + + +class TestSynchronousStrategy: + @pytest.mark.smoke + def test_initialization(self): + """Test initialization of SynchronousStrategy.""" + strategy = SynchronousStrategy() + assert strategy.type_ == "synchronous" + + @pytest.mark.smoke + def test_limits(self): + """Test that SynchronousStrategy enforces proper limits.""" + strategy = SynchronousStrategy() + assert strategy.processes_limit == 1 + assert strategy.requests_limit == 1 + + @pytest.mark.smoke + def test_create_timings_valid(self): + """Test creating timings with valid parameters.""" + strategy = SynchronousStrategy() + timing = strategy.create_request_timings(0, 1, 1) + assert isinstance(timing, LastCompletionRequestTimings) + + @pytest.mark.sanity + def test_create_timings_invalid(self): + """Test that invalid parameters raise ValueError.""" + strategy = SynchronousStrategy() + + with pytest.raises(ValueError): + strategy.create_request_timings(1, 1, 1) # rank != 0 + + with pytest.raises(ValueError): + strategy.create_request_timings(0, 2, 1) # world_size > 1 + + @pytest.mark.smoke + def test_string_representation(self): + """Test __str__ method for SynchronousStrategy.""" + strategy = SynchronousStrategy() + result = str(strategy) + assert result == "synchronous" + + @pytest.mark.smoke + def test_marshalling(self): + """Test marshalling to/from pydantic dict formats.""" + strategy = SynchronousStrategy() + data = strategy.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "synchronous" + + reconstructed = SynchronousStrategy.model_validate(data) + assert isinstance(reconstructed, SynchronousStrategy) + assert reconstructed.type_ == "synchronous" + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, SynchronousStrategy) + assert base_reconstructed.type_ == "synchronous" + + # Test model_validate_json pathway + json_str = strategy.model_dump_json() + json_reconstructed = SynchronousStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, SynchronousStrategy) + assert json_reconstructed.type_ == "synchronous" + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, SynchronousStrategy) + assert base_json_reconstructed.type_ == "synchronous" + + +class TestConcurrentStrategy: + @pytest.fixture( + params=[ + {"streams": 1}, + {"streams": 4}, + {"streams": 8, "startup_duration": 2.0}, + {"streams": 2, "startup_duration": 0.0}, + ] + ) + def valid_instances(self, 
request): + """Creates various valid configurations of ConcurrentStrategy.""" + constructor_args = request.param + instance = ConcurrentStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test initialization of ConcurrentStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("streams", 0), + ("streams", -1), + ("startup_duration", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"streams": 2} + kwargs[field] = value + with pytest.raises(ValidationError): + ConcurrentStrategy(**kwargs) + + @pytest.mark.smoke + def test_limits(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test that ConcurrentStrategy returns correct limits.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + assert instance.processes_limit == streams + assert instance.requests_limit == streams + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + startup_duration = constructor_args.get("startup_duration", 0.0) + + # Test with different rank and world_size combinations + for local_rank in range(min(streams, 2)): + for local_world_size in range(1, min(streams + 1, 3)): + if local_rank < local_world_size: + timing = instance.create_request_timings( + local_rank, local_world_size, streams + ) + assert isinstance(timing, LastCompletionRequestTimings) + + # Verify startup behavior + if startup_duration > 0: + # Check that timing has proper startup configuration + expected_delay_per_stream = startup_duration / streams + streams_per_worker = streams // local_world_size + expected_offset = ( + local_rank * streams_per_worker * expected_delay_per_stream + ) + assert timing.offset == pytest.approx(expected_offset, abs=1e-5) + + @pytest.mark.sanity + def test_create_timings_invalid( + self, valid_instances: tuple[ConcurrentStrategy, dict] + ): + """Test invalid inputs for create request timings.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + + # Test various invalid configurations + invalid_configs = [ + (streams, 1, 1), # rank >= streams + (0, streams + 1, 1), # world_size > streams + ] + + for local_rank, local_world_size, local_max_concurrency in invalid_configs: + if local_rank >= streams or local_world_size > streams: + with pytest.raises(ValueError): + instance.create_request_timings( + local_rank, local_world_size, local_max_concurrency + ) + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[ConcurrentStrategy, dict] + ): + """Test __str__ method for ConcurrentStrategy.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + result = str(instance) + assert result == f"concurrent@{streams}" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == 
"concurrent" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ConcurrentStrategy.model_validate(data) + assert isinstance(reconstructed, ConcurrentStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, ConcurrentStrategy) + assert base_reconstructed.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = ConcurrentStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, ConcurrentStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, ConcurrentStrategy) + assert base_json_reconstructed.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value + + +class TestThroughputStrategy: + @pytest.fixture( + params=[ + {}, + {"max_concurrency": 10}, + {"startup_duration": 5.0}, + {"max_concurrency": 5, "startup_duration": 2.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ThroughputStrategy.""" + constructor_args = request.param + instance = ThroughputStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test initialization of ThroughputStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("max_concurrency", 0), + ("max_concurrency", -1), + ("startup_duration", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + ThroughputStrategy(**kwargs) + + @pytest.mark.smoke + def test_limits(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test that ThroughputStrategy returns correct limits.""" + instance, constructor_args = valid_instances + max_concurrency = constructor_args.get("max_concurrency") + assert instance.processes_limit == max_concurrency + assert instance.requests_limit == max_concurrency + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + startup_duration = constructor_args.get("startup_duration", 0.0) + + # Test with different configurations + for local_rank in range(3): + for local_world_size in range(1, 4): + for local_max_concurrency in range(1, 6): + timing = instance.create_request_timings( + local_rank, local_world_size, local_max_concurrency + ) + assert isinstance(timing, NoDelayRequestTimings) + + # Verify startup configuration + if startup_duration > 0: + assert timing.startup_duration == startup_duration + assert timing.startup_target_requests == local_max_concurrency + expected_offset = ( + 0.05 * 
startup_duration * (local_rank / local_world_size) + ) + assert timing.offset == pytest.approx(expected_offset, abs=1e-5) + else: + assert timing.startup_duration == 0.0 + assert timing.offset == 0.0 + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[ThroughputStrategy, dict] + ): + """Test __str__ method for ThroughputStrategy.""" + instance, _ = valid_instances + result = str(instance) + assert result == "throughput" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "throughput" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ThroughputStrategy.model_validate(data) + assert isinstance(reconstructed, ThroughputStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, ThroughputStrategy) + assert base_reconstructed.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = ThroughputStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, ThroughputStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, ThroughputStrategy) + assert base_json_reconstructed.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value + + +class TestAsyncConstantStrategy: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0}, + {"rate": 10.3, "max_concurrency": 8}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of AsyncConstantStrategy.""" + constructor_args = request.param + instance = AsyncConstantStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test initialization of AsyncConstantStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + AsyncConstantStrategy(**kwargs) + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + + # Test with different worker configurations + for local_world_size in range(1, 5): + timing = instance.create_request_timings(0, local_world_size, 1) + assert isinstance(timing, ConstantRateRequestTimings) + + # 
Rate should be distributed across workers + expected_worker_rate = rate / local_world_size + assert timing.rate == pytest.approx(expected_worker_rate, abs=1e-5) + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[AsyncConstantStrategy, dict] + ): + """Test __str__ method for AsyncConstantStrategy.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + result = str(instance) + assert result == f"constant@{rate:.2f}" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "constant" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = AsyncConstantStrategy.model_validate(data) + assert isinstance(reconstructed, AsyncConstantStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, AsyncConstantStrategy) + assert base_reconstructed.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = AsyncConstantStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, AsyncConstantStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, AsyncConstantStrategy) + assert base_json_reconstructed.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value + + +class TestAsyncPoissonStrategy: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0, "random_seed": 123}, + {"rate": 10.3, "random_seed": 456, "max_concurrency": 8}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of AsyncPoissonStrategy.""" + constructor_args = request.param + instance = AsyncPoissonStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test initialization of AsyncPoissonStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"rate": 1.0, "random_seed": 42} + kwargs[field] = value + with pytest.raises(ValidationError): + AsyncPoissonStrategy(**kwargs) + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + base_seed = constructor_args.get("random_seed", 42) + + # Test with different worker configurations + for local_rank in 
range(3): + for local_world_size in range(1, 4): + timing = instance.create_request_timings( + local_rank, local_world_size, 1 + ) + assert isinstance(timing, PoissonRateRequestTimings) + + # Rate should be distributed across workers + expected_worker_rate = rate / local_world_size + assert timing.rate == pytest.approx(expected_worker_rate, abs=1e-5) + + # Each worker should have a unique seed + expected_seed = base_seed + local_rank + assert timing.random_seed == expected_seed + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[AsyncPoissonStrategy, dict] + ): + """Test __str__ method for AsyncPoissonStrategy.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + result = str(instance) + assert result == f"poisson@{rate:.2f}" + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + assert data["type_"] == "poisson" + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = AsyncPoissonStrategy.model_validate(data) + assert isinstance(reconstructed, AsyncPoissonStrategy) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + # Test polymorphic reconstruction via base registry class + base_reconstructed = SchedulingStrategy.model_validate(data) + assert isinstance(base_reconstructed, AsyncPoissonStrategy) + assert base_reconstructed.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(base_reconstructed, key) == value + + # Test model_validate_json pathway + json_str = instance.model_dump_json() + json_reconstructed = AsyncPoissonStrategy.model_validate_json(json_str) + assert isinstance(json_reconstructed, AsyncPoissonStrategy) + + for key, value in constructor_args.items(): + assert getattr(json_reconstructed, key) == value + + # Test polymorphic model_validate_json via base class + base_json_reconstructed = SchedulingStrategy.model_validate_json(json_str) + assert isinstance(base_json_reconstructed, AsyncPoissonStrategy) + assert base_json_reconstructed.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(base_json_reconstructed, key) == value diff --git a/tests/unit/scheduler/test_worker.py b/tests/unit/scheduler/test_worker.py new file mode 100644 index 00000000..b62d66d5 --- /dev/null +++ b/tests/unit/scheduler/test_worker.py @@ -0,0 +1,672 @@ +from __future__ import annotations + +import asyncio +import inspect +import random +import time +from dataclasses import dataclass +from functools import wraps +from multiprocessing import Barrier, Event, Process +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from typing import Any, Generic, Literal + +import pytest +import pytest_asyncio + +from guidellm.scheduler import ( + BackendInterface, + ConstantRateRequestTimings, + LastCompletionRequestTimings, + MeasuredRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, + ScheduledRequestInfo, + ScheduledRequestTimings, + SchedulerMessagingPydanticRegistry, + WorkerProcess, +) +from guidellm.utils import InterProcessMessagingQueue + +STANDARD_NUM_REQUESTS: int = 200 + + +def async_timeout(delay): + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): 
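+            # Bound the whole wrapped coroutine with asyncio.wait_for so a
+            # hung worker or queue fails the test after `delay` seconds
+            # instead of stalling the suite.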
+ return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +@dataclass +class TimingsBounds: + exact: float | None = None + lower: float | None = None + upper: float | None = None + prev_request: Literal["greater", "greater_equal", "less", "less_equal"] | None = ( + None + ) + tolerance: float = 10e-4 + actual_tolerance: float = 10e-4 + + +class MockRequestTimings(MeasuredRequestTimings): + """Mock timing implementation for testing.""" + + +class MockBackend(BackendInterface): + """Mock backend for testing worker functionality.""" + + def __init__( + self, + lifecycle_delay: float = 0.1, + resolve_delay: float = 0.0, + should_fail: bool = False, + request_error_rate: float = 0.0, + ): + self.lifecycle_delay = lifecycle_delay + self.resolve_delay = resolve_delay + self.should_fail = should_fail + self.request_error_rate = request_error_rate + self.process_startup_called = False + self.validate_called = False + self.process_shutdown_called = False + self.resolve_called = False + + @property + def processes_limit(self) -> int | None: + return None + + @property + def requests_limit(self) -> int | None: + return None + + @property + def info(self) -> dict[str, Any]: + return { + "type": "mock", + "lifecycle_delay": self.lifecycle_delay, + "resolve_delay": self.resolve_delay, + } + + async def process_startup(self): + await asyncio.sleep(self.lifecycle_delay) + self.process_startup_called = True + + async def validate(self): + await asyncio.sleep(self.lifecycle_delay) + self.validate_called = True + if self.should_fail: + raise RuntimeError("Mock validation failed") + + async def process_shutdown(self): + await asyncio.sleep(self.lifecycle_delay) + self.process_shutdown_called = True + + async def resolve(self, request, request_info, request_history): + self.resolve_called = True + await asyncio.sleep( + self.resolve_delay if not str(request).startswith("cancel") else 1000.0 + ) + if self.should_fail: + raise RuntimeError("Mock resolve failed") + if self.request_error_rate > 0.0 and random.random() < self.request_error_rate: + raise RuntimeError("Mock resolve failed") + yield f"response_for_{request}", request_info + + +class TestWorkerProcess: + """Test suite for WorkerProcess class.""" + + @pytest_asyncio.fixture( + params=[ + { + "messaging": { + "serialization": "dict", + "encoding": None, + "max_buffer_receive_size": 2, + }, + "worker": { + "async_limit": 1, + }, + }, + { + "messaging": { + "serialization": "dict", + "encoding": None, + "max_buffer_receive_size": 100, + }, + "worker": { + "async_limit": 1000, + }, + }, + ], + ) + async def valid_instances(self, request): + """Fixture providing test data for WorkerProcess.""" + constructor_args = request.param + main_messaging = InterProcessMessagingQueue( + **constructor_args["messaging"], poll_interval=0.01 + ) + + try: + instance = WorkerProcess( + messaging=main_messaging.create_worker_copy(0), + backend=MockBackend(), + request_timings=LastCompletionRequestTimings(), + **constructor_args["worker"], + startup_barrier=Barrier(2), + requests_generated_event=Event(), + constraint_reached_event=Event(), + shutdown_event=Event(), + error_event=Event(), + ) + await main_messaging.start( + pydantic_models=list( + SchedulerMessagingPydanticRegistry.registry.values() + ) + ) + yield instance, main_messaging, constructor_args + finally: + await main_messaging.stop() + + @pytest.mark.smoke + def test_class_signatures( + self, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, 
dict], + ): + """Test inheritance and type relationships.""" + worker_process, main_messaging, constructor_args = valid_instances + + # Class + assert isinstance(worker_process, Generic) + assert issubclass(WorkerProcess, Generic) + + # Generics + orig_bases = getattr(WorkerProcess, "__orig_bases__", ()) + assert len(orig_bases) > 0 + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is Generic + ), + None, + ) + assert generic_base is not None + type_args = getattr(generic_base, "__args__", ()) + assert len(type_args) == 2 # RequestT, ResponseT + + # Function signatures + run_sig = inspect.signature(WorkerProcess.run) + assert len(run_sig.parameters) == 1 + assert "self" in run_sig.parameters + + run_async_sig = inspect.signature(WorkerProcess.run_async) + assert len(run_async_sig.parameters) == 1 + assert "self" in run_async_sig.parameters + + stop_processing_sig = inspect.signature(WorkerProcess._stop_monitor) + assert len(stop_processing_sig.parameters) == 1 + assert "self" in stop_processing_sig.parameters + + requests_processing_sig = inspect.signature(WorkerProcess._process_requests) + assert len(requests_processing_sig.parameters) == 1 + assert "self" in requests_processing_sig.parameters + + @pytest.mark.smoke + def test_initialization( + self, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + ): + """Test basic initialization of WorkerProcess.""" + instance, main_messaging, constructor_args = valid_instances + + # messaging + assert instance.messaging is not None + assert isinstance(instance.messaging, InterProcessMessagingQueue) + assert instance.messaging is not main_messaging + assert instance.messaging.worker_index is not None + assert instance.messaging.worker_index == 0 + assert ( + instance.messaging.serialization + == constructor_args["messaging"]["serialization"] + ) + assert instance.messaging.encoding == constructor_args["messaging"]["encoding"] + assert ( + instance.messaging.max_buffer_receive_size + == constructor_args["messaging"]["max_buffer_receive_size"] + ) + + # worker + assert instance.async_limit == constructor_args["worker"]["async_limit"] + assert instance.startup_barrier is not None + assert isinstance(instance.startup_barrier, ProcessingBarrier) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, ProcessingEvent) + assert instance.error_event is not None + assert isinstance(instance.error_event, ProcessingEvent) + assert instance.requests_generated_event is not None + assert isinstance(instance.requests_generated_event, ProcessingEvent) + assert instance.constraint_reached_event is not None + assert isinstance(instance.constraint_reached_event, ProcessingEvent) + assert instance.backend is not None + assert isinstance(instance.backend, MockBackend) + assert instance.request_timings is not None + assert isinstance(instance.request_timings, LastCompletionRequestTimings) + assert not instance.startup_completed + + @pytest.mark.sanity + def test_invalid_initialization(self): + """Test that invalid initialization raises appropriate errors.""" + + # Test with missing required parameters + with pytest.raises(TypeError): + WorkerProcess() + + # Create a complete set of valid parameters + backend = MockBackend() + request_timings = LastCompletionRequestTimings() + barrier = Barrier(2) + shutdown_event = Event() + error_event = Event() + requests_generated_event = Event() + constraint_reached_event = Event() + messaging = 
InterProcessMessagingQueue() + + # Test missing each required parameter one by one + required_params = [ + "messaging", + "backend", + "request_timings", + "async_limit", + "startup_barrier", + "requests_generated_event", + "constraint_reached_event", + "shutdown_event", + "error_event", + ] + + for param_to_remove in required_params: + kwargs = { + "messaging": messaging, + "backend": backend, + "request_timings": request_timings, + "async_limit": 5, + "startup_barrier": barrier, + "requests_generated_event": requests_generated_event, + "constraint_reached_event": constraint_reached_event, + "shutdown_event": shutdown_event, + "error_event": error_event, + } + + del kwargs[param_to_remove] + + with pytest.raises(TypeError): + WorkerProcess(**kwargs) + + @pytest.mark.smoke + @pytest.mark.asyncio + # @async_timeout(15) + @pytest.mark.parametrize( + ("num_requests", "num_canceled", "error_rate"), + [ + (20, 0, 0), + (STANDARD_NUM_REQUESTS, 20, 0.5), + ], + ) + async def test_run_async_lifecycle( # noqa: C901, PLR0912 + self, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + num_requests: int, + num_canceled: int, + error_rate: float, + ): + """Test the asynchronous request processing of WorkerProcess.""" + instance, main_messaging, constructor_args = valid_instances + instance.backend.request_error_rate = error_rate + instance_task = asyncio.create_task(instance.run_async()) + + try: + await asyncio.to_thread(instance.startup_barrier.wait) + start_time = time.time() + + # Send regular requests + requests_tracker = {} + for index in range(num_requests): + request = f"request_{index}" + request_info = ScheduledRequestInfo( + request_id=request, + scheduler_start_time=start_time, + scheduler_process_id=0, + ) + request_info.scheduler_timings.queued = time.time() + requests_tracker[request] = { + "sent": True, + "received_pending": 0, + "received_in_progress": 0, + "received_resolved": 0, + } + await main_messaging.put( + (request, request_info), + timeout=2.0, + ) + + # Process regular requests + error_count = 0 + for _ in range(num_requests * 3): + # Each request must have a pending, in_progress, and resolution + response, request, request_info = await main_messaging.get(timeout=2.0) + assert request is not None + assert request_info is not None + assert request_info.request_id is not None + assert request_info.status is not None + assert request_info.scheduler_node_id > -1 + assert request_info.scheduler_process_id > -1 + assert request_info.scheduler_start_time == start_time + assert request_info.scheduler_timings is not None + assert request_info.scheduler_timings.targeted_start is not None + assert request_info.scheduler_timings.targeted_start >= start_time + + if request_info.status == "pending": + requests_tracker[request]["received_pending"] += 1 + assert request_info.scheduler_timings.dequeued is not None + assert ( + request_info.scheduler_timings.dequeued + >= request_info.scheduler_timings.targeted_start + ) + elif request_info.status == "in_progress": + requests_tracker[request]["received_in_progress"] += 1 + assert request_info.scheduler_timings.scheduled_at is not None + assert ( + request_info.scheduler_timings.scheduled_at + >= request_info.scheduler_timings.dequeued + ) + assert request_info.scheduler_timings.resolve_start is not None + assert ( + request_info.scheduler_timings.resolve_start + >= request_info.scheduler_timings.scheduled_at + ) + elif request_info.status == "completed": + assert response == f"response_for_{request}" + 
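                    # The completed payload is exactly what MockBackend.resolve
+                    # yielded, so this equality also verifies request/response pairing.
+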
requests_tracker[request]["received_resolved"] += 1
+                    assert request_info.scheduler_timings.resolve_end is not None
+                    assert (
+                        request_info.scheduler_timings.resolve_end
+                        > request_info.scheduler_timings.resolve_start
+                    )
+                elif request_info.status == "errored":
+                    assert response is None
+                    requests_tracker[request]["received_resolved"] += 1
+                    error_count += 1
+                    assert request_info.scheduler_timings.resolve_end is not None
+                    assert (
+                        request_info.scheduler_timings.resolve_end
+                        > request_info.scheduler_timings.resolve_start
+                    )
+                else:
+                    raise ValueError(f"Unexpected status: {request_info.status}")
+
+            # Ensure correct error rate
+            assert float(error_count) / num_requests == pytest.approx(
+                error_rate, rel=0.2
+            )
+
+            # Ensure no extra statuses
+            with pytest.raises(asyncio.TimeoutError):
+                await main_messaging.get(timeout=0.5)
+
+            # Send cancel requests
+            for index in range(num_canceled):
+                cancel_request = f"cancel_request_{index}"
+                cancel_info = ScheduledRequestInfo(
+                    request_id=cancel_request,
+                    scheduler_start_time=start_time,
+                    scheduler_process_id=0,
+                )
+                cancel_info.scheduler_timings.queued = time.time()
+                requests_tracker[cancel_request] = {
+                    "sent": True,
+                    "received_pending": 0,
+                    "received_in_progress": 0,
+                    "received_resolved": 0,
+                }
+                await main_messaging.put(
+                    (cancel_request, cancel_info),
+                    timeout=2.0,
+                )
+
+            # Receive expected updates for cancels up to the async limit
+            for _ in range(2 * min(num_canceled, instance.async_limit)):
+                # Each request (up to the async limit) will have pending, in_progress
+                response, request, request_info = await main_messaging.get(timeout=2.0)
+                if request_info.status == "pending":
+                    requests_tracker[request]["received_pending"] += 1
+                elif request_info.status == "in_progress":
+                    requests_tracker[request]["received_in_progress"] += 1
+                else:
+                    raise ValueError(f"Unexpected status: {request_info.status}")
+
+            # Signal constraints reached to start canceling
+            instance.constraint_reached_event.set()
+            await asyncio.sleep(0)
+
+            # Receive the remaining canceled updates
+            for _ in range(num_canceled):
+                # All cancel requests should resolve as cancelled (no other statuses)
+                response, request, request_info = await main_messaging.get(timeout=2.0)
+                assert request is not None
+                assert request_info is not None
+                assert request_info.request_id is not None
+                assert request_info.status is not None
+                assert request_info.scheduler_node_id > -1
+                assert request_info.scheduler_process_id > -1
+                assert request_info.scheduler_start_time == start_time
+                assert request_info.scheduler_timings is not None
+
+                if request_info.status == "cancelled":
+                    requests_tracker[request]["received_resolved"] += 1
+                    assert request_info.scheduler_timings.resolve_end is not None
+                    assert request_info.scheduler_timings.resolve_end > start_time
+                else:
+                    raise ValueError(f"Unexpected status: {request_info.status}")
+
+            # Ensure no extra statuses
+            with pytest.raises(asyncio.TimeoutError):
+                await main_messaging.get(timeout=0.5)
+
+            # Signal requests stop now that all requests have been processed
+            instance.requests_generated_event.set()
+            await asyncio.sleep(0)
+
+            # Validate all the requests are correct
+            for request_key in [f"request_{index}" for index in range(num_requests)]:
+                assert request_key in requests_tracker
+                assert requests_tracker[request_key]["sent"]
+                assert requests_tracker[request_key]["received_pending"] == 1
+                assert requests_tracker[request_key]["received_resolved"] == 1
+                if request_key.startswith("request"):
+                    assert
requests_tracker[request_key]["received_in_progress"] == 1 + finally: + # Shut down + instance.shutdown_event.set() + await asyncio.wait_for(instance_task, timeout=2.0) + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(15) + @pytest.mark.parametrize( + ("request_timings", "timing_bounds"), + [ + ( + LastCompletionRequestTimings(offset=0.1), + [ + TimingsBounds(lower=0.1, prev_request="greater_equal") + for _ in range(STANDARD_NUM_REQUESTS) + ], + ), + ( + NoDelayRequestTimings(offset=0.05), + [ + TimingsBounds(lower=0.05, upper=0.05, actual_tolerance=1.0) + for _ in range(STANDARD_NUM_REQUESTS) + ], + ), + ( + ConstantRateRequestTimings(rate=100, offset=0.2), + [ + TimingsBounds( + exact=0.2 + ind * 0.01, + lower=0.2, + prev_request="greater", + actual_tolerance=10e-2, + ) + for ind in range(STANDARD_NUM_REQUESTS) + ], + ), + ( + PoissonRateRequestTimings(rate=200, offset=0.01), + [ + TimingsBounds(lower=0.01, prev_request="greater") + for ind in range(STANDARD_NUM_REQUESTS) + ], + ), + ], + ids=[ + "LastCompletion", + "NoDelay", + "ConstantRate", + "PoissonRate", + ], + ) + async def test_run_with_timings( # noqa: C901, PLR0912 + self, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + request_timings: ScheduledRequestTimings, + timing_bounds: list[TimingsBounds], + ): + instance, main_messaging, constructor_args = valid_instances + instance.request_timings = request_timings + num_requests = STANDARD_NUM_REQUESTS + assert len(timing_bounds) == num_requests + + # Start process + process = Process(target=instance.run) + process.start() + + try: + await asyncio.to_thread(instance.startup_barrier.wait) + start_time = time.time() + 0.1 + + # Send regular requests + requests_tracker = {} + for ind in range(num_requests): + request = f"request_{ind}" + requests_tracker[request] = { + "sent": True, + "target_start_time": -1, + "actual_start_time": -1, + "received_pending": 0, + "received_in_progress": 0, + "received_resolved": 0, + } + await main_messaging.put( + ( + request, + ScheduledRequestInfo(scheduler_start_time=start_time), + ), + timeout=2.0, + ) + + # Process regular requests + for _ in range(num_requests * 3): + # Each request must have pending, in_progress, and resolved statuses + response, request, request_info = await main_messaging.get(timeout=2.0) + if request_info.status == "pending": + requests_tracker[request]["received_pending"] += 1 + elif request_info.status == "in_progress": + requests_tracker[request]["received_in_progress"] += 1 + requests_tracker[request]["target_start_time"] = ( + request_info.scheduler_timings.targeted_start + ) + requests_tracker[request]["actual_start_time"] = ( + request_info.scheduler_timings.resolve_start + ) + elif request_info.status == "completed": + assert response == f"response_for_{request}" + requests_tracker[request]["received_resolved"] += 1 + else: + raise ValueError(f"Unexpected status: {request_info.status}") + + # Ensure no extra statuses + with pytest.raises(asyncio.TimeoutError): + await main_messaging.get(timeout=0.1) + + # Trigger stopping for constraints and requests + instance.requests_generated_event.set() + instance.constraint_reached_event.set() + await asyncio.sleep(0) + + # Validate request values are correct + for ind in range(num_requests): + request = f"request_{ind}" + assert requests_tracker[request]["received_pending"] == 1 + assert requests_tracker[request]["received_in_progress"] == 1 + assert requests_tracker[request]["received_resolved"] == 1 + + bounds = timing_bounds[ind] + 
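                # Normalize absolute timestamps to offsets from the shared
+                # start_time so the per-request bounds compare in relative time.
+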
target_offset = ( + requests_tracker[request]["target_start_time"] - start_time + ) + actual_offset = ( + requests_tracker[request]["actual_start_time"] - start_time + ) + prev_offset = ( + requests_tracker[f"request_{ind - 1}"]["target_start_time"] + - start_time + if ind > 0 + else None + ) + + if bounds.exact is not None: + assert target_offset == pytest.approx( + bounds.exact, rel=bounds.tolerance + ) + assert target_offset == pytest.approx( + actual_offset, rel=bounds.actual_tolerance or bounds.tolerance + ) + if bounds.lower is not None: + assert target_offset >= bounds.lower - bounds.tolerance + assert actual_offset >= bounds.lower - ( + bounds.actual_tolerance or bounds.tolerance + ) + if bounds.upper is not None: + assert target_offset <= bounds.upper + bounds.tolerance + assert actual_offset <= bounds.upper + ( + bounds.actual_tolerance or bounds.tolerance + ) + if bounds.prev_request is not None and prev_offset is not None: + if bounds.prev_request == "greater": + assert target_offset > prev_offset - bounds.tolerance + elif bounds.prev_request == "greater_equal": + assert target_offset >= prev_offset - bounds.tolerance + elif bounds.prev_request == "less": + assert target_offset < prev_offset + bounds.tolerance + elif bounds.prev_request == "less_equal": + assert target_offset <= prev_offset + bounds.tolerance + finally: + # Trigger shutdown + instance.shutdown_event.set() + await asyncio.to_thread(process.join, timeout=2.0) + + if process.is_alive(): + process.terminate() + await asyncio.to_thread(process.join, timeout=2.0) + assert process.exitcode <= 0, ( + f"Process exited with error code: {process.exitcode}" + ) diff --git a/tests/unit/scheduler/test_worker_group.py b/tests/unit/scheduler/test_worker_group.py new file mode 100644 index 00000000..b72fb95b --- /dev/null +++ b/tests/unit/scheduler/test_worker_group.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +import asyncio +import inspect +import time +from functools import wraps +from multiprocessing.context import BaseContext +from multiprocessing.managers import BaseManager +from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Barrier, Event +from typing import Any, Generic, Literal + +import pytest +from pydantic import Field + +from guidellm.scheduler import ( + AsyncConstantStrategy, + BackendInterface, + ConcurrentStrategy, + MaxDurationConstraint, + MaxNumberConstraint, + MeasuredRequestTimings, + ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, + SchedulerState, + SynchronousStrategy, + ThroughputStrategy, + WorkerProcessGroup, +) +from guidellm.scheduler.worker_group import WorkerGroupState +from guidellm.utils import InterProcessMessaging + + +def async_timeout(delay): + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockRequestTimings(MeasuredRequestTimings): + """Mock timing implementation for testing.""" + + timings_type: Literal["mock"] = Field(default="mock") + + +class MockBackend(BackendInterface): + """Mock backend for testing worker group functionality.""" + + def __init__( + self, + processes_limit_value: int | None = None, + requests_limit_value: int | None = None, + ): + self._processes_limit = processes_limit_value + self._requests_limit = requests_limit_value + + @property + def processes_limit(self) -> int | None: + return self._processes_limit + + @property + def requests_limit(self) -> int 
| None:
+        return self._requests_limit
+
+    @property
+    def info(self) -> dict[str, Any]:
+        return {"type": "mock"}
+
+    async def process_startup(self):
+        pass
+
+    async def validate(self):
+        pass
+
+    async def process_shutdown(self):
+        pass
+
+    async def resolve(self, request, request_info, request_history):
+        request_info.request_timings = MockRequestTimings(
+            request_start=time.time(), request_end=time.time()
+        )
+        yield f"response_for_{request}", request_info
+
+
+class TestWorkerProcessGroup:
+    """Test suite for WorkerProcessGroup class."""
+
+    def setup_method(self):
+        self._original_messaging_registry = (
+            SchedulerMessagingPydanticRegistry.registry.copy()
+            if SchedulerMessagingPydanticRegistry.registry
+            else {}
+        )
+        self._original_timings_registry = (
+            MeasuredRequestTimings.registry.copy()
+            if MeasuredRequestTimings.registry
+            else {}
+        )
+        MeasuredRequestTimings.register_decorator(MockRequestTimings, "mock")
+        SchedulerMessagingPydanticRegistry.register_decorator(
+            MockRequestTimings, "mock"
+        )
+
+    def teardown_method(self):
+        SchedulerMessagingPydanticRegistry.registry = self._original_messaging_registry
+        MeasuredRequestTimings.registry = self._original_timings_registry
+        MeasuredRequestTimings.model_rebuild(force=True)
+        ScheduledRequestInfo.model_rebuild(force=True)
+
+    @pytest.fixture(
+        params=[
+            {
+                "requests": None,
+                "cycle_requests": ["request1", "request2", "request3"],
+                "strategy": SynchronousStrategy(),
+                "constraints": {"max_num": MaxNumberConstraint(max_num=10)},
+            },
+            {
+                "requests": None,
+                "cycle_requests": ["req_a", "req_b"],
+                "strategy": ConcurrentStrategy(streams=2),
+                "constraints": {"max_num": MaxNumberConstraint(max_num=5)},
+            },
+            {
+                "requests": ["req_x", "req_y", "req_z"],
+                "cycle_requests": None,
+                "strategy": ThroughputStrategy(max_concurrency=5),
+                "constraints": {},
+            },
+            {
+                "requests": None,
+                "cycle_requests": ["req_8", "req_9", "req_10"],
+                "strategy": AsyncConstantStrategy(rate=20),
+                "constraints": {"max_duration": MaxDurationConstraint(max_duration=1)},
+            },
+        ],
+        ids=["sync_max", "concurrent_max", "throughput_no_cycle", "constant_duration"],
+    )
+    def valid_instances(self, request):
+        """Fixture providing test data for WorkerProcessGroup."""
+        constructor_args = request.param.copy()
+        instance = WorkerProcessGroup(**request.param, backend=MockBackend())
+        return instance, constructor_args
+
+    @pytest.mark.smoke
+    def test_class_signatures(self, valid_instances):
+        """Test inheritance and type relationships."""
+        instance, _ = valid_instances
+
+        # Class
+        assert isinstance(instance, Generic)
+        assert issubclass(WorkerProcessGroup, Generic)
+
+        # Generics
+        orig_bases = getattr(WorkerProcessGroup, "__orig_bases__", ())
+        assert len(orig_bases) > 0
+        generic_base = next(
+            (
+                base
+                for base in orig_bases
+                if hasattr(base, "__origin__") and base.__origin__ is Generic
+            ),
+            None,
+        )
+        assert generic_base is not None
+        type_args = getattr(generic_base, "__args__", ())
+        assert len(type_args) == 2
+
+        # Function signatures
+        create_processes_sig = inspect.signature(WorkerProcessGroup.create_processes)
+        assert len(create_processes_sig.parameters) == 1
+        assert "self" in create_processes_sig.parameters
+
+        start_sig = inspect.signature(WorkerProcessGroup.start)
+        assert len(start_sig.parameters) == 2
+        assert "self" in start_sig.parameters
+        assert "start_time" in start_sig.parameters
+
+        request_updates_sig = inspect.signature(WorkerProcessGroup.request_updates)
+        assert len(request_updates_sig.parameters) == 1
+        assert "self" in
request_updates_sig.parameters + + shutdown_sig = inspect.signature(WorkerProcessGroup.shutdown) + assert len(shutdown_sig.parameters) == 1 + assert "self" in shutdown_sig.parameters + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test basic initialization of WorkerProcessGroup.""" + instance, constructor_args = valid_instances + + # Core attributes + assert isinstance(instance.backend, MockBackend) + assert instance.requests is constructor_args["requests"] + assert instance.cycle_requests is constructor_args["cycle_requests"] + assert isinstance(instance.strategy, type(constructor_args["strategy"])) + assert isinstance(instance.constraints, dict) + assert instance.constraints == constructor_args["constraints"] + + # Multiprocessing attributes (should be None initially) + assert instance.mp_context is None + assert instance.mp_manager is None + assert instance.processes is None + + # Synchronization primitives (should be None initially) + assert instance.startup_barrier is None + assert instance.shutdown_event is None + assert instance.error_event is None + assert instance.requests_generated_event is None + assert instance.constraint_reached_event is None + + # Scheduler state and messaging (should be None initially) + assert instance.state is None + assert instance.messaging is None + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("requests", "cycle_requests", "expected_error"), + [ + (None, None, ValueError), + ([], iter([]), ValueError), # cycle_requests as Iterator + (None, iter(["req1"]), ValueError), # cycle_requests as Iterator + ], + ids=["no_requests", "cycle_as_iterator_empty", "cycle_as_iterator_data"], + ) + def test_invalid_initialization_values( + self, requests, cycle_requests, expected_error + ): + """Test WorkerProcessGroup with invalid initialization values.""" + with pytest.raises(expected_error): + WorkerProcessGroup( + requests=requests, + cycle_requests=cycle_requests, + backend=MockBackend(), + strategy=SynchronousStrategy(), + constraints={}, + ) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test WorkerProcessGroup initialization without required fields.""" + with pytest.raises(TypeError): + WorkerProcessGroup() + + @pytest.mark.smoke + @async_timeout(10) + @pytest.mark.asyncio + async def test_lifecycle(self, valid_instances: tuple[WorkerProcessGroup, dict]): # noqa: C901, PLR0912 + """Test the lifecycle methods of WorkerProcessGroup.""" + instance, constructor_args = valid_instances + assert instance.requests or instance.cycle_requests + assert instance.backend + assert instance.strategy + assert instance.constraints is not None + + # Validate create_processes works and sets correct state + await instance.create_processes() + assert instance.mp_context is not None + assert isinstance(instance.mp_context, BaseContext) + assert instance.mp_manager is not None + assert isinstance(instance.mp_manager, BaseManager) + assert instance.processes is not None + assert isinstance(instance.processes, list) + assert len(instance.processes) > 0 + assert all(isinstance(proc, BaseProcess) for proc in instance.processes) + assert all(proc.is_alive() for proc in instance.processes) + assert instance.startup_barrier is not None + assert isinstance(instance.startup_barrier, Barrier) + assert instance.requests_generated_event is not None + assert isinstance(instance.requests_generated_event, Event) + assert instance.constraint_reached_event is not None + assert isinstance(instance.constraint_reached_event, Event) + 
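        # create_processes() must build every synchronization primitive up
+        # front; the workers block on startup_barrier and watch these events
+        # for shutdown and error signaling.
+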
assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, Event) + assert instance.error_event is not None + assert isinstance(instance.error_event, Event) + assert instance.messaging is not None + assert isinstance(instance.messaging, InterProcessMessaging) + assert instance.messaging.worker_index is None + + # Validate start works and sets correct state + start_time = time.time() + 0.1 + await instance.start(start_time=start_time) + assert instance.state is not None + assert isinstance(instance.state, WorkerGroupState) + assert not instance.requests_generated_event.is_set() + assert not instance.constraint_reached_event.is_set() + assert not instance.shutdown_event.is_set() + assert not instance.error_event.is_set() + + # Test iter updates + requests_tracker = {} + + async for ( + response, + request, + request_info, + scheduler_state, + ) in instance.request_updates(): + # Validate returned request + assert request is not None + + # Validate returned request info and response + assert request_info is not None + assert isinstance(request_info, ScheduledRequestInfo) + assert request_info.request_id is not None + assert request_info.status is not None + if request_info.request_id not in requests_tracker: + requests_tracker[request_info.request_id] = { + "received_pending": 0, + "received_in_progress": 0, + "received_resolved": 0, + "received_cancelled": 0, + } + assert request_info.scheduler_node_id > -1 + assert request_info.scheduler_process_id > -1 + assert request_info.scheduler_start_time == start_time + assert request_info.scheduler_timings is not None + if request_info.status == "pending": + requests_tracker[request_info.request_id]["received_pending"] += 1 + assert request_info.scheduler_timings.dequeued is not None + assert request_info.scheduler_timings.targeted_start is not None + assert request_info.scheduler_timings.targeted_start >= start_time + elif request_info.status == "in_progress": + requests_tracker[request_info.request_id]["received_in_progress"] += 1 + assert request_info.scheduler_timings.scheduled_at is not None + assert ( + request_info.scheduler_timings.scheduled_at + >= request_info.scheduler_timings.dequeued + ) + assert request_info.scheduler_timings.resolve_start is not None + assert ( + request_info.scheduler_timings.resolve_start + >= request_info.scheduler_timings.scheduled_at + ) + elif request_info.status == "completed": + requests_tracker[request_info.request_id]["received_resolved"] += 1 + assert response is not None + assert request_info.scheduler_timings.resolve_end is not None + assert ( + request_info.scheduler_timings.resolve_end + > request_info.scheduler_timings.resolve_start + ) + assert request_info.request_timings is not None + assert isinstance(request_info.request_timings, MockRequestTimings) + assert request_info.request_timings.request_start is not None + assert ( + request_info.request_timings.request_start + >= request_info.scheduler_timings.targeted_start + ) + assert request_info.request_timings.request_end is not None + assert ( + request_info.request_timings.request_end + >= request_info.request_timings.request_start + ) + elif request_info.status in ("errored", "cancelled"): + assert response is None + requests_tracker[request_info.request_id]["received_resolved"] += 1 + assert request_info.scheduler_timings.resolve_end is not None + assert ( + request_info.scheduler_timings.resolve_end + > request_info.scheduler_start_time + ) + if request_info.status == "cancelled": + 
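                    # Cancelled updates are double-counted on purpose: both as
+                    # resolved (above) and as cancelled for the checks below.
+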
requests_tracker[request_info.request_id]["received_cancelled"] += 1 + + # Validate state structure + assert scheduler_state is not None + assert isinstance(scheduler_state, SchedulerState) + assert scheduler_state.node_id > -1 + assert scheduler_state.start_time == start_time + assert scheduler_state.end_time is not None + if constructor_args.get("constraints"): + assert scheduler_state.remaining_fraction is not None + assert scheduler_state.remaining_fraction >= 0.0 + assert scheduler_state.remaining_fraction <= 1.0 + if constructor_args.get("constraints", {}).get("max_num") is not None: + assert scheduler_state.remaining_requests is not None + assert scheduler_state.remaining_requests >= 0 + assert ( + scheduler_state.remaining_requests + <= constructor_args["constraints"]["max_num"].max_num + ) + if constructor_args.get("constraints", {}).get("max_duration") is not None: + assert scheduler_state.remaining_duration is not None + assert scheduler_state.remaining_duration >= 0.0 + assert ( + scheduler_state.remaining_duration + <= constructor_args["constraints"]["max_duration"].max_duration + ) + assert scheduler_state.created_requests >= 0 + assert scheduler_state.queued_requests >= 0 + assert scheduler_state.pending_requests >= 0 + assert scheduler_state.processing_requests >= 0 + assert scheduler_state.processed_requests >= 0 + assert scheduler_state.successful_requests >= 0 + assert scheduler_state.errored_requests >= 0 + assert scheduler_state.cancelled_requests >= 0 + + # Validate correctness of all updates + for _, counts in requests_tracker.items(): + assert counts["received_cancelled"] in (0, 1) + if counts["received_cancelled"] == 0: + assert counts["received_pending"] == 1 + assert counts["received_in_progress"] >= 1 + assert counts["received_resolved"] == 1 + assert scheduler_state is not None # last yielded state + assert scheduler_state.end_time > scheduler_state.start_time + assert scheduler_state.end_queuing_time is not None + assert scheduler_state.end_queuing_constraints is not None + assert scheduler_state.end_processing_time is not None + assert scheduler_state.end_processing_time >= scheduler_state.start_time + assert scheduler_state.end_processing_constraints is not None + assert scheduler_state.scheduler_constraints is not None + assert scheduler_state.created_requests == len(requests_tracker) + assert scheduler_state.queued_requests == 0 + assert scheduler_state.pending_requests == 0 + assert scheduler_state.processing_requests == 0 + assert scheduler_state.processed_requests == len(requests_tracker) + assert scheduler_state.successful_requests >= 0 + assert scheduler_state.errored_requests >= 0 + assert scheduler_state.cancelled_requests >= 0 + assert ( + scheduler_state.successful_requests + + scheduler_state.errored_requests + + scheduler_state.cancelled_requests + == len(requests_tracker) + ) + if constructor_args.get("constraints"): + assert list(scheduler_state.scheduler_constraints.keys()) == list( + constructor_args["constraints"].keys() + ) + assert scheduler_state.remaining_fraction == 0.0 + if "max_num" in constructor_args["constraints"]: + assert "max_num" in scheduler_state.end_queuing_constraints + assert "max_num" in scheduler_state.end_processing_constraints + max_num = constructor_args["constraints"]["max_num"].max_num + assert scheduler_state.created_requests == max_num + assert scheduler_state.successful_requests == max_num + assert scheduler_state.errored_requests == 0 + assert scheduler_state.cancelled_requests == 0 + if "max_duration" in 
constructor_args["constraints"]: + assert "max_duration" in scheduler_state.end_queuing_constraints + assert "max_duration" in scheduler_state.end_processing_constraints + assert scheduler_state.remaining_duration == 0.0 + else: + assert "requests_exhausted" in scheduler_state.scheduler_constraints + assert "requests_exhausted" in scheduler_state.end_queuing_constraints + assert "requests_exhausted" in scheduler_state.end_processing_constraints + assert scheduler_state.remaining_fraction is None + assert scheduler_state.remaining_requests is None + assert scheduler_state.remaining_duration is None + + # Test shutdown + exceptions = await instance.shutdown() + + # Check valid shutdown behavior + assert isinstance(exceptions, list) + assert len(exceptions) == 0 + assert instance.messaging is None + assert instance.state is None + assert instance.processes is None + assert instance.startup_barrier is None + assert instance.requests_generated_event is None + assert instance.constraint_reached_event is None + assert instance.shutdown_event is None + assert instance.error_event is None + assert instance.mp_manager is None + assert instance.mp_context is None
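+
+        # Sketch of typical driver usage, mirroring the lifecycle exercised
+        # above (argument and method names as used in this test, not taken
+        # from any external docs):
+        #
+        #     group = WorkerProcessGroup(
+        #         requests=None,
+        #         cycle_requests=["req_1", "req_2"],
+        #         backend=backend,
+        #         strategy=SynchronousStrategy(),
+        #         constraints={"max_num": MaxNumberConstraint(max_num=2)},
+        #     )
+        #     await group.create_processes()
+        #     await group.start(start_time=time.time() + 0.1)
+        #     async for response, request, info, state in group.request_updates():
+        #         ...  # consume status updates until constraints end the run
+        #     assert await group.shutdown() == []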