diff --git a/example_usage.py b/example_usage.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 0b1014cb..6c46da4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,13 @@ keywords = [ ] dependencies = [ "click>=8.0.0,<8.2.0", + "culsans~=0.9.0", "datasets", + "eval_type_backport", "ftfy>=6.0.0", "httpx[http2]<1.0.0", "loguru", + "msgpack", "numpy", "pillow", "protobuf", @@ -139,6 +142,7 @@ ignore_missing_imports=true [tool.ruff] +target-version = "py39" line-length = 88 indent-width = 4 exclude = ["build", "dist", "env", ".venv"] @@ -149,15 +153,16 @@ indent-style = "space" [tool.ruff.lint] ignore = [ - "PLR0913", - "TC001", - "COM812", - "ISC001", - "TC002", + "COM812", # ignore trailing comma errors due to older Python versions + "PD011", # ignore .values usage since ruff assumes it's a Pandas DataFrame + "PLR0913", # ignore too many arguments in function definitions "PLW1514", # allow Path.open without encoding "RET505", # allow `else` blocks "RET506", # allow `else` blocks - "PD011", # ignore .values usage since ruff assumes it's a Pandas DataFrame + "S311", # allow standard pseudo-random generators + "TC001", # ignore imports used only for type checking + "TC002", # ignore imports used only for type checking + "TC003", # ignore imports used only for type checking ] select = [ # Rules reference: https://docs.astral.sh/ruff/rules/ diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py index 315a28f0..0f1a412e 100644 --- a/src/guidellm/backend/__init__.py +++ b/src/guidellm/backend/__init__.py @@ -1,23 +1,24 @@ +""" +Backend infrastructure for GuideLLM language model interactions. + +Provides abstract base classes, implemented backends, request/response objects, +and timing utilities for standardized communication with LLM providers. +""" + from .backend import ( Backend, BackendType, ) -from .openai import CHAT_COMPLETIONS_PATH, TEXT_COMPLETIONS_PATH, OpenAIHTTPBackend -from .response import ( - RequestArgs, - ResponseSummary, - StreamingResponseType, - StreamingTextResponse, +from .objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, ) __all__ = [ - "CHAT_COMPLETIONS_PATH", - "TEXT_COMPLETIONS_PATH", "Backend", "BackendType", - "OpenAIHTTPBackend", - "RequestArgs", - "ResponseSummary", - "StreamingResponseType", - "StreamingTextResponse", + "GenerationRequest", + "GenerationRequestTimings", + "GenerationResponse", ] diff --git a/src/guidellm/backend/backend.py b/src/guidellm/backend/backend.py index bf2788a7..feace1a2 100644 --- a/src/guidellm/backend/backend.py +++ b/src/guidellm/backend/backend.py @@ -1,13 +1,25 @@ -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Any, Literal, Optional, Union +""" +Backend interface and registry for generative AI model interactions. -from loguru import logger -from PIL import Image +Provides the abstract base class for implementing backends that communicate with +generative AI models. Backends handle the lifecycle of generation requests. -from guidellm.backend.response import ResponseSummary, StreamingTextResponse -from guidellm.config import settings +Classes: + Backend: Abstract base class for generative AI backends with registry support. + +Type Aliases: + BackendType: Literal type defining supported backend implementations. 
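+
+A minimal creation sketch (the target URL below is a placeholder for a locally
+running OpenAI-compatible server):
+
+Example:
+::
+    backend = Backend.create("openai_http", target="http://localhost:8000")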
+""" + +from typing import Literal, Optional + +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.scheduler import BackendInterface +from guidellm.utils.registry import RegistryMixin __all__ = [ "Backend", @@ -18,242 +30,75 @@ BackendType = Literal["openai_http"] -class Backend(ABC): +class Backend( + RegistryMixin["type[Backend]"], + BackendInterface[GenerationRequest, GenerationRequestTimings, GenerationResponse], +): """ - Abstract base class for generative AI backends. - - This class provides a common interface for creating and interacting with different - generative AI backends. Subclasses should implement the abstract methods to - define specific backend behavior. - - :cvar _registry: A registration dictionary that maps BackendType to backend classes. - :param type_: The type of the backend. + Base class for generative AI backends with registry and lifecycle. + + Provides a standard interface for backends that communicate with generative AI + models. Combines the registry pattern for automatic discovery with a defined + lifecycle for process-based distributed execution. + + Backend lifecycle phases: + 1. Creation and configuration + 2. Process startup - Initialize resources in worker process + 3. Validation - Verify backend readiness + 4. Request resolution - Process generation requests + 5. Process shutdown - Clean up resources + + Backend state (excluding process_startup resources) must be pickleable for + distributed execution across process boundaries. + + Example: + :: + @Backend.register("my_backend") + class MyBackend(Backend): + def __init__(self, api_key: str): + super().__init__("my_backend") + self.api_key = api_key + + async def process_startup(self): + self.client = MyAPIClient(self.api_key) + + backend = Backend.create("my_backend", api_key="secret") """ - _registry: dict[BackendType, "type[Backend]"] = {} - - @classmethod - def register(cls, backend_type: BackendType): - """ - A decorator to register a backend class in the backend registry. - - :param backend_type: The type of backend to register. - :type backend_type: BackendType - :return: The decorated backend class. - :rtype: Type[Backend] - """ - if backend_type in cls._registry: - raise ValueError(f"Backend type already registered: {backend_type}") - - if not issubclass(cls, Backend): - raise TypeError("Only subclasses of Backend can be registered") - - def inner_wrapper(wrapped_class: type["Backend"]): - cls._registry[backend_type] = wrapped_class - logger.info("Registered backend type: {}", backend_type) - return wrapped_class - - return inner_wrapper - @classmethod def create(cls, type_: BackendType, **kwargs) -> "Backend": """ - Factory method to create a backend instance based on the backend type. + Create a backend instance based on the backend type. :param type_: The type of backend to create. - :type type_: BackendType :param kwargs: Additional arguments for backend initialization. :return: An instance of a subclass of Backend. - :rtype: Backend :raises ValueError: If the backend type is not registered. 
""" - logger.info("Creating backend of type {}", type_) - - if type_ not in cls._registry: - err = ValueError(f"Unsupported backend type: {type_}") - logger.error("{}", err) - raise err + backend = cls.get_registered_object(type_) - return Backend._registry[type_](**kwargs) + return backend(**kwargs) def __init__(self, type_: BackendType): - self._type = type_ - - @property - def type_(self) -> BackendType: - """ - :return: The type of the backend. """ - return self._type + Initialize a backend instance. - @property - @abstractmethod - def target(self) -> str: - """ - :return: The target location for the backend. + :param type_: The backend type identifier. """ - ... + self.type_ = type_ @property - @abstractmethod - def model(self) -> Optional[str]: + def processes_limit(self) -> Optional[int]: """ - :return: The model used for the backend requests. + :return: Maximum number of worker processes supported. None if unlimited. """ - ... + return None @property - @abstractmethod - def info(self) -> dict[str, Any]: - """ - :return: The information about the backend. - """ - ... - - @abstractmethod - async def reset(self) -> None: - """ - Reset the connection object. This is useful for backends that - reuse connections or have state that needs to be cleared. + def requests_limit(self) -> Optional[int]: """ - ... - - async def validate(self): - """ - Handle final setup and validate the backend is ready for use. - If not successful, raises the appropriate exception. - """ - logger.info("{} validating backend {}", self.__class__.__name__, self.type_) - await self.check_setup() - models = await self.available_models() - if not models: - raise ValueError("No models available for the backend") - - # Use the preferred route defined in the global settings when performing the - # validation request. This avoids calling an unavailable endpoint (ie - # /v1/completions) when the deployment only supports the chat completions - # endpoint. - if settings.preferred_route == "chat_completions": - async for _ in self.chat_completions( # type: ignore[attr-defined] - content="Test connection", output_token_count=1 - ): - pass - else: - async for _ in self.text_completions( # type: ignore[attr-defined] - prompt="Test connection", output_token_count=1 - ): - pass - - await self.reset() - - @abstractmethod - async def check_setup(self): - """ - Check the setup for the backend. - If unsuccessful, raises the appropriate exception. - - :raises ValueError: If the setup check fails. - """ - ... - - @abstractmethod - async def prepare_multiprocessing(self): - """ - Prepare the backend for use in a multiprocessing environment. - This is useful for backends that have instance state that can not - be shared across processes and should be cleared out and re-initialized - for each new process. - """ - ... - - @abstractmethod - async def available_models(self) -> list[str]: - """ - Get the list of available models for the backend. - - :return: The list of available models. - :rtype: List[str] - """ - ... - - @abstractmethod - async def text_completions( - self, - prompt: Union[str, list[str]], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - """ - Generate text only completions for the given prompt. - Does not support multiple modalities, complicated chat interfaces, - or chat templates. Specifically, it requests with only the prompt. 
- - :param prompt: The prompt (or list of prompts) to generate a completion for. - If a list is supplied, these are concatenated and run through the model - for a single prompt. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. - """ - ... - - @abstractmethod - async def chat_completions( - self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - raw_content: bool = False, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - """ - Generate chat completions for the given content. - Supports multiple modalities, complicated chat interfaces, and chat templates. - Specifically, it requests with the content, which can be any combination of - text, images, and audio provided the target model supports it, - and returns the output text. Additionally, any chat templates - for the model are applied within the backend. - - :param content: The content (or list of content) to generate a completion for. - This supports any combination of text, images, and audio (model dependent). - Supported text only request examples: - content="Sample prompt", content=["Sample prompt", "Second prompt"], - content=[{"type": "text", "value": "Sample prompt"}. - Supported text and image request examples: - content=["Describe the image", PIL.Image.open("image.jpg")], - content=["Describe the image", Path("image.jpg")], - content=["Describe the image", {"type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}]. - Supported text and audio request examples: - content=["Transcribe the audio", Path("audio.wav")], - content=["Transcribe the audio", {"type": "input_audio", - "input_audio": {"data": f"{base64_bytes}", "format": "wav}]. - Additionally, if raw_content=True then the content is passed directly to the - backend without any processing. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. + :return: Maximum number of concurrent requests supported globally. + None if unlimited. """ - ... 
+ return None diff --git a/src/guidellm/backend/interface.py b/src/guidellm/backend/interface.py new file mode 100644 index 00000000..4f38c76a --- /dev/null +++ b/src/guidellm/backend/interface.py @@ -0,0 +1,97 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import ( + Any, + Generic, + Optional, + TypeVar, +) + +from guidellm.scheduler import ( + MeasuredRequestTimingsT, + RequestT, + ResponseT, + ScheduledRequestInfo, +) + + +class BackendInterface(ABC, Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): + """ + Abstract interface for request processing backends. Note: before process_startup + is invoked, the implementation must ensure all properties are pickleable. + """ + + @property + @abstractmethod + def processes_limit(self) -> Optional[int]: + """Maximum worker processes supported, or None if unlimited.""" + ... + + @property + @abstractmethod + def requests_limit(self) -> Optional[int]: + """Maximum concurrent requests supported, or None if unlimited.""" + ... + + @abstractmethod + def info(self) -> dict[str, Any]: + """ + :return: Backend metadata including model any initializaiton and + configuration information. + """ + ... + + @abstractmethod + async def process_startup(self) -> None: + """ + Perform backend initialization and startup procedures. + + :raises: Implementation-specific exceptions for startup failures. + """ + ... + + @abstractmethod + async def validate(self) -> None: + """ + Validate backend configuration and operational status. + + :raises: Implementation-specific exceptions for validation failures. + """ + ... + + @abstractmethod + async def process_shutdown(self) -> None: + """ + Perform backend cleanup and shutdown procedures. + + :raises: Implementation-specific exceptions for shutdown failures. + """ + ... + + @abstractmethod + def resolve( + self, + request: RequestT, + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + history: Optional[list[tuple[RequestT, ResponseT]]] = None, + ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo[MeasuredRequestTimingsT]]]: + """ + Process a request and yield incremental response updates. + + :param request: The request object to process. + :param request_info: Scheduling metadata and timing information. + :param history: Optional conversation history for multi-turn requests. + :yield: Tuples of (response, updated_request_info) for each response chunk. + :raises: Implementation-specific exceptions for processing failures. + """ + ... + + @abstractmethod + async def default_model(self) -> Optional[str]: + """ + :return: The default model name or identifier for generation requests. + """ + ... + + +BackendT = TypeVar("BackendT", bound="BackendInterface") diff --git a/src/guidellm/backend/objects.py b/src/guidellm/backend/objects.py new file mode 100644 index 00000000..10a76f1e --- /dev/null +++ b/src/guidellm/backend/objects.py @@ -0,0 +1,148 @@ +""" +Backend object models for request and response handling. + +Provides standardized models for generation requests, responses, and timing +information to ensure consistent data handling across different backend +implementations. 
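+
+A minimal construction sketch (the prompt text and token counts below are
+placeholders):
+
+Example:
+::
+    request = GenerationRequest(
+        content="Summarize the release notes.",
+        request_type="text_completions",
+        stats={"prompt_tokens": 6},
+        constraints={"output_tokens": 64},
+    )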
+""" + +import uuid +from typing import Any, Literal, Optional + +from pydantic import Field + +from guidellm.objects.pydantic import StandardBaseModel +from guidellm.scheduler import MeasuredRequestTimings + +__all__ = [ + "GenerationRequest", + "GenerationRequestTimings", + "GenerationResponse", +] + + +class GenerationRequest(StandardBaseModel): + """Request model for backend generation operations.""" + + request_id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for the request.", + ) + request_type: Literal["text_completions", "chat_completions"] = Field( + default="text_completions", + description=( + "Type of request. 'text_completions' uses backend.text_completions(), " + "'chat_completions' uses backend.chat_completions()." + ), + ) + content: Any = Field( + description=( + "Request content. For text_completions: string or list of strings. " + "For chat_completions: string, list of messages, or raw content " + "(set raw_content=True in params)." + ) + ) + params: dict[str, Any] = Field( + default_factory=dict, + description=( + "Additional parameters passed to backend methods. " + "Common: max_tokens, temperature, stream." + ), + ) + stats: dict[Literal["prompt_tokens"], int] = Field( + default_factory=dict, + description="Request statistics including prompt token count.", + ) + constraints: dict[Literal["output_tokens"], int] = Field( + default_factory=dict, + description="Request constraints such as maximum output tokens.", + ) + + +class GenerationResponse(StandardBaseModel): + """Response model for backend generation operations.""" + + request_id: str = Field( + description="Unique identifier matching the original GenerationRequest." + ) + request_args: dict[str, Any] = Field( + description="Arguments passed to the backend for this request." + ) + value: Optional[str] = Field( + default=None, + description="Complete generated text content. None for streaming responses.", + ) + delta: Optional[str] = Field( + default=None, description="Incremental text content for streaming responses." + ) + iterations: int = Field( + default=0, description="Number of generation iterations completed." + ) + request_prompt_tokens: Optional[int] = Field( + default=None, description="Token count from the original request prompt." + ) + request_output_tokens: Optional[int] = Field( + default=None, + description="Expected output token count from the original request.", + ) + response_prompt_tokens: Optional[int] = Field( + default=None, description="Actual prompt token count reported by the backend." + ) + response_output_tokens: Optional[int] = Field( + default=None, description="Actual output token count reported by the backend." + ) + + @property + def prompt_tokens(self) -> Optional[int]: + """ + :return: The number of prompt tokens used in the request + (response_prompt_tokens if available, otherwise request_prompt_tokens). + """ + return self.response_prompt_tokens or self.request_prompt_tokens + + @property + def output_tokens(self) -> Optional[int]: + """ + :return: The number of output tokens generated in the response + (response_output_tokens if available, otherwise request_output_tokens). + """ + return self.response_output_tokens or self.request_output_tokens + + @property + def total_tokens(self) -> Optional[int]: + """ + :return: The total number of tokens used in the request and response. + Sum of prompt_tokens and output_tokens. 
+ """ + if self.prompt_tokens is None or self.output_tokens is None: + return None + return self.prompt_tokens + self.output_tokens + + def preferred_prompt_tokens( + self, preferred_source: Literal["request", "response"] + ) -> Optional[int]: + if preferred_source == "request": + return self.request_prompt_tokens or self.response_prompt_tokens + else: + return self.response_prompt_tokens or self.request_prompt_tokens + + def preferred_output_tokens( + self, preferred_source: Literal["request", "response"] + ) -> Optional[int]: + if preferred_source == "request": + return self.request_output_tokens or self.response_output_tokens + else: + return self.response_output_tokens or self.request_output_tokens + + +class GenerationRequestTimings(MeasuredRequestTimings): + """Timing model for tracking generation request lifecycle events.""" + + first_iteration: Optional[float] = Field( + default=None, + description="Unix timestamp when the first generation iteration began.", + ) + last_iteration: Optional[float] = Field( + default=None, + description="Unix timestamp when the last generation iteration completed.", + ) diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index e62e9003..cc251153 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -1,705 +1,642 @@ +""" +OpenAI HTTP backend implementation for GuideLLM. + +Provides HTTP-based backend for OpenAI-compatible servers including OpenAI API, +vLLM servers, and other compatible inference engines. Supports text and chat +completions with streaming, authentication, and multimodal capabilities. + +Classes: + UsageStats: Token usage statistics for generation requests. + OpenAIHTTPBackend: HTTP backend for OpenAI-compatible API servers. +""" + import base64 +import contextlib import copy import json import time -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, ClassVar, Optional, Union import httpx -from loguru import logger from PIL import Image +from pydantic import dataclasses from guidellm.backend.backend import Backend -from guidellm.backend.response import ( - RequestArgs, - ResponseSummary, - StreamingTextResponse, +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, ) -from guidellm.config import settings +from guidellm.scheduler import ScheduledRequestInfo -__all__ = [ - "CHAT_COMPLETIONS", - "CHAT_COMPLETIONS_PATH", - "MODELS", - "TEXT_COMPLETIONS", - "TEXT_COMPLETIONS_PATH", - "OpenAIHTTPBackend", -] +__all__ = ["OpenAIHTTPBackend", "UsageStats"] -TEXT_COMPLETIONS_PATH = "/v1/completions" -CHAT_COMPLETIONS_PATH = "/v1/chat/completions" +@dataclasses.dataclass +class UsageStats: + """Token usage statistics for generation requests.""" -EndpointType = Literal["chat_completions", "models", "text_completions"] -CHAT_COMPLETIONS: EndpointType = "chat_completions" -MODELS: EndpointType = "models" -TEXT_COMPLETIONS: EndpointType = "text_completions" + prompt_tokens: Optional[int] = None + output_tokens: Optional[int] = None @Backend.register("openai_http") class OpenAIHTTPBackend(Backend): """ - A HTTP-based backend implementation for requests to an OpenAI compatible server. - For example, a vLLM server instance or requests to OpenAI's API. - - :param target: The target URL string for the OpenAI server. ex: http://0.0.0.0:8000 - :param model: The model to use for all requests on the target server. 
- If none is provided, the first available model will be used. - :param api_key: The API key to use for requests to the OpenAI server. - If provided, adds an Authorization header with the value - "Authorization: Bearer {api_key}". - If not provided, no Authorization header is added. - :param organization: The organization to use for requests to the OpenAI server. - For example, if set to "org_123", adds an OpenAI-Organization header with the - value "OpenAI-Organization: org_123". - If not provided, no OpenAI-Organization header is added. - :param project: The project to use for requests to the OpenAI server. - For example, if set to "project_123", adds an OpenAI-Project header with the - value "OpenAI-Project: project_123". - If not provided, no OpenAI-Project header is added. - :param timeout: The timeout to use for requests to the OpenAI server. - If not provided, the default timeout provided from settings is used. - :param http2: If True, uses HTTP/2 for requests to the OpenAI server. - Defaults to True. - :param follow_redirects: If True, the HTTP client will follow redirect responses. - If not provided, the default value from settings is used. - :param max_output_tokens: The maximum number of tokens to request for completions. - If not provided, the default maximum tokens provided from settings is used. - :param extra_query: Query parameters to include in requests to the OpenAI server. - If "chat_completions", "models", or "text_completions" are included as keys, - the values of these keys will be used as the parameters for the respective - endpoint. - If not provided, no extra query parameters are added. - :param extra_body: Body parameters to include in requests to the OpenAI server. - If "chat_completions", "models", or "text_completions" are included as keys, - the values of these keys will be included in the body for the respective - endpoint. - If not provided, no extra body parameters are added. - :param remove_from_body: Parameters that should be removed from the body of each - request. - If not provided, no parameters are removed from the body. + HTTP backend for OpenAI-compatible servers. + + Supports OpenAI API, vLLM servers, and other compatible endpoints with + text/chat completions, streaming, authentication, and multimodal inputs. + Handles request formatting, response parsing, error handling, and token + usage tracking with flexible parameter customization. 
+ + Example: + :: + backend = OpenAIHTTPBackend( + target="http://localhost:8000", + model="gpt-3.5-turbo", + api_key="your-api-key" + ) + + await backend.process_startup() + async for response, request_info in backend.resolve(request, info): + process_response(response) + await backend.process_shutdown() """ + HEALTH_PATH: ClassVar[str] = "/health" + MODELS_PATH: ClassVar[str] = "/v1/models" + TEXT_COMPLETIONS_PATH: ClassVar[str] = "/v1/completions" + CHAT_COMPLETIONS_PATH: ClassVar[str] = "/v1/chat/completions" + + MODELS_KEY: ClassVar[str] = "models" + TEXT_COMPLETIONS_KEY: ClassVar[str] = "text_completions" + CHAT_COMPLETIONS_KEY: ClassVar[str] = "chat_completions" + def __init__( self, - target: Optional[str] = None, + target: str, model: Optional[str] = None, api_key: Optional[str] = None, organization: Optional[str] = None, project: Optional[str] = None, - timeout: Optional[float] = None, - http2: Optional[bool] = True, - follow_redirects: Optional[bool] = None, + timeout: float = 60.0, + http2: bool = True, + follow_redirects: bool = True, max_output_tokens: Optional[int] = None, + stream_response: bool = True, extra_query: Optional[dict] = None, extra_body: Optional[dict] = None, remove_from_body: Optional[list[str]] = None, headers: Optional[dict] = None, - verify: Optional[bool] = None, + verify: bool = False, ): - super().__init__(type_="openai_http") - self._target = target or settings.openai.base_url - - if not self._target: - raise ValueError("Target URL must be provided for OpenAI HTTP backend.") - - if self._target.endswith("/v1") or self._target.endswith("/v1/"): - # backwards compatability, strip v1 off - self._target = self._target[:-3] - - if self._target.endswith("/"): - self._target = self._target[:-1] - - self._model = model - - # Start with default headers based on other params - default_headers: dict[str, str] = {} - api_key = api_key or settings.openai.api_key - bearer_token = settings.openai.bearer_token - if api_key: - default_headers["Authorization"] = f"Bearer {api_key}" - elif bearer_token: - default_headers["Authorization"] = bearer_token - - self.organization = organization or settings.openai.organization - if self.organization: - default_headers["OpenAI-Organization"] = self.organization - - self.project = project or settings.openai.project - if self.project: - default_headers["OpenAI-Project"] = self.project - - # User-provided headers from kwargs or settings override defaults - merged_headers = default_headers.copy() - merged_headers.update(settings.openai.headers or {}) - if headers: - merged_headers.update(headers) - - # Remove headers with None values for backward compatibility and convenience - self.headers = {k: v for k, v in merged_headers.items() if v is not None} - - self.timeout = timeout if timeout is not None else settings.request_timeout - self.http2 = http2 if http2 is not None else settings.request_http2 - self.follow_redirects = ( - follow_redirects - if follow_redirects is not None - else settings.request_follow_redirects - ) - self.verify = verify if verify is not None else settings.openai.verify - self.max_output_tokens = ( - max_output_tokens - if max_output_tokens is not None - else settings.openai.max_output_tokens - ) - self.extra_query = extra_query - self.extra_body = extra_body - self.remove_from_body = remove_from_body - self._async_client: Optional[httpx.AsyncClient] = None - - @property - def target(self) -> str: """ - :return: The target URL string for the OpenAI server. + Initialize OpenAI HTTP backend. 
+ + :param target: Target URL for the OpenAI server (e.g., "http://localhost:8000"). + :param model: Model to use for requests. If None, uses first available model. + :param api_key: API key for authentication. Adds Authorization header + if provided. + :param organization: Organization ID. Adds OpenAI-Organization header + if provided. + :param project: Project ID. Adds OpenAI-Project header if provided. + :param timeout: Request timeout in seconds. Defaults to 60 seconds. + :param http2: Whether to use HTTP/2. Defaults to True. + :param follow_redirects: Whether to follow redirects. Default True. + :param max_output_tokens: Maximum tokens for completions. If None, none is set. + :param stream_response: Whether to stream responses by default. Can be + overridden per request. Defaults to True. + :param extra_query: Additional query parameters. Both general and + endpoint-specific with type keys supported. + :param extra_body: Additional body parameters. Both general and + endpoint-specific with type keys supported. + :param remove_from_body: Parameter names to remove from request bodies. + :param headers: Additional HTTP headers. + :param verify: Whether to verify SSL certificates. Default False. """ - return self._target + super().__init__(type_="openai_http") - @property - def model(self) -> Optional[str]: - """ - :return: The model to use for all requests on the target server. - If validate hasn't been called yet and no model was passed in, - this will be None until validate is called to set the default. - """ - return self._model + # Request Values + self.target = target.rstrip("/").removesuffix("/v1") + self.model = model + self.headers = self._build_headers(api_key, organization, project, headers) + + # Store configuration + self.timeout = timeout + self.http2 = http2 + self.follow_redirects = follow_redirects + self.verify = verify + self.max_output_tokens = max_output_tokens + self.stream_response = stream_response + self.extra_query = extra_query or {} + self.extra_body = extra_body or {} + self.remove_from_body = remove_from_body or [] + + # Runtime state + self._in_process = False + self._async_client: Optional[httpx.AsyncClient] = None - @property def info(self) -> dict[str, Any]: """ - :return: The information about the backend. + :return: Dictionary containing backend configuration details. """ return { - "max_output_tokens": self.max_output_tokens, + "target": self.target, + "model": self.model, + "headers": self.headers, "timeout": self.timeout, "http2": self.http2, "follow_redirects": self.follow_redirects, - "headers": self.headers, - "text_completions_path": TEXT_COMPLETIONS_PATH, - "chat_completions_path": CHAT_COMPLETIONS_PATH, + "verify": self.verify, + "max_output_tokens": self.max_output_tokens, + "stream_response": self.stream_response, + "extra_query": self.extra_query, + "extra_body": self.extra_body, + "remove_from_body": self.remove_from_body, + "health_path": self.HEALTH_PATH, + "models_path": self.MODELS_PATH, + "text_completions_path": self.TEXT_COMPLETIONS_PATH, + "chat_completions_path": self.CHAT_COMPLETIONS_PATH, } - async def reset(self) -> None: + async def process_startup(self): """ - Reset the connection object. This is useful for backends that - reuse connections or have state that needs to be cleared. - For this backend, it closes the async client if it exists. + Initialize HTTP client and backend resources. + + :raises RuntimeError: If backend is already initialized. + :raises httpx.Exception: If HTTP client cannot be created. 
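+
+        Example (a minimal sketch; pairing startup with shutdown via
+        ``try``/``finally`` is illustrative, not required by this API):
+        ::
+            await backend.process_startup()
+            try:
+                await backend.validate()
+            finally:
+                await backend.process_shutdown()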
""" - if self._async_client is not None: - await self._async_client.aclose() + if self._in_process: + raise RuntimeError("Backend already started up for process.") + + self._async_client = httpx.AsyncClient( + http2=self.http2, + timeout=self.timeout, + follow_redirects=self.follow_redirects, + verify=self.verify, + ) + self._in_process = True - async def check_setup(self): + async def process_shutdown(self): """ - Check if the backend is setup correctly and can be used for requests. - Specifically, if a model is not provided, it grabs the first available model. - If no models are available, raises a ValueError. - If a model is provided and not available, raises a ValueError. + Clean up HTTP client and backend resources. - :raises ValueError: If no models or the provided model is not available. + :raises RuntimeError: If backend was not properly initialized. + :raises httpx.Exception: If HTTP client cannot be closed. """ - models = await self.available_models() - if not models: - raise ValueError(f"No models available for target: {self.target}") - - if not self.model: - self._model = models[0] - elif self.model not in models: - raise ValueError( - f"Model {self.model} not found in available models:" - f"{models} for target: {self.target}" - ) + if not self._in_process: + raise RuntimeError("Backend not started up for process.") - async def prepare_multiprocessing(self): + await self._async_client.aclose() # type: ignore [union-attr] + self._async_client = None + self._in_process = False + + async def validate(self): """ - Prepare the backend for use in a multiprocessing environment. - Clears out the sync and async clients to ensure they are re-initialized - for each process. + Validate backend configuration and connectivity. + + Validate backend configuration and connectivity through test requests, + and auto-selects first available model if none is configured. + + :raises RuntimeError: If backend cannot connect or validate configuration. """ - if self._async_client is not None: - await self._async_client.aclose() - self._async_client = None + self._check_in_process() + + if self.model: + with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): + # Model is set, use /health endpoint as first check + target = f"{self.target}{self.HEALTH_PATH}" + headers = self._get_headers() + response = await self._async_client.get(target, headers=headers) # type: ignore [union-attr] + response.raise_for_status() + + return + + with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): + # Check if models endpoint is available next + models = await self.available_models() + if models and not self.model: + self.model = models[0] + elif not self.model: + raise RuntimeError( + "No model available and could not set a default model " + "from the server's available models." + ) + + return + + with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): + # Last check, fall back on dummy request to text completions + async for _, __ in self.text_completions( + prompt="Validate backend", + request_id="validate", + output_token_count=1, + stream_response=False, + ): + pass + + return + + raise RuntimeError( + "Backend validation failed. Could not connect to the server or " + "validate the backend configuration." + ) async def available_models(self) -> list[str]: """ - Get the available models for the target server using the OpenAI models endpoint: - /v1/models + Get available models from the target server. + + :return: List of model identifiers. 
+ :raises HTTPError: If models endpoint returns an error. + :raises RuntimeError: If backend is not initialized. """ - target = f"{self.target}/v1/models" - headers = self._headers() - params = self._params(MODELS) - response = await self._get_async_client().get( - target, headers=headers, params=params - ) + self._check_in_process() + + target = f"{self.target}{self.MODELS_PATH}" + headers = self._get_headers() + params = self._get_params(self.MODELS_KEY) + response = await self._async_client.get(target, headers=headers, params=params) # type: ignore [union-attr] response.raise_for_status() - models = [] + return [item["id"] for item in response.json()["data"]] + + async def default_model(self) -> Optional[str]: + """ + Get the default model for this backend. + + :return: Model name or None if no model is available. + """ + if self.model or not self._in_process: + return self.model + + models = await self.available_models() + return models[0] if models else None + + async def resolve( + self, + request: GenerationRequest, + request_info: ScheduledRequestInfo[GenerationRequestTimings], + history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + ) -> AsyncIterator[ + tuple[GenerationResponse, ScheduledRequestInfo[GenerationRequestTimings]] + ]: + """ + Process a generation request and yield progressive responses. + + Handles request formatting, timing tracking, API communication, and + response parsing with streaming support. + + :param request: Generation request with content and parameters. + :param request_info: Request tracking info updated with timing metadata. + :param history: Conversation history. Currently not supported. + :raises NotImplementedError: If history is provided. + :yields: Tuples of (response, updated_request_info) as generation progresses. 
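+
+        Example (illustrative sketch; assumes ``request`` is a GenerationRequest
+        and ``request_info`` was supplied by the scheduler):
+        ::
+            async for response, info in backend.resolve(request, request_info):
+                if response.delta is not None:
+                    print(response.delta, end="")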
+ """ + self._check_in_process() + if history is not None: + raise NotImplementedError( + "Multi-turn requests with conversation history are not yet supported" + ) + + response = GenerationResponse( + request_id=request.request_id, + request_args={ + "request_type": request.request_type, + "output_token_count": request.constraints.get("output_tokens"), + **request.params, + }, + value="", + request_prompt_tokens=request.stats.get("prompt_tokens"), + request_output_tokens=request.constraints.get("output_tokens"), + ) + request_info.request_timings = GenerationRequestTimings() + request_info.request_timings.request_start = time.time() + + completion_method = ( + self.text_completions + if request.request_type == "text_completions" + else self.chat_completions + ) + completion_kwargs = ( + { + "prompt": request.content, + "request_id": request.request_id, + "output_token_count": request.constraints.get("output_tokens"), + "stream_response": request.params.get("stream", self.stream_response), + **request.params, + } + if request.request_type == "text_completions" + else { + "content": request.content, + "request_id": request.request_id, + "output_token_count": request.constraints.get("output_tokens"), + "stream_response": request.params.get("stream", self.stream_response), + **request.params, + } + ) + + async for delta, usage_stats in completion_method(**completion_kwargs): + if request_info.request_timings.request_start is None: + request_info.request_timings.request_start = time.time() + + if delta is not None: + if request_info.request_timings.first_iteration is None: + request_info.request_timings.first_iteration = time.time() + response.value += delta # type: ignore [operator] + response.delta = delta + request_info.request_timings.last_iteration = time.time() + response.iterations += 1 + + if usage_stats is not None: + request_info.request_timings.request_end = time.time() + response.request_output_tokens = usage_stats.output_tokens + response.request_prompt_tokens = usage_stats.prompt_tokens - for item in response.json()["data"]: - models.append(item["id"]) + yield response, request_info - return models + if request_info.request_timings.request_end is None: + request_info.request_timings.request_end = time.time() + response.delta = None + yield response, request_info - async def text_completions( # type: ignore[override] + async def text_completions( self, prompt: Union[str, list[str]], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, + request_id: Optional[str], # noqa: ARG002 output_token_count: Optional[int] = None, + stream_response: bool = True, **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: + ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: """ - Generate text completions for the given prompt using the OpenAI - completions endpoint: /v1/completions. - - :param prompt: The prompt (or list of prompts) to generate a completion for. - If a list is supplied, these are concatenated and run through the model - for a single prompt. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. 
- :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. + Generate text completions using the /v1/completions endpoint. + + :param prompt: Text prompt(s) for completion. Single string or list. + :param request_id: Request identifier for tracking. + :param output_token_count: Maximum tokens to generate. Overrides default + if specified. + :param stream_response: Whether to stream response progressively. + :param kwargs: Additional request parameters (temperature, top_p, etc.). + :yields: Tuples of (generated_text, usage_stats). First yield is (None, None). + :raises RuntimeError: If backend is not initialized. + :raises HTTPError: If API request fails. """ - logger.debug("{} invocation with args: {}", self.__class__.__name__, locals()) - - if isinstance(prompt, list): - raise ValueError( - "List prompts (batching) is currently not supported for " - f"text_completions OpenAI pathways. Received: {prompt}" - ) - - headers = self._headers() - params = self._params(TEXT_COMPLETIONS) - payload = self._completions_payload( - endpoint_type=TEXT_COMPLETIONS, - orig_kwargs=kwargs, + self._check_in_process() + target = f"{self.target}{self.TEXT_COMPLETIONS_PATH}" + headers = self._get_headers() + params = self._get_params(self.TEXT_COMPLETIONS_KEY) + body = self._get_body( + endpoint_type=self.TEXT_COMPLETIONS_KEY, + request_kwargs=kwargs, max_output_tokens=output_token_count, prompt=prompt, ) + yield None, None # Initial yield for async iterator to signal start - try: - async for resp in self._iterative_completions_request( - type_="text_completions", - request_id=request_id, - request_prompt_tokens=prompt_token_count, - request_output_tokens=output_token_count, + if not stream_response: + response = await self._async_client.post( # type: ignore [union-attr] + target, headers=headers, params=params, - payload=payload, - ): - yield resp - except Exception as ex: - logger.error( - "{} request with headers: {} and params: {} and payload: {} failed: {}", - self.__class__.__name__, - headers, - params, - payload, - ex, + json=body, + ) + response.raise_for_status() + data = response.json() + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), ) - raise ex + return + + body.update({"stream": True, "stream_options": {"include_usage": True}}) + async with self._async_client.stream( # type: ignore [union-attr] + "POST", + target, + headers=headers, + params=params, + json=body, + ) as stream: + stream.raise_for_status() + async for line in stream.aiter_lines(): + if not line or not line.strip().startswith("data:"): + continue + if line.strip() == "data: [DONE]": + break + data = json.loads(line.strip()[len("data: ") :]) + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), + ) - async def chat_completions( # type: ignore[override] + async def chat_completions( self, content: Union[ str, list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], Any, ], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, + request_id: Optional[str] = None, # noqa: ARG002 output_token_count: Optional[int] = None, raw_content: bool = False, + stream_response: bool = True, **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: + ) -> AsyncIterator[tuple[Optional[str], 
Optional[UsageStats]]]: """ - Generate chat completions for the given content using the OpenAI - chat completions endpoint: /v1/chat/completions. - - :param content: The content (or list of content) to generate a completion for. - This supports any combination of text, images, and audio (model dependent). - Supported text only request examples: - content="Sample prompt", content=["Sample prompt", "Second prompt"], - content=[{"type": "text", "value": "Sample prompt"}. - Supported text and image request examples: - content=["Describe the image", PIL.Image.open("image.jpg")], - content=["Describe the image", Path("image.jpg")], - content=["Describe the image", {"type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}]. - Supported text and audio request examples: - content=["Transcribe the audio", Path("audio.wav")], - content=["Transcribe the audio", {"type": "input_audio", - "input_audio": {"data": f"{base64_bytes}", "format": "wav}]. - Additionally, if raw_content=True then the content is passed directly to the - backend without any processing. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. + Generate chat completions using the /v1/chat/completions endpoint. + + Supports multimodal inputs including text and images with message formatting. + + :param content: Chat content - string, list of mixed content, or raw content + when raw_content=True. + :param request_id: Request identifier (currently unused). + :param output_token_count: Maximum tokens to generate. Overrides default + if specified. + :param raw_content: If True, passes content directly without formatting. + :param stream_response: Whether to stream response progressively. + :param kwargs: Additional request parameters (temperature, top_p, tools, etc.). + :yields: Tuples of (generated_text, usage_stats). First yield is (None, None). + :raises RuntimeError: If backend is not initialized. + :raises HTTPError: If API request fails. 
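+
+        Example (illustrative sketch; the image path and token limit below are
+        placeholders):
+        ::
+            async for delta, usage in backend.chat_completions(
+                content=["Describe the image", Path("image.jpg")],
+                output_token_count=64,
+            ):
+                if delta is not None:
+                    print(delta, end="")
+                if usage is not None:
+                    print(usage.output_tokens)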
""" - logger.debug("{} invocation with args: {}", self.__class__.__name__, locals()) - headers = self._headers() - params = self._params(CHAT_COMPLETIONS) - messages = ( - content if raw_content else self._create_chat_messages(content=content) - ) - payload = self._completions_payload( - endpoint_type=CHAT_COMPLETIONS, - orig_kwargs=kwargs, + self._check_in_process() + target = f"{self.target}{self.CHAT_COMPLETIONS_PATH}" + headers = self._get_headers() + params = self._get_params(self.CHAT_COMPLETIONS_KEY) + body = self._get_body( + endpoint_type=self.CHAT_COMPLETIONS_KEY, + request_kwargs=kwargs, max_output_tokens=output_token_count, - messages=messages, + messages=self._get_chat_messages(content) if not raw_content else content, + **kwargs, ) + yield None, None # Initial yield for async iterator to signal start - try: - async for resp in self._iterative_completions_request( - type_="chat_completions", - request_id=request_id, - request_prompt_tokens=prompt_token_count, - request_output_tokens=output_token_count, - headers=headers, - params=params, - payload=payload, - ): - yield resp - except Exception as ex: - logger.error( - "{} request with headers: {} and params: {} and payload: {} failed: {}", - self.__class__.__name__, - headers, - params, - payload, - ex, + if not stream_response: + response = await self._async_client.post( # type: ignore [union-attr] + target, headers=headers, params=params, json=body ) - raise ex - - def _get_async_client(self) -> httpx.AsyncClient: - """ - Get the async HTTP client for making requests. - If the client has not been created yet, it will create one. - - :return: The async HTTP client. - """ - if self._async_client is None or self._async_client.is_closed: - client = httpx.AsyncClient( - http2=self.http2, - timeout=self.timeout, - follow_redirects=self.follow_redirects, - verify=self.verify, + response.raise_for_status() + data = response.json() + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), ) - self._async_client = client - else: - client = self._async_client + return - return client - - def _headers(self) -> dict[str, str]: - headers = { - "Content-Type": "application/json", - } - headers.update(self.headers) - return headers - - def _params(self, endpoint_type: EndpointType) -> dict[str, str]: - if self.extra_query is None: - return {} - - if ( - CHAT_COMPLETIONS in self.extra_query - or MODELS in self.extra_query - or TEXT_COMPLETIONS in self.extra_query - ): - return self.extra_query.get(endpoint_type, {}) - - return self.extra_query - - def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]: - if self.extra_body is None: - return {} - - if ( - CHAT_COMPLETIONS in self.extra_body - or MODELS in self.extra_body - or TEXT_COMPLETIONS in self.extra_body - ): - return copy.deepcopy(self.extra_body.get(endpoint_type, {})) - - return copy.deepcopy(self.extra_body) + body.update({"stream": True, "stream_options": {"include_usage": True}}) + async with self._async_client.stream( # type: ignore [union-attr] + "POST", target, headers=headers, params=params, json=body + ) as stream: + stream.raise_for_status() + async for line in stream.aiter_lines(): + if not line or not line.strip().startswith("data:"): + continue + if line.strip() == "data: [DONE]": + break + data = json.loads(line.strip()[len("data: ") :]) + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), + ) - def _completions_payload( + def _build_headers( self, - endpoint_type: 
EndpointType, - orig_kwargs: Optional[dict], - max_output_tokens: Optional[int], - **kwargs, - ) -> dict: - payload = self._extra_body(endpoint_type) - payload.update(orig_kwargs or {}) - payload.update(kwargs) - payload["model"] = self.model - payload["stream"] = True - payload["stream_options"] = { - "include_usage": True, - } + api_key: Optional[str], + organization: Optional[str], + project: Optional[str], + user_headers: Optional[dict], + ) -> dict[str, str]: + headers = {} - if max_output_tokens or self.max_output_tokens: - logger.debug( - "{} adding payload args for setting output_token_count: {}", - self.__class__.__name__, - max_output_tokens or self.max_output_tokens, + if api_key: + headers["Authorization"] = ( + f"Bearer {api_key}" if not api_key.startswith("Bearer") else api_key + ) + if organization: + headers["OpenAI-Organization"] = organization + if project: + headers["OpenAI-Project"] = project + if user_headers: + headers.update(user_headers) + + return {key: val for key, val in headers.items() if val is not None} + + def _check_in_process(self): + if not self._in_process or self._async_client is None: + raise RuntimeError( + "Backend not started up for process, cannot process requests." ) - payload["max_tokens"] = max_output_tokens or self.max_output_tokens - payload["max_completion_tokens"] = payload["max_tokens"] - - if max_output_tokens: - # only set stop and ignore_eos if max_output_tokens set at request level - # otherwise the instance value is just the max to enforce we stay below - payload["stop"] = None - payload["ignore_eos"] = True - if self.remove_from_body: - for key in self.remove_from_body: - payload.pop(key, None) + def _get_headers(self) -> dict[str, str]: + return { + "Content-Type": "application/json", + **self.headers, + } - return payload + def _get_params(self, endpoint_type: str) -> dict[str, str]: + if endpoint_type in self.extra_query: + return copy.deepcopy(self.extra_query[endpoint_type]) + return copy.deepcopy(self.extra_query) - @staticmethod - def _create_chat_messages( + def _get_chat_messages( + self, content: Union[ str, list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], Any, ], - ) -> list[dict]: + ) -> list[dict[str, Any]]: if isinstance(content, str): - return [ - { - "role": "user", - "content": content, - } - ] - - if isinstance(content, list): - resolved_content = [] - - for item in content: - if isinstance(item, dict): - resolved_content.append(item) - elif isinstance(item, str): - resolved_content.append({"type": "text", "text": item}) - elif isinstance(item, Image.Image) or ( - isinstance(item, Path) and item.suffix.lower() in [".jpg", ".jpeg"] - ): - image = item if isinstance(item, Image.Image) else Image.open(item) - encoded = base64.b64encode(image.tobytes()).decode("utf-8") - resolved_content.append( - { - "type": "image", - "image": { - "url": f"data:image/jpeg;base64,{encoded}", - }, - } - ) - elif isinstance(item, Path) and item.suffix.lower() in [".wav"]: - encoded = base64.b64encode(item.read_bytes()).decode("utf-8") - resolved_content.append( - { - "type": "input_audio", - "input_audio": { - "data": f"{encoded}", - "format": "wav", - }, - } - ) - else: - raise ValueError( - f"Unsupported content item type: {item} in list: {content}" - ) - - return [ - { - "role": "user", - "content": resolved_content, - } - ] - - raise ValueError(f"Unsupported content type: {content}") - - async def _iterative_completions_request( - self, - type_: Literal["text_completions", "chat_completions"], - request_id: 
Optional[str], - request_prompt_tokens: Optional[int], - request_output_tokens: Optional[int], - headers: dict[str, str], - params: dict[str, str], - payload: dict[str, Any], - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - if type_ == "text_completions": - target = f"{self.target}{TEXT_COMPLETIONS_PATH}" - elif type_ == "chat_completions": - target = f"{self.target}{CHAT_COMPLETIONS_PATH}" + return [{"role": "user", "content": content}] + + if not isinstance(content, list): + raise ValueError(f"Unsupported content type: {type(content)}") + + resolved_content = [] + for item in content: + if isinstance(item, dict): + resolved_content.append(item) + elif isinstance(item, str): + resolved_content.append({"type": "text", "text": item}) + elif isinstance(item, (Image.Image, Path)): + resolved_content.append(self._get_chat_message_media_item(item)) + else: + raise ValueError(f"Unsupported content item type: {type(item)}") + + return [{"role": "user", "content": resolved_content}] + + def _get_chat_message_media_item( + self, item: Union[Path, Image.Image] + ) -> dict[str, Any]: + if isinstance(item, Image.Image): + encoded = base64.b64encode(item.tobytes()).decode("utf-8") + return { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{encoded}"}, + } + + # Handle file paths + suffix = item.suffix.lower() + if suffix in [".jpg", ".jpeg"]: + image = Image.open(item) + encoded = base64.b64encode(image.tobytes()).decode("utf-8") + return { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{encoded}"}, + } + elif suffix == ".wav": + encoded = base64.b64encode(item.read_bytes()).decode("utf-8") + return { + "type": "input_audio", + "input_audio": {"data": encoded, "format": "wav"}, + } else: - raise ValueError(f"Unsupported type: {type_}") - - logger.info( - "{} making request: {} to target: {} using http2: {} following " - "redirects: {} for timeout: {} with headers: {} and params: {} and ", - "payload: {}", - self.__class__.__name__, - request_id, - target, - self.http2, - self.follow_redirects, - self.timeout, - headers, - params, - payload, - ) - - response_value = "" - response_prompt_count: Optional[int] = None - response_output_count: Optional[int] = None - iter_count = 0 - start_time = time.time() - iter_time = start_time - first_iter_time: Optional[float] = None - last_iter_time: Optional[float] = None - - yield StreamingTextResponse( - type_="start", - value="", - start_time=start_time, - first_iter_time=None, - iter_count=iter_count, - delta="", - time=start_time, - request_id=request_id, - ) - - # reset start time after yielding start response to ensure accurate timing - start_time = time.time() - - async with self._get_async_client().stream( - "POST", target, headers=headers, params=params, json=payload - ) as stream: - stream.raise_for_status() - - async for line in stream.aiter_lines(): - iter_time = time.time() - logger.debug( - "{} request: {} recieved iter response line: {}", - self.__class__.__name__, - request_id, - line, - ) - - if not line or not line.strip().startswith("data:"): - continue + raise ValueError(f"Unsupported file type: {suffix}") - if line.strip() == "data: [DONE]": - break - - data = json.loads(line.strip()[len("data: ") :]) - if delta := self._extract_completions_delta_content(type_, data): - if first_iter_time is None: - first_iter_time = iter_time - last_iter_time = iter_time - - iter_count += 1 - response_value += delta - - yield StreamingTextResponse( - type_="iter", - value=response_value, - 
iter_count=iter_count, - start_time=start_time, - first_iter_time=first_iter_time, - delta=delta, - time=iter_time, - request_id=request_id, - ) - - if usage := self._extract_completions_usage(data): - response_prompt_count = usage["prompt"] - response_output_count = usage["output"] - - logger.info( - "{} request: {} with headers: {} and params: {} and payload: {} completed" - "with: {}", - self.__class__.__name__, - request_id, - headers, - params, - payload, - response_value, - ) + def _get_body( + self, + endpoint_type: str, + request_kwargs: Optional[dict[str, Any]], + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> dict[str, Any]: + # Start with endpoint-specific extra body parameters + extra_body = self.extra_body.get(endpoint_type, self.extra_body) + + body = copy.deepcopy(extra_body) + body.update(request_kwargs or {}) + body.update(kwargs) + body["model"] = self.model + + # Handle token limits + max_tokens = max_output_tokens or self.max_output_tokens + if max_tokens is not None: + body.update( + { + "max_tokens": max_tokens, + "max_completion_tokens": max_tokens, + } + ) + # Set stop conditions only for request-level limits + if max_output_tokens: + body.update({"stop": None, "ignore_eos": True}) - yield ResponseSummary( - value=response_value, - request_args=RequestArgs( - target=target, - headers=headers, - params=params, - payload=payload, - timeout=self.timeout, - http2=self.http2, - follow_redirects=self.follow_redirects, - ), - start_time=start_time, - end_time=iter_time, - first_iter_time=first_iter_time, - last_iter_time=last_iter_time, - iterations=iter_count, - request_prompt_tokens=request_prompt_tokens, - request_output_tokens=request_output_tokens, - response_prompt_tokens=response_prompt_count, - response_output_tokens=response_output_count, - request_id=request_id, - ) + return {key: val for key, val in body.items() if val is not None} - @staticmethod - def _extract_completions_delta_content( - type_: Literal["text_completions", "chat_completions"], data: dict - ) -> Optional[str]: - if "choices" not in data or not data["choices"]: + def _get_completions_text_content(self, data: dict) -> Optional[str]: + if not data.get("choices"): return None - if type_ == "text_completions": - return data["choices"][0]["text"] + choice = data["choices"][0] + return choice.get("text") or choice.get("delta", {}).get("content") - if type_ == "chat_completions": - return data["choices"][0]["delta"]["content"] - - raise ValueError(f"Unsupported type: {type_}") - - @staticmethod - def _extract_completions_usage( - data: dict, - ) -> Optional[dict[Literal["prompt", "output"], int]]: - if "usage" not in data or not data["usage"]: + def _get_completions_usage_stats(self, data: dict) -> Optional[UsageStats]: + if not data.get("usage"): return None - return { - "prompt": data["usage"]["prompt_tokens"], - "output": data["usage"]["completion_tokens"], - } + return UsageStats( + prompt_tokens=data["usage"].get("prompt_tokens"), + output_tokens=data["usage"].get("completion_tokens"), + ) diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py deleted file mode 100644 index ee2101d7..00000000 --- a/src/guidellm/backend/response.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import Any, Literal, Optional - -from pydantic import computed_field - -from guidellm.config import settings -from guidellm.objects.pydantic import StandardBaseModel - -__all__ = [ - "RequestArgs", - "ResponseSummary", - "StreamingResponseType", - "StreamingTextResponse", -] - - 
-StreamingResponseType = Literal["start", "iter"] - - -class StreamingTextResponse(StandardBaseModel): - """ - A model representing the response content for a streaming text request. - - :param type_: The type of the response; either 'start' or 'iter'. - :param value: The value of the response up to this iteration. - :param start_time: The time.time() the request started. - :param iter_count: The iteration count for the response. For 'start' this is 0 - and for the first 'iter' it is 1. - :param delta: The text delta added to the response for this stream iteration. - :param time: If 'start', the time.time() the request started. - If 'iter', the time.time() the iteration was received. - :param request_id: The unique identifier for the request, if any. - """ - - type_: StreamingResponseType - value: str - start_time: float - first_iter_time: Optional[float] - iter_count: int - delta: str - time: float - request_id: Optional[str] = None - - -class RequestArgs(StandardBaseModel): - """ - A model representing the arguments for a request to a backend. - Biases towards an HTTP request, but can be used for other types of backends. - - :param target: The target URL or function for the request. - :param headers: The headers, if any, included in the request such as authorization. - :param params: The query parameters, if any, included in the request. - :param payload: The payload / arguments for the request including the prompt / - content and other configurations. - :param timeout: The timeout for the request in seconds, if any. - :param http2: Whether HTTP/2 was used for the request, if applicable. - :param follow_redirects: Whether the request should follow redirect responses. - """ - - target: str - headers: dict[str, str] - params: dict[str, str] - payload: dict[str, Any] - timeout: Optional[float] = None - http2: Optional[bool] = None - follow_redirects: Optional[bool] = None - - -class ResponseSummary(StandardBaseModel): - """ - A model representing a summary of a backend request. - Always returned as the final iteration of a streaming request. - - :param value: The final value returned from the request. - :param request_args: The arguments used to make the request. - :param iterations: The number of iterations in the request. - :param start_time: The time the request started. - :param end_time: The time the request ended. - :param first_iter_time: The time the first iteration was received. - :param last_iter_time: The time the last iteration was received. - :param request_prompt_tokens: The number of tokens measured in the prompt - for the request, if any. - :param request_output_tokens: The number of tokens enforced for the output - for the request, if any. - :param response_prompt_tokens: The number of tokens measured in the prompt - for the response, if any. - :param response_output_tokens: The number of tokens measured in the output - for the response, if any. - :param request_id: The unique identifier for the request, if any. - :param error: The error message, if any, returned from making the request. 
- """ - - value: str - request_args: RequestArgs - iterations: int = 0 - start_time: float - end_time: float - first_iter_time: Optional[float] - last_iter_time: Optional[float] - request_prompt_tokens: Optional[int] = None - request_output_tokens: Optional[int] = None - response_prompt_tokens: Optional[int] = None - response_output_tokens: Optional[int] = None - request_id: Optional[str] = None - error: Optional[str] = None - - @computed_field # type: ignore[misc] - @property - def prompt_tokens(self) -> Optional[int]: - """ - The number of tokens measured in the prompt based on preferences - for trusting the input or response. - - :return: The number of tokens in the prompt, if any. - """ - if settings.preferred_prompt_tokens_source == "request": - return self.request_prompt_tokens or self.response_prompt_tokens - - return self.response_prompt_tokens or self.request_prompt_tokens - - @computed_field # type: ignore[misc] - @property - def output_tokens(self) -> Optional[int]: - """ - The number of tokens measured in the output based on preferences - for trusting the input or response. - - :return: The number of tokens in the output, if any. - """ - if self.error is not None: - # error occurred, can't trust request tokens were all generated - return self.response_prompt_tokens - - if settings.preferred_output_tokens_source == "request": - return self.request_output_tokens or self.response_output_tokens - - return self.response_output_tokens or self.request_output_tokens diff --git a/src/guidellm/benchmark/__init__.py b/src/guidellm/benchmark/__init__.py index a4676c7e..69bdf860 100644 --- a/src/guidellm/benchmark/__init__.py +++ b/src/guidellm/benchmark/__init__.py @@ -3,12 +3,12 @@ Benchmark, BenchmarkArgs, BenchmarkMetrics, - BenchmarkRunStats, + BenchmarkSchedulerStats, BenchmarkT, GenerativeBenchmark, GenerativeMetrics, + GenerativeRequestStats, GenerativeTextErrorStats, - GenerativeTextResponseStats, StatusBreakdown, ) from .benchmarker import Benchmarker, BenchmarkerResult, GenerativeBenchmarker @@ -38,7 +38,7 @@ "BenchmarkAggregator", "BenchmarkArgs", "BenchmarkMetrics", - "BenchmarkRunStats", + "BenchmarkSchedulerStats", "BenchmarkT", "Benchmarker", "BenchmarkerProgressDisplay", @@ -51,10 +51,10 @@ "GenerativeBenchmarksConsole", "GenerativeBenchmarksReport", "GenerativeMetrics", + "GenerativeRequestStats", "GenerativeTextBenchmarkerProgressDisplay", "GenerativeTextBenchmarkerTaskProgressState", "GenerativeTextErrorStats", - "GenerativeTextResponseStats", "Profile", "ProfileType", "StatusBreakdown", diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index af7f1a13..c62f6177 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -1,760 +1,833 @@ -import time -from abc import ABC, abstractmethod -from pathlib import Path +""" +Benchmark result aggregation and compilation interfaces. + +Provides protocols and implementations for collecting, processing, and compiling +benchmark data from scheduler executions into final metrics and statistics. + +Classes: + Aggregator: Protocol for processing benchmark data updates. + CompilableAggregator: Protocol for aggregators that can compile final results. + SchedulerStatsAggregator: Aggregates scheduler timing and performance metrics. + GenerativeRequestsStatsProgressAggregator: Tracks generation metrics during run. + GenerativeRequestsAggregator: Compiles complete generative benchmark results. 
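+
+Example:
+    Minimal sketch of the aggregator contract; the class below is illustrative
+    only and is not part of this module:
+    ::
+        class RequestCounter:
+            def __call__(
+                self, agg_state, response, request, request_info, scheduler_state
+            ):
+                # accumulate a simple per-request counter in the shared state
+                agg_state["seen"] = agg_state.get("seen", 0) + 1
+                return agg_state
+
+            def compile(self, agg_state, scheduler_state):
+                # turn the accumulated state into final, reportable results
+                return {"requests_seen": agg_state.get("seen", 0)}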
+ +Functions: + add_aggregate_metric: Helper for accumulating timing and count metrics. + +Type Variables: + RequestT: Generic request object type. + ResponseT: Generic response object type. + RequestTimingsT: Generic request timing object type. +""" + +import math from typing import ( Any, - Generic, Literal, Optional, - TypeVar, + Protocol, Union, + runtime_checkable, ) from pydantic import Field -from guidellm.backend import ResponseSummary +from guidellm.backend import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) from guidellm.benchmark.benchmark import ( - BenchmarkArgs, - BenchmarkRunStats, - BenchmarkT, - GenerativeBenchmark, - GenerativeTextErrorStats, - GenerativeTextResponseStats, + BenchmarkSchedulerStats, + GenerativeMetrics, + GenerativeRequestStats, ) from guidellm.config import settings from guidellm.objects import ( - RunningStats, StandardBaseModel, StatusBreakdown, - TimeRunningStats, -) -from guidellm.request import ( - GenerationRequest, - GenerativeRequestLoaderDescription, - RequestLoaderDescription, + StatusDistributionSummary, ) from guidellm.scheduler import ( - GenerativeRequestsWorkerDescription, + MeasuredRequestTimingsT, RequestT, ResponseT, - SchedulerRequestResult, - WorkerDescription, + ScheduledRequestInfo, + SchedulerState, ) -from guidellm.utils import check_load_processor __all__ = [ - "AggregatorT", - "BenchmarkAggregator", - "GenerativeBenchmarkAggregator", + "Aggregator", + "CompilableAggregator", + "GenerativeRequestsAggregator", + "GenerativeRequestsStatsProgressAggregator", + "SchedulerStatsAggregator", + "add_aggregate_metric", ] -class SchedulerRunningStats(StandardBaseModel): +@runtime_checkable +class Aggregator(Protocol[ResponseT, RequestT, MeasuredRequestTimingsT]): """ - The metrics for the scheduler stored as running statistics for easy calculations - of rates, averages, totals, etc. + Protocol for processing benchmark data updates during execution. + + Defines the interface for aggregators that collect and process request/response + data from scheduler executions. Implementations update aggregation state with + each completed request for eventual compilation into final metrics. """ - created_requests: RunningStats = Field( - description=( - "The running statistics for the number of requests created for this " - "benchmark run. This includes all requests created, regardless of " - "their status." - ), - default_factory=RunningStats, - ) - queued_requests: RunningStats = Field( - description=( - "The running statistics for the number of requests pending in queue " - "for this benchmark run. This includes requests that are waiting to " - "be scheduled." - ), - default_factory=RunningStats, - ) - scheduled_requests: RunningStats = Field( - description=( - "The running statistics for the number of requests scheduled (actively " - "running but waiting for the desired start time) for this benchmark run." - ), - default_factory=RunningStats, - ) - processing_requests: RunningStats = Field( - description=( - "The running statistics for the number of requests actively being " - "processed by the worker for this benchmark run." - ), - default_factory=RunningStats, - ) - completed_requests: RunningStats = Field( - description=( - "The running statistics for the number of requests completed for this " - "benchmark run. This includes requests within the warmup and cooldown " - "period, if any, along with the final results." 
- ), - default_factory=RunningStats, - ) + def __call__( + self, + agg_state: dict[str, Any], + response: Optional[ResponseT], + request: RequestT, + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + scheduler_state: SchedulerState, + ) -> Optional[dict[str, Any]]: + """ + Process a completed request and update aggregation state. + + :param agg_state: Current aggregation state to update in-place. + :param response: Response generated for the request, if successful. + :param request: The processed request object. + :param request_info: Scheduling metadata and timing information. + :param scheduler_state: Current scheduler execution state. + :return: Optional intermediate updates for progress reporting. + """ + ... -class RequestsRunningStats(StandardBaseModel): +@runtime_checkable +class CompilableAggregator(Aggregator[ResponseT, RequestT, MeasuredRequestTimingsT]): """ - The metrics for requests that have succeeded, been canceled, or errored stored - as running statistics for easy calculations of rates, averages, totals, etc. + Protocol for aggregators that compile final results from aggregated state. + + Extends the Aggregator protocol with the ability to transform accumulated + state into final benchmark results and metrics after execution completes. """ - totals: StatusBreakdown[RunningStats, RunningStats, RunningStats, RunningStats] = ( - Field( - description=( - "The running statistics for the total number of requests that " - "completed within the benchmark run." - ), - default_factory=lambda: StatusBreakdown( - successful=RunningStats(), - errored=RunningStats(), - incomplete=RunningStats(), - total=RunningStats(), - ), - ) - ) - queued_time: TimeRunningStats = Field( - description=( - "The running statistics for the time spent in queue for all requests that " - "completed within the benchmark run. This is the time from when the " - "request was created to when it was dequeued by the worker." - ), - default_factory=TimeRunningStats, - ) - scheduled_time_delay: TimeRunningStats = Field( - description=( - "The running statistics for the time spent from when a request was " - "dequeued by the worker to when it was actually scheduled by the worker" - "for all requests that completed within the benchmark run. " - "This should be as close to 0 as possible, any additional time is " - "overheads from the system or the worker." - ), - default_factory=TimeRunningStats, - ) - scheduled_time_sleep: TimeRunningStats = Field( - description=( - "The running statistics for the time for each request spent sleeping til " - "the desired start time was reached for all requests that completed within " - "the benchmark run. This is the time from when the request was scheduled " - "to when the desired start time was reached. " - ), - default_factory=TimeRunningStats, - ) - worker_start_delay: TimeRunningStats = Field( - description=( - "The running statistics for the time delay between when the request was " - "scheduled and when the worker actually started processing subtracting any " - "sleep time for all requests that completed within the benchmark run. " - "This should be as close to 0 as possible, any additional time is " - "overheads from the system or the worker." - ), - default_factory=TimeRunningStats, - ) - worker_time: TimeRunningStats = Field( - description=( - "The running statistics for the time spent processing all requests that " - "completed within the benchmark run. This is the time from when the " - "request was started to when it was completed." 
- ), - default_factory=TimeRunningStats, - ) - worker_start_time_targeted_delay: TimeRunningStats = Field( - description=( - "The running statistics for the delay between the targeted start time and " - "the actual start time for requests that completed within the benchmark " - "run. This represents delays from the best case desired start time. " - "For async strategies, this represents delays from the ideal system. " - "For sync strategies, since those are doubled in queue, this should be " - "as close to the time for a request to be processed as possible." - ), - default_factory=TimeRunningStats, - ) - request_start_time_delay: TimeRunningStats = Field( - description=( - "The running statistics for the delay between the actual request being " - "made and the time the worker started on the request for all requests " - "that completed within the benchmark run. This time should be as close to " - "0 as possible, any additional time is overhead from the system or " - "the worker." - ), - default_factory=TimeRunningStats, - ) - request_start_time_targeted_delay: TimeRunningStats = Field( - description=( - "The running statistics for the delay between the targeted start time and " - "the actual start time for all requests that completed within the " - "benchmark run. This represents delays from the best case desired start " - "time. For async strategies, this represents delays from the ideal system. " - "For sync strategies, since those are duplicated in queue, this should be " - "as close to the time for a request to be processed." - ), - default_factory=TimeRunningStats, - ) - request_time_delay: TimeRunningStats = Field( - description=( - "The running statistics for the delay in time between the total request " - "time and the worker time. This should be as close to 0 as possible, any " - "additional time is overhead from the system or the worker. " - ), - default_factory=TimeRunningStats, - ) - request_time: TimeRunningStats = Field( - description=( - "The running statistics for the time spent processing all requests that " - "completed within the benchmark run. This is the time from when the " - "request was created to when it was completed." - ), - default_factory=TimeRunningStats, - ) + def compile( + self, agg_state: dict[str, Any], scheduler_state: SchedulerState + ) -> dict[str, Any]: + """ + Compile aggregated state into final benchmark results. + + :param agg_state: The accumulated aggregation state. + :param scheduler_state: Final scheduler execution state. + :return: Compiled benchmark results and metrics. + """ -class BenchmarkAggregator( - ABC, StandardBaseModel, Generic[BenchmarkT, RequestT, ResponseT] +def add_aggregate_metric( + base_key: str, + agg_state: dict[str, Any], + end_val: Optional[Union[int, float]], + start_val: Optional[Union[int, float]] = 0.0, + count: int = 1, ): """ - A pydantic base class representing the base class for aggregating benchmark results. - The purpose is to receive and process results from a Benchmarker as it iterates - through a Scheduler for an individual benchmark run. - As results are added, lightweight statistics are updated and stored for immediate - progress and informational updates to the caller. - Once the benchmark run is complete, the `compile` method is called to finalize - the benchmark and return a Benchmark object with all the results and statistics - fully calculated. + Add timing or count metrics to aggregation state. + + Accumulates delta values and counts for computing averages and totals. 
+ Creates entries for "{base_key}_total" and "{base_key}_count" in agg_state. + + :param base_key: Base key name for the metric. + :param agg_state: Aggregation state dictionary to update. + :param end_val: End value for calculating delta, or None to skip. + :param start_val: Start value for calculating delta, defaults to 0.0. + :param count: Number of occurrences to count, defaults to 1. """ + if start_val is None or end_val is None: + return + + delta_val = end_val - start_val + agg_state[f"{base_key}_total"] = agg_state.get(f"{base_key}_total", 0) + delta_val + agg_state[f"{base_key}_count"] = agg_state.get(f"{base_key}_count", 0) + count - type_: Literal["benchmark_aggregator"] = "benchmark_aggregator" - run_id: str = Field( - description=( - "The unique identifier for the encompasing benchmark run that this " - "benchmark was a part of." - ) - ) - args: BenchmarkArgs = Field( - description=( - "The arguments used to create the benchmark run that this benchmark was " - "a part of." - ) - ) - worker_description: Union[ - GenerativeRequestsWorkerDescription, WorkerDescription - ] = Field( - description=( - "The description and specifics for the worker used to resolve requests " - "for this benchmark." - ), - discriminator="type_", - ) - request_loader_description: Union[ - GenerativeRequestLoaderDescription, RequestLoaderDescription - ] = Field( - description=( - "The description and specifics for the request loader used to create " - "requests for this benchmark." - ), - discriminator="type_", - ) - extras: dict[str, Any] = Field( - description=( - "Any additional information or metadata that was passed for this benchmark." - ) - ) - in_warmup: bool = Field( - description=( - "A flag to indicate if the benchmark is currently in the warmup phase." - ), - default=False, - exclude=True, - ) - in_cooldown: bool = Field( - description=( - "A flag to indicate if the benchmark is currently in the cooldown phase." - ), - default=False, - exclude=True, - ) - scheduler_stats: SchedulerRunningStats = Field( - description=( - "The running statistics for the scheduler for this benchmark run. " - "This includes all requests created, regardless of their status." - ), - default_factory=SchedulerRunningStats, - ) - requests_stats: RequestsRunningStats = Field( - description=( - "The running statistics for the requests for this benchmark run. " - "This includes all requests created, regardless of their status." - ), - default_factory=RequestsRunningStats, - ) - results: StatusBreakdown[ - list[SchedulerRequestResult[RequestT, ResponseT]], - list[SchedulerRequestResult[RequestT, ResponseT]], - list[SchedulerRequestResult[RequestT, ResponseT]], - None, - ] = Field( - description=( - "The completed requests for this benchmark run broken down by status" - "and excluding warmup and cooldown requests." - ), - default_factory=lambda: StatusBreakdown( # type: ignore[arg-type] - successful=[], - errored=[], - incomplete=[], - total=None, - ), - ) - def add_result( +class SchedulerStatsAggregator( + CompilableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT] +): + """ + Aggregates scheduler timing and performance metrics. + + Collects timing data for various scheduler phases including queuing, + resolution, and processing delays to generate performance statistics. 
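+
+    Example:
+        Illustrative sketch only; the response, request, request_info, and
+        scheduler_state objects are assumed to come from the benchmarker loop
+        and are not constructed here.
+        ::
+            aggregator = SchedulerStatsAggregator()
+            agg_state: dict[str, Any] = {}
+
+            # called once per finished request; accumulates *_total/*_count keys
+            aggregator(agg_state, response, request, request_info, scheduler_state)
+
+            # after the run, averages are compiled into BenchmarkSchedulerStats
+            stats = aggregator.compile(agg_state, scheduler_state)["scheduler_stats"]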
+ """ + + def __call__( self, - result: SchedulerRequestResult[RequestT, ResponseT], - ) -> bool: + agg_state: dict[str, Any], + response: Optional[ResponseT], + request: RequestT, + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + scheduler_state: SchedulerState, + ) -> Optional[dict[str, Any]]: """ - Add a result to the aggregator. This will update the internal statistics - and add the result to the list of results if it is not within the warmup or - cooldown period. - - :param result: The result to add to the aggregator. - :return: True if the result was added, False if it was added because it - did not fit within the warmup or cooldown period, was not requested, - or is not finished + Aggregate scheduler timing metrics for a completed request. + + :param agg_state: Current aggregation state to update. + :param response: Response generated for the request, if successful. + :param request: The processed request object. + :param request_info: Scheduling metadata and timing information. + :param scheduler_state: Current scheduler execution state. + :return: Updated aggregation state for intermediate reporting. """ - # Add scheduler statistics - self.scheduler_stats.created_requests += max( - 0, result.run_info.created_requests - ) - self.scheduler_stats.queued_requests += max(0, result.run_info.queued_requests) - self.scheduler_stats.scheduled_requests += max( - 0, result.run_info.scheduled_requests - ) - self.scheduler_stats.processing_requests += max( - 0, result.run_info.processing_requests - ) - self.scheduler_stats.completed_requests += max( - 0, result.run_info.completed_requests - ) - - if result.type_ != "request_complete" or ( - result.request_info.canceled and not result.request_info.requested - ): - # If the result is not completed yet, don't add to the results - # If the result was canceled and not started, ignore it - return False - - # Add request statistics - self.requests_stats.totals.total += 1 - if result.request_info.canceled: - self.requests_stats.totals.incomplete += 1 - elif result.request_info.errored: - self.requests_stats.totals.errored += 1 - elif result.request_info.completed: - self.requests_stats.totals.successful += 1 - else: - raise ValueError( - "Unexpected state: request_info must be either " - "completed, canceled, or errored. 
" - f"Got {result.request_info}" - ) - - self.requests_stats.queued_time.update( - result.request_info.dequeued_time - result.request_info.queued_time - ) - self.requests_stats.scheduled_time_delay.update( - result.request_info.scheduled_time - result.request_info.dequeued_time + if response is None: + return None + + add_aggregate_metric( + "queued_time", + agg_state, + request_info.scheduler_timings.dequeued, + request_info.scheduler_timings.queued, ) - sleep_time = max( - 0.0, - result.request_info.targeted_start_time - - result.request_info.scheduled_time, + add_aggregate_metric( + "worker_resolve_start_delay", + agg_state, + request_info.scheduler_timings.resolve_start, + request_info.scheduler_timings.scheduled, ) - self.requests_stats.scheduled_time_sleep.update(sleep_time) - time_to_worker_start = ( - result.request_info.worker_start - result.request_info.scheduled_time + add_aggregate_metric( + "worker_resolve_time", + agg_state, + request_info.scheduler_timings.resolve_end, + request_info.scheduler_timings.resolve_start, ) - self.requests_stats.worker_start_delay.update(time_to_worker_start - sleep_time) - self.requests_stats.worker_time.update( - result.request_info.worker_end - result.request_info.worker_start + add_aggregate_metric( + "worker_resolve_end_delay", + agg_state, + request_info.scheduler_timings.resolve_end, + request_info.request_timings.request_end, ) - self.requests_stats.worker_start_time_targeted_delay.update( - result.request_info.worker_start - result.request_info.targeted_start_time + add_aggregate_metric( + "finalized_delay", + agg_state, + request_info.scheduler_timings.finalized, + request_info.scheduler_timings.resolve_end, ) - self.requests_stats.request_start_time_delay.update( - result.request_info.worker_start - result.request_info.targeted_start_time + add_aggregate_metric( + "worker_targeted_start_delay", + agg_state, + request_info.scheduler_timings.resolve_start, + request_info.scheduler_timings.targeted_start, ) - self.requests_stats.request_start_time_targeted_delay.update( - result.request_info.worker_start - result.request_info.targeted_start_time + add_aggregate_metric( + "request_start_delay", + agg_state, + request_info.scheduler_timings.resolve_start, + request_info.request_timings.request_start, ) - self.requests_stats.request_time_delay.update( - (result.request_info.worker_end - result.request_info.worker_start) - - (result.request_info.worker_end - result.request_info.worker_start) + add_aggregate_metric( + "request_time", + agg_state, + request_info.request_timings.request_end, + request_info.request_timings.request_start, ) - self.requests_stats.request_time.update( - result.request_info.worker_end - result.request_info.worker_start + add_aggregate_metric( + "request_targeted_start_delay", + agg_state, + request_info.request_timings.request_start, + request_info.scheduler_timings.targeted_start, ) - # Add result to the list of results provided we are not in warmup or cooldown - total_completed = self.requests_stats.totals.total.total - global_start_time = self.requests_stats.totals.total.start_time + return agg_state - in_warmup_number = ( - self.args.warmup_number and total_completed <= self.args.warmup_number - ) - in_warmup_duration = ( - self.args.warmup_duration - and result.request_info.worker_start - <= (global_start_time + self.args.warmup_duration) - ) + def compile( + self, agg_state: dict[str, Any], scheduler_state: SchedulerState + ) -> dict[Literal["scheduler_stats"], BenchmarkSchedulerStats]: + """ + Compile scheduler 
timing metrics into benchmark statistics. - if in_warmup_number or in_warmup_duration: - self.in_warmup = True - return True + :param agg_state: Accumulated timing data and counts. + :param scheduler_state: Final scheduler execution state. + :return: Dictionary containing compiled scheduler statistics. + """ + return { + "scheduler_stats": BenchmarkSchedulerStats( + start_time=scheduler_state.start_time, + end_time=scheduler_state.end_time, + requests_made=StatusBreakdown( + successful=scheduler_state.successful_requests, + incomplete=scheduler_state.cancelled_requests, + errored=scheduler_state.errored_requests, + ), + queued_time_avg=( + agg_state.get("queued_time_total", 0.0) + / agg_state.get("queued_time_count", 1) + ), + worker_resolve_start_delay_avg=( + agg_state.get("worker_resolve_start_delay_total", 0.0) + / agg_state.get("worker_resolve_start_delay_count", 1) + ), + worker_resolve_time_avg=( + agg_state.get("worker_resolve_time_total", 0.0) + / agg_state.get("worker_resolve_time_count", 1) + ), + worker_resolve_end_delay_avg=( + agg_state.get("worker_resolve_end_delay_total", 0.0) + / agg_state.get("worker_resolve_end_delay_count", 1) + ), + finalized_delay_avg=( + agg_state.get("finalized_delay_total", 0.0) + / agg_state.get("finalized_delay_count", 1) + ), + worker_targeted_start_delay_avg=( + agg_state.get("worker_targeted_start_delay_total", 0.0) + / agg_state.get("worker_targeted_start_delay_count", 1) + ), + request_start_delay_avg=( + agg_state.get("request_start_delay_total", 0.0) + / agg_state.get("request_start_delay_count", 1) + ), + request_time_avg=( + agg_state.get("request_time_total", 0.0) + / agg_state.get("request_time_count", 1) + ), + request_targeted_delay_avg=( + agg_state.get("request_targeted_delay_total", 0.0) + / agg_state.get("request_targeted_delay_count", 1) + ), + ), + } - self.in_warmup = False - in_cooldown_number = ( - self.args.cooldown_number - and self.args.max_number - and total_completed > self.args.max_number - self.args.cooldown_number - ) - in_cooldown_duration = ( - self.args.cooldown_duration - and self.args.max_duration - and result.request_info.worker_start - > global_start_time + self.args.max_duration - self.args.cooldown_duration - ) - if in_cooldown_number or in_cooldown_duration: - self.in_cooldown = True - return True +class GenerativeRequestsStatsProgressAggregator( + Aggregator[GenerationResponse, GenerationRequest, GenerationRequestTimings] +): + """ + Tracks generative model metrics during benchmark execution. - self.in_cooldown = False + Aggregates token-level metrics including time to first token, inter-token + latency, and token counts for real-time progress monitoring. + """ - if result.request_info.canceled: - self.results.incomplete.append(result) - elif result.request_info.errored: - self.results.errored.append(result) - elif result.request_info.completed: - self.results.successful.append(result) - else: - raise ValueError( - "Unexpected state: request_info must be either " - "completed, canceled, or errored. " - f"Got {result.request_info}" + def __call__( + self, + agg_state: dict[str, Any], + response: Optional[GenerationResponse], + request: GenerationRequest, + request_info: ScheduledRequestInfo[GenerationRequestTimings], + scheduler_state: SchedulerState, + ) -> Optional[dict[str, Any]]: + """ + Aggregate generative model metrics for a completed request. + + :param agg_state: Current aggregation state to update. + :param response: Generation response with token and timing data. 
+ :param request: The processed generation request. + :param request_info: Scheduling metadata and timing information. + :param scheduler_state: Current scheduler execution state. + :return: Updated aggregation state for progress reporting. + """ + if response is None: + return None + + if ( + request_info.status == "completed" + and request_info.request_timings.request_end is not None + ): + agg_state["requests_per_second"] = scheduler_state.successful_requests / ( + request_info.request_timings.request_end - scheduler_state.start_time + ) + add_aggregate_metric( + "request_latency", + agg_state, + request_info.request_timings.request_end, + request_info.request_timings.request_start, ) - return True + if ( + request_info.status == "completed" + and request_info.request_timings.first_iteration is not None + and request_info.request_timings.last_iteration is not None + and response.output_tokens + ): + add_aggregate_metric( + "time_per_output_token", + agg_state, + request_info.request_timings.last_iteration, + request_info.request_timings.request_start, + count=response.output_tokens, + ) - @abstractmethod - def compile(self) -> BenchmarkT: - """ - Compile the benchmark results and statistics into a Benchmark object. - This is required to be implemented by subclasses to finalize the benchmark - and return the compiled object. - """ - ... + if ( + request_info.request_timings.first_iteration is not None + and request_info.request_timings.request_start is not None + ): + add_aggregate_metric( + "time_to_first_token", + agg_state, + request_info.request_timings.first_iteration, + request_info.request_timings.request_start, + ) + + if ( + request_info.request_timings.first_iteration is not None + and request_info.request_timings.last_iteration is not None + and response.output_tokens is not None + and response.output_tokens > 1 + ): + add_aggregate_metric( + "inter_token_latency", + agg_state, + request_info.request_timings.last_iteration, + request_info.request_timings.first_iteration, + count=response.output_tokens - 1, + ) + if response.prompt_tokens is not None: + add_aggregate_metric( + "prompt_tokens", + agg_state, + response.prompt_tokens, + ) + if request_info.request_timings.request_end is not None: + agg_state["prompt_tokens_per_second"] = agg_state[ + "prompt_tokens_total" + ] / ( + request_info.request_timings.request_end + - scheduler_state.start_time + ) + + if response.output_tokens is not None: + add_aggregate_metric( + "output_tokens", + agg_state, + response.output_tokens, + ) + if request_info.request_timings.request_end is not None: + agg_state["output_tokens_per_second"] = agg_state[ + "output_tokens_total" + ] / ( + request_info.request_timings.request_end + - scheduler_state.start_time + ) + + if response.total_tokens is not None: + add_aggregate_metric( + "total_tokens", + agg_state, + response.total_tokens, + ) + if request_info.request_timings.request_end is not None: + agg_state["total_tokens_per_second"] = agg_state[ + "total_tokens_total" + ] / ( + request_info.request_timings.request_end + - scheduler_state.start_time + ) -AggregatorT = TypeVar("AggregatorT", bound=BenchmarkAggregator) + return agg_state -class GenerativeRequestsRunningStats(RequestsRunningStats): +class GenerativeRequestsAggregator( + StandardBaseModel, + CompilableAggregator[ + GenerationResponse, GenerationRequest, GenerationRequestTimings + ], +): """ - The metrics for generative requests that have succeeded, been canceled, or errored - stored as running statistics for easy calculations of 
rates, averages, totals, etc. + Compiles complete generative benchmark results with warmup/cooldown filtering. + + Aggregates request data during execution and compiles comprehensive metrics + including timing distributions, token statistics, and throughput measurements. + Supports filtering warmup and cooldown periods from final results. """ - time_to_first_token: TimeRunningStats = Field( - description=( - "The running statistics for the time from the start of the request to the " - "first token being generated for all requests that completed within the " - "benchmark run." - ), - default_factory=TimeRunningStats, - ) - inter_token_latency: TimeRunningStats = Field( - description=( - "The running statistics for the time between each token being generated " - "for all requests that completed within the benchmark run." - ), - default_factory=TimeRunningStats, + warmup_requests: Optional[int] = Field( + default=None, + description="Number of warmup requests to ignore at benchmark start", ) - prompt_tokens: RunningStats = Field( - description=( - "The running statistics for the token count for the prompt for all " - "requests that completed, if available in the response." - ), - default_factory=RunningStats, + warmup_duration: Optional[float] = Field( + default=None, + description="Warmup duration in seconds to ignore at benchmark start", ) - output_tokens: RunningStats = Field( - description=( - "The running statistics for the token count for the output for all " - "requests that completed, if available in the response." - ), - default_factory=RunningStats, + cooldown_requests: Optional[int] = Field( + default=None, + description="Number of cooldown requests to ignore at benchmark end", ) - total_tokens: RunningStats = Field( - description=( - "The running statistics for the total token count for all requests that " - "completed, if available in the response." - ), - default_factory=RunningStats, + cooldown_duration: Optional[float] = Field( + default=None, + description="Cooldown duration in seconds to ignore at benchmark end", ) + _in_cooldown: bool = False + _in_warmup: bool = False + def __call__( + self, + agg_state: dict[str, Any], + response: Optional[GenerationResponse], + request: GenerationRequest, + request_info: ScheduledRequestInfo[GenerationRequestTimings], + scheduler_state: SchedulerState, + ) -> Optional[dict[str, Any]]: + """ + Collect completed requests for final compilation. -class GenerativeBenchmarkAggregator( - BenchmarkAggregator[GenerativeBenchmark, GenerationRequest, ResponseSummary] -): - type_: Literal["generative_benchmark_aggregator"] = ( - "generative_benchmark_aggregator" # type: ignore[assignment] - ) - processor: Optional[Union[str, Path, Any]] = Field( - description=( - "The tokenizer to use for calculating token counts when none are " - "avaiable that match the preferred source." - ) - ) - processor_args: Optional[dict[str, Any]] = Field( - description=( - "Additional arguments to pass to the tokenizer if it requires " - "any specific configuration for loading or processing." - ), - ) - worker_description: GenerativeRequestsWorkerDescription = Field( - description=( - "The description and specifics for the worker used to resolve requests " - "for this benchmark." - ), - discriminator="type_", - ) - request_loader_description: GenerativeRequestLoaderDescription = Field( - description=( - "The description and specifics for the request loader used to create " - "requests for this benchmark." 
-        ),
-        discriminator="type_",
-    )
-    requests_stats: GenerativeRequestsRunningStats = Field(
-        description=(
-            "The running statistics for the requests for this benchmark run. "
-            "This includes all requests created, regardless of their status."
-        ),
-        default_factory=GenerativeRequestsRunningStats,
-    )
+        Filters requests based on warmup/cooldown settings and categorizes by
+        completion status for comprehensive benchmark analysis.
 
-    def add_result(
-        self, result: SchedulerRequestResult[GenerationRequest, ResponseSummary]
-    ) -> bool:
+        :param agg_state: Current aggregation state to update.
+        :param response: Generation response data.
+        :param request: The processed generation request.
+        :param request_info: Scheduling metadata and timing information.
+        :param scheduler_state: Current scheduler execution state.
+        :return: Warmup/cooldown status flags for progress reporting; requests
+            are only collected here and turned into results by `compile`.
         """
-        Add a result to the aggregator. This will update the internal statistics
-        and add the result to the list of results if it is not within the warmup or
-        cooldown period.
+        status = {
+            "requests_in_warmup": False,
+            "requests_in_cooldown": False,
+        }
+
+        if (
+            response is None
+            or request_info.status not in {"completed", "canceled", "errored"}
+            or (request_info.status == "canceled" and request_info.started_at is None)
+        ):
+            # Ignore requests that don't have a response yet.
+            # Ignore requests that were canceled before they started.
+            return status
+
+        if (
+            self.warmup_requests is not None
+            and self.warmup_requests >= scheduler_state.processed_requests
+        ) or (
+            self.warmup_duration is not None
+            and request_info.request_timings.request_end is not None
+            and (
+                scheduler_state.start_time + self.warmup_duration
+                >= request_info.request_timings.request_end
+            )
+        ):
+            status["requests_in_warmup"] = True
+
+            return status
+
+        if (
+            self.cooldown_requests is not None
+            and scheduler_state.remaining_requests is not None
+            and self.cooldown_requests >= scheduler_state.remaining_requests
+        ) or (
+            self.cooldown_duration is not None
+            and scheduler_state.remaining_duration is not None
+            and self.cooldown_duration >= scheduler_state.remaining_duration
+        ):
+            status["requests_in_cooldown"] = True
+
+            return status
 
-        :param result: The result to add to the aggregator. 
- """ - if not super().add_result(result): - return False + if "completed" not in agg_state: + agg_state["completed"] = [] + agg_state["errored"] = [] + agg_state["incomplete"] = [] - if result.request is None: - raise ValueError("Request is None, cannot add result.") + if request_info.status == "completed": + agg_state["completed"].append((response, request, request_info)) + elif request_info.status == "canceled": + agg_state["incomplete"].append((response, request, request_info)) + else: + agg_state["errored"].append((response, request, request_info)) - if result.response is None: - raise ValueError("Response is None, cannot add result.") + return status - self.requests_stats.request_start_time_delay.update( - result.response.start_time - result.request_info.worker_start - ) - self.requests_stats.request_start_time_targeted_delay.update( - result.response.start_time - result.request_info.targeted_start_time - ) - self.requests_stats.request_time_delay.update( - (result.response.start_time - result.request_info.worker_start) - + result.request_info.worker_end - - result.response.end_time - ) - self.requests_stats.request_time.update( - result.response.end_time - result.response.start_time - ) - if result.response.first_iter_time: - self.requests_stats.time_to_first_token.update( - result.response.first_iter_time - result.response.start_time - ) - if result.response.last_iter_time and result.response.first_iter_time: - self.requests_stats.inter_token_latency.update( - result.response.last_iter_time - result.response.first_iter_time, - count=(result.response.output_tokens or 1) - 1, - ) - self.requests_stats.prompt_tokens += result.response.request_prompt_tokens or 0 - self.requests_stats.output_tokens += result.response.request_output_tokens or 0 - total_tokens = (result.response.request_prompt_tokens or 0) + ( - result.response.request_output_tokens or 0 - ) - self.requests_stats.total_tokens += total_tokens + def compile( + self, agg_state: dict[str, Any], scheduler_state: SchedulerState + ) -> dict[str, Any]: + """ + Compile aggregated requests into comprehensive benchmark results. - return True + Transforms collected request data into detailed metrics including timing + distributions, token statistics, throughput measurements, and status breakdowns. - def compile(self) -> GenerativeBenchmark: + :param agg_state: Accumulated request data categorized by completion status. + :param scheduler_state: Final scheduler execution state. + :return: Complete benchmark results with metrics and request statistics. """ - Compile the benchmark results and statistics into a GenerativeBenchmark object. - This is required to be implemented by subclasses to finalize the benchmark - and return the compiled object. 
- """ - successful, incomplete, errored = self._compile_results() - - return GenerativeBenchmark.from_stats( - run_id=self.run_id, - successful=successful, - incomplete=incomplete, - errored=errored, - args=self.args, - run_stats=BenchmarkRunStats( - start_time=self.requests_stats.totals.total.start_time, - end_time=time.time(), - requests_made=StatusBreakdown( - successful=int(self.requests_stats.totals.successful.total), - errored=int(self.requests_stats.totals.errored.total), - incomplete=int(self.requests_stats.totals.incomplete.total), - total=int(self.requests_stats.totals.total.total), - ), - queued_time_avg=self.requests_stats.queued_time.mean, - scheduled_time_delay_avg=self.requests_stats.scheduled_time_delay.mean, - scheduled_time_sleep_avg=self.requests_stats.scheduled_time_sleep.mean, - worker_start_delay_avg=self.requests_stats.worker_start_delay.mean, - worker_time_avg=self.requests_stats.worker_time.mean, - worker_start_time_targeted_delay_avg=self.requests_stats.worker_start_time_targeted_delay.mean, - request_start_time_delay_avg=self.requests_stats.request_start_time_delay.mean, - request_start_time_targeted_delay_avg=self.requests_stats.request_start_time_targeted_delay.mean, - request_time_delay_avg=self.requests_stats.request_time_delay.mean, - request_time_avg=self.requests_stats.request_time.mean, - ), - worker=self.worker_description, - requests_loader=self.request_loader_description, - extras=self.extras, + successful: list[GenerativeRequestStats] = [ + self._create_generate_stats(response, request, request_info) + for (response, request, request_info) in agg_state.get("completed", []) + ] + incomplete: list[GenerativeRequestStats] = [ + self._create_generate_stats(response, request, request_info) + for (response, request, request_info) in agg_state.get("incomplete", []) + ] + errored: list[GenerativeRequestStats] = [ + self._create_generate_stats(response, request, request_info) + for (response, request, request_info) in agg_state.get("errored", []) + ] + total: list[GenerativeRequestStats] = successful + incomplete + errored + total_types = list[Literal["successful", "incomplete", "error"]] = [ + *["successful"] * len(successful), + *["incomplete"] * len(incomplete), + *["error"] * len(errored), + ] + start_time = min( + [math.inf] + + [ + req.scheduler_info.request_timings.request_start + for req in total + if req.scheduler_info.request_timings.request_start is not None + ] + ) + end_time = max( + [-1 * math.inf] + + [ + req.scheduler_info.request_timings.request_end + for req in total + if req.scheduler_info.request_timings.request_end is not None + ] ) - def _compile_results( - self, - ) -> tuple[ - list[GenerativeTextResponseStats], - list[GenerativeTextErrorStats], - list[GenerativeTextErrorStats], - ]: - successful: list[GenerativeTextResponseStats] = [ - GenerativeTextResponseStats( - request_id=result.request.request_id, - request_type=result.request.request_type, - scheduler_info=result.request_info, - prompt=str(result.request.content), - prompt_tokens=self._compile_tokens_count( - value=str(result.request.content), - requests_tokens=result.response.request_prompt_tokens, - response_tokens=result.response.response_prompt_tokens, - preferred_tokens_source=settings.preferred_prompt_tokens_source, - errored=False, + return { + "start_time": start_time, + "end_time": end_time, + "request_totals": StatusBreakdown( + successful=len(successful), + incomplete=len(incomplete), + errored=len(errored), + total=len(total), + ), + "requests": StatusBreakdown( + 
successful=successful, + incomplete=incomplete, + errored=errored, + ), + "metrics": GenerativeMetrics( + requests_per_second=( + StatusDistributionSummary.from_request_times( + request_types=total_types, + requests=[ + ( + req.scheduler_info.request_timings.request_start, + req.scheduler_info.request_timings.request_end, + ) + for req in total + if ( + req.scheduler_info.request_timings.request_start + is not None + and req.scheduler_info.request_timings.request_end + is not None + ) + ], + distribution_type="rate", + ) ), - output=result.response.value, - output_tokens=self._compile_tokens_count( - value=result.response.value, - requests_tokens=result.response.request_output_tokens, - response_tokens=result.response.response_output_tokens, - preferred_tokens_source=settings.preferred_output_tokens_source, - errored=False, + request_concurrency=( + StatusDistributionSummary.from_request_times( + request_types=total_types, + requests=[ + ( + req.scheduler_info.request_timings.request_start, + req.scheduler_info.request_timings.request_end, + ) + for req in total + if ( + req.scheduler_info.request_timings.request_start + is not None + and req.scheduler_info.request_timings.request_end + is not None + ) + ], + distribution_type="concurrency", + ) ), - start_time=result.response.start_time, - end_time=result.response.end_time, - first_token_time=result.response.first_iter_time or -1.0, - last_token_time=result.response.last_iter_time or -1.0, - ) - for result in self.results.successful - if result.request and result.response - ] - incomplete: list[GenerativeTextErrorStats] = [ - GenerativeTextErrorStats( - error=result.response.error or "", - request_id=result.request.request_id, - request_type=result.request.request_type, - scheduler_info=result.request_info, - prompt=str(result.request.content), - prompt_tokens=self._compile_tokens_count( - value=str(result.request.content), - requests_tokens=result.response.request_prompt_tokens, - response_tokens=result.response.response_prompt_tokens, - preferred_tokens_source=settings.preferred_prompt_tokens_source, - errored=True, + request_latency=( + StatusDistributionSummary.from_values( + value_types=total_types, + values=[ + req.request_latency + for req in total + if req.request_latency is not None + ], + ) ), - output=result.response.value, - output_tokens=self._compile_tokens_count( - value=result.response.value, - requests_tokens=result.response.request_output_tokens, - response_tokens=result.response.response_output_tokens, - preferred_tokens_source=settings.preferred_output_tokens_source, - errored=True, + prompt_token_count=( + StatusDistributionSummary.from_values( + value_types=[ + type_ + for type_, req in zip(total_types, total) + if req.prompt_tokens is not None + ], + values=[ + req.prompt_tokens + for req in total + if req.prompt_tokens is not None + ], + ) ), - start_time=result.response.start_time, - end_time=result.response.end_time, - first_token_time=result.response.first_iter_time, - last_token_time=result.response.last_iter_time, - ) - for result in self.results.incomplete - if result.request and result.response - ] - error: list[GenerativeTextErrorStats] = [ - GenerativeTextErrorStats( - error=result.response.error or "", - request_id=result.request.request_id, - request_type=result.request.request_type, - scheduler_info=result.request_info, - prompt=str(result.request.content), - prompt_tokens=self._compile_tokens_count( - value=str(result.request.content), - requests_tokens=result.response.request_prompt_tokens, - 
response_tokens=result.response.response_prompt_tokens, - preferred_tokens_source=settings.preferred_prompt_tokens_source, - errored=True, + output_token_count=( + StatusDistributionSummary.from_values( + value_types=[ + type_ + for type_, req in zip(total_types, total) + if req.output_tokens is not None + ], + values=[ + req.output_tokens + for req in total + if req.output_tokens is not None + ], + ) ), - output=result.response.value, - output_tokens=self._compile_tokens_count( - value=result.response.value, - requests_tokens=result.response.request_output_tokens, - response_tokens=result.response.response_output_tokens, - preferred_tokens_source=settings.preferred_output_tokens_source, - errored=True, + total_token_count=( + StatusDistributionSummary.from_values( + value_types=[ + type_ + for type_, req in zip(total_types, total) + if req.prompt_tokens is not None + or req.output_tokens is not None + ], + values=( + (req.prompt_tokens or 0) + (req.output_tokens or 0) + for req in total + if req.prompt_tokens is not None + or req.output_tokens is not None + ), + ) ), - start_time=result.response.start_time, - end_time=result.response.end_time, - first_token_time=result.response.first_iter_time, - last_token_time=result.response.last_iter_time, - ) - for result in self.results.errored - if result.request and result.response - ] - - return successful, incomplete, error + time_to_first_token_ms=( + StatusDistributionSummary.from_values( + value_types=[ + type_ + for type_, req in zip(total_types, total) + if req.time_to_first_token_ms is not None + ], + values=[ + req.time_to_first_token_ms + for req in total + if req.time_to_first_token_ms is not None + ], + ) + ), + time_per_output_token_ms=( + StatusDistributionSummary.from_values( + value_types=[ + type_ + for type_, req in zip(total_types, total) + if req.time_per_output_token_ms is not None + ], + values=[ + req.time_per_output_token_ms + for req in total + if req.time_per_output_token_ms is not None + ], + weights=[ + req.output_tokens + for req in total + if req.time_per_output_token_ms is not None + ], + ) + ), + inter_token_latency_ms=( + StatusDistributionSummary.from_values( + value_types=[ + type_ + for type_, req in zip(total_types, total) + if req.inter_token_latency_ms is not None + ], + values=[ + req.inter_token_latency_ms + for req in total + if req.inter_token_latency_ms is not None + ], + weights=[ + req.output_tokens - 1 + for req in total + if req.inter_token_latency_ms is not None + ], + ) + ), + output_tokens_per_second=( + StatusDistributionSummary.from_iterable_request_times( + request_types=[ + type_ + for type_, req in zip(total_types, total) + if req.output_tokens_per_second is not None + ], + requests=[ + ( + req.scheduler_info.request_timings.request_start, + req.scheduler_info.request_timings.request_end, + ) + for req in total + if req.output_tokens_per_second is not None + ], + first_iter_times=[ + req.scheduler_info.request_timings.first_iteration + for req in total + if req.output_tokens_per_second is not None + and req.scheduler_info.request_timings.first_iteration + is not None + ], + iter_counts=[ + req.output_tokens + for req in total + if req.output_tokens_per_second is not None + and req.output_tokens is not None + ], + ) + ), + tokens_per_second=( + StatusDistributionSummary.from_iterable_request_times( + request_types=[ + type_ + for type_, req in zip(total_types, total) + if req.tokens_per_second is not None + ], + requests=[ + ( + req.scheduler_info.request_timings.request_start, + 
req.scheduler_info.request_timings.request_end, + ) + for req in total + if req.tokens_per_second is not None + ], + first_iter_times=[ + req.scheduler_info.request_timings.first_iteration + for req in total + if req.tokens_per_second is not None + ], + iter_counts=[ + req.output_tokens + for req in total + if req.tokens_per_second is not None + ], + first_iter_counts=[ + req.prompt_tokens + for req in total + if req.tokens_per_second is not None + ], + ) + ), + ), + } + + @classmethod + def _create_generate_stats( + cls, + response: GenerationResponse, + request: GenerationRequest, + request_info: ScheduledRequestInfo[GenerationRequestTimings], + ) -> GenerativeRequestStats: + prompt_tokens = response.preferred_prompt_tokens( + settings.preferred_prompt_tokens_source + ) + output_tokens = response.preferred_output_tokens( + settings.preferred_output_tokens_source + ) - def _compile_tokens_count( - self, - value: str, - requests_tokens: Optional[int], - response_tokens: Optional[int], - preferred_tokens_source: Optional[Literal["request", "response", "local"]], - errored: bool, - ) -> int: - if not errored and preferred_tokens_source == "response" and response_tokens: - return response_tokens or 0 - - if not errored and preferred_tokens_source == "request" and requests_tokens: - return requests_tokens or 0 - - if preferred_tokens_source in {"response", "request"} and ( - self.processor is None or errored or response_tokens or requests_tokens - ): - # we had a preferred tokens source that isn't local and we either - # have the data to return something or we don't have the ability - # to calculate locally - return response_tokens or requests_tokens or 0 - - self.processor = check_load_processor( - self.processor, - processor_args=self.processor_args, - error_msg="Processor/Tokenizer is required for calculating token counts.", + return GenerativeRequestStats( + request_id=request.request_id, + request_type=request.request_type, + prompt=str(request.content), + request_args=response.request_args, + output=response.value, + iterations=response.iterations, + prompt_tokens=prompt_tokens, + output_tokens=output_tokens, + total_tokens=( + prompt_tokens + output_tokens + if prompt_tokens is not None and output_tokens is not None + else None + ), + scheduler_info=request_info, ) - return len(self.processor.tokenize(value)) diff --git a/src/guidellm/benchmark/benchmark.py b/src/guidellm/benchmark/benchmark.py index 1e2a5f4b..a91a88e9 100644 --- a/src/guidellm/benchmark/benchmark.py +++ b/src/guidellm/benchmark/benchmark.py @@ -1,9 +1,35 @@ -import random +""" +Benchmark data models and metrics for performance measurement and analysis. + +Provides comprehensive data structures for capturing, storing, and analyzing +benchmark results from scheduler executions. Includes timing measurements, +token statistics, and performance metrics for generative AI workloads. + +Classes: + BenchmarkSchedulerStats: Scheduler timing and performance statistics. + BenchmarkMetrics: Core benchmark metrics and distributions. + BenchmarkRequestStats: Individual request processing statistics. + Benchmark: Base benchmark result container with generic metrics. + GenerativeRequestStats: Request statistics for generative AI workloads. + GenerativeMetrics: Comprehensive metrics for generative benchmarks. + GenerativeBenchmark: Complete generative benchmark results and analysis. + GenerativeBenchmarksReport: Container for multiple benchmark results. + +Type Variables: + BenchmarkMetricsT: Generic benchmark metrics type. 
+ BenchmarkRequestStatsT: Generic request statistics type. + BenchmarkT: Generic benchmark container type. +""" + +import json import uuid -from typing import Any, Literal, Optional, TypeVar, Union +from pathlib import Path +from typing import Any, ClassVar, Generic, Literal, Optional, TypedDict, TypeVar, Union +import yaml from pydantic import Field, computed_field +from guidellm.backend import GenerationRequestTimings from guidellm.benchmark.profile import ( AsyncProfile, ConcurrentProfile, @@ -17,819 +43,450 @@ StatusBreakdown, StatusDistributionSummary, ) -from guidellm.request import ( - GenerativeRequestLoaderDescription, - RequestLoaderDescription, -) from guidellm.scheduler import ( AsyncConstantStrategy, AsyncPoissonStrategy, ConcurrentStrategy, - GenerativeRequestsWorkerDescription, - SchedulerRequestInfo, + ScheduledRequestInfo, + SchedulerState, SchedulingStrategy, SynchronousStrategy, ThroughputStrategy, - WorkerDescription, ) __all__ = [ "Benchmark", - "BenchmarkArgs", "BenchmarkMetrics", - "BenchmarkRunStats", + "BenchmarkSchedulerStats", "BenchmarkT", "GenerativeBenchmark", + "GenerativeBenchmarksReport", "GenerativeMetrics", - "GenerativeTextErrorStats", - "GenerativeTextResponseStats", - "StatusBreakdown", + "GenerativeRequestStats", ] -class BenchmarkArgs(StandardBaseModel): - """ - A serializable model representing the arguments used to specify a benchmark run - and how data was collected for it. - """ - - profile: Union[ - AsyncProfile, - SweepProfile, - ConcurrentProfile, - ThroughputProfile, - SynchronousProfile, - Profile, - ] = Field( - description=( - "The profile used for the entire benchmark run that the strategy for " - "this benchmark was pulled from." - ), - discriminator="type_", - ) - strategy_index: int = Field( - description=( - "The index of the strategy in the profile that was used for this benchmark." - ) - ) - strategy: Union[ - ConcurrentStrategy, - SchedulingStrategy, - ThroughputStrategy, - SynchronousStrategy, - AsyncPoissonStrategy, - AsyncConstantStrategy, - SchedulingStrategy, - ] = Field( - description="The scheduling strategy used to run this benchmark. ", - discriminator="type_", - ) - max_number: Optional[int] = Field( - description="The maximum number of requests to run for this benchmark, if any." - ) - max_duration: Optional[float] = Field( - description="The maximum duration in seconds to run this benchmark, if any." - ) - warmup_number: Optional[int] = Field( - description=( - "The number of requests to run for the warmup phase of this benchmark, " - "if any. These are requests that were not included in the final results." - ) - ) - warmup_duration: Optional[float] = Field( - description=( - "The duration in seconds to run for the warmup phase of this benchmark, " - "if any. These are requests that were not included in the final results." - ) - ) - cooldown_number: Optional[int] = Field( - description=( - "The number of requests to run for the cooldown phase of this benchmark, " - "if any. These are requests that were not included in the final results." - ) - ) - cooldown_duration: Optional[float] = Field( - description=( - "The duration in seconds to run for the cooldown phase of this benchmark, " - "if any. These are requests that were not included in the final results." - ) - ) - - -class BenchmarkRunStats(StandardBaseModel): - """ - A serializable model representing the run process statistics for the - entire benchmark run across all requests including warmup and cooldown. 
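# A small aside on the GenerativeMetrics compilation in aggregator.py above: each
# StatusDistributionSummary.from_values call is handed two parallel lists, one of
# request statuses and one of metric values, filtered the same way so they stay
# aligned. A minimal, hedged sketch of that pairing with made-up data; the
# guidellm.objects import path is taken from this diff, everything else here is
# hypothetical:

from guidellm.objects import StatusDistributionSummary

ttft_by_status = [("successful", 212.5), ("successful", 198.0), ("error", None)]

ttft_ms = StatusDistributionSummary.from_values(
    value_types=[status for status, value in ttft_by_status if value is not None],
    values=[value for _, value in ttft_by_status if value is not None],
)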
- """ +class BenchmarkSchedulerStats(StandardBaseModel): + """Scheduler timing and performance statistics.""" start_time: float = Field( - description="The start time of the benchmark run.", - ) - end_time: float = Field( - description="The end time of the benchmark run.", + description="Unix timestamp when the benchmark run started" ) + end_time: float = Field(description="Unix timestamp when the benchmark run ended") requests_made: StatusBreakdown[int, int, int, int] = Field( - description=( - "The number of requests made for the benchmark run broken down by " - "status including successful, incomplete, errored, and the sum of all three" - ) + description="Request counts by status: successful, incomplete, errored, total" ) queued_time_avg: float = Field( - description=( - "The average time spent in the queue for each request in the benchmark " - "run until it was dequeued by a worker." - ) - ) - scheduled_time_delay_avg: float = Field( - description=( - "The average time delay between when a request was dequeued and when it " - "was scheduled to be processed by a worker in the benchmark run. " - "This should be as close to 0 as possible, any additional time is " - "overheads from the system or the worker." - ) + description="Avg time requests spent in the queue (seconds)" ) - scheduled_time_sleep_avg: float = Field( - description=( - "The average time spent sleeping til the desired start time was reached " - "after being scheduled by the worker in the benchmark run." - ) + worker_resolve_start_delay_avg: float = Field( + description="Avg delay before worker begins resolving req after dequeue (sec)" ) - worker_start_delay_avg: float = Field( - description=( - "The average time delay between when a request was scheduled and when " - "the worker started processing it in the benchmark run. " - "This should be as close to 0 as possible, any additional time is " - "overheads from the system or the worker." - ) + worker_resolve_time_avg: float = Field( + description="Avg time for worker to resolve requests (seconds)" ) - worker_time_avg: float = Field( - description=( - "The average time taken by the worker to process each request in the " - "benchmark run. This includes the time to generate the response and " - "any additional processing time." - ) + worker_resolve_end_delay_avg: float = Field( + description="Avg delay after request end till worker resolves (seconds)" ) - worker_start_time_targeted_delay_avg: float = Field( - description=( - "The average time delay between when a request was targeted to start " - "and when the worker actually started processing it in the benchmark " - "run. For async strategies, this represents delays from the ideal " - "system. For sync strategies, since those are doubled in queue, " - "this should be as close to the time for a request to be processed " - "as possible. Any additional time is overhead from the system or " - "the worker." - ) + finalized_delay_avg: float = Field( + description="Avg delay after resolve til finalized with in scheduler (sec)" ) - request_start_time_delay_avg: float = Field( - description=( - "The average time delay between the actual request being made " - "and the time the worker started on the request for all requests " - "that completed within the benchmark run. This time should be as close " - "to 0 as possible, any additional time is overhead from the system or " - "the worker." 
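# The *_avg fields above reduce per-request scheduler timings to simple means. A
# rough, self-contained sketch of that reduction using only the standard library;
# the RequestTiming dataclass and its attribute names are hypothetical stand-ins,
# not guidellm's internal timing objects:

from dataclasses import dataclass
from statistics import mean


@dataclass
class RequestTiming:
    queued_at: float         # when the request entered the queue
    dequeued_at: float       # when a worker dequeued it
    resolve_start_at: float  # when the worker began resolving it


def queued_time_avg(timings: list[RequestTiming]) -> float:
    """Average seconds each request spent waiting in the queue."""
    return mean(t.dequeued_at - t.queued_at for t in timings)


def worker_resolve_start_delay_avg(timings: list[RequestTiming]) -> float:
    """Average delay between dequeue and the worker starting to resolve."""
    return mean(t.resolve_start_at - t.dequeued_at for t in timings)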
- ) + worker_targeted_start_delay_avg: float = Field( + description="Avg delay from targeted start to actual worker start (seconds)" ) - request_start_time_targeted_delay_avg: float = Field( - description=( - "The average time delay between when the targeted start time and " - "the actual start time for each request in the benchmark run. " - "For async strategies, this represents delays from the ideal " - "system. For sync strategies, this should be as close to the " - "time for a request to be processed as possible. Any additional " - "time is overhead from the system or the worker." - ) + request_start_delay_avg: float = Field( + description="Avg delay after resolve til request start (seconds)" ) - request_time_delay_avg: float = Field( - description=( - "The average time delay between the total request time and the " - "worker time. This should be as close to 0 as possible, any additional " - "time is overhead from the system or the worker. " - ) - ) - request_time_avg: float = Field( - description=( - "The average time spent processing all requests in the benchmark run. " - "This is the time from when the actual request was started to when " - "it was completed." - ) + request_time_avg: float = Field(description="Avg request processing time (seconds)") + request_targeted_delay_avg: float = Field( + description="Avg delay from targeted start to actual request start" ) +class SchedulerDict(TypedDict, total=False): + """Scheduler configuration and execution state dictionary.""" + + strategy: Union[ + AsyncConstantStrategy, + AsyncPoissonStrategy, + ConcurrentStrategy, + SynchronousStrategy, + ThroughputStrategy, + SchedulingStrategy, + ] + constraints: dict[str, dict[str, Any]] + state: SchedulerState + + +class BenchmarkerDict(TypedDict, total=False): + """Benchmarker configuration and component settings dictionary.""" + + profile: Union[ + AsyncProfile, + ConcurrentProfile, + SynchronousProfile, + ThroughputProfile, + SweepProfile, + Profile, + ] + requests: dict[str, Any] + backend: dict[str, Any] + environment: dict[str, Any] + aggregators: dict[str, dict[str, Any]] + + class BenchmarkMetrics(StandardBaseModel): - """ - A serializable model representing the metrics for a benchmark run. - """ + """Core benchmark metrics and statistical distributions.""" requests_per_second: StatusDistributionSummary = Field( - description="The distribution of requests per second for the benchmark.", + description="Distribution of requests per second across benchmark execution" ) request_concurrency: StatusDistributionSummary = Field( - description="The distribution of requests concurrency for the benchmark.", + description="Distribution of concurrent request counts during execution" ) + request_latency: StatusDistributionSummary = Field( + description="Distribution of request latencies for completed requests" + ) + + +BenchmarkMetricsT = TypeVar("BenchmarkMetricsT", bound=BenchmarkMetrics) + + +class BenchmarkRequestStats(StandardBaseModel): + """Individual request processing statistics and scheduling metadata.""" + + scheduler_info: ScheduledRequestInfo[GenerationRequestTimings] = Field( + description="Scheduler metadata and timing information for the request" + ) + +BenchmarkRequestStatsT = TypeVar("BenchmarkRequestStatsT", bound=BenchmarkRequestStats) -class Benchmark(StandardBaseModel): - """ - The base serializable model representing a benchmark run and its results. - Specific benchmarker implementations should extend this model to include - additional information or metadata as needed. 
- Note, requests_per_second and request_concurrency are kept at this level - and are expected to be populated by the subclass implementation to ensure - the logic for Profiles can include more complicated logic for determining - what rates and concurrency values to use for subsequent strategies. - """ +class Benchmark(StandardBaseModel, Generic[BenchmarkMetricsT, BenchmarkRequestStatsT]): + """Base benchmark result container with execution metadata.""" type_: Literal["benchmark"] = "benchmark" id_: str = Field( default_factory=lambda: str(uuid.uuid4()), - description="The unique identifier for the benchmark.", + description="Unique identifier for this benchmark execution", ) run_id: str = Field( - description=( - "The unique identifier for the encompasing benchmark run that this " - "benchmark was a part of." - ) + description="Identifier for the benchmarker run containing this benchmark" ) - args: BenchmarkArgs = Field( - description=( - "The arguments used to specify how to run the benchmark and collect data." - ) + run_index: int = Field( + description="Sequential index of this benchmark within the benchmarker run" ) - run_stats: BenchmarkRunStats = Field( - description=( - "The process statistics for the entire benchmark run across all requests." - ) + scheduler: SchedulerDict = Field( + description="Scheduler configuration and execution state" ) - worker: Union[WorkerDescription] = Field( - description=( - "The description and specifics for the worker used to resolve requests " - "for this benchmark." - ), + benchmarker: BenchmarkerDict = Field( + description="Benchmarker configuration and component settings" ) - request_loader: Union[RequestLoaderDescription] = Field( - description=( - "The description and specifics for the request loader used to create " - "requests for this benchmark." - ), + env_args: dict[str, Any] = Field( + description="Environment arguments and runtime configuration" ) extras: dict[str, Any] = Field( - description=( - "Any additional information or metadata that was passed for this benchmark." - ) + description="Additional metadata and custom benchmark parameters" ) - metrics: BenchmarkMetrics = Field( - description=( - "The metrics for the benchmark run represented as a distribution of " - "various per-request statistics." - ), + run_stats: BenchmarkSchedulerStats = Field( + description="Scheduler timing and performance statistics" + ) + start_time: float = Field( + default=-1.0, description="Unix timestamp when the first request was initiated" + ) + end_time: float = Field( + default=-1.0, description="Unix timestamp when the last request completed" + ) + + @computed_field # type: ignore[misc] + @property + def duration(self) -> float: + """ + Benchmark execution duration in seconds. + + :return: Time elapsed from first request start to last request completion. 
+ """ + return self.end_time - self.start_time + + metrics: BenchmarkMetricsT = Field( + description="Performance metrics and statistical distributions" + ) + request_totals: StatusBreakdown[int, int, int, int] = Field( + description="Request counts by status: successful, incomplete, errored, total" + ) + requests: StatusBreakdown[ + list[BenchmarkRequestStatsT], + list[BenchmarkRequestStatsT], + list[BenchmarkRequestStatsT], + None, + ] = Field( + description="Request details grouped by status: successful, incomplete, errored" ) BenchmarkT = TypeVar("BenchmarkT", bound=Benchmark) -class GenerativeTextResponseStats(StandardBaseModel): - """ - A serializable model representing the request values, response values, and - statistics for a generative text response. - """ +class GenerativeRequestStats(BenchmarkRequestStats): + """Request statistics for generative AI text generation workloads.""" - type_: Literal["generative_text_response"] = "generative_text_response" - request_id: Optional[str] = Field( - description="The unique identifier for the request.", - ) + type_: Literal["generative_request_stats"] = "generative_request_stats" + request_id: str = Field(description="Unique identifier for the request") request_type: Literal["text_completions", "chat_completions"] = Field( - description="The type of request made to the generative backend." + description="Type of generative request: text or chat completion" ) - scheduler_info: SchedulerRequestInfo = Field( - description=( - "The info about the request from the scheduler about how it was run." - ), + prompt: str = Field(description="Input text prompt for generation") + request_args: dict[str, Any] = Field( + description="Generation parameters and configuration options" ) - prompt: str = Field( - description="The text prompt used for the generative request.", + output: Optional[str] = Field( + description="Generated text output, if request completed successfully" ) - output: str = Field( - description="The generated text output from the generative request.", + iterations: int = Field( + description="Number of processing iterations for the request" ) - prompt_tokens: int = Field( - description="The number of tokens in the prompt text.", + prompt_tokens: Optional[int] = Field( + description="Number of tokens in the input prompt" ) - output_tokens: int = Field( - description="The number of tokens in the generated output text.", - ) - start_time: float = Field( - description="The time the request started.", - ) - end_time: float = Field( - description="The time the request ended.", - ) - first_token_time: float = Field( - description="The time the first token was received.", - ) - last_token_time: float = Field( - description="The time the last token was received.", + output_tokens: Optional[int] = Field( + description="Number of tokens in the generated output" ) @computed_field # type: ignore[misc] @property - def request_latency(self) -> float: - """ - :return: The duration of the request in seconds from the start to the end. + def total_tokens(self) -> Optional[int]: """ - return self.end_time - self.start_time + Total token count including prompt and output tokens. - @computed_field # type: ignore[misc] - @property - def time_to_first_token_ms(self) -> float: + :return: Sum of prompt and output tokens, or None if either is unavailable. """ - :return: The time in milliseconds from the start of the request to the first - token received. 
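# request_totals above is a StatusBreakdown carrying four integers: successful,
# incomplete, errored, and their sum. A tiny hedged example of filling it in by
# hand; the constructor keywords match the usage elsewhere in this diff, while
# the import path is an assumption:

from guidellm.objects import StatusBreakdown

request_totals = StatusBreakdown(
    successful=95,
    incomplete=2,
    errored=3,
    total=100,
)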
- """ - return 1000 * (self.first_token_time - self.start_time) + if self.prompt_tokens is None or self.output_tokens is None: + return None + + return self.prompt_tokens + self.output_tokens @computed_field # type: ignore[misc] @property - def time_per_output_token_ms(self) -> float: + def request_latency(self) -> Optional[float]: """ - :return: The average time in milliseconds per output token generated. - This includes the time to generate the first token and all other tokens. + End-to-end request processing latency in seconds. + + :return: Duration from request start to completion, or None if unavailable. """ - if self.output_tokens == 0: - return 0.0 + if ( + not self.scheduler_info.request_timings.request_end + or not self.scheduler_info.request_timings.request_start + ): + return None return ( - 1000 * (self.last_token_time - self.first_token_time) / self.output_tokens + self.scheduler_info.request_timings.request_end + - self.scheduler_info.request_timings.request_start ) @computed_field # type: ignore[misc] @property - def inter_token_latency_ms(self) -> float: + def time_to_first_token_ms(self) -> Optional[float]: """ - :return: The average time in milliseconds between generating tokens in the - output text. Note, does not include the time to generate the first token. + Time to first token generation in milliseconds. + + :return: Latency from request start to first token, or None if unavailable. """ - if self.output_tokens <= 1: - return 0.0 + if ( + not self.scheduler_info.request_timings.first_iteration + or not self.scheduler_info.request_timings.request_start + ): + return None - return ( - 1000 - * (self.last_token_time - self.first_token_time) - / (self.output_tokens - 1) + return 1000 * ( + self.scheduler_info.request_timings.first_iteration + - self.scheduler_info.request_timings.request_start ) @computed_field # type: ignore[misc] @property - def tokens_per_second(self) -> float: - """ - :return: The average number of tokens generated per second in the prompt and - output text. + def time_per_output_token_ms(self) -> Optional[float]: """ - if (latency := self.request_latency) == 0.0: - return 0.0 + Average time per output token in milliseconds. - return (self.prompt_tokens + self.output_tokens) / latency + Includes time for first token and all subsequent tokens. - @computed_field # type: ignore[misc] - @property - def output_tokens_per_second(self) -> float: - """ - :return: The average number of output tokens generated per second. + :return: Average milliseconds per output token, or None if unavailable. """ - if (latency := self.request_latency) == 0.0: - return 0.0 - - return self.output_tokens / latency - - -class GenerativeTextErrorStats(GenerativeTextResponseStats): - """ - A serializable model representing the request values, response values, and - statistics for a generative text response that errored. - Extends and overrides the GenerativeTextResponseStats model to include the - error message and optional properties given the error occurred. - """ + if ( + not self.scheduler_info.request_timings.request_start + or not self.scheduler_info.request_timings.last_iteration + or not self.output_tokens + ): + return None - type_: Literal["generative_text_error"] = "generative_text_error" # type: ignore[assignment] - error: str = Field( - description=( - "The error message for the error that occurred while making the request." 
+ return ( + 1000 + * ( + self.scheduler_info.request_timings.last_iteration + - self.scheduler_info.request_timings.request_start + ) + / self.output_tokens ) - ) - output: Optional[str] = Field( # type: ignore[assignment] - default=None, - description=( - "The generated text output from the generative request, if any, " - "before the error occurred." - ), - ) - first_token_time: Optional[float] = Field( # type: ignore[assignment] - default=None, - description=( - "The time the first token was received, if any, before the error occurred." - ), - ) - last_token_time: Optional[float] = Field( # type: ignore[assignment] - default=None, - description=( - "The time the last token was received, if any, before the error occurred." - ), - ) @computed_field # type: ignore[misc] @property - def time_to_first_token_ms(self) -> Optional[float]: # type: ignore[override] + def inter_token_latency_ms(self) -> Optional[float]: """ - :return: The time in milliseconds from the start of the request to the first - token received. None if the first token was not received. - """ - if self.first_token_time is None: - return None + Average inter-token latency in milliseconds. - return super().time_to_first_token_ms + Measures time between token generations, excluding first token. - @computed_field # type: ignore[misc] - @property - def time_per_output_token_ms(self) -> Optional[float]: # type: ignore[override] - """ - :return: The average time in milliseconds per output token generated. - This includes the time to generate the first token and all other tokens. - None if the output_tokens is None or 0. + :return: Average milliseconds between tokens, or None if unavailable. """ if ( - self.output_tokens is None - or self.output_tokens == 0 - or self.first_token_time is None - or self.last_token_time is None + not self.scheduler_info.request_timings.first_iteration + or not self.scheduler_info.request_timings.last_iteration + or not self.output_tokens + or self.output_tokens <= 1 ): return None - return super().time_per_output_token_ms + return ( + 1000 + * ( + self.scheduler_info.request_timings.last_iteration + - self.scheduler_info.request_timings.first_iteration + ) + / (self.output_tokens - 1) + ) @computed_field # type: ignore[misc] @property - def inter_token_latency_ms(self) -> Optional[float]: # type: ignore[override] + def tokens_per_second(self) -> Optional[float]: """ - :return: The average time in milliseconds between generating tokens in the - output text. Note, does not include the time to generate the first token. - None if there were no output_tokens or the first token was not received. + Overall token throughput including prompt and output tokens. + + :return: Total tokens per second, or None if unavailable. """ - if ( - self.output_tokens is None - or self.first_token_time is None - or self.last_token_time is None - ): + if not (latency := self.request_latency) or not (tokens := self.total_tokens): return None - return super().inter_token_latency_ms + return tokens / latency @computed_field # type: ignore[misc] @property - def output_tokens_per_second(self) -> Optional[float]: # type: ignore[override] + def output_tokens_per_second(self) -> Optional[float]: """ - :return: The average number of tokens generated per second in the output text. - Note, does not include the time to generate the first token. None if there - were no output_tokens or the first token was not received. + Output token generation throughput. + + :return: Output tokens per second, or None if unavailable. 
""" - if self.inter_token_latency_ms is None: + if not (latency := self.request_latency) or not self.output_tokens: return None - return super().output_tokens_per_second + return self.output_tokens / latency class GenerativeMetrics(BenchmarkMetrics): - """ - A serializable model representing the metrics for a generative benchmark run. - """ + """Comprehensive metrics for generative AI benchmarks.""" - request_latency: StatusDistributionSummary = Field( - description="The distribution of latencies for the completed requests.", - ) prompt_token_count: StatusDistributionSummary = Field( - description=( - "The distribution of token counts in the prompts for completed, " - "errored, and all requests." - ) + description="Distribution of prompt token counts by request status" ) output_token_count: StatusDistributionSummary = Field( - description=( - "The distribution of token counts in the outputs for completed, " - "errored, and all requests." - ) + description="Distribution of output token counts by request status" + ) + total_token_count: StatusDistributionSummary = Field( + description="Distribution of total token counts by request status" ) time_to_first_token_ms: StatusDistributionSummary = Field( - description=( - "The distribution of latencies to receiving the first token in " - "milliseconds for completed, errored, and all requests." - ), + description="Distribution of first token latencies in milliseconds" ) time_per_output_token_ms: StatusDistributionSummary = Field( - description=( - "The distribution of latencies per output token in milliseconds for " - "completed, errored, and all requests. " - "This includes the time to generate the first token and all other tokens." - ), + description="Distribution of average time per output token in milliseconds" ) inter_token_latency_ms: StatusDistributionSummary = Field( - description=( - "The distribution of latencies between tokens in milliseconds for " - "completed, errored, and all requests." - ), + description="Distribution of inter-token latencies in milliseconds" ) output_tokens_per_second: StatusDistributionSummary = Field( - description=( - "The distribution of output tokens per second for completed, " - "errored, and all requests." - ), + description="Distribution of output token generation rates" ) tokens_per_second: StatusDistributionSummary = Field( - description=( - "The distribution of tokens per second, including prompt and output tokens " - "for completed, errored, and all requests." - ), + description="Distribution of total token throughput including prompt and output" ) -class GenerativeBenchmark(Benchmark): - """ - A serializable model representing a benchmark run and its results for generative - requests and responses. Includes the completed and errored requests, the start - and end times for the benchmark, and the statistics for the requests and responses. - """ +class GenerativeBenchmark(Benchmark[GenerativeMetrics, GenerativeRequestStats]): + """Complete generative AI benchmark results with specialized metrics.""" type_: Literal["generative_benchmark"] = "generative_benchmark" # type: ignore[assignment] - start_time: float = Field( - description="The start time of the first request for the benchmark.", - ) - end_time: float = Field( - description="The end time of the last request for the benchmark.", - ) - @computed_field # type: ignore[misc] - @property - def duration(self) -> float: - """ - :return: The duration of the benchmark in seconds from the start of the - first request to the end of the last request. 
- """ - return self.end_time - self.start_time - worker: GenerativeRequestsWorkerDescription = Field( - description=( - "The description and specifics for the worker used to resolve requests " - "for this benchmark." - ), - ) - request_loader: GenerativeRequestLoaderDescription = Field( - description=( - "The description and specifics for the request loader used to create " - "requests for this benchmark." - ), - ) - metrics: GenerativeMetrics = Field( - description=( - "The metrics for the benchmark run represented as a distribution of " - "various per-request statistics." - ), - ) - # Output is ordered so keep the requests at the end for better readability in files - request_totals: StatusBreakdown[int, int, int, int] = Field( - description=( - "The number of requests made for the benchmark broken down by status " - "including successful, incomplete, errored, and the sum of all three" - ) - ) - request_samples: Optional[StatusBreakdown[int, int, int, None]] = Field( - description=( - "The number of requests that were randomly sampled for " - "the benchmark. None if no sampling was applied." - ), - default=None, - ) - requests: StatusBreakdown[ - list[GenerativeTextResponseStats], - list[GenerativeTextErrorStats], - list[GenerativeTextErrorStats], - None, - ] = Field( - description=( - "The breakdown of requests for the benchmark run including successful, " - "incomplete, and errored requests." - ), - ) +class GenerativeBenchmarksReport(StandardBaseModel): + """Container for multiple benchmark results with load/save functionality.""" + + DEFAULT_FILE: ClassVar[str] = "benchmarks.json" - def set_sample_size(self, sample_size: Optional[int]) -> "GenerativeBenchmark": + @staticmethod + def load_file( + path: Union[str, Path], type_: Literal["json", "yaml"] | None = None + ) -> "GenerativeBenchmarksReport": """ - Set the sample size for the benchmark. This will randomly sample the - requests for each status type to the given sample size or the maximum - number of requests for that status type, whichever is smaller. - This is applied to requests.successful, requests.errored, and - requests.incomplete. - If None, no sampling is applied and the state is kept. - - :param sample_size: The number of requests to sample for each status type. - :return: The benchmark with the sampled requests. - :raises ValueError: If the sample size is invalid. + Load a report from a file. + + :param path: The path to load the report from. + :param type_: File type override, auto-detected from extension if None. + :return: The loaded report. + :raises ValueError: If file type is unsupported. 
""" + path = Path(path) if not isinstance(path, Path) else path - if sample_size is not None: - if sample_size < 0 or not isinstance(sample_size, int): - raise ValueError( - f"Sample size must be non-negative integer, given {sample_size}" - ) + if path.is_dir(): + path = path / GenerativeBenchmarksReport.DEFAULT_FILE - sample_size = min(sample_size, len(self.requests.successful)) - error_sample_size = min(sample_size, len(self.requests.errored)) - incomplete_sample_size = min(sample_size, len(self.requests.incomplete)) + path.parent.mkdir(parents=True, exist_ok=True) + path_suffix = path.suffix.lower()[1:] - self.requests.successful = random.sample( - self.requests.successful, sample_size - ) - self.requests.errored = random.sample( - self.requests.errored, error_sample_size - ) - self.requests.incomplete = random.sample( - self.requests.incomplete, incomplete_sample_size - ) - self.request_samples = StatusBreakdown( - successful=len(self.requests.successful), - incomplete=len(self.requests.incomplete), - errored=len(self.requests.errored), - ) + with path.open("r") as file: + if (type_ or path_suffix) == "json": + model_dict = json.loads(file.read()) + elif (type_ or path_suffix) in ["yaml", "yml"]: + model_dict = yaml.safe_load(file) + else: + raise ValueError(f"Unsupported file type: {type_} for {path}.") - return self + return GenerativeBenchmarksReport.model_validate(model_dict) - @staticmethod - def from_stats( - run_id: str, - successful: list[GenerativeTextResponseStats], - incomplete: list[GenerativeTextErrorStats], - errored: list[GenerativeTextErrorStats], - args: BenchmarkArgs, - run_stats: BenchmarkRunStats, - worker: GenerativeRequestsWorkerDescription, - requests_loader: GenerativeRequestLoaderDescription, - extras: Optional[dict[str, Any]], - ) -> "GenerativeBenchmark": + benchmarks: list[GenerativeBenchmark] = Field( + description="The list of completed benchmarks contained within the report.", + default_factory=list, + ) + + def save_file( + self, path: Union[str, Path], type_: Literal["json", "yaml"] | None = None + ) -> Path: """ - Create a GenerativeBenchmark instance from the given statistics and metadata. - Given the completed and errored requests, the benchmark will fill in the - remaining statistics for the various metrics required for a benchmark. - This is the preferred method for creating a GenerativeBenchmark instance - to ensure all statistics are properly calculated and populated. - - :param run_id: The unique identifier for the benchmark run. - :param completed: The list of completed requests. - :param errored: The list of errored requests. - :param args: The arguments used to specify how to run the benchmark - and collect data. - :param run_stats: The process statistics for the entire benchmark run across - all requests. - :param worker: The description and specifics for the worker used to resolve - requests. - :param requests_loader: The description and specifics for the request loader - used to create requests. - :param extras: Any additional information or metadata that was passed for - this benchmark. - :return: A GenerativeBenchmark instance with the given statistics and metadata - populated and calculated + Save the report to a file. + + :param path: The path to save the report to. + :param type_: File type override, auto-detected from extension if None. + :return: The path to the saved report. + :raises ValueError: If file type is unsupported. 
""" - total = successful + incomplete + errored - total_types: list[Literal["successful", "incomplete", "error"]] = [ - *["successful"] * len(successful), # type: ignore[list-item] - *["incomplete"] * len(incomplete), # type: ignore[list-item] - *["error"] * len(errored), # type: ignore[list-item] - ] - start_time = min(req.start_time for req in total) - end_time = max(req.end_time for req in total) - - total_with_prompt, total_types_with_prompt = ( - zip(*filtered) - if ( - filtered := list( - filter(lambda val: bool(val[0].prompt), zip(total, total_types)) - ) - ) - else ([], []) - ) - total_with_output_first, total_types_with_output_first = ( - zip(*filtered) - if ( - filtered := list( - filter( - lambda val: bool(val[0].output_tokens > 0), - zip(total, total_types), - ) - ) - ) - else ([], []) - ) - total_with_output_multi, total_types_with_output_multi = ( - zip(*filtered) - if ( - filtered := list( - filter( - lambda val: bool(val[0].output_tokens > 1), - zip(total, total_types), - ) - ) - ) - else ([], []) - ) + path = Path(path) if not isinstance(path, Path) else path - return GenerativeBenchmark( - run_id=run_id, - args=args, - run_stats=run_stats, - extras=extras or {}, - start_time=start_time, - end_time=end_time, - worker=worker, - request_loader=requests_loader, - metrics=GenerativeMetrics( - requests_per_second=StatusDistributionSummary.from_request_times( - request_types=total_types, - requests=[(req.start_time, req.end_time) for req in total], - distribution_type="rate", - ), - request_concurrency=StatusDistributionSummary.from_request_times( - request_types=total_types, - requests=[(req.start_time, req.end_time) for req in total], - distribution_type="concurrency", - ), - request_latency=StatusDistributionSummary.from_values( - value_types=total_types, - values=[req.request_latency for req in total], - ), - prompt_token_count=StatusDistributionSummary.from_values( - value_types=list(total_types_with_prompt), - values=[req.prompt_tokens for req in total_with_prompt], - ), - output_token_count=StatusDistributionSummary.from_values( - value_types=list(total_types_with_output_first), - values=[req.output_tokens for req in total_with_output_first], - ), - time_to_first_token_ms=StatusDistributionSummary.from_values( - value_types=list(total_types_with_output_first), - values=[ - req.time_to_first_token_ms or 0 - for req in total_with_output_first - ], - ), - time_per_output_token_ms=StatusDistributionSummary.from_values( - value_types=list(total_types_with_output_first), - values=[ - req.time_per_output_token_ms or 0 - for req in total_with_output_first - ], - weights=[req.output_tokens for req in total_with_output_first], - ), - inter_token_latency_ms=StatusDistributionSummary.from_values( - value_types=list(total_types_with_output_multi), - values=[ - req.inter_token_latency_ms or 0 - for req in total_with_output_multi - ], - weights=[req.output_tokens - 1 for req in total_with_output_multi], - ), - output_tokens_per_second=StatusDistributionSummary.from_iterable_request_times( - request_types=list(total_types_with_output_first), - requests=[ - (req.start_time, req.end_time) - for req in total_with_output_first - ], - first_iter_times=[ - req.first_token_time or req.start_time - for req in total_with_output_first - ], - iter_counts=[req.output_tokens for req in total_with_output_first], - ), - tokens_per_second=StatusDistributionSummary.from_iterable_request_times( - request_types=list(total_types_with_output_first), - requests=[ - (req.start_time, req.end_time) - for req in 
total_with_output_first - ], - first_iter_times=[ - req.first_token_time or req.start_time - for req in total_with_output_first - ], - iter_counts=[req.output_tokens for req in total_with_output_first], - first_iter_counts=[ - req.prompt_tokens for req in total_with_output_first - ], - ), - ), - request_totals=StatusBreakdown( - successful=len(successful), - incomplete=len(incomplete), - errored=len(errored), - total=len(total), - ), - requests=StatusBreakdown( - successful=successful, - incomplete=incomplete, - errored=errored, - ), - ) + if path.is_dir(): + path = path / GenerativeBenchmarksReport.DEFAULT_FILE + + path.parent.mkdir(parents=True, exist_ok=True) + path_suffix = path.suffix.lower()[1:] + model_dict = self.model_dump() + + if (type_ or path_suffix) == "json": + save_str = json.dumps(model_dict) + elif (type_ or path_suffix) in ["yaml", "yml"]: + save_str = yaml.dump(model_dict) + else: + raise ValueError(f"Unsupported file type: {type_} for {path}.") + + with path.open("w") as file: + file.write(save_str) + + return path diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 11b6d245..30822f16 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -1,334 +1,242 @@ -import time +""" +Benchmark execution orchestration and lifecycle management. + +Provides the core benchmarking engine that coordinates request scheduling, +data aggregation, and result compilation across different execution strategies +and environments. + +Classes: + Benchmarker: Abstract benchmark orchestrator for request processing workflows. + +Type Variables: + BenchmarkT: Generic benchmark result type. + RequestT: Generic request object type. + RequestTimingsT: Generic request timing object type. + ResponseT: Generic response object type. 
+""" + import uuid -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Iterable -from pathlib import Path +from abc import ABC +from collections.abc import AsyncIterator, Iterable from typing import ( Any, Generic, - Literal, Optional, Union, ) -from pydantic import Field -from transformers import PreTrainedTokenizerBase # type: ignore # noqa: PGH003 - -from guidellm.backend import Backend, ResponseSummary -from guidellm.benchmark.aggregator import ( - AggregatorT, - BenchmarkT, - GenerativeBenchmarkAggregator, -) -from guidellm.benchmark.benchmark import BenchmarkArgs, GenerativeBenchmark +from guidellm.benchmark.aggregator import Aggregator, CompilableAggregator +from guidellm.benchmark.benchmark import BenchmarkT from guidellm.benchmark.profile import Profile -from guidellm.objects import StandardBaseModel -from guidellm.request import ( - GenerationRequest, - GenerativeRequestLoaderDescription, - RequestLoaderDescription, -) from guidellm.scheduler import ( - GenerativeRequestsWorker, - RequestsWorker, + BackendT, + Constraint, + Environment, + MeasuredRequestTimingsT, RequestT, ResponseT, Scheduler, - SchedulerRequestResult, + SchedulerState, SchedulingStrategy, ) +from guidellm.utils import InfoMixin, ThreadSafeSingletonMixin -__all__ = ["Benchmarker", "BenchmarkerResult", "GenerativeBenchmarker"] +__all__ = ["Benchmarker"] -class BenchmarkerResult( - StandardBaseModel, Generic[AggregatorT, BenchmarkT, RequestT, ResponseT] +class Benchmarker( + Generic[BenchmarkT, RequestT, MeasuredRequestTimingsT, ResponseT], + ABC, + ThreadSafeSingletonMixin, ): - type_: Literal[ - "run_start", - "run_complete", - "scheduler_start", - "scheduler_update", - "scheduler_complete", - "benchmark_compiled", - ] - start_time: float - end_number: int - profile: Profile - current_index: int - current_strategy: Optional[SchedulingStrategy] = None - current_aggregator: Optional[AggregatorT] = None - current_benchmark: Optional[BenchmarkT] = None - current_result: Optional[SchedulerRequestResult[RequestT, ResponseT]] = None - - -class BenchmarkerStrategyLimits(StandardBaseModel): - requests_loader_size: Optional[int] = Field( - description="Size of the request loader.", - ) - max_number_per_strategy: Optional[int] = Field( - description="Maximum number of requests to process per strategy.", - ge=0, - ) - max_duration_per_strategy: Optional[float] = Field( - description="Maximum duration (in seconds) to process requests per strategy.", - ge=0, - ) - warmup_percent_per_strategy: Optional[float] = Field( - description="Percentage of requests to use for warmup.", - ge=0, - le=1, - ) - cooldown_percent_per_strategy: Optional[float] = Field( - description="Percentage of requests to use for cooldown.", - ge=0, - le=1, - ) - - @property - def max_number(self) -> Optional[int]: - if self.max_number_per_strategy is not None: - return self.max_number_per_strategy - - if self.requests_loader_size is not None: - return self.requests_loader_size + """ + Abstract benchmark orchestrator for request processing workflows. - return None + Coordinates the execution of benchmarking runs across different scheduling + strategies, aggregating metrics and compiling results. Manages the complete + benchmark lifecycle from request submission through result compilation. 
- @property - def max_duration(self) -> Optional[float]: - return self.max_duration_per_strategy - - @property - def warmup_number(self) -> Optional[int]: - if self.warmup_percent_per_strategy is None or self.max_number is None: - return None - - return int(self.warmup_percent_per_strategy * self.max_number) - - @property - def warmup_duration(self) -> Optional[float]: - if self.warmup_percent_per_strategy is None or self.max_duration is None: - return None - - return self.warmup_percent_per_strategy * self.max_duration - - @property - def cooldown_number(self) -> Optional[int]: - if self.cooldown_percent_per_strategy is None or self.max_number is None: - return None - - return int(self.cooldown_percent_per_strategy * self.max_number) - - @property - def cooldown_duration(self) -> Optional[float]: - if self.cooldown_percent_per_strategy is None or self.max_duration is None: - return None - - return self.cooldown_percent_per_strategy * self.max_duration - - -class Benchmarker(Generic[AggregatorT, BenchmarkT, RequestT, ResponseT], ABC): - def __init__( - self, - worker: RequestsWorker[RequestT, ResponseT], - request_loader: Iterable[RequestT], - requests_loader_description: RequestLoaderDescription, - benchmark_save_extras: Optional[dict[str, Any]] = None, - ): - self.worker = worker - self.scheduler: Scheduler[RequestT, ResponseT] = Scheduler( - worker=worker, request_loader=request_loader - ) - self.requests_loader_description = requests_loader_description - self.benchmark_save_extras = benchmark_save_extras + Implements thread-safe singleton pattern to ensure consistent state across + concurrent benchmark operations. + """ async def run( self, + requests: Iterable[ + Union[RequestT, Iterable[Union[RequestT, tuple[RequestT, float]]]] + ], + backend: BackendT[RequestT, MeasuredRequestTimingsT, ResponseT], profile: Profile, - max_number_per_strategy: Optional[int], - max_duration_per_strategy: Optional[float], - warmup_percent_per_strategy: Optional[float], - cooldown_percent_per_strategy: Optional[float], - ) -> AsyncGenerator[ - BenchmarkerResult[AggregatorT, BenchmarkT, RequestT, ResponseT], None + environment: Environment, + benchmark_aggregators: dict[ + str, + Union[ + Aggregator[ResponseT, RequestT, MeasuredRequestTimingsT], + CompilableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], + ], + ], + benchmark_class: type[BenchmarkT], + ) -> AsyncIterator[ + tuple[ + dict[str, Any], + Optional[BenchmarkT], + SchedulingStrategy, + Optional[SchedulerState], + ] ]: - try: - requests_loader_size = len(self.scheduler.request_loader) # type: ignore[arg-type] - except Exception: # noqa: BLE001 - requests_loader_size = None - - strategy_limits = BenchmarkerStrategyLimits( - requests_loader_size=requests_loader_size, - max_number_per_strategy=max_number_per_strategy, - max_duration_per_strategy=max_duration_per_strategy, - warmup_percent_per_strategy=warmup_percent_per_strategy, - cooldown_percent_per_strategy=cooldown_percent_per_strategy, - ) - start_time = time.time() - end_number = len(profile.strategy_types) - current_index = -1 - run_id = str(uuid.uuid4()) - - yield BenchmarkerResult( - type_="run_start", - start_time=start_time, - end_number=end_number, - profile=profile, - current_index=current_index, - current_strategy=None, - current_aggregator=None, - current_benchmark=None, - current_result=None, - ) - - while scheduling_strategy := profile.next_strategy(): - current_index += 1 - aggregator = self.create_benchmark_aggregator( - run_id=run_id, - profile=profile, - 
strategy_index=current_index, - strategy=scheduling_strategy, - limits=strategy_limits, - ) - - async for result in self.scheduler.run( - scheduling_strategy=scheduling_strategy, - max_number=max_number_per_strategy, - max_duration=max_duration_per_strategy, - ): - if result.type_ == "run_start": - yield BenchmarkerResult( - type_="scheduler_start", - start_time=start_time, - end_number=end_number, - profile=profile, - current_index=current_index, - current_strategy=scheduling_strategy, - current_aggregator=aggregator, - current_benchmark=None, - current_result=None, - ) - elif result.type_ == "run_complete": - yield BenchmarkerResult( - type_="scheduler_complete", - start_time=start_time, - end_number=end_number, - profile=profile, - current_index=current_index, - current_strategy=scheduling_strategy, - current_aggregator=aggregator, - current_benchmark=None, - current_result=None, - ) - elif isinstance(result, SchedulerRequestResult): - aggregator.add_result(result) - - yield BenchmarkerResult( - type_="scheduler_update", - start_time=start_time, - end_number=end_number, - profile=profile, - current_index=current_index, - current_strategy=scheduling_strategy, - current_aggregator=aggregator, - current_benchmark=None, - current_result=result, - ) - else: - raise ValueError(f"Unexpected result type: {type(result)}") - - benchmark: BenchmarkT = aggregator.compile() - profile.completed_strategy( - average_rate=benchmark.metrics.requests_per_second.successful.mean, - average_concurrency=benchmark.metrics.request_concurrency.successful.mean, - ) - - yield BenchmarkerResult( - type_="benchmark_compiled", - start_time=start_time, - end_number=end_number, - profile=profile, - current_index=current_index, - current_strategy=scheduling_strategy, - current_aggregator=None, - current_benchmark=benchmark, - current_result=None, - ) - - yield BenchmarkerResult( - type_="run_complete", - start_time=start_time, - end_number=end_number, - profile=profile, - current_index=current_index, - current_strategy=None, - current_aggregator=None, - current_benchmark=None, - current_result=None, - ) - - @abstractmethod - def create_benchmark_aggregator( - self, - run_id: str, - profile: Profile, - strategy_index: int, - strategy: SchedulingStrategy, - limits: BenchmarkerStrategyLimits, - ) -> AggregatorT: ... - - -class GenerativeBenchmarker( - Benchmarker[ - GenerativeBenchmarkAggregator, - GenerativeBenchmark, - GenerationRequest, - ResponseSummary, - ], -): - def __init__( - self, - backend: Backend, - request_loader: Iterable[GenerationRequest], - request_loader_description: GenerativeRequestLoaderDescription, - benchmark_save_extras: Optional[dict[str, Any]] = None, - processor: Optional[Union[str, Path, PreTrainedTokenizerBase]] = None, - processor_args: Optional[dict[str, Any]] = None, - ): - super().__init__( - worker=GenerativeRequestsWorker(backend), - request_loader=request_loader, - requests_loader_description=request_loader_description, - benchmark_save_extras=benchmark_save_extras, - ) - self.processor = processor - self.processor_args = processor_args - - def create_benchmark_aggregator( - self, + """ + Execute benchmark runs across multiple scheduling strategies. + + Orchestrates the complete benchmark workflow: iterates through scheduling + strategies from the profile, executes requests through the scheduler, + aggregates metrics, and compiles final benchmark results. + + :param requests: Request datasets for processing across strategies. + :param backend: Backend interface for request processing. 
+ :param profile: Benchmark profile defining strategies and constraints. + :param environment: Execution environment for coordination. + :param benchmark_aggregators: Metric aggregation functions by name. + :param benchmark_class: Class for constructing final benchmark objects. + :yield: Tuples of (metrics_update, benchmark_result, strategy, state). + :raises Exception: If benchmark execution or compilation fails. + """ + with self.thread_lock: + run_id = str(uuid.uuid4()) + strategies_generator = profile.strategies_generator() + strategy, constraints = next(strategies_generator) + + while strategy is not None: + yield {}, None, strategy, None + aggregators_state = {key: {} for key in benchmark_aggregators} + + async for ( + response, + request, + request_info, + scheduler_state, + ) in Scheduler[ + BackendT, RequestT, MeasuredRequestTimingsT, ResponseT + ].run( + requests=requests, + backend=backend, + strategy=strategy, + env=environment, + **constraints, + ): + aggregators_update = {} + for key, aggregator in benchmark_aggregators.items(): + update = aggregator( + aggregators_state[key], + response, + request, + request_info, + scheduler_state, + ) + if update: + aggregators_update.update(update) + yield aggregators_update, None, strategy, scheduler_state + + benchmark_kwargs = self._compile_benchmark_kwargs( + run_id=run_id, + run_index=len(profile.completed_strategies), + profile=profile, + requests=requests, + backend=backend, + environment=environment, + aggregators=benchmark_aggregators, + aggregators_state=aggregators_state, + strategy=strategy, + constraints=constraints, + scheduler_state=scheduler_state, + ) + benchmark = benchmark_class(**benchmark_kwargs) + yield {}, benchmark, strategy, None + + strategy, constraints = strategies_generator.send(benchmark) + + @classmethod + def _compile_benchmark_kwargs( + cls, run_id: str, + run_index: int, profile: Profile, - strategy_index: int, + requests: Iterable[ + Union[RequestT, Iterable[Union[RequestT, tuple[RequestT, float]]]] + ], + backend: BackendT[RequestT, MeasuredRequestTimingsT, ResponseT], + environment: Environment, + aggregators: dict[ + str, + Union[ + Aggregator[ResponseT, RequestT, MeasuredRequestTimingsT], + CompilableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], + ], + ], + aggregators_state: dict[str, dict[str, Any]], strategy: SchedulingStrategy, - limits: BenchmarkerStrategyLimits, - ) -> GenerativeBenchmarkAggregator: - return GenerativeBenchmarkAggregator( - run_id=run_id, - args=BenchmarkArgs( - profile=profile, - strategy_index=strategy_index, - strategy=strategy, - max_number=limits.max_number, - max_duration=limits.max_duration, - warmup_number=limits.warmup_number, - warmup_duration=limits.warmup_duration, - cooldown_number=limits.cooldown_number, - cooldown_duration=limits.cooldown_duration, - ), - worker_description=self.worker.description, # type: ignore[arg-type] - request_loader_description=self.requests_loader_description, # type: ignore[arg-type] - extras=self.benchmark_save_extras or {}, - processor=self.processor, - processor_args=self.processor_args, - ) + constraints: dict[str, Union[Any, dict[str, Any], Constraint]], + scheduler_state: Optional[SchedulerState], + ) -> dict[str, Any]: + """ + Compile benchmark construction parameters from execution results. + + Aggregates metadata from scheduler execution and compiles it into + structured parameters for benchmark object construction. + + :param run_id: Unique identifier for the benchmark run. 
+ :param run_index: Index of this strategy in the benchmark profile. + :param profile: Benchmark profile containing strategy configuration. + :param requests: Request datasets used for the benchmark. + :param backend: Backend interface used for request processing. + :param environment: Execution environment for coordination. + :param aggregators: Metric aggregation functions by name. + :param aggregators_state: Current state of metric aggregators. + :param strategy: Scheduling strategy that was executed. + :param constraints: Runtime constraints applied during execution. + :param scheduler_state: Final state of scheduler execution. + :return: Dictionary of parameters for benchmark object construction. + :raises ValueError: If aggregator output conflicts with existing keys. + """ + benchmark_kwargs = { + "run_id": run_id, + "run_index": run_index, + "scheduler": { + "strategy": strategy, + "constraints": { + key: InfoMixin.extract_from_obj(val) for key, val in constraints + }, + "state": scheduler_state, + }, + "benchmarker": { + "profile": profile, + "requests": InfoMixin.extract_from_obj(requests), + "backend": InfoMixin.extract_from_obj(backend), + "environment": InfoMixin.extract_from_obj(environment), + "aggregators": { + key: InfoMixin.extract_from_obj(aggregator) + for key, aggregator in aggregators.items() + }, + }, + "system": {}, + "extras": {}, + } + for key, aggregator in aggregators.items(): + if not isinstance(aggregator, CompilableAggregator): + continue + + compiled = aggregator.compile(aggregators_state[key]) + + if key not in benchmark_kwargs: + benchmark_kwargs[key] = compiled + continue + + existing_val = benchmark_kwargs[key] + if not (isinstance(existing_val, dict) and isinstance(compiled, dict)): + raise ValueError( + f"Key '{key}' already exists with value {existing_val} " + f"(type: {type(existing_val).__name__}) and cannot be " + f"overwritten with {compiled} (type: {type(compiled).__name__})" + ) + existing_val.update(compiled) + + return benchmark_kwargs diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index 2ef85c3e..74658818 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -3,21 +3,45 @@ from typing import Any, Literal, Optional, Union from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict +from pydantic import Field, validate_call from transformers import ( # type: ignore[import] PreTrainedTokenizerBase, ) -from guidellm.backend import Backend, BackendType -from guidellm.benchmark.benchmarker import GenerativeBenchmarker +from guidellm.backend import ( + Backend, + BackendType, + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.benchmark.aggregator import ( + GenerativeRequestsAggregator, + GenerativeRequestsStatsProgressAggregator, + SchedulerStatsAggregator, +) +from guidellm.benchmark.benchmark import GenerativeBenchmark +from guidellm.benchmark.benchmarker import Benchmarker from guidellm.benchmark.output import ( GenerativeBenchmarksConsole, GenerativeBenchmarksReport, ) -from guidellm.benchmark.profile import ProfileType, create_profile -from guidellm.benchmark.progress import GenerativeTextBenchmarkerProgressDisplay +from guidellm.benchmark.profile import Profile, ProfileType +from guidellm.benchmark.progress import ( + BenchmarkerProgress, + BenchmarkerProgressGroup, + GenerativeConsoleBenchmarkerProgress, +) from guidellm.benchmark.scenario import GenerativeTextScenario, Scenario from 
guidellm.request import GenerativeRequestLoader from guidellm.scheduler import StrategyType +from guidellm.scheduler.environment import NonDistributedEnvironment + +__all__ = [ + "benchmark_generative_text", + "benchmark_with_scenario", + "reimport_benchmarks_report", +] async def benchmark_with_scenario(scenario: Scenario, **kwargs): @@ -31,6 +55,7 @@ async def benchmark_with_scenario(scenario: Scenario, **kwargs): raise ValueError(f"Unsupported Scenario type {type(scenario)}") +@validate_call async def benchmark_generative_text( target: str, backend_type: BackendType, @@ -53,28 +78,27 @@ async def benchmark_generative_text( rate: Optional[Union[float, list[float]]], max_seconds: Optional[float], max_requests: Optional[int], + max_errors: Optional[int], + max_error_rate: Optional[float], + max_global_error_rate: Optional[float], warmup_percent: Optional[float], cooldown_percent: Optional[float], output_path: Optional[Union[str, Path]], output_extras: Optional[dict[str, Any]], output_sampling: Optional[int], random_seed: int, - show_progress: bool = True, - show_progress_scheduler_stats: bool = False, + progress: Optional[list[BenchmarkerProgress]] = Field( + default_factory=lambda: [GenerativeConsoleBenchmarkerProgress()] + ), output_console: bool = True, ) -> tuple[GenerativeBenchmarksReport, Optional[Path]]: - console = GenerativeBenchmarksConsole(enabled=show_progress) - console.print_line("Creating backend...") + console = GenerativeBenchmarksConsole(enabled=progress is not None) backend = Backend.create( backend_type, target=target, model=model, **(backend_args or {}) ) - await backend.validate() - console.print_line( - f"Backend {backend_type} connected to {target} for model {backend.model}." - ) if processor is None: - processor = backend.model + processor = await backend.default_model() console.print_line("Creating request loader...") request_loader = GenerativeRequestLoader( @@ -83,11 +107,6 @@ async def benchmark_generative_text( processor=processor, processor_args=processor_args, shuffle=data_sampler == "random", - iter_type=( - "finite" # assume a finite dataset is our limit - if max_requests is None and max_seconds is None - else "infinite" # default to infinite so we don't run out of data - ), random_seed=random_seed, ) unique_requests = request_loader.num_unique_items(raise_err=False) @@ -96,41 +115,74 @@ async def benchmark_generative_text( if unique_requests > 0 else f"Created loader with unknown number unique requests from {data}.\n\n" ) - - profile = create_profile(rate_type=rate_type, rate=rate) - benchmarker = GenerativeBenchmarker( - backend=backend, - request_loader=request_loader, - request_loader_description=request_loader.description, - benchmark_save_extras=output_extras, - processor=processor, - processor_args=processor_args, - ) - progress = ( - GenerativeTextBenchmarkerProgressDisplay( - display_scheduler_stats=show_progress_scheduler_stats - ) - if show_progress - else None + profile = Profile.create( + rate_type=rate_type, + rate=rate, + random_seed=random_seed, + constraints={ + key: val + for key, val in { + "max_requests": max_requests, + "max_seconds": max_seconds, + "max_errors": max_errors, + "max_error_rate": max_error_rate, + "max_global_error_rate": max_global_error_rate, + }.items() + if val is not None + }, ) report = GenerativeBenchmarksReport() + aggregators = { + "scheduler_stats": SchedulerStatsAggregator(), + "requests_progress": GenerativeRequestsStatsProgressAggregator(), + "requests": GenerativeRequestsAggregator( + warmup_requests=( + 
int(max_requests * warmup_percent) + if warmup_percent and max_requests + else None + ), + warmup_duration=( + max_seconds * warmup_percent if warmup_percent and max_seconds else None + ), + cooldown_requests=( + int(max_requests * cooldown_percent) + if cooldown_percent and max_requests + else None + ), + cooldown_duration=( + max_seconds * cooldown_percent + if cooldown_percent and max_seconds + else None + ), + ), + } + progress_group = BenchmarkerProgressGroup( + instances=progress or [], enabled=progress is not None + ) - async for result in benchmarker.run( - profile=profile, - max_number_per_strategy=max_requests, - max_duration_per_strategy=max_seconds, - warmup_percent_per_strategy=warmup_percent, - cooldown_percent_per_strategy=cooldown_percent, + async for ( + _aggregator_update, + benchmark, + _strategy, + _scheduler_state, + ) in progress_group( + profile, + Benchmarker[ + GenerativeBenchmark, + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, + ].run( + requests=request_loader, + backend=backend, + profile=profile, + environment=NonDistributedEnvironment(), + benchmark_aggregators=aggregators, + benchmark_class=GenerativeBenchmark, + ), ): - if progress: - progress.update(result) - - if result.type_ == "benchmark_compiled": - if result.current_benchmark is None: - raise ValueError("Current benchmark is None") - report.benchmarks.append( - result.current_benchmark.set_sample_size(output_sampling) - ) + if benchmark: + report.benchmarks.append(benchmark) if output_console: console.benchmarks = report.benchmarks diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 8a113f72..01f576b2 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -1,19 +1,22 @@ import csv import json import math +from abc import ABC, abstractmethod from collections import OrderedDict from datetime import datetime from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, ClassVar, Optional, Union import humps # type: ignore[import-not-found] -import yaml -from pydantic import Field from rich.console import Console from rich.padding import Padding from rich.text import Text -from guidellm.benchmark.benchmark import GenerativeBenchmark, GenerativeMetrics +from guidellm.benchmark.benchmark import ( + GenerativeBenchmark, + GenerativeBenchmarksReport, + GenerativeMetrics, +) from guidellm.benchmark.profile import ( AsyncProfile, ConcurrentProfile, @@ -23,7 +26,6 @@ from guidellm.config import settings from guidellm.objects import ( DistributionSummary, - StandardBaseModel, StatusDistributionSummary, ) from guidellm.presentation import UIDataBuilder @@ -32,396 +34,176 @@ from guidellm.utils import Colors, split_text_list_by_length __all__ = [ - "GenerativeBenchmarksConsole", - "GenerativeBenchmarksReport", + "GenerativeBenchmarkerCSV", + "GenerativeBenchmarkerConsole", + "GenerativeBenchmarkerHTML", + "GenerativeBenchmarkerOutput", ] -class GenerativeBenchmarksReport(StandardBaseModel): - """ - A pydantic model representing a completed benchmark report. - Contains a list of benchmarks along with convenience methods for finalizing - and saving the report. - """ - - @staticmethod - def load_file(path: Union[str, Path]) -> "GenerativeBenchmarksReport": - """ - Load a report from a file. The file type is determined by the file extension. - If the file is a directory, it expects a file named benchmarks.json under the - directory. - - :param path: The path to load the report from. 
- :return: The loaded report. - """ - path, type_ = GenerativeBenchmarksReport._file_setup(path) - - if type_ == "json": - with path.open("r") as file: - model_dict = json.load(file) - - return GenerativeBenchmarksReport.model_validate(model_dict) - - if type_ == "yaml": - with path.open("r") as file: - model_dict = yaml.safe_load(file) - - return GenerativeBenchmarksReport.model_validate(model_dict) - - if type_ == "csv": - raise ValueError(f"CSV file type is not supported for loading: {path}.") - - if type_ == "html": - raise ValueError(f"HTML file type is not supported for loading: {path}.") - - raise ValueError(f"Unsupported file type: {type_} for {path}.") - - benchmarks: list[GenerativeBenchmark] = Field( - description="The list of completed benchmarks contained within the report.", - default_factory=list, - ) - - def set_sample_size( - self, sample_size: Optional[int] - ) -> "GenerativeBenchmarksReport": - """ - Set the sample size for each benchmark in the report. In doing this, it will - reduce the contained requests of each benchmark to the sample size. - If sample size is None, it will return the report as is. - - :param sample_size: The sample size to set for each benchmark. - If None, the report will be returned as is. - :return: The report with the sample size set for each benchmark. - """ - - if sample_size is not None: - for benchmark in self.benchmarks: - benchmark.set_sample_size(sample_size) - - return self - - def save_file(self, path: Union[str, Path]) -> Path: - """ - Save the report to a file. The file type is determined by the file extension. - If the file is a directory, it will save the report to a file named - benchmarks.json under the directory. - - :param path: The path to save the report to. - :return: The path to the saved report. - """ - path, type_ = GenerativeBenchmarksReport._file_setup(path) - - if type_ == "json": - return self.save_json(path) - - if type_ == "yaml": - return self.save_yaml(path) - - if type_ == "csv": - return self.save_csv(path) - - if type_ == "html": - return self.save_html(path) +class GenerativeBenchmarkerOutput(ABC): + @abstractmethod + async def finalize(self, report: GenerativeBenchmarksReport) -> Any: ... - raise ValueError(f"Unsupported file type: {type_} for {path}.") - def save_json(self, path: Union[str, Path]) -> Path: - """ - Save the report to a JSON file containing all of the report data which is - reloadable using the pydantic model. If the file is a directory, it will save - the report to a file named benchmarks.json under the directory. +class GenerativeBenchmarkerConsole(GenerativeBenchmarkerOutput): + """Console output formatter for benchmark results with rich formatting.""" - :param path: The path to save the report to. - :return: The path to the saved report. + def __init__(self): """ - path, type_ = GenerativeBenchmarksReport._file_setup(path, "json") - - if type_ != "json": - raise ValueError( - f"Unsupported file type for saving a JSON: {type_} for {path}." - ) - - model_dict = self.model_dump() - model_json = json.dumps(model_dict) - - with path.open("w") as file: - file.write(model_json) - - return path - - def save_yaml(self, path: Union[str, Path]) -> Path: + Initialize the console output formatter. """ - Save the report to a YAML file containing all of the report data which is - reloadable using the pydantic model. If the file is a directory, it will save - the report to a file named benchmarks.yaml under the directory. - - :param path: The path to save the report to. 
- :return: The path to the saved report. - """ - - path, type_ = GenerativeBenchmarksReport._file_setup(path, "yaml") - - if type_ != "yaml": - raise ValueError( - f"Unsupported file type for saving a YAML: {type_} for {path}." - ) - - model_dict = self.model_dump() - model_yaml = yaml.dump(model_dict) - - with path.open("w") as file: - file.write(model_yaml) - - return path - - def save_csv(self, path: Union[str, Path]) -> Path: - """ - Save the report to a CSV file containing the summarized statistics and values - for each report. Note, this data is not reloadable using the pydantic model. - If the file is a directory, it will save the report to a file named - benchmarks.csv under the directory. - - :param path: The path to save the report to. - :return: The path to the saved report. - """ - path, type_ = GenerativeBenchmarksReport._file_setup(path, "csv") - - if type_ != "csv": - raise ValueError( - f"Unsupported file type for saving a CSV: {type_} for {path}." - ) - - with path.open("w", newline="") as file: - writer = csv.writer(file) - headers: list[str] = [] - rows: list[list[Union[str, float, list[float]]]] = [] - - for benchmark in self.benchmarks: - benchmark_headers: list[str] = [] - benchmark_values: list[Union[str, float, list[float]]] = [] - - desc_headers, desc_values = self._benchmark_desc_headers_and_values( - benchmark - ) - benchmark_headers += desc_headers - benchmark_values += desc_values - - for status in StatusDistributionSummary.model_fields: - status_headers, status_values = ( - self._benchmark_status_headers_and_values(benchmark, status) - ) - benchmark_headers += status_headers - benchmark_values += status_values - - benchmark_extra_headers, benchmark_extra_values = ( - self._benchmark_extras_headers_and_values(benchmark) - ) - benchmark_headers += benchmark_extra_headers - benchmark_values += benchmark_extra_values - - if not headers: - headers = benchmark_headers - rows.append(benchmark_values) - - writer.writerow(headers) - for row in rows: - writer.writerow(row) - - return path + self.console = Console() - def save_html(self, path: Union[str, Path]) -> Path: + async def finalize(self, report: GenerativeBenchmarksReport): """ - Download html, inject report data and save to a file. + Print the complete benchmark report to the console. - :param path: The path to create the report at. - :return: The path to the report. + :param report: The completed benchmark report. + :return: None (console output doesn't save to a file). 
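+
+        Example (illustrative sketch; assumes report is a completed
+        GenerativeBenchmarksReport):
+        ::
+            console_output = GenerativeBenchmarkerConsole()
+            await console_output.finalize(report)  # prints metadata, info, and stats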
""" + self._print_benchmarks_metadata(report.benchmarks) + self._print_benchmarks_info(report.benchmarks) + self._print_benchmarks_stats(report.benchmarks) - data_builder = UIDataBuilder(self.benchmarks) - data = data_builder.to_dict() - camel_data = humps.camelize(data) - ui_api_data = {} - for k, v in camel_data.items(): - key = f"window.{humps.decamelize(k)} = {{}};" - value = f"window.{humps.decamelize(k)} = {json.dumps(v, indent=2)};\n" - ui_api_data[key] = value - return create_report(ui_api_data, path) - - @staticmethod - def _file_setup( - path: Union[str, Path], - default_file_type: Literal["json", "yaml", "csv", "html"] = "json", - ) -> tuple[Path, Literal["json", "yaml", "csv", "html"]]: - path = Path(path) if not isinstance(path, Path) else path - - if path.is_dir(): - path = path / f"benchmarks.{default_file_type}" - - path.parent.mkdir(parents=True, exist_ok=True) - path_suffix = path.suffix.lower() - - if path_suffix == ".json": - return path, "json" - - if path_suffix in [".yaml", ".yml"]: - return path, "yaml" - - if path_suffix in [".csv"]: - return path, "csv" - - if path_suffix in [".html"]: - return path, "html" + def _print_benchmarks_metadata(self, benchmarks: list[GenerativeBenchmark]): + start_time = benchmarks[0].run_stats.start_time + end_time = benchmarks[-1].run_stats.end_time + duration = end_time - start_time - raise ValueError( - f"Unsupported file extension: {path_suffix} for {path}; " - "expected json, yaml, csv, or html." - ) + self._print_section_header("Benchmarks Metadata") + self._print_labeled_line("Run id", str(benchmarks[0].run_id)) + self._print_labeled_line("Duration", f"{duration:.1f} seconds") + self._print_labeled_line("Profile", self._get_profile_str(benchmarks[0])) + self._print_labeled_line("Scheduler", self._get_scheduler_str(benchmarks[0])) + self._print_labeled_line("Environment", self._get_env_args_str(benchmarks[0])) + self._print_labeled_line("Extras", self._get_extras_str(benchmarks[0])) - @staticmethod - def _benchmark_desc_headers_and_values( - benchmark: GenerativeBenchmark, - ) -> tuple[list[str], list[Union[str, float]]]: + def _print_benchmarks_info(self, benchmarks: list[GenerativeBenchmark]): + sections = { + "Metadata": (0, 3), + "Requests Made": (4, 6), + "Prompt Tok/Req": (7, 9), + "Output Tok/Req": (10, 12), + "Prompt Tok Total": (13, 15), + "Output Tok Total": (16, 18), + } headers = [ - "Type", - "Run Id", - "Id", - "Name", + "Benchmark", "Start Time", "End Time", - "Duration", - ] - values: list[Union[str, float]] = [ - benchmark.type_, - benchmark.run_id, - benchmark.id_, - strategy_display_str(benchmark.args.strategy), - datetime.fromtimestamp(benchmark.start_time).strftime("%Y-%m-%d %H:%M:%S"), - datetime.fromtimestamp(benchmark.end_time).strftime("%Y-%m-%d %H:%M:%S"), - benchmark.duration, - ] - - if len(headers) != len(values): - raise ValueError("Headers and values length mismatch.") - - return headers, values - - @staticmethod - def _benchmark_extras_headers_and_values( - benchmark: GenerativeBenchmark, - ) -> tuple[list[str], list[str]]: - headers = ["Args", "Worker", "Request Loader", "Extras"] - values: list[str] = [ - json.dumps(benchmark.args.model_dump()), - json.dumps(benchmark.worker.model_dump()), - json.dumps(benchmark.request_loader.model_dump()), - json.dumps(benchmark.extras), - ] - - if len(headers) != len(values): - raise ValueError("Headers and values length mismatch.") - - return headers, values - - @staticmethod - def _benchmark_status_headers_and_values( - benchmark: GenerativeBenchmark, 
status: str - ) -> tuple[list[str], list[Union[float, list[float]]]]: - headers = [ - f"{status.capitalize()} Requests", - ] - values = [ - getattr(benchmark.request_totals, status), + "Duration (s)", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", + "Comp", + "Inc", + "Err", ] - for metric in GenerativeMetrics.model_fields: - metric_headers, metric_values = ( - GenerativeBenchmarksReport._benchmark_status_metrics_stats( - benchmark, status, metric - ) + rows = [] + for benchmark in benchmarks: + rows.append( + [ + strategy_display_str(benchmark.scheduler["strategy"]), + datetime.fromtimestamp(benchmark.start_time).strftime("%H:%M:%S"), + datetime.fromtimestamp(benchmark.end_time).strftime("%H:%M:%S"), + f"{(benchmark.end_time - benchmark.start_time):.1f}", + f"{benchmark.request_totals.successful:.0f}", + f"{benchmark.request_totals.incomplete:.0f}", + f"{benchmark.request_totals.errored:.0f}", + f"{benchmark.metrics.prompt_token_count.successful.mean:.1f}", + f"{benchmark.metrics.prompt_token_count.incomplete.mean:.1f}", + f"{benchmark.metrics.prompt_token_count.errored.mean:.1f}", + f"{benchmark.metrics.output_token_count.successful.mean:.1f}", + f"{benchmark.metrics.output_token_count.incomplete.mean:.1f}", + f"{benchmark.metrics.output_token_count.errored.mean:.1f}", + f"{benchmark.metrics.prompt_token_count.successful.total_sum:.0f}", + f"{benchmark.metrics.prompt_token_count.incomplete.total_sum:.0f}", + f"{benchmark.metrics.prompt_token_count.errored.total_sum:.0f}", + f"{benchmark.metrics.output_token_count.successful.total_sum:.0f}", + f"{benchmark.metrics.output_token_count.incomplete.total_sum:.0f}", + f"{benchmark.metrics.output_token_count.errored.total_sum:.0f}", + ] ) - headers += metric_headers - values += metric_values - - if len(headers) != len(values): - raise ValueError("Headers and values length mismatch.") - return headers, values + self._print_table(headers, rows, "Benchmarks Info", sections) - @staticmethod - def _benchmark_status_metrics_stats( - benchmark: GenerativeBenchmark, - status: str, - metric: str, - ) -> tuple[list[str], list[Union[float, list[float]]]]: - status_display = status.capitalize() - metric_display = metric.replace("_", " ").capitalize() - status_dist_summary: StatusDistributionSummary = getattr( - benchmark.metrics, metric - ) - dist_summary: DistributionSummary = getattr(status_dist_summary, status) + def _print_benchmarks_stats(self, benchmarks: list[GenerativeBenchmark]): + sections = { + "Metadata": (0, 0), + "Request Stats": (1, 2), + "Out Tok/sec": (3, 3), + "Tot Tok/sec": (4, 4), + "Req Latency (sec)": (5, 7), + "TTFT (ms)": (8, 10), + "ITL (ms)": (11, 13), + "TPOT (ms)": (14, 16), + } headers = [ - f"{status_display} {metric_display} mean", - f"{status_display} {metric_display} median", - f"{status_display} {metric_display} std dev", - ( - f"{status_display} {metric_display} " - "[min, 0.1, 1, 5, 10, 25, 75, 90, 95, 99, max]" - ), - ] - values: list[Union[float, list[float]]] = [ - dist_summary.mean, - dist_summary.median, - dist_summary.std_dev, - [ - dist_summary.min, - dist_summary.percentiles.p001, - dist_summary.percentiles.p01, - dist_summary.percentiles.p05, - dist_summary.percentiles.p10, - dist_summary.percentiles.p25, - dist_summary.percentiles.p75, - dist_summary.percentiles.p90, - dist_summary.percentiles.p95, - dist_summary.percentiles.p99, - dist_summary.max, - ], + "Benchmark", + "Per Second", + "Concurrency", + "mean", + "mean", + "mean", + "median", + 
"p99", + "mean", + "median", + "p99", + "mean", + "median", + "p99", + "mean", + "median", + "p99", ] - if len(headers) != len(values): - raise ValueError("Headers and values length mismatch.") - - return headers, values - - -class GenerativeBenchmarksConsole: - """ - A class for outputting progress and benchmark results to the console. - Utilizes the rich library for formatting, enabling colored and styled output. - """ - - def __init__(self, enabled: bool = True): - """ - :param enabled: Whether to enable console output. Defaults to True. - If False, all console output will be suppressed. - """ - self.enabled = enabled - self.benchmarks: Optional[list[GenerativeBenchmark]] = None - self.console = Console() + rows = [] + for benchmark in benchmarks: + rows.append( + [ + strategy_display_str(benchmark.scheduler["strategy"]), + f"{benchmark.metrics.requests_per_second.successful.mean:.2f}", + f"{benchmark.metrics.request_concurrency.successful.mean:.2f}", + f"{benchmark.metrics.output_tokens_per_second.successful.mean:.1f}", + f"{benchmark.metrics.tokens_per_second.successful.mean:.1f}", + f"{benchmark.metrics.request_latency.successful.mean:.2f}", + f"{benchmark.metrics.request_latency.successful.median:.2f}", + f"{benchmark.metrics.request_latency.successful.percentiles.p99:.2f}", + f"{benchmark.metrics.time_to_first_token_ms.successful.mean:.1f}", + f"{benchmark.metrics.time_to_first_token_ms.successful.median:.1f}", + f"{benchmark.metrics.time_to_first_token_ms.successful.percentiles.p99:.1f}", + f"{benchmark.metrics.inter_token_latency_ms.successful.mean:.1f}", + f"{benchmark.metrics.inter_token_latency_ms.successful.median:.1f}", + f"{benchmark.metrics.inter_token_latency_ms.successful.percentiles.p99:.1f}", + f"{benchmark.metrics.time_per_output_token_ms.successful.mean:.1f}", + f"{benchmark.metrics.time_per_output_token_ms.successful.median:.1f}", + f"{benchmark.metrics.time_per_output_token_ms.successful.percentiles.p99:.1f}", + ] + ) - @property - def benchmarks_profile_str(self) -> str: - """ - :return: A string representation of the profile used for the benchmarks. - """ - profile = self.benchmarks[0].args.profile if self.benchmarks else None + self._print_table(headers, rows, "Benchmarks Stats", sections) + def _get_profile_str(self, benchmark: GenerativeBenchmark) -> str: + profile = benchmark.benchmarker.get("profile") if profile is None: return "None" profile_args = OrderedDict( { "type": profile.type_, - "strategies": profile.strategy_types, + "strategies": getattr(profile, "strategy_types", []), } ) @@ -438,131 +220,63 @@ def benchmarks_profile_str(self) -> str: return ", ".join(f"{key}={value}" for key, value in profile_args.items()) - @property - def benchmarks_args_str(self) -> str: - """ - :return: A string representation of the arguments used for the benchmarks. - """ - args = self.benchmarks[0].args if self.benchmarks else None - - if args is None: - return "None" - + def _get_args_str(self, benchmark: GenerativeBenchmark) -> str: + args = benchmark.args args_dict = OrderedDict( { "max_number": args.max_number, "max_duration": args.max_duration, - "warmup_number": args.warmup_number, - "warmup_duration": args.warmup_duration, - "cooldown_number": args.cooldown_number, - "cooldown_duration": args.cooldown_duration, - } - ) - - return ", ".join(f"{key}={value}" for key, value in args_dict.items()) - - @property - def benchmarks_worker_desc_str(self) -> str: - """ - :return: A string representation of the worker used for the benchmarks. 
- """ - return str(self.benchmarks[0].worker) if self.benchmarks else "None" - - @property - def benchmarks_request_loader_desc_str(self) -> str: - """ - :return: A string representation of the request loader used for the benchmarks. - """ - return str(self.benchmarks[0].request_loader) if self.benchmarks else "None" - - @property - def benchmarks_extras_str(self) -> str: - """ - :return: A string representation of the extras used for the benchmarks. - """ - extras = self.benchmarks[0].extras if self.benchmarks else None + "warmup_number": args.warmup_number, + "warmup_duration": args.warmup_duration, + "cooldown_number": args.cooldown_number, + "cooldown_duration": args.cooldown_duration, + } + ) + return ", ".join(f"{key}={value}" for key, value in args_dict.items()) + def _get_extras_str(self, benchmark: GenerativeBenchmark) -> str: + extras = benchmark.extras if not extras: return "None" - return ", ".join(f"{key}={value}" for key, value in extras.items()) - def print_section_header(self, title: str, indent: int = 0, new_lines: int = 2): - """ - Print out a styled section header to the console. - The title is underlined, bolded, and colored with the INFO color. - - :param title: The title of the section. - :param indent: The number of spaces to indent the title. - Defaults to 0. - :param new_lines: The number of new lines to print before the title. - Defaults to 2. - """ - self.print_line( - value=f"{title}:", - style=f"bold underline {Colors.INFO}", + def _print_section_header(self, title: str, indent: int = 0, new_lines: int = 2): + self._print_line( + f"{title}:", + f"bold underline {Colors.INFO}", indent=indent, new_lines=new_lines, ) - def print_labeled_line( + def _print_labeled_line( self, label: str, value: str, indent: int = 4, new_lines: int = 0 ): - """ - Print out a styled, labeled line (label: value) to the console. - The label is bolded and colored with the INFO color, - and the value is italicized. - - :param label: The label of the line. - :param value: The value of the line. - :param indent: The number of spaces to indent the line. - Defaults to 4. - :param new_lines: The number of new lines to print before the line. - Defaults to 0. - """ - self.print_line( - value=[label + ":", value], - style=["bold " + Colors.INFO, "italic"], + self._print_line( + [label + ":", value], + ["bold " + Colors.INFO, "italic"], new_lines=new_lines, indent=indent, ) - def print_line( + def _print_line( self, value: Union[str, list[str]], style: Union[str, list[str]] = "", indent: int = 0, new_lines: int = 0, ): - """ - Print out a a value to the console as a line with optional indentation. - - :param value: The value to print. - :param style: The style to apply to the value. - Defaults to none. - :param indent: The number of spaces to indent the line. - Defaults to 0. - :param new_lines: The number of new lines to print before the value. - Defaults to 0. - """ - if not self.enabled: - return - text = Text() - for _ in range(new_lines): text.append("\n") if not isinstance(value, list): value = [value] - if not isinstance(style, list): style = [style for _ in range(len(value))] if len(value) != len(style): raise ValueError( - f"Value and style length mismatch. Value length: {len(value)}, " - f"Style length: {len(style)}." 
+ f"Value and style length mismatch: {len(value)} vs {len(style)}" ) for val, sty in zip(value, style): @@ -570,128 +284,80 @@ def print_line( self.console.print(Padding.indent(text, indent)) - def print_table( + def _print_table( self, headers: list[str], rows: list[list[Any]], title: str, sections: Optional[dict[str, tuple[int, int]]] = None, - max_char_per_col: int = 2**10, + max_char_per_col: int = 1024, indent: int = 0, new_lines: int = 2, ): - """ - Print a table to the console with the given headers and rows. - - :param headers: The headers of the table. - :param rows: The rows of the table. - :param title: The title of the table. - :param sections: The sections of the table grouping columns together. - This is a mapping of the section display name to a tuple of the start and - end column indices. If None, no sections are added (default). - :param max_char_per_col: The maximum number of characters per column. - :param indent: The number of spaces to indent the table. - Defaults to 0. - :param new_lines: The number of new lines to print before the table. - Defaults to 0. - """ - if rows and any(len(row) != len(headers) for row in rows): raise ValueError( - f"Headers and rows length mismatch. Headers length: {len(headers)}, " - f"Row length: {len(rows[0]) if rows else 'N/A'}." + f"Headers and rows length mismatch: {len(headers)} vs {len(rows[0]) if rows else 'N/A'}" ) - max_characters_per_column = self.calculate_max_chars_per_column( + max_chars_per_column = self._calculate_max_chars_per_column( headers, rows, sections, max_char_per_col ) - self.print_section_header(title, indent=indent, new_lines=new_lines) - self.print_table_divider( - max_characters_per_column, include_separators=False, indent=indent - ) + self._print_section_header(title, indent=indent, new_lines=new_lines) + self._print_table_divider(max_chars_per_column, False, indent) if sections: - self.print_table_sections( - sections, max_characters_per_column, indent=indent - ) - self.print_table_row( - split_text_list_by_length(headers, max_characters_per_column), - style=f"bold {Colors.INFO}", - indent=indent, - ) - self.print_table_divider( - max_characters_per_column, include_separators=True, indent=indent + self._print_table_sections(sections, max_chars_per_column, indent) + self._print_table_row( + split_text_list_by_length(headers, max_chars_per_column), + f"bold {Colors.INFO}", + indent, ) + self._print_table_divider(max_chars_per_column, True, indent) for row in rows: - self.print_table_row( - split_text_list_by_length(row, max_characters_per_column), - style="italic", - indent=indent, + self._print_table_row( + split_text_list_by_length(row, max_chars_per_column), + "italic", + indent, ) - self.print_table_divider( - max_characters_per_column, include_separators=False, indent=indent - ) + self._print_table_divider(max_chars_per_column, False, indent) - def calculate_max_chars_per_column( + def _calculate_max_chars_per_column( self, headers: list[str], rows: list[list[Any]], sections: Optional[dict[str, tuple[int, int]]], max_char_per_col: int, ) -> list[int]: - """ - Calculate the maximum number of characters per column in the table. - This is done by checking the length of the headers, rows, and optional sections - to ensure all columns are accounted for and spaced correctly. - - :param headers: The headers of the table. - :param rows: The rows of the table. - :param sections: The sections of the table grouping columns together. 
- This is a mapping of the section display name to a tuple of the start and - end column indices. If None, no sections are added (default). - :param max_char_per_col: The maximum number of characters per column. - :return: A list of the maximum number of characters per column. - """ - max_characters_per_column = [] + """Calculate maximum characters per column for table formatting.""" + max_chars_per_column = [] for ind in range(len(headers)): - max_characters_per_column.append(min(len(headers[ind]), max_char_per_col)) - + max_chars_per_column.append(min(len(headers[ind]), max_char_per_col)) for row in rows: - max_characters_per_column[ind] = max( - max_characters_per_column[ind], len(str(row[ind])) + max_chars_per_column[ind] = max( + max_chars_per_column[ind], len(str(row[ind])) ) if not sections: - return max_characters_per_column + return max_chars_per_column - for section in sections: - start_col, end_col = sections[section] - min_section_len = len(section) + ( - end_col - start_col - ) # ensure we have enough space for separators + for section, (start_col, end_col) in sections.items(): + min_section_len = len(section) + (end_col - start_col) chars_in_columns = sum( - max_characters_per_column[start_col : end_col + 1] + max_chars_per_column[start_col : end_col + 1] ) + 2 * (end_col - start_col) if min_section_len > chars_in_columns: add_chars_per_col = math.ceil( (min_section_len - chars_in_columns) / (end_col - start_col + 1) ) for col in range(start_col, end_col + 1): - max_characters_per_column[col] += add_chars_per_col + max_chars_per_column[col] += add_chars_per_col - return max_characters_per_column + return max_chars_per_column - def print_table_divider( + def _print_table_divider( self, max_chars_per_column: list[int], include_separators: bool, indent: int = 0 ): - """ - Print a divider line for the table (top and bottom of table with '=' characters) - - :param max_chars_per_column: The maximum number of characters per column. - :param include_separators: Whether to include separators between columns. - :param indent: The number of spaces to indent the line. - Defaults to 0. - """ + """Print table divider line.""" if include_separators: columns = [ settings.table_headers_border_char * max_chars @@ -704,29 +370,15 @@ def print_table_divider( settings.table_border_char * (max_chars + 2) for max_chars in max_chars_per_column ] - columns[-1] = columns[-1][:-2] - self.print_line(value=columns, style=Colors.INFO, indent=indent) + self._print_line(columns, Colors.INFO, indent) - def print_table_sections( + def _print_table_sections( self, sections: dict[str, tuple[int, int]], max_chars_per_column: list[int], indent: int = 0, ): - """ - Print the sections of the table with corresponding separators to the columns - the sections are mapped to to ensure it is compliant with a CSV format. - For example, a section named "Metadata" with columns 0-3 will print this: - Metadata ,,,, - Where the spaces plus the separators at the end will span the columns 0-3. - All columns must be accounted for in the sections. - - :param sections: The sections of the table. - :param max_chars_per_column: The maximum number of characters per column. - :param indent: The number of spaces to indent the line. - Defaults to 0. 
- """ section_tuples = [(start, end, name) for name, (start, end) in sections.items()] section_tuples.sort(key=lambda x: x[0]) @@ -750,30 +402,23 @@ def print_table_sections( end_col - start_col + 1 ) num_separators = end_col - start_col - line_values.append(section) - line_styles.append("bold " + Colors.INFO) - line_values.append( - " " * (section_length - len(section) - num_separators - 2) + line_values.extend( + [ + section, + " " * (section_length - len(section) - num_separators - 2), + settings.table_column_separator_char * num_separators, + settings.table_column_separator_char + " ", + ] ) - line_styles.append("") - line_values.append(settings.table_column_separator_char * num_separators) - line_styles.append("") - line_values.append(settings.table_column_separator_char + " ") - line_styles.append(Colors.INFO) + line_styles.extend(["bold " + Colors.INFO, "", "", Colors.INFO]) + line_values = line_values[:-1] line_styles = line_styles[:-1] - self.print_line(value=line_values, style=line_styles, indent=indent) + self._print_line(line_values, line_styles, indent) - def print_table_row( + def _print_table_row( self, column_lines: list[list[str]], style: str, indent: int = 0 ): - """ - Print a single row of a table to the console. - - :param column_lines: The lines of text to print for each column. - :param indent: The number of spaces to indent the line. - Defaults to 0. - """ for row in range(len(column_lines[0])): print_line = [] print_styles = [] @@ -788,209 +433,213 @@ def print_table_row( print_styles.extend([style, Colors.INFO, ""]) print_line = print_line[:-2] print_styles = print_styles[:-2] - self.print_line(value=print_line, style=print_styles, indent=indent) + self._print_line(print_line, print_styles, indent) - def print_benchmarks_metadata(self): - """ - Print out the metadata of the benchmarks to the console including the run id, - duration, profile, args, worker, request loader, and extras. - """ - if not self.benchmarks: - raise ValueError( - "No benchmarks to print metadata for. Please set benchmarks first." - ) +class GenerativeBenchmarkerCSV(GenerativeBenchmarkerOutput): + """CSV output formatter for benchmark results.""" - start_time = self.benchmarks[0].run_stats.start_time - end_time = self.benchmarks[-1].run_stats.end_time - duration = end_time - start_time + DEFAULT_FILE: ClassVar[str] = "benchmarks.json" - self.print_section_header(title="Benchmarks Metadata") - self.print_labeled_line( - label="Run id", - value=str(self.benchmarks[0].run_id), - ) - self.print_labeled_line( - label="Duration", - value=f"{duration:.1f} seconds", - ) - self.print_labeled_line( - label="Profile", - value=self.benchmarks_profile_str, - ) - self.print_labeled_line( - label="Args", - value=self.benchmarks_args_str, - ) - self.print_labeled_line( - label="Worker", - value=self.benchmarks_worker_desc_str, - ) - self.print_labeled_line( - label="Request Loader", - value=self.benchmarks_request_loader_desc_str, - ) - self.print_labeled_line( - label="Extras", - value=self.benchmarks_extras_str, + def __init__(self, output_path: Optional[Union[str, Path]] = None): + """ + Initialize the CSV output formatter. + + :param output_path: Optional path where CSV file should be saved. + If not provided, will be saved to the default location. 
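+
+        Example (illustrative sketch; the output path shown is hypothetical):
+        ::
+            csv_output = GenerativeBenchmarkerCSV(output_path="results/benchmarks.csv")
+            saved_path = await csv_output.finalize(report)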
+ """ + output_path = output_path or GenerativeBenchmarkerCSV.DEFAULT_FILE + output_path = ( + Path(output_path) if not isinstance(output_path, Path) else output_path ) - def print_benchmarks_info(self): + if output_path.is_dir(): + output_path = output_path / GenerativeBenchmarkerCSV.DEFAULT_FILE + + self.output_path = output_path + + async def finalize(self, report: GenerativeBenchmarksReport) -> Path: """ - Print out the benchmark information to the console including the start time, - end time, duration, request totals, and token totals for each benchmark. + Save the benchmark report as a CSV file. + + :param report: The completed benchmark report. + :return: Path to the saved CSV file. """ - if not self.benchmarks: - raise ValueError( - "No benchmarks to print info for. Please set benchmarks first." - ) + self.output_path.parent.mkdir(parents=True, exist_ok=True) - sections = { - "Metadata": (0, 3), - "Requests Made": (4, 6), - "Prompt Tok/Req": (7, 9), - "Output Tok/Req": (10, 12), - "Prompt Tok Total": (13, 15), - "Output Tok Total": (16, 18), - } + with self.output_path.open("w", newline="") as file: + writer = csv.writer(file) + headers: list[str] = [] + rows: list[list[Union[str, float, list[float]]]] = [] + + for benchmark in report.benchmarks: + benchmark_headers: list[str] = [] + benchmark_values: list[Union[str, float, list[float]]] = [] + + # Add description headers and values + desc_headers, desc_values = self._get_benchmark_desc_headers_and_values( + benchmark + ) + benchmark_headers.extend(desc_headers) + benchmark_values.extend(desc_values) + + # Add status-based metrics + for status in StatusDistributionSummary.model_fields: + status_headers, status_values = ( + self._get_benchmark_status_headers_and_values(benchmark, status) + ) + benchmark_headers.extend(status_headers) + benchmark_values.extend(status_values) + + # Add extra fields + extras_headers, extras_values = ( + self._get_benchmark_extras_headers_and_values(benchmark) + ) + benchmark_headers.extend(extras_headers) + benchmark_values.extend(extras_values) + + if not headers: + headers = benchmark_headers + rows.append(benchmark_values) + + writer.writerow(headers) + for row in rows: + writer.writerow(row) + + return self.output_path + + def _get_benchmark_desc_headers_and_values( + self, benchmark: GenerativeBenchmark + ) -> tuple[list[str], list[Union[str, float]]]: + """Get description headers and values for a benchmark.""" headers = [ - "Benchmark", + "Type", + "Run Id", + "Id", + "Name", "Start Time", "End Time", - "Duration (s)", - "Comp", - "Inc", - "Err", - "Comp", - "Inc", - "Err", - "Comp", - "Inc", - "Err", - "Comp", - "Inc", - "Err", - "Comp", - "Inc", - "Err", + "Duration", ] - rows = [] + values: list[Union[str, float]] = [ + benchmark.type_, + benchmark.run_id, + benchmark.id_, + strategy_display_str(benchmark.args.strategy), + datetime.fromtimestamp(benchmark.start_time).strftime("%Y-%m-%d %H:%M:%S"), + datetime.fromtimestamp(benchmark.end_time).strftime("%Y-%m-%d %H:%M:%S"), + benchmark.duration, + ] + return headers, values - for benchmark in self.benchmarks: - rows.append( - [ - strategy_display_str(benchmark.args.strategy), - f"{datetime.fromtimestamp(benchmark.start_time).strftime('%H:%M:%S')}", - f"{datetime.fromtimestamp(benchmark.end_time).strftime('%H:%M:%S')}", - f"{(benchmark.end_time - benchmark.start_time):.1f}", - f"{benchmark.request_totals.successful:.0f}", - f"{benchmark.request_totals.incomplete:.0f}", - f"{benchmark.request_totals.errored:.0f}", - 
f"{benchmark.metrics.prompt_token_count.successful.mean:.1f}", - f"{benchmark.metrics.prompt_token_count.incomplete.mean:.1f}", - f"{benchmark.metrics.prompt_token_count.errored.mean:.1f}", - f"{benchmark.metrics.output_token_count.successful.mean:.1f}", - f"{benchmark.metrics.output_token_count.incomplete.mean:.1f}", - f"{benchmark.metrics.output_token_count.errored.mean:.1f}", - f"{benchmark.metrics.prompt_token_count.successful.total_sum:.0f}", - f"{benchmark.metrics.prompt_token_count.incomplete.total_sum:.0f}", - f"{benchmark.metrics.prompt_token_count.errored.total_sum:.0f}", - f"{benchmark.metrics.output_token_count.successful.total_sum:.0f}", - f"{benchmark.metrics.output_token_count.incomplete.total_sum:.0f}", - f"{benchmark.metrics.output_token_count.errored.total_sum:.0f}", - ] - ) + def _get_benchmark_extras_headers_and_values( + self, benchmark: GenerativeBenchmark + ) -> tuple[list[str], list[str]]: + """Get extra fields headers and values for a benchmark.""" + headers = ["Args", "Worker", "Request Loader", "Extras"] + values: list[str] = [ + json.dumps(benchmark.args.model_dump()), + json.dumps(benchmark.worker.model_dump()), + json.dumps(benchmark.request_loader.model_dump()), + json.dumps(benchmark.extras), + ] + return headers, values - self.print_table( - headers=headers, rows=rows, title="Benchmarks Info", sections=sections - ) + def _get_benchmark_status_headers_and_values( + self, benchmark: GenerativeBenchmark, status: str + ) -> tuple[list[str], list[Union[float, list[float]]]]: + """Get status-based metrics headers and values for a benchmark.""" + headers = [f"{status.capitalize()} Requests"] + values = [getattr(benchmark.request_totals, status)] - def print_benchmarks_stats(self): - """ - Print out the benchmark statistics to the console including the requests per - second, request concurrency, output tokens per second, total tokens per second, - request latency, time to first token, inter token latency, and time per output - token for each benchmark. - """ - if not self.benchmarks: - raise ValueError( - "No benchmarks to print stats for. Please set benchmarks first." 
+ for metric in GenerativeMetrics.model_fields: + metric_headers, metric_values = self._get_benchmark_status_metrics_stats( + benchmark, status, metric ) + headers.extend(metric_headers) + values.extend(metric_values) + + return headers, values + + def _get_benchmark_status_metrics_stats( + self, benchmark: GenerativeBenchmark, status: str, metric: str + ) -> tuple[list[str], list[Union[float, list[float]]]]: + """Get statistical metrics for a specific status and metric.""" + status_display = status.capitalize() + metric_display = metric.replace("_", " ").capitalize() + status_dist_summary: StatusDistributionSummary = getattr( + benchmark.metrics, metric + ) + dist_summary: DistributionSummary = getattr(status_dist_summary, status) - sections = { - "Metadata": (0, 0), - "Request Stats": (1, 2), - "Out Tok/sec": (3, 3), - "Tot Tok/sec": (4, 4), - "Req Latency (sec)": (5, 7), - "TTFT (ms)": (8, 10), - "ITL (ms)": (11, 13), - "TPOT (ms)": (14, 16), - } headers = [ - "Benchmark", - "Per Second", - "Concurrency", - "mean", - "mean", - "mean", - "median", - "p99", - "mean", - "median", - "p99", - "mean", - "median", - "p99", - "mean", - "median", - "p99", + f"{status_display} {metric_display} mean", + f"{status_display} {metric_display} median", + f"{status_display} {metric_display} std dev", + f"{status_display} {metric_display} [min, 0.1, 1, 5, 10, 25, 75, 90, 95, 99, max]", ] - rows = [] + values: list[Union[float, list[float]]] = [ + dist_summary.mean, + dist_summary.median, + dist_summary.std_dev, + [ + dist_summary.min, + dist_summary.percentiles.p001, + dist_summary.percentiles.p01, + dist_summary.percentiles.p05, + dist_summary.percentiles.p10, + dist_summary.percentiles.p25, + dist_summary.percentiles.p75, + dist_summary.percentiles.p90, + dist_summary.percentiles.p95, + dist_summary.percentiles.p99, + dist_summary.max, + ], + ] + return headers, values - for benchmark in self.benchmarks: - rows.append( - [ - strategy_display_str(benchmark.args.strategy), - f"{benchmark.metrics.requests_per_second.successful.mean:.2f}", - f"{benchmark.metrics.request_concurrency.successful.mean:.2f}", - f"{benchmark.metrics.output_tokens_per_second.successful.mean:.1f}", - f"{benchmark.metrics.tokens_per_second.successful.mean:.1f}", - f"{benchmark.metrics.request_latency.successful.mean:.2f}", - f"{benchmark.metrics.request_latency.successful.median:.2f}", - f"{benchmark.metrics.request_latency.successful.percentiles.p99:.2f}", - f"{benchmark.metrics.time_to_first_token_ms.successful.mean:.1f}", - f"{benchmark.metrics.time_to_first_token_ms.successful.median:.1f}", - f"{benchmark.metrics.time_to_first_token_ms.successful.percentiles.p99:.1f}", - f"{benchmark.metrics.inter_token_latency_ms.successful.mean:.1f}", - f"{benchmark.metrics.inter_token_latency_ms.successful.median:.1f}", - f"{benchmark.metrics.inter_token_latency_ms.successful.percentiles.p99:.1f}", - f"{benchmark.metrics.time_per_output_token_ms.successful.mean:.1f}", - f"{benchmark.metrics.time_per_output_token_ms.successful.median:.1f}", - f"{benchmark.metrics.time_per_output_token_ms.successful.percentiles.p99:.1f}", - ] - ) - self.print_table( - headers=headers, - rows=rows, - title="Benchmarks Stats", - sections=sections, +class GenerativeBenchmarkerHTML(GenerativeBenchmarkerOutput): + """HTML output formatter for benchmark results.""" + + DEFAULT_FILE: ClassVar[str] = "benchmarks.html" + + def __init__(self, output_path: Optional[Union[str, Path]] = None): + """ + Initialize the HTML output formatter. 
+
+        :param output_path: Optional path where HTML file should be saved.
+            If not provided, will be saved to the default location.
+        """
+        output_path = output_path or GenerativeBenchmarkerHTML.DEFAULT_FILE
+        output_path = (
+            Path(output_path) if not isinstance(output_path, Path) else output_path
        )

-    def print_full_report(self):
+        if output_path.is_dir():
+            output_path = output_path / GenerativeBenchmarkerHTML.DEFAULT_FILE
+
+        self.output_path = output_path
+
+    async def finalize(self, report: GenerativeBenchmarksReport) -> Path:
        """
-        Print out the benchmark statistics to the console.
-        Temporarily enables the console if it's disabled.
+        Save the benchmark report as an HTML file.

-        Format:
-        - Metadata
-        - Info
-        - Stats
+        :param report: The completed benchmark report.
+        :return: Path to the saved HTML file.
        """
-        orig_enabled = self.enabled
-        self.enabled = True
-        self.print_benchmarks_metadata()
-        self.print_benchmarks_info()
-        self.print_benchmarks_stats()
-        self.enabled = orig_enabled
+        data_builder = UIDataBuilder(report.benchmarks)
+        data = data_builder.to_dict()
+        camel_data = humps.camelize(data)
+
+        ui_api_data = {}
+        for key, value in camel_data.items():
+            placeholder_key = f"window.{humps.decamelize(key)} = {{}};"
+            replacement_value = (
+                f"window.{humps.decamelize(key)} = {json.dumps(value, indent=2)};\n"
+            )
+            ui_api_data[placeholder_key] = replacement_value
+
+        create_report(ui_api_data, self.output_path)
+
+        return self.output_path
diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py
index 642cb7a8..1623767d 100644
--- a/src/guidellm/benchmark/profile.py
+++ b/src/guidellm/benchmark/profile.py
@@ -1,20 +1,45 @@
-from collections.abc import Sequence
-from typing import Literal, Optional, Union
+"""
+Benchmarking profile configurations for coordinating multi-strategy execution.
+
+Provides configurable profile abstractions for orchestrating sequential and
+parallel execution of different scheduling strategies during benchmarking,
+with automatic strategy generation and constraint management.
+
+Classes:
+    Profile: Abstract base for multi-strategy benchmarking profiles.
+    SynchronousProfile: Single synchronous strategy execution profile.
+    ConcurrentProfile: Fixed-concurrency strategy execution profile.
+    ThroughputProfile: Maximum throughput strategy execution profile.
+    AsyncProfile: Rate-based asynchronous strategy execution profile.
+    SweepProfile: Adaptive multi-strategy sweep execution profile.
+
+Type Aliases:
+    ProfileType: Literal type for supported profile configurations.
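+
+Example:
+::
+    # Illustrative sketch; "constant" and the numeric values are assumptions.
+    profile = Profile.create(rate_type="constant", rate=10.0, random_seed=42)
+    strategy = profile.next_strategy(None, None)  # AsyncConstantStrategy at 10 req/s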
+""" + +from abc import ABC, abstractmethod +from collections.abc import Generator +from typing import Any, Generic, Literal, Optional, Union import numpy as np from pydantic import Field, computed_field -from guidellm.config import settings +from guidellm.benchmark.benchmark import BenchmarkT from guidellm.objects import StandardBaseModel from guidellm.scheduler import ( AsyncConstantStrategy, AsyncPoissonStrategy, ConcurrentStrategy, + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, SchedulingStrategy, + StrategyT, StrategyType, SynchronousStrategy, ThroughputStrategy, ) +from guidellm.utils import RegistryMixin __all__ = [ "AsyncProfile", @@ -24,386 +49,595 @@ "SweepProfile", "SynchronousProfile", "ThroughputProfile", - "create_profile", ] ProfileType = Literal["synchronous", "concurrent", "throughput", "async", "sweep"] -class Profile(StandardBaseModel): +class Profile( + StandardBaseModel, + ABC, + Generic[StrategyT, BenchmarkT], + RegistryMixin, +): + """ + Abstract base for multi-strategy benchmarking execution profiles. + + Coordinates sequential execution of scheduling strategies with automatic + strategy generation, constraint management, and completion tracking for + comprehensive benchmarking workflows. + """ + + @classmethod + def create( + cls, + rate_type: str, + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> "Profile": + """ + Create a profile instance based on the specified type. + + :param rate_type: The type of profile to create. + :param rate: Rate parameter for profile configuration. + :param random_seed: Random seed for stochastic strategies. + :param kwargs: Additional arguments for profile configuration. + :return: Configured profile instance for the specified type. + :raises ValueError: If the profile type is not registered. + """ + profile_class: type[Profile] = cls.get_registered_object(rate_type) + resolved_kwargs = profile_class.resolve_args( + rate_type=rate_type, rate=rate, random_seed=random_seed, **kwargs + ) + + return profile_class(**resolved_kwargs) + + @classmethod + @abstractmethod + def resolve_args( + cls, + rate_type: str, + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Resolve and validate arguments for profile construction. + + :param rate_type: The type of the profile. + :param rate: Rate parameter for configuration. + :param random_seed: Random seed for stochastic strategies. + :param kwargs: Additional arguments to resolve. + :return: Dictionary of resolved arguments for profile construction. + """ + ... + type_: Literal["profile"] = Field( - description="The type of benchmarking profile to use.", - ) - completed_strategies: int = Field( - default=0, - description="The number of scheduling strategies generated so far.", + description="The type of benchmarking profile to use", ) - measured_rates: list[float] = Field( + completed_strategies: list[SchedulingStrategy] = Field( default_factory=list, - description=("The average rates measured for the strategies that have run."), + description="The strategies that have completed execution", ) - measured_concurrencies: list[float] = Field( - default_factory=list, - description=( - "The average concurrency measured for the strategies that have run." 
- ), + constraints: Optional[ + dict[str, Union[Any, dict[str, Any], ConstraintInitializer]] + ] = Field( + default=None, + description="Runtime constraints to apply during strategy execution", ) - def completed_strategy(self, average_rate: float, average_concurrency: float): - self.measured_rates.append(average_rate) - self.measured_concurrencies.append(average_concurrency) - self.completed_strategies += 1 - @computed_field # type: ignore[misc] @property def strategy_types(self) -> list[StrategyType]: - return [] + """ + :return: List of all strategy types expected to be executed in this profile. + """ + return [strat.type_ for strat in self.completed_strategies] + + def strategies_generator( + self, + ) -> Generator[ + tuple[ + Optional[StrategyT], + Optional[dict[str, Union[Any, dict[str, Any], Constraint]]], + ], + BenchmarkT, + None, + ]: + """ + Generate strategies and constraints for sequential profile execution. + + :return: Generator yielding (strategy, constraints) tuples and + receiving benchmark results from each execution. + """ + prev_strategy: Optional[StrategyT] = None + prev_benchmark: Optional[BenchmarkT] = None + + while ( + strategy := self.next_strategy(prev_strategy, prev_benchmark) + ) is not None: + constraints = self.next_strategy_constraints( + strategy, prev_strategy, prev_benchmark + ) + prev_benchmark = yield ( + strategy, + constraints, + ) + prev_strategy = strategy + self.completed_strategies.append(prev_strategy) + + @abstractmethod + def next_strategy( + self, + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[StrategyT]: + """ + Generate the next strategy to execute in the profile sequence. + + :param prev_strategy: The previously completed strategy. + :param prev_benchmark: Benchmark results from the previous strategy. + :return: Next strategy to execute, or None if profile is complete. + """ + ... + + def next_strategy_constraints( + self, + next_strategy: Optional[StrategyT], + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[dict[str, Union[Any, dict[str, Any], Constraint]]]: + """ + Generate constraints for the next strategy execution. + + :param next_strategy: The next strategy to be executed. + :param prev_strategy: The previously completed strategy. + :param prev_benchmark: Benchmark results from the previous strategy. + :return: Constraints dictionary for the next strategy, or None. + """ + return ( + ConstraintsInitializerFactory.resolve(self.constraints) + if next_strategy and self.constraints + else None + ) - def next_strategy(self) -> Optional[SchedulingStrategy]: - return None +@Profile.register("synchronous") +class SynchronousProfile(Profile[StrategyT, BenchmarkT]): + """Single synchronous strategy execution profile.""" -class SynchronousProfile(Profile): type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] + @classmethod + def resolve_args( + cls, + rate_type: str, + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Resolve arguments for synchronous profile construction. + + :param rate_type: The type/strategy of the profile (ignored). + :param rate: Rate parameter (must be None, will be stripped). + :param random_seed: Random seed (ignored and stripped). + :param kwargs: Additional arguments to pass through. + :return: Dictionary of resolved arguments. + :raises ValueError: If rate is not None. 
+ """ + if rate is not None: + raise ValueError("SynchronousProfile does not accept a rate parameter") + + return kwargs + @property def strategy_types(self) -> list[StrategyType]: + """Get the single synchronous strategy type.""" return [self.type_] - def next_strategy(self) -> Optional[SchedulingStrategy]: - if self.completed_strategies >= 1: + def next_strategy( + self, + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[StrategyT]: + """ + Generate synchronous strategy or None if already completed. + + :param prev_strategy: The previously completed strategy (unused). + :param prev_benchmark: Benchmark results from the previous strategy (unused). + :return: SynchronousStrategy for the first execution, None afterward. + """ + if len(self.completed_strategies) >= 1: return None return SynchronousStrategy() - @staticmethod - def from_standard_args( - rate_type: Union[StrategyType, ProfileType], - rate: Optional[Union[float, Sequence[float]]], - **kwargs, - ) -> "SynchronousProfile": - if rate_type != "synchronous": - raise ValueError("Rate type must be 'synchronous' for synchronous profile.") - if rate is not None: - raise ValueError( - "Rate does not apply to synchronous profile, it must be set to None." - ) - - if kwargs: - raise ValueError( - "No additional arguments are allowed for synchronous profile." - ) - - return SynchronousProfile() +@Profile.register("concurrent") +class ConcurrentProfile(Profile[StrategyT, BenchmarkT]): + """Fixed-concurrency strategy execution profile with configurable stream counts.""" - -class ConcurrentProfile(Profile): type_: Literal["concurrent"] = "concurrent" # type: ignore[assignment] - streams: Union[int, Sequence[int]] = Field( - description="The number of concurrent streams to use.", + streams: Union[int, list[int]] = Field( + description="Number of concurrent streams for request scheduling", + gt=0, + ) + startup_duration: float = Field( + default=0.0, + description=( + "Duration in seconds for distributing startup requests " + "before completion-based timing" + ), + ge=0, ) + @classmethod + def resolve_args( + cls, + rate_type: str, + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Resolve arguments for concurrent profile construction. + + :param rate_type: The type/strategy of the profile (ignored). + :param rate: Rate parameter, remapped to streams. + :param random_seed: Random seed (ignored and stripped). + :param kwargs: Additional arguments to pass through. + :return: Dictionary of resolved arguments. + :raises ValueError: If rate is None. + """ + kwargs["streams"] = rate + return kwargs + @property def strategy_types(self) -> list[StrategyType]: - num_strategies = len(self.streams) if isinstance(self.streams, Sequence) else 1 - + """Get concurrent strategy types for each configured stream count.""" + num_strategies = len(self.streams) if isinstance(self.streams, list) else 1 return [self.type_] * num_strategies - def next_strategy(self) -> Optional[SchedulingStrategy]: - streams = self.streams if isinstance(self.streams, Sequence) else [self.streams] - - if self.completed_strategies >= len(streams): + def next_strategy( + self, + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[StrategyT]: + """ + Generate concurrent strategy for the next stream count. + + :param prev_strategy: The previously completed strategy (unused). 
+ :param prev_benchmark: Benchmark results from the previous strategy (unused). + :return: ConcurrentStrategy with next stream count, or None if complete. + """ + streams = self.streams if isinstance(self.streams, list) else [self.streams] + + if len(self.completed_strategies) >= len(streams): return None return ConcurrentStrategy( - streams=streams[self.completed_strategies], + streams=streams[len(self.completed_strategies)], + startup_duration=self.startup_duration, ) - @staticmethod - def from_standard_args( - rate_type: Union[StrategyType, ProfileType], - rate: Optional[Union[float, Sequence[float]]], - **kwargs, - ) -> "ConcurrentProfile": - if rate_type != "concurrent": - raise ValueError("Rate type must be 'concurrent' for concurrent profile.") - - if not rate: - raise ValueError("Rate (streams) must be provided for concurrent profile.") - if not isinstance(rate, Sequence): - rate = [rate] - - if not all(stream.is_integer() and stream > 0 for stream in rate): - raise ValueError( - f"All rate values (streams) must be positive integers, received {rate}" - ) - - if kwargs: - raise ValueError( - "No additional arguments are allowed for concurrent profile." - ) +@Profile.register("throughput") +class ThroughputProfile(Profile[StrategyT, BenchmarkT]): + """ + Maximum throughput strategy execution profile with optional concurrency limits. + """ - return ConcurrentProfile(streams=[int(rat) for rat in rate]) - - -class ThroughputProfile(Profile): type_: Literal["throughput"] = "throughput" # type: ignore[assignment] max_concurrency: Optional[int] = Field( default=None, - description="The maximum number of concurrent requests that can be scheduled.", + description="Maximum number of concurrent requests to schedule", + gt=0, + ) + startup_duration: float = Field( + default=0.0, + description=( + "Duration in seconds for distributing startup requests " + "before full throughput scheduling" + ), + ge=0, ) + @classmethod + def resolve_args( + cls, + rate_type: str, + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Resolve arguments for throughput profile construction. + + :param rate_type: The type/strategy of the profile (ignored). + :param rate: Rate parameter to remap to max_concurrency. + :param random_seed: Random seed (ignored and stripped). + :param kwargs: Additional arguments to pass through. + :return: Dictionary of resolved arguments. + """ + # Remap rate to max_concurrency, strip out random_seed + kwargs.pop("random_seed", None) + if rate is not None: + kwargs["max_concurrency"] = rate + return kwargs + @property def strategy_types(self) -> list[StrategyType]: + """Get the single throughput strategy type.""" return [self.type_] - def next_strategy(self) -> Optional[SchedulingStrategy]: - if self.completed_strategies >= 1: + def next_strategy( + self, + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[StrategyT]: + """ + Generate throughput strategy or None if already completed. + + :param prev_strategy: The previously completed strategy (unused). + :param prev_benchmark: Benchmark results from the previous strategy (unused). + :return: ThroughputStrategy for the first execution, None afterward. 
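+
+        Example (illustrative sketch; the concurrency limit is an assumption):
+        ::
+            profile = ThroughputProfile(max_concurrency=512)
+            strategy = profile.next_strategy(None, None)  # ThroughputStrategy instance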
+ """ + if len(self.completed_strategies) >= 1: return None return ThroughputStrategy( max_concurrency=self.max_concurrency, + startup_duration=self.startup_duration, ) - @staticmethod - def from_standard_args( - rate_type: Union[StrategyType, ProfileType], - rate: Optional[Union[float, Sequence[float]]], - **kwargs, - ) -> "ThroughputProfile": - if rate_type != "throughput": - raise ValueError("Rate type must be 'throughput' for throughput profile.") - - if rate is not None: - raise ValueError( - "Rate does not apply to throughput profile, it must be set to None." - ) - - return ThroughputProfile(**kwargs) +@Profile.register(["async", "constant", "poisson"]) +class AsyncProfile(Profile[StrategyT, BenchmarkT]): + """ + Rate-based asynchronous strategy execution profile with configurable patterns. + """ -class AsyncProfile(ThroughputProfile): type_: Literal["async"] = "async" # type: ignore[assignment] strategy_type: Literal["constant", "poisson"] = Field( - description="The type of asynchronous strategy to use.", + description="Type of asynchronous strategy pattern to use", ) - rate: Union[float, Sequence[float]] = Field( - description="The rate of requests per second to use.", + rate: Union[float, list[float]] = Field( + description="Request scheduling rate in requests per second", + gt=0, ) - initial_burst: bool = Field( - default=True, + startup_duration: float = Field( + default=0.0, description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." + "Duration in seconds for distributing startup requests " + "to converge quickly to desired rate" ), + ge=0, + ) + max_concurrency: Optional[int] = Field( + default=None, + description="Maximum number of concurrent requests to schedule", + gt=0, ) random_seed: int = Field( default=42, - description=( - "The random seed to use for the asynchronous strategy. " - "This is used to generate random numbers for the Poisson strategy." - ), + description="Random seed for Poisson distribution strategy", ) + @classmethod + def resolve_args( + cls, + rate_type: Union[ProfileType, StrategyT], + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Resolve arguments for async profile construction. + + :param rate_type: The type/strategy of the profile. + :param rate: Rate parameter for the profile. + :param random_seed: Random seed for stochastic strategies. + :param kwargs: Additional arguments to pass through. + :return: Dictionary of resolved arguments. + :raises ValueError: If rate is None. 
+ """ + if rate is None: + raise ValueError("AsyncProfile requires a rate parameter") + + kwargs["strategy_type"] = ( + rate_type + if rate_type in ["constant", "poisson"] + else kwargs.get("strategy_type", "constant") + ) + kwargs["rate"] = rate + kwargs["random_seed"] = random_seed + return kwargs + @property def strategy_types(self) -> list[StrategyType]: - num_strategies = len(self.rate) if isinstance(self.rate, Sequence) else 1 - + """Get async strategy types for each configured rate.""" + num_strategies = len(self.rate) if isinstance(self.rate, list) else 1 return [self.strategy_type] * num_strategies - def next_strategy(self) -> Optional[SchedulingStrategy]: - rate = self.rate if isinstance(self.rate, Sequence) else [self.rate] - - if self.completed_strategies >= len(rate): + def next_strategy( + self, + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[StrategyT]: + """ + Generate async strategy for the next configured rate. + + :param prev_strategy: The previously completed strategy (unused). + :param prev_benchmark: Benchmark results from the previous strategy (unused). + :return: AsyncConstantStrategy or AsyncPoissonStrategy for next rate, + or None if all rates completed. + :raises ValueError: If strategy_type is neither 'constant' nor 'poisson'. + """ + rate = self.rate if isinstance(self.rate, list) else [self.rate] + + if len(self.completed_strategies) >= len(rate): return None + current_rate = rate[len(self.completed_strategies)] + if self.strategy_type == "constant": return AsyncConstantStrategy( - rate=rate[self.completed_strategies], - initial_burst=self.initial_burst, + rate=current_rate, + startup_duration=self.startup_duration, max_concurrency=self.max_concurrency, ) elif self.strategy_type == "poisson": return AsyncPoissonStrategy( - rate=rate[self.completed_strategies], - initial_burst=self.initial_burst, + rate=current_rate, + startup_duration=self.startup_duration, max_concurrency=self.max_concurrency, random_seed=self.random_seed, ) else: raise ValueError(f"Invalid strategy type: {self.strategy_type}") - @staticmethod - def from_standard_args( # type: ignore[override] - rate_type: Union[StrategyType, ProfileType], - rate: Optional[Union[float, Sequence[float]]], - random_seed: int, - **kwargs, - ) -> "AsyncProfile": - if rate_type not in ("async", "constant", "poisson"): - raise ValueError( - "Rate type must be in ('async', 'constant', 'poisson') " - f"for async profile. Received: {rate_type}" - ) - - if not rate: - raise ValueError("Rate must be provided for async profile.") - - if not isinstance(rate, Sequence): - rate = [rate] - - if not all(isinstance(r, (float, int)) and r > 0 for r in rate): - raise ValueError( - f"All rate values must be positive numbers, received {rate}" - ) - if rate_type == "async": - rate_type = "constant" # default to constant if not specified +@Profile.register("sweep") +class SweepProfile(Profile[StrategyT, BenchmarkT]): + """ + Adaptive multi-strategy sweep execution profile with rate discovery. 
+ """ - return AsyncProfile( - strategy_type=rate_type, # type: ignore[arg-type] - rate=rate, - random_seed=random_seed, - **kwargs, - ) - - -class SweepProfile(AsyncProfile): type_: Literal["sweep"] = "sweep" # type: ignore[assignment] sweep_size: int = Field( - description="The number of strategies to generate for the sweep.", + description="Number of strategies to generate for the sweep", + ) + strategy_type: Literal["constant", "poisson"] = "constant" + startup_duration: float = Field( + default=0.0, + description=( + "Duration in seconds for distributing startup requests " + "to converge quickly to desired rate" + ), + ge=0, + ) + max_concurrency: Optional[int] = Field( + default=None, + description="Maximum number of concurrent requests to schedule", + gt=0, + ) + random_seed: int = Field( + default=42, + description="Random seed for Poisson distribution strategy", ) - rate: float = -1 - rate_type: Literal["constant", "poisson"] = "constant" + synchronous_rate: float = Field( + default=-1.0, + description="Measured rate from synchronous strategy execution", + ) + throughput_rate: float = Field( + default=-1.0, + description="Measured rate from throughput strategy execution", + ) + async_rates: list[float] = Field( + default_factory=list, + description="Generated rates for async strategy sweep", + ) + measured_rates: list[float] = Field( + default_factory=list, + description="Calculated interpolated rates between synchronous and throughput", + ) + + @classmethod + def resolve_args( + cls, + rate_type: str, + rate: Optional[Union[float, int, list[float, int]]], + random_seed: int, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Resolve arguments for sweep profile construction. + + :param rate_type: The type/strategy for async strategies in the sweep. + :param rate: Rate parameter (ignored for sweep). + :param random_seed: Random seed for stochastic strategies. + :param kwargs: Additional arguments to pass through. + :return: Dictionary of resolved arguments. + """ + kwargs["sweep_size"] = kwargs.get("sweep_size", rate) + kwargs["random_seed"] = random_seed + if rate_type in ["constant", "poisson"]: + kwargs["strategy_type"] = rate_type + return kwargs @property def strategy_types(self) -> list[StrategyType]: - return ( - ["synchronous"] + ["throughput"] + [self.rate_type] * (self.sweep_size - 2) # type: ignore[return-value] - ) - - def next_strategy(self) -> Optional[SchedulingStrategy]: - if self.completed_strategies >= self.sweep_size: - return None - - if self.completed_strategies == 0: + """Get strategy types for the complete sweep sequence.""" + types = ["synchronous", "throughput"] + types += [self.strategy_type] * (self.sweep_size - len(types)) + return types + + def next_strategy( + self, + prev_strategy: Optional[StrategyT], + prev_benchmark: Optional[BenchmarkT], + ) -> Optional[StrategyT]: + """ + Generate the next strategy in the adaptive sweep sequence. + + Executes synchronous and throughput strategies first to measure + baseline rates, then generates interpolated rates for async strategies. + + :param prev_strategy: The previously completed strategy. + :param prev_benchmark: Benchmark results from the previous strategy. + :return: Next strategy in sweep sequence, or None if complete. + :raises ValueError: If strategy_type is neither 'constant' nor 'poisson'. 
+ """ + if prev_strategy is None: return SynchronousStrategy() - if self.completed_strategies == 1: + if prev_strategy.type_ == "synchronous": + self.synchronous_rate = ( + prev_benchmark.metrics.requests_per_second.successful.mean + ) + return ThroughputStrategy( max_concurrency=self.max_concurrency, + startup_duration=self.startup_duration, ) - min_rate = self.measured_rates[0] - max_rate = self.measured_rates[1] - rates = np.linspace(min_rate, max_rate, self.sweep_size - 1)[1:] + if prev_strategy.type_ == "throughput": + self.throughput_rate = ( + prev_benchmark.metrics.requests_per_second.successful.mean + ) + self.measured_rates = list( + np.linspace( + self.synchronous_rate, + self.throughput_rate, + self.sweep_size - 1, + ) + )[1:] # don't rerun synchronous - if self.rate_type == "constant": + if len(self.completed_strategies) >= self.sweep_size: + return None + + next_rate_index = len( + [ + strat + for strat in self.completed_strategies + if strat.type_ == self.strategy_type + ] + ) + + if self.strategy_type == "constant": return AsyncConstantStrategy( - rate=rates[self.completed_strategies - 2], - initial_burst=self.initial_burst, + rate=self.measured_rates[next_rate_index], + startup_duration=self.startup_duration, max_concurrency=self.max_concurrency, ) - elif self.rate_type == "poisson": + elif self.strategy_type == "poisson": return AsyncPoissonStrategy( - rate=rates[self.completed_strategies - 2], - initial_burst=self.initial_burst, + rate=self.measured_rates[next_rate_index], + startup_duration=self.startup_duration, max_concurrency=self.max_concurrency, + random_seed=self.random_seed, ) else: - raise ValueError(f"Invalid strategy type: {self.rate_type}") - - @staticmethod - def from_standard_args( # type: ignore[override] - rate_type: Union[StrategyType, ProfileType], - rate: Optional[Union[float, Sequence[float]]], - random_seed: int, - **kwargs, - ) -> "SweepProfile": - if rate_type != "sweep": - raise ValueError("Rate type must be 'sweep' for sweep profile.") - - if "sweep_size" in kwargs: - raise ValueError("Sweep size must not be provided, use rate instead.") - - if isinstance(rate, Sequence): - if len(rate) != 1: - raise ValueError( - "Rate must be a single value for sweep profile, received " - f"{len(rate)} values." - ) - rate = rate[0] - - if not rate: - rate = settings.default_sweep_number - - if not rate: - raise ValueError( - "Rate (sweep_size) must be provided for concurrent profile." 
- ) - - if ( - not isinstance(rate, (int, float)) - or (isinstance(rate, float) and not rate.is_integer()) - or rate <= 1 - ): - raise ValueError( - f"Rate (sweep_size) must be a positive integer > 1, received {rate} " - f"with type {type(rate)}" - ) - - if not kwargs: - kwargs = {} - - if "strategy_type" not in kwargs: - kwargs["strategy_type"] = "constant" - - return SweepProfile(sweep_size=int(rate), random_seed=random_seed, **kwargs) - - -def create_profile( - rate_type: Union[StrategyType, ProfileType], - rate: Optional[Union[float, Sequence[float]]], - random_seed: int = 42, - **kwargs, -) -> "Profile": - if rate_type == "synchronous": - return SynchronousProfile.from_standard_args( - rate_type=rate_type, - rate=rate, - **kwargs, - ) - - if rate_type == "concurrent": - return ConcurrentProfile.from_standard_args( - rate_type=rate_type, - rate=rate, - **kwargs, - ) - - if rate_type == "throughput": - return ThroughputProfile.from_standard_args( - rate_type=rate_type, - rate=rate, - **kwargs, - ) - - if rate_type in ("async", "constant", "poisson"): - return AsyncProfile.from_standard_args( - rate_type=rate_type, - rate=rate, - random_seed=random_seed, - **kwargs, - ) - - if rate_type == "sweep": - return SweepProfile.from_standard_args( - rate_type=rate_type, - rate=rate, - random_seed=random_seed, - **kwargs, - ) - - raise ValueError(f"Invalid profile type: {rate_type}") + raise ValueError(f"Invalid strategy type: {self.strategy_type}") diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py index d6f437e1..ef867a63 100644 --- a/src/guidellm/benchmark/progress.py +++ b/src/guidellm/benchmark/progress.py @@ -1,8 +1,27 @@ -import math -import time +""" +Benchmark progress tracking and console display abstractions. + +Provides progress tracking interfaces and implementations for monitoring benchmark +execution, displaying real-time statistics, and managing UI updates during +generative benchmarking operations. + +Classes: + BenchmarkerProgress: Abstract base for benchmark progress tracking. + BenchmarkerProgressGroup: Composite progress handler for multiple instances. + GenerativeConsoleBenchmarkerProgress: Console-based progress display. + +Type Variables: + BenchmarkT: Generic benchmark object type. 
+""" + +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from collections.abc import AsyncIterable, AsyncIterator, Iterable from dataclasses import dataclass from datetime import datetime -from typing import Generic, Optional, TypeVar, Union +from typing import Any, Generic, Literal from rich.console import Group from rich.live import Live @@ -10,7 +29,6 @@ from rich.progress import ( BarColumn, Progress, - ProgressColumn, SpinnerColumn, TaskID, TaskProgressColumn, @@ -19,123 +37,607 @@ TimeRemainingColumn, ) -from guidellm.benchmark.aggregator import ( - BenchmarkAggregator, - GenerativeBenchmarkAggregator, -) -from guidellm.benchmark.benchmark import Benchmark, GenerativeBenchmark -from guidellm.benchmark.benchmarker import BenchmarkerResult +from guidellm.benchmark.benchmark import BenchmarkT, GenerativeBenchmark +from guidellm.benchmark.profile import Profile from guidellm.scheduler import ( + SchedulerState, SchedulingStrategy, StrategyType, - strategy_display_str, ) -from guidellm.utils import Colors +from guidellm.utils import Colors, format_value_display __all__ = [ - "BenchmarkerProgressDisplay", - "BenchmarkerTaskProgressState", - "GenerativeTextBenchmarkerProgressDisplay", - "GenerativeTextBenchmarkerTaskProgressState", + "BenchmarkerProgress", + "BenchmarkerProgressGroup", + "GenerativeConsoleBenchmarkerProgress", ] -@dataclass -class BenchmarkerTaskProgressState: - display_scheduler_stats: bool - - task_id: TaskID - strategy: Union[StrategyType, SchedulingStrategy] - started: bool = False - compiling: bool = False - ended: bool = False - - start_time: Optional[float] = None - max_number: Optional[float] = None - max_duration: Optional[float] = None - in_warmup: bool = False - in_cooldown: bool = False - - requests_rate: float = 0 - request_latency: float = 0 - requests_processing: float = 0 - requests_successful: float = 0 - requests_incomplete: float = 0 - requests_errored: float = 0 - - worker_overheads_time_ms: float = 0.0 - backend_overheads_time_ms: float = 0.0 - requests_sleep_time_ms: float = 0.0 - requests_targeted_start_time_delay_ms: float = 0.0 +class BenchmarkerProgress(Generic[BenchmarkT], ABC): + """ + Abstract base class for tracking and displaying benchmark progress. + + Provides lifecycle hooks for monitoring benchmark execution stages including + initialization, start, updates, completion, and finalization. Supports + enable/disable functionality for conditional progress tracking. + """ + + def __init__(self, enabled: bool = True): + """ + Initialize progress tracker. + + :param enabled: Whether to enable progress tracking and display. + """ + self.enabled = enabled + self.profile: Profile = None + self.current_strategy: SchedulingStrategy = None @property - def description(self) -> str: - return strategy_display_str(self.strategy) + def enabled(self) -> bool: + """ + :return: Whether progress tracking is currently enabled. + """ + return self._enabled + + @enabled.setter + def enabled(self, value: bool) -> None: + """ + :param value: True to enable progress tracking, False to disable. + :raises RuntimeError: If called after progress run has started. 
+ """ + if self.profile is not None: + raise RuntimeError( + "Cannot change enabled state after __call__ for progress run" + ) + + self._enabled = value + + def __call__( + self, + profile: Profile, + agen: AsyncIterable[ + tuple[ + dict[str, Any], + BenchmarkT | None, + SchedulingStrategy, + SchedulerState | None, + ] + ], + ) -> AsyncIterator[ + tuple[ + dict[str, Any], + BenchmarkT | None, + SchedulingStrategy, + SchedulerState | None, + ] + ]: + """ + Track progress through benchmark execution pipeline. + + Wraps the provided async generator to monitor benchmark progress, + calling appropriate lifecycle hooks based on execution state. + + :param profile: Benchmark profile configuration. + :param agen: Async generator yielding benchmark execution updates. + :return: Async iterator forwarding original updates with progress tracking. + """ + + async def aiterator() -> AsyncIterator[ + tuple[ + dict[str, Any], + BenchmarkT | None, + SchedulingStrategy, + SchedulerState | None, + ] + ]: + self.profile = profile + if self.enabled: + await self.on_initialize(profile) + + async for aggregator_update, benchmark, strategy, scheduler_state in agen: + if self.enabled: + await self.on_raw_update( + profile, + aggregator_update, + benchmark, + strategy, + scheduler_state, + ) + + if self.current_strategy != strategy: + self.current_strategy = strategy + await self.on_benchmark_start(strategy) + elif benchmark is not None: + await self.on_benchmark_complete(benchmark) + self.current_strategy = None + else: + await self.on_benchmark_update( + aggregator_update, scheduler_state + ) + + yield aggregator_update, benchmark, strategy, scheduler_state + + if self.enabled: + await self.on_finalize() + + return aiterator() + + @abstractmethod + async def on_initialize(self, profile: Profile): + """ + Initialize progress tracking for benchmark profile. + + :param profile: Benchmark profile configuration. + """ + + @abstractmethod + async def on_benchmark_start(self, strategy: SchedulingStrategy): + """ + Handle start of new benchmark strategy execution. + + :param strategy: Scheduling strategy being executed. + """ + + @abstractmethod + async def on_benchmark_update( + self, aggregator_update: dict[str, Any], scheduler_state: SchedulerState + ): + """ + Handle benchmark execution progress update. + + :param aggregator_update: Current benchmark metrics and statistics. + :param scheduler_state: Current scheduler execution state. + """ + + @abstractmethod + async def on_benchmark_complete(self, benchmark: BenchmarkT): + """ + Handle completion of benchmark strategy execution. + + :param benchmark: Completed benchmark results. + """ + + @abstractmethod + async def on_finalize(self): + """Finalize progress tracking and cleanup resources.""" + + async def on_raw_update( + self, + profile: Profile, + aggregator_update: dict[str, Any], + benchmark: BenchmarkT | None, + strategy: SchedulingStrategy, + scheduler_state: SchedulerState | None, + ): + """ + Handle raw benchmark execution update. + + Optional hook for accessing all execution state updates. Default + implementation does nothing. + + :param profile: Benchmark profile configuration. + :param aggregator_update: Current benchmark metrics and statistics. + :param benchmark: Completed benchmark if available. + :param strategy: Current scheduling strategy. + :param scheduler_state: Current scheduler execution state. 
+ """ + + +class BenchmarkerProgressGroup(BenchmarkerProgress[BenchmarkT]): + """ + Composite progress handler that manages multiple progress instances. + + Distributes progress events to all contained progress instances, enabling + parallel progress tracking through multiple channels (e.g., console display + and file logging). + + :param instances: Collection of progress handlers to manage. + :param enabled: Whether the group is active. + """ + + def __init__( + self, + instances: ( + Iterable[BenchmarkerProgress[BenchmarkT]] + | list[BenchmarkerProgress[BenchmarkT]] + ), + enabled: bool = True, + ): + """ + Initialize progress group with handler instances. + + :param instances: Progress handler instances to coordinate. + :param enabled: Whether to enable the progress group. + """ + self.instances: list[BenchmarkerProgress[BenchmarkT]] = list(instances) + super().__init__(enabled=enabled) @property - def total(self) -> Optional[float]: - if self.max_number is None and self.max_duration is None: - return None + def enabled(self) -> bool: + """Whether the progress group is currently enabled.""" + return self._enabled + + @enabled.setter + def enabled(self, value: bool): + """ + Set enabled state for group and all contained instances. + + :param value: New enabled state. + """ + self._enabled = value + for instance in self.instances: + instance.enabled = value - return 1000 + async def on_initialize(self, profile: Profile): + """ + Initialize all progress handler instances. + + :param profile: Benchmark profile configuration. + """ + await asyncio.gather( + *[child.on_initialize(profile) for child in self.instances] + ) + + async def on_benchmark_start(self, strategy: SchedulingStrategy): + """ + Notify all handlers of benchmark strategy start. + + :param strategy: Scheduling strategy being executed. + """ + await asyncio.gather( + *[child.on_benchmark_start(strategy) for child in self.instances] + ) + + async def on_benchmark_update( + self, aggregator_update: dict[str, Any], scheduler_state: SchedulerState + ): + """ + Distribute benchmark updates to all handlers. + + :param aggregator_update: Current benchmark metrics and statistics. + :param scheduler_state: Current scheduler execution state. + """ + await asyncio.gather( + *[ + child.on_benchmark_update(aggregator_update, scheduler_state) + for child in self.instances + ] + ) + + async def on_benchmark_complete(self, benchmark: BenchmarkT): + """ + Notify all handlers of benchmark completion. + + :param benchmark: Completed benchmark results. + """ + await asyncio.gather( + *[child.on_benchmark_complete(benchmark) for child in self.instances] + ) + + async def on_finalize(self): + """Finalize all progress handler instances.""" + await asyncio.gather(*[child.on_finalize() for child in self.instances]) + + async def on_raw_update( + self, + profile: Profile, + aggregator_update: dict[str, Any], + benchmark: BenchmarkT | None, + strategy: SchedulingStrategy, + scheduler_state: SchedulerState | None, + ): + """ + Distribute raw updates to all handlers. + + :param profile: Benchmark profile configuration. + :param aggregator_update: Current benchmark metrics and statistics. + :param benchmark: Completed benchmark if available. + :param strategy: Current scheduling strategy. + :param scheduler_state: Current scheduler execution state. 
+ """ + await asyncio.gather( + *[ + child.on_raw_update( + profile, + aggregator_update, + benchmark, + strategy, + scheduler_state, + ) + for child in self.instances + ] + ) + + +class GenerativeConsoleBenchmarkerProgress( + BenchmarkerProgress[GenerativeBenchmark], Live +): + """ + Console-based progress display for generative benchmarks. + + Provides real-time visual progress tracking using Rich library components, + displaying benchmark execution statistics, timing information, and progress + bars in a structured console interface. + """ + + def __init__(self, enabled: bool = True, display_scheduler_stats: bool = False): + """ + Initialize console progress display. + + :param enabled: Whether to enable progress tracking and display. + :param display_scheduler_stats: Whether to display scheduler statistics. + """ + super(BenchmarkerProgress, self).__init__(enabled=enabled) + super(Live, self).__init__( + refresh_per_second=4, + auto_refresh=True, + redirect_stdout=True, + redirect_stderr=True, + ) + self.display_scheduler_stats: bool = display_scheduler_stats + self.run_progress: Progress = None + self.run_progress_task: TaskID = None + self.tasks_progress: _GenerativeProgressTasks = None + + async def on_initialize(self, profile: Profile): + """ + Initialize console display components and start rendering. + + :param profile: Benchmark profile configuration. + """ + self.tasks_progress = _GenerativeProgressTasks( + profile=profile, display_scheduler_stats=self.display_scheduler_stats + ) + self.run_progress = Progress( + TextColumn("Generating...", style=f"italic {Colors.PROGRESS}"), + BarColumn( + bar_width=None, + complete_style=Colors.PROGRESS, + finished_style=Colors.SUCCESS, + ), + TextColumn( + "({task.fields[completed_benchmarks]}/{task.fields[total_benchmarks]})", + style=Colors.PROGRESS, + ), + TextColumn("["), + TimeElapsedColumn(), + TextColumn("<"), + TimeRemainingColumn(), + TextColumn("]"), + ) + self.run_progress_task = self.run_progress.add_task("") + self._sync_run_progress() + self.update( + Group( + Panel( + self.tasks_progress, + title="Benchmarks", + title_align="left", + expand=True, + ), + self.run_progress, + ) + ) + self.start() + + async def on_benchmark_start(self, strategy: SchedulingStrategy): + """ + Update display for new benchmark strategy start. + + :param strategy: Scheduling strategy being executed. + """ + self.tasks_progress.start_benchmark(strategy) + self._sync_run_progress() + + async def on_benchmark_update( + self, aggregator_update: dict[str, Any], scheduler_state: SchedulerState + ): + """ + Update display with current benchmark progress. + + :param aggregator_update: Current benchmark metrics and statistics. + :param scheduler_state: Current scheduler execution state. + """ + self.tasks_progress.update_benchmark(aggregator_update, scheduler_state) + self._sync_run_progress() + + async def on_benchmark_complete(self, benchmark: GenerativeBenchmark): + """ + Update display for completed benchmark. + + :param benchmark: Completed benchmark results. 
+ """ + self.tasks_progress.complete_benchmark(benchmark) + self._sync_run_progress() + + async def on_finalize(self): + """Stop display rendering and cleanup resources.""" + self.tasks_progress.finalize() + self._sync_run_progress() + self.run_progress.stop_task(self.run_progress_task) + self.stop() + self.run_progress = None + self.run_progress_task = None + self.tasks_progress = None + + def _sync_run_progress(self): + """Synchronize overall progress display with task progress.""" + self.run_progress.update( + self.run_progress_task, + total=self.tasks_progress.steps_total, + completed=self.tasks_progress.steps_progress, + completed_benchmarks=self.tasks_progress.tasks_progress, + total_benchmarks=self.tasks_progress.tasks_total, + ) + + +# Scaling factor for progress calculations to provide granular progress updates +_PROGRESS_SCALE = 1000 + + +class _GenerativeProgressTasks(Progress): + def __init__(self, profile: Profile, display_scheduler_stats: bool): + self.profile: Profile = profile + self.display_scheduler_stats: bool = display_scheduler_stats + self.benchmark_task_states: list[_GenerativeProgressTaskState] = [] + self.current_index: int = -1 + + summary_text = "{task.fields[requests_summary]}\n{task.fields[tokens_summary]}" + if self.display_scheduler_stats: + summary_text += "\n{task.fields[scheduler_stats]}" + super().__init__( + TextColumn("[{task.fields[start_time]}]"), + SpinnerColumn(style=Colors.PROGRESS), + TaskProgressColumn(style=Colors.PROGRESS), + TextColumn("{task.description}"), + TextColumn("({task.fields[progress_status]})"), + TextColumn(" "), + TextColumn(summary_text), + ) + + for strategy_type in profile.strategy_types: + task_state = _GenerativeProgressTaskState( + strategy_type=strategy_type, + ) + task_id = self.add_task(**task_state.current) + task_state.task_id = task_id + self.benchmark_task_states.append(task_state) + + @property + def tasks_total(self) -> int: + return len(self.benchmark_task_states) @property - def completed(self) -> int: - if self.ended: - return 1000 + def tasks_progress(self) -> int: + return self.current_index + 1 - if self.max_number is None and self.max_duration is None: - return 0 + @property + def steps_total(self) -> int: + return _PROGRESS_SCALE * len(self.benchmark_task_states) - number = self.requests_successful + self.requests_errored - number_percent = ( - number / float(self.max_number) * 1000 if self.max_number else -math.inf + @property + def steps_progress(self) -> int: + progress_current_task = ( + self.benchmark_task_states[self.current_index].progress + if self.current_index < len(self.benchmark_task_states) + else 0 + ) + progress_total = self.current_index + (progress_current_task or 0) + + return progress_total * _PROGRESS_SCALE + + def start_benchmark(self, strategy: SchedulingStrategy): + self.current_index += 1 + if self.current_index >= len(self.benchmark_task_states): + # New task past initially estimated, append it to the end + task_state = _GenerativeProgressTaskState(strategy_type=strategy.type_) + task_id = self.add_task(**task_state.current) + task_state.task_id = task_id + self.benchmark_task_states.append(task_state) + + self.benchmark_task_states[self.current_index].start(strategy) + self.update( + self.benchmark_task_states[self.current_index].task_id, + start=True, + **self.benchmark_task_states[self.current_index].current, + ) + + def update_benchmark( + self, aggregator_update: dict[str, Any], scheduler_state: SchedulerState + ): + self.benchmark_task_states[self.current_index].update( + 
aggregator_update, scheduler_state ) - duration_percent = ( - (time.time() - self.start_time) / self.max_duration * 1000 - if self.max_duration and self.start_time - else -math.inf + self.update( + self.benchmark_task_states[self.current_index].task_id, + **self.benchmark_task_states[self.current_index].current, ) - return min(int(max(number_percent, duration_percent)), 1000) + def complete_benchmark(self, benchmark: GenerativeBenchmark): + self.benchmark_task_states[self.current_index].complete(benchmark) + self.update( + self.benchmark_task_states[self.current_index].task_id, + **self.benchmark_task_states[self.current_index].current, + ) + + def finalize(self): + self.stop() + + +@dataclass +class _GenerativeProgressTaskState: + task_id: TaskID = None + strategy_type: StrategyType + strategy: SchedulingStrategy | None = None + benchmark_status: Literal[ + "pending", "in_warmup", "in_progress", "in_cooldown", "completed" + ] = "pending" + progress: float | None = None + start_time: float = -1.0 + successful_requests: int = -1 + cancelled_requests: int = -1 + errored_requests: int = -1 + request_concurrency: int = -1 + requests_per_second: float = -1.0 + request_latency: float = -1.0 + output_tokens: int = -1 + output_tokens_rate: float = -1.0 + prompt_tokens: int = -1 + total_tokens_rate: float = -1.0 + time_to_first_token: float = -1.0 + inter_token_latency: float = -1.0 + queued_time: float = -1.0 + request_targeted_start_delay: float = -1.0 + scheduler_overheads_time: float = -1.0 @property - def fields(self) -> dict[str, str]: - fields = { + def current(self) -> dict[str, Any]: + return { "start_time": self.formatted_start_time, + "description": str(self.strategy or self.strategy_type), "progress_status": self.formatted_progress_status, "requests_summary": self.formatted_requests_summary, + "tokens_summary": self.formatted_tokens_summary, + "scheduler_stats": self.formatted_scheduler_stats, + "completed": None, + "total": None, } - if self.display_scheduler_stats: - fields["scheduler_stats"] = self.formatted_scheduler_stats + @property + def completed(self) -> float | None: + if self.benchmark_status == "pending": + return None + + if self.benchmark_status == "completed": + return _PROGRESS_SCALE - return fields + return self.progress * _PROGRESS_SCALE if self.progress is not None else None + + @property + def total(self) -> float | None: + return _PROGRESS_SCALE if self.benchmark_status != "pending" else None @property def formatted_start_time(self) -> str: - if self.start_time is None: + if self.start_time < 0.0: return "--:--:--" return datetime.fromtimestamp(self.start_time).strftime("%H:%M:%S") @property def formatted_progress_status(self) -> str: - if self.ended: - status = "complete" - color = Colors.SUCCESS - elif self.compiling: - status = "compiling" - color = Colors.PROGRESS - elif self.started and self.in_warmup: + if self.benchmark_status == "in_warmup": status = "warmup" color = Colors.PROGRESS - elif self.started and self.in_cooldown: - status = "cooldown" - color = Colors.PROGRESS - elif self.started: + elif self.benchmark_status == "in_progress": status = "running" color = Colors.PROGRESS + elif self.benchmark_status == "in_cooldown": + status = "cooldown" + color = Colors.PROGRESS + elif self.benchmark_status == "completed": + status = "complete" + color = Colors.SUCCESS else: status = "pending" color = Colors.INFO @@ -144,20 +646,20 @@ def formatted_progress_status(self) -> str: @property def formatted_requests_summary(self) -> str: - if not self.started: + if 
self.benchmark_status == "pending": return " " return ( f"[{Colors.INFO}]Req:[/{Colors.INFO}] " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_rate, + + format_value_display( + value=self.requests_per_second, label="req/s", total_characters=12, digits_places=4, decimal_places=1, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( + + format_value_display( value=self.request_latency, label="Lat", units="s", @@ -166,32 +668,32 @@ def formatted_requests_summary(self) -> str: decimal_places=2, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_processing, + + format_value_display( + value=self.request_concurrency, label="Conc", total_characters=12, digits_places=4, decimal_places=1, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_successful, + + format_value_display( + value=self.successful_requests, label="Comp", total_characters=12, digits_places=5, decimal_places=0, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_incomplete, + + format_value_display( + value=self.cancelled_requests, label="Inc", total_characters=12, digits_places=5, decimal_places=0, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_errored, + + format_value_display( + value=self.errored_requests, label="Err", total_characters=12, digits_places=5, @@ -199,101 +701,14 @@ def formatted_requests_summary(self) -> str: ) ) - @property - def formatted_scheduler_stats(self) -> str: - if not self.started: - return " " - - return ( - f"[{Colors.INFO}]Sys:[/{Colors.INFO}] " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.worker_overheads_time_ms, - label="Work OH", - units="ms", - total_characters=18, - digits_places=3, - decimal_places=1, - ) - + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.backend_overheads_time_ms, - label="Back OH", - units="ms", - total_characters=18, - digits_places=3, - decimal_places=1, - ) - + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_sleep_time_ms, - label="Req Sleep", - units="ms", - total_characters=18, - digits_places=5, - decimal_places=0, - ) - + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.requests_targeted_start_time_delay_ms, - label="Start Del", - units="ms", - total_characters=18, - digits_places=5, - decimal_places=0, - ) - ) - - @staticmethod - def format_progress_display( - value: float, - label: str, - units: str = "", - total_characters: Optional[int] = None, - digits_places: Optional[int] = None, - decimal_places: Optional[int] = None, - ) -> str: - if decimal_places is None and digits_places is None: - formatted_number = f"{value}:.0f" - elif digits_places is None: - formatted_number = f"{value:.{decimal_places}f}" - elif decimal_places is None: - formatted_number = f"{value:>{digits_places}f}" - else: - formatted_number = f"{value:>{digits_places}.{decimal_places}f}" - - result = f"{formatted_number}{units} [{Colors.INFO}]{label}[/{Colors.INFO}]" - - if total_characters is not None: - total_characters += len(Colors.INFO) * 2 + 5 - - if len(result) < total_characters: - result = result.rjust(total_characters) - - return result - - -class GenerativeTextBenchmarkerTaskProgressState(BenchmarkerTaskProgressState): - output_tokens: float = 0 - prompt_tokens: float = 0 - output_tokens_rate: float = 0 - total_tokens_rate: float = 0 - tokens_ttft: float = 0 - 
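# Illustrative sketch (not part of the diff): what a fixed-width formatter in the
# spirit of format_value_display (called above) might look like. The real helper
# lives in guidellm.utils and also emits Rich colour markup; this simplified
# stand-in only shows the alignment idea, so columns such as "req/s" and "Lat"
# stay vertically aligned as their values change width.
def format_value(value: float, label: str, units: str = "",
                 digits_places: int = 4, decimal_places: int = 1,
                 total_characters: int = 12) -> str:
    number = f"{value:>{digits_places}.{decimal_places}f}"
    return f"{number}{units} {label}".rjust(total_characters)


line = ", ".join([
    format_value(123.4, "req/s"),
    format_value(0.25, "Lat", units="s", decimal_places=2),
    format_value(8.0, "Conc"),
])
print(line)  # three right-aligned 12-character columns, e.g. " 123.4 req/s, ..."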
tokens_itl: float = 0 - - @property - def fields(self) -> dict[str, str]: - fields = super().fields - fields["tokens_summary"] = self.formatted_tokens_summary - return fields - @property def formatted_tokens_summary(self) -> str: - if not self.started: + if self.benchmark_status == "pending": return " " return ( f"[{Colors.INFO}]Tok:[/{Colors.INFO}] " - + BenchmarkerTaskProgressState.format_progress_display( + + format_value_display( value=self.output_tokens_rate, label="gen/s", total_characters=12, @@ -301,7 +716,7 @@ def formatted_tokens_summary(self) -> str: decimal_places=1, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( + + format_value_display( value=self.total_tokens_rate, label="tot/s", total_characters=12, @@ -309,8 +724,8 @@ def formatted_tokens_summary(self) -> str: decimal_places=1, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.tokens_ttft, + + format_value_display( + value=self.time_to_first_token, label="TTFT", units="ms", total_characters=12, @@ -318,8 +733,8 @@ def formatted_tokens_summary(self) -> str: decimal_places=1, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( - value=self.tokens_itl, + + format_value_display( + value=self.inter_token_latency, label="ITL", units="ms", total_characters=12, @@ -327,7 +742,7 @@ def formatted_tokens_summary(self) -> str: decimal_places=1, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( + + format_value_display( value=self.prompt_tokens, label="Prompt", total_characters=12, @@ -335,7 +750,7 @@ def formatted_tokens_summary(self) -> str: decimal_places=0, ) + ", " - + BenchmarkerTaskProgressState.format_progress_display( + + format_value_display( value=self.output_tokens, label="Gen", total_characters=12, @@ -344,377 +759,184 @@ def formatted_tokens_summary(self) -> str: ) ) + @property + def formatted_scheduler_stats(self) -> str: + if self.benchmark_status == "pending": + return " " -BTPS = TypeVar("BTPS", bound=BenchmarkerTaskProgressState) - - -class BenchmarkerProgressDisplay(Generic[BTPS]): - def __init__(self, display_scheduler_stats: bool): - self.display_scheduler_stats = display_scheduler_stats - self.started = False - self.benchmarker_tasks_progress = Progress(*self.create_task_progress_columns()) - self.benchmarker_tasks_panel = Panel( - self.benchmarker_tasks_progress, - title="Benchmarks", - title_align="left", - expand=True, - ) - self.benchmarker_progress = Progress( - TextColumn("Generating...", style=f"italic {Colors.PROGRESS}"), - BarColumn( - bar_width=None, - complete_style=Colors.PROGRESS, - finished_style=Colors.SUCCESS, - ), - TextColumn( - "({task.fields[completed_benchmarks]}/{task.fields[total_benchmarks]})", - style=Colors.PROGRESS, - ), - TextColumn("["), - TimeElapsedColumn(), - TextColumn("<"), - TimeRemainingColumn(), - TextColumn("]"), - ) - self.benchmarker_live = Live( - Group( - self.benchmarker_tasks_panel, - self.benchmarker_progress, - ), - redirect_stdout=True, - redirect_stderr=True, - ) - self.active_task: Optional[TaskID] = None - self.benchmarker_tasks: list[BTPS] = [] - self.progress_task: Optional[TaskID] = None - - def update(self, result: BenchmarkerResult): - if result.type_ == "run_start": - if self.started: - raise RuntimeError("Progress display already started.") - - self.handle_start(result) - self.started = True - elif result.type_ == "run_complete": - if not self.started: - raise RuntimeError("Progress display not started.") - - self.handle_end(result) - self.started = False - else: - if not 
self.started: - raise RuntimeError("Progress display not started.") - - self.handle_update(result) - - def handle_start(self, result: BenchmarkerResult): - self.benchmarker_live.start() - - for index, strategy_type in enumerate(result.profile.strategy_types): - task_id = self.benchmarker_tasks_progress.add_task( - description=strategy_type, - start=False, - total=None, - completed=0, - visible=False, + return ( + f"[{Colors.INFO}]Sys:[/{Colors.INFO}] , " + + format_value_display( + value=self.request_targeted_start_delay, + label="Start Del", + units="ms", + total_characters=18, + digits_places=5, + decimal_places=0, ) - task_progress_state = self.create_task_progress_state( - task_id=task_id, - index=index, - strategy_type=strategy_type, - result=result, + + format_value_display( + value=self.scheduler_overheads_time, + label="Sched OH", + units="ms", + total_characters=18, + digits_places=3, + decimal_places=1, ) - self.benchmarker_tasks.append(task_progress_state) - self.benchmarker_tasks_progress.update( - task_id, - description=task_progress_state.description, - visible=True, - **task_progress_state.fields, # type: ignore[arg-type] + + ", " + + format_value_display( + value=self.queued_time, + label="Queued", + units="ms", + total_characters=18, + digits_places=5, + decimal_places=0, ) - - self.progress_task = self.benchmarker_progress.add_task( - "", - total=len(self.benchmarker_tasks) * 1000, - completed_benchmarks=0, - total_benchmarks=len(self.benchmarker_tasks), ) - def handle_update(self, result: BenchmarkerResult): - current_state: BTPS = self.benchmarker_tasks[result.current_index] - - if result.type_ == "scheduler_start": - self.handle_update_scheduler_start(current_state, result) - self.active_task = current_state.task_id - elif result.type_ == "scheduler_update": - self.handle_update_scheduler_update(current_state, result) - elif result.type_ == "scheduler_complete": - self.handle_update_scheduler_complete(current_state, result) - elif result.type_ == "benchmark_compiled": - self.handle_update_benchmark_compiled(current_state, result) - else: - raise ValueError(f"Unknown result type: {result.type_}") - - if self.progress_task is None: - raise RuntimeError("Progress task not set.") - - self.benchmarker_tasks_progress.update( - current_state.task_id, - description=current_state.description, - completed=current_state.completed, - total=current_state.total, - **current_state.fields, # type: ignore[arg-type] - ) - self.benchmarker_progress.update( - self.progress_task, - completed=(result.current_index * 1000) + current_state.completed, - total=1000 * len(self.benchmarker_tasks), - completed_benchmarks=( - result.current_index + (1 if current_state.ended else 0) - ), - total_benchmarks=len(self.benchmarker_tasks), - ) - - if current_state.ended: - self.benchmarker_tasks_progress.stop_task(current_state.task_id) - self.active_task = None - - def handle_update_scheduler_start( - self, progress_state: BTPS, result: BenchmarkerResult - ): - if self.active_task is not None: - raise RuntimeError("Active task already set.") - - progress_state.strategy = result.current_strategy # type: ignore[assignment] - progress_state.started = True - current_aggregator: BenchmarkAggregator = result.current_aggregator # type: ignore[assignment] - progress_state.start_time = ( - current_aggregator.requests_stats.totals.total.start_time - ) - progress_state.max_number = current_aggregator.args.max_number - progress_state.max_duration = current_aggregator.args.max_duration + def start(self, strategy: 
SchedulingStrategy): + self.strategy = strategy + self.strategy_type = strategy.type_ - def handle_update_scheduler_update( - self, progress_state: BTPS, result: BenchmarkerResult + def update( + self, aggregator_update: dict[str, Any], scheduler_state: SchedulerState ): - if self.active_task is None: - raise RuntimeError("Active task not set.") - - if self.active_task != progress_state.task_id: - raise RuntimeError("Active task does not match current task.") - - current_aggregator: BenchmarkAggregator = result.current_aggregator # type: ignore[assignment] - progress_state.in_warmup = current_aggregator.in_warmup - progress_state.in_cooldown = current_aggregator.in_cooldown - progress_state.requests_rate = ( - current_aggregator.requests_stats.totals.successful.rate - ) - progress_state.request_latency = ( - current_aggregator.requests_stats.request_time.mean - ) - progress_state.requests_processing = ( - current_aggregator.scheduler_stats.processing_requests.last - ) - progress_state.requests_successful = ( - current_aggregator.requests_stats.totals.successful.total - ) - progress_state.requests_incomplete = ( - current_aggregator.requests_stats.totals.incomplete.total - ) - progress_state.requests_errored = ( - current_aggregator.requests_stats.totals.errored.total - ) - progress_state.worker_overheads_time_ms = ( - current_aggregator.requests_stats.scheduled_time_delay.mean_ms - + current_aggregator.requests_stats.worker_start_delay.mean_ms - ) - progress_state.backend_overheads_time_ms = ( - current_aggregator.requests_stats.request_time_delay.mean_ms - ) - progress_state.requests_sleep_time_ms = ( - current_aggregator.requests_stats.scheduled_time_sleep.mean_ms - ) - progress_state.requests_targeted_start_time_delay_ms = ( - current_aggregator.requests_stats.request_start_time_targeted_delay.mean_ms + self.progress = scheduler_state.remaining_fraction + status: Literal["in_warmup", "in_progress", "in_cooldown"] = "in_progress" + if aggregator_update.get("requests_in_warmup"): + status = "in_warmup" + elif aggregator_update.get("requests_in_cooldown"): + status = "in_cooldown" + self._update_processing_states( + benchmark_status=status, + start_time=scheduler_state.start_time, + successful_requests=scheduler_state.successful_requests, + cancelled_requests=scheduler_state.cancelled_requests, + errored_requests=scheduler_state.errored_requests, + ) + self._update_request_stats( + request_concurrency=scheduler_state.processing_requests, + requests_per_second=aggregator_update.get("requests_per_second"), + request_latency=aggregator_update.get("request_latency"), + ) + self._update_token_stats( + output_tokens=aggregator_update.get("output_tokens"), + output_tokens_rate=aggregator_update.get("output_tokens_rate"), + prompt_tokens=aggregator_update.get("prompt_tokens"), + total_tokens_rate=aggregator_update.get("total_tokens_rate"), + time_to_first_token=aggregator_update.get("time_to_first_token") + * 1000, # ms + inter_token_latency=aggregator_update.get("inter_token_latency") * 1000, + ) + self._update_system_stats( + request_targeted_start_delay=aggregator_update.get( + "request_targeted_start_delay" + ), + queued_time=aggregator_update.get("queued_time"), + scheduler_overheads_time=aggregator_update.get("scheduler_overheads_time"), + ) + + def complete(self, benchmark: GenerativeBenchmark): + self._update_processing_states( + benchmark_status="completed", + start_time=benchmark.start_time, + successful_requests=benchmark.request_totals.successful, + 
cancelled_requests=benchmark.request_totals.incomplete, + errored_requests=benchmark.request_totals.errored, + ) + self._update_request_stats( + request_concurrency=benchmark.metrics.request_concurrency.successful.mean, + requests_per_second=benchmark.metrics.requests_per_second.successful.mean, + request_latency=benchmark.metrics.request_latency.successful.mean, + ) + self._update_token_stats( + output_tokens=benchmark.metrics.output_token_count.successful.mean, + output_tokens_rate=benchmark.metrics.output_tokens_per_second.successful.mean, + prompt_tokens=benchmark.metrics.prompt_token_count.successful.mean, + total_tokens_rate=benchmark.metrics.tokens_per_second.successful.mean, + time_to_first_token=( + benchmark.metrics.time_to_first_token_ms.successful.mean + ), + inter_token_latency=( + benchmark.metrics.inter_token_latency_ms.successful.mean + ), + converted=True, ) - def handle_update_scheduler_complete( + def _update_processing_states( self, - progress_state: BTPS, - result: BenchmarkerResult, # noqa: ARG002 - ): - if self.active_task is None: - raise RuntimeError("Active task not set.") - - if self.active_task != progress_state.task_id: - raise RuntimeError("Active task does not match current task.") - - progress_state.in_warmup = False - progress_state.in_cooldown = False - progress_state.compiling = True - - def handle_update_benchmark_compiled( - self, progress_state: BTPS, result: BenchmarkerResult + benchmark_status: Literal[ + "pending", "in_warmup", "in_progress", "in_cooldown", "completed" + ], + start_time: float | None = None, + successful_requests: int | None = None, + cancelled_requests: int | None = None, + errored_requests: int | None = None, ): - if self.active_task is None: - raise RuntimeError("Active task not set.") - - if self.active_task != progress_state.task_id: - raise RuntimeError("Active task does not match current task.") - - current_benchmark: Benchmark = result.current_benchmark # type: ignore[assignment] - progress_state.compiling = False - progress_state.ended = True - progress_state.requests_rate = ( - current_benchmark.metrics.requests_per_second.successful.mean - ) - progress_state.requests_processing = ( - current_benchmark.metrics.request_concurrency.successful.mean - ) - - def handle_end(self, result: BenchmarkerResult): # noqa: ARG002 - if self.progress_task is None: - raise RuntimeError("Progress task not set.") - - self.benchmarker_progress.update( - self.progress_task, - completed=len(self.benchmarker_tasks) * 1000, - total=len(self.benchmarker_tasks) * 1000, - completed_benchmarks=len(self.benchmarker_tasks), - total_benchmarks=len(self.benchmarker_tasks), - ) - self.benchmarker_progress.stop_task(self.progress_task) - self.benchmarker_live.stop() - self.active_task = None - self.benchmarker_tasks = [] - self.progress_task = None - - def create_task_progress_columns(self) -> list[ProgressColumn]: - columns = [ - TextColumn("[{task.fields[start_time]}]"), - SpinnerColumn(style=Colors.PROGRESS), - TaskProgressColumn(style=Colors.PROGRESS), - TextColumn("{task.description}"), - TextColumn("({task.fields[progress_status]})"), - TextColumn(" "), - ] - - if not self.display_scheduler_stats: - columns += [ - TextColumn("{task.fields[requests_summary]}\n"), - ] - else: - columns += [ - TextColumn( - "{task.fields[requests_summary]}\n{task.fields[scheduler_stats]}\n" - ), - ] - - return columns - - def create_task_progress_state( - self, - task_id: TaskID, - index: int, # noqa: ARG002 - strategy_type: StrategyType, - result: BenchmarkerResult, # 
noqa: ARG002 - ) -> BTPS: - return BenchmarkerTaskProgressState( # type: ignore[return-value] - display_scheduler_stats=self.display_scheduler_stats, - task_id=task_id, - strategy=strategy_type, - ) - - -class GenerativeTextBenchmarkerProgressDisplay( - BenchmarkerProgressDisplay[GenerativeTextBenchmarkerTaskProgressState] -): - def handle_update_scheduler_update( + self.benchmark_status = benchmark_status + + if start_time is not None: + self.start_time = start_time + if successful_requests is not None: + self.successful_requests = successful_requests + if cancelled_requests is not None: + self.cancelled_requests = cancelled_requests + if errored_requests is not None: + self.errored_requests = errored_requests + + def _update_request_stats( self, - progress_state: GenerativeTextBenchmarkerTaskProgressState, - result: BenchmarkerResult, + request_concurrency: int | None = None, + requests_per_second: float | None = None, + request_latency: float | None = None, ): - super().handle_update_scheduler_update(progress_state, result) - current_aggregator: GenerativeBenchmarkAggregator = result.current_aggregator # type: ignore[assignment] - progress_state.output_tokens = ( - current_aggregator.requests_stats.output_tokens.mean - ) - progress_state.prompt_tokens = ( - current_aggregator.requests_stats.prompt_tokens.mean - ) - progress_state.output_tokens_rate = ( - current_aggregator.requests_stats.output_tokens.rate - ) - progress_state.total_tokens_rate = ( - current_aggregator.requests_stats.total_tokens.rate - ) - progress_state.tokens_ttft = ( - current_aggregator.requests_stats.time_to_first_token.mean_ms - ) - progress_state.tokens_itl = ( - current_aggregator.requests_stats.inter_token_latency.mean_ms - ) - - def handle_update_benchmark_compiled( + if request_concurrency is not None: + self.request_concurrency = request_concurrency + if requests_per_second is not None: + self.requests_per_second = requests_per_second + if request_latency is not None: + self.request_latency = request_latency + + def _update_token_stats( self, - progress_state: GenerativeTextBenchmarkerTaskProgressState, - result: BenchmarkerResult, + output_tokens: int | None = None, + output_tokens_rate: float | None = None, + prompt_tokens: int | None = None, + total_tokens_rate: float | None = None, + time_to_first_token: float | None = None, + inter_token_latency: float | None = None, + converted: bool = False, ): - super().handle_update_benchmark_compiled(progress_state, result) - - current_benchmark: GenerativeBenchmark = result.current_benchmark # type: ignore[assignment] - progress_state.request_latency = ( - current_benchmark.metrics.request_latency.successful.mean - ) - progress_state.requests_successful = current_benchmark.request_totals.successful - progress_state.requests_errored = current_benchmark.request_totals.errored - progress_state.requests_incomplete = current_benchmark.request_totals.incomplete - progress_state.output_tokens = ( - current_benchmark.metrics.output_token_count.successful.mean - ) - progress_state.prompt_tokens = ( - current_benchmark.metrics.prompt_token_count.successful.mean - ) - progress_state.output_tokens_rate = ( - current_benchmark.metrics.output_tokens_per_second.successful.mean - ) - progress_state.total_tokens_rate = ( - current_benchmark.metrics.tokens_per_second.successful.mean - ) - progress_state.tokens_ttft = ( - current_benchmark.metrics.time_to_first_token_ms.successful.mean - ) - progress_state.tokens_itl = ( - 
current_benchmark.metrics.inter_token_latency_ms.successful.mean - ) + if output_tokens is not None: + self.output_tokens = output_tokens + if output_tokens_rate is not None: + self.output_tokens_rate = output_tokens_rate + if prompt_tokens is not None: + self.prompt_tokens = prompt_tokens + if total_tokens_rate is not None: + self.total_tokens_rate = total_tokens_rate + if time_to_first_token is not None: + self.time_to_first_token = time_to_first_token * ( + 1000 if not converted else 1 + ) + if inter_token_latency is not None: + self.inter_token_latency = inter_token_latency * ( + 1000 if not converted else 1 + ) - def create_task_progress_state( + def _update_system_stats( self, - task_id: TaskID, - index: int, # noqa: ARG002 - strategy_type: StrategyType, - result: BenchmarkerResult, # noqa: ARG002 - ) -> GenerativeTextBenchmarkerTaskProgressState: - return GenerativeTextBenchmarkerTaskProgressState( - display_scheduler_stats=self.display_scheduler_stats, - task_id=task_id, - strategy=strategy_type, - ) - - def create_task_progress_columns(self) -> list[ProgressColumn]: - columns = super().create_task_progress_columns() - columns = columns[:-1] # remove the last display info column - - if not self.display_scheduler_stats: - columns += [ - TextColumn( - "{task.fields[requests_summary]}\n{task.fields[tokens_summary]}", - ), - ] - else: - columns += [ - TextColumn( - "{task.fields[requests_summary]}\n{task.fields[tokens_summary]}\n{task.fields[scheduler_stats]}", - ), - ] - - return columns + request_targeted_start_delay: float | None = None, + queued_time: float | None = None, + scheduler_overheads_time: float | None = None, + converted: bool = False, + ): + if request_targeted_start_delay is not None: + self.request_targeted_start_delay = request_targeted_start_delay * ( + 1000 if not converted else 1 + ) + if queued_time is not None: + self.queued_time = queued_time * (1000 if not converted else 1) + if scheduler_overheads_time is not None: + self.scheduler_overheads_time = scheduler_overheads_time * ( + 1000 if not converted else 1 + ) diff --git a/src/guidellm/config.py b/src/guidellm/config.py index beda55fc..6460821d 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/config.py @@ -133,17 +133,15 @@ class Settings(BaseSettings): max_concurrency: int = 512 max_worker_processes: int = 10 max_add_requests_per_loop: int = 20 + scheduler_start_delay_non_distributed: float = 0.1 + scheduler_poll_interval: float = 0.05 # Data settings dataset: DatasetSettings = DatasetSettings() # Request/stats settings - preferred_prompt_tokens_source: Optional[ - Literal["request", "response", "local"] - ] = "response" - preferred_output_tokens_source: Optional[ - Literal["request", "response", "local"] - ] = "response" + preferred_prompt_tokens_source: Literal["request", "response"] = "response" + preferred_output_tokens_source: Literal["request", "response"] = "response" preferred_backend: Literal["openai"] = "openai" preferred_route: Literal["text_completions", "chat_completions"] = ( "text_completions" diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 50ab3cca..6760dcb0 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -11,10 +11,10 @@ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase # type: ignore[import] +from guidellm.backend import GenerationRequest from guidellm.config import settings from guidellm.dataset import ColumnInputTypes, load_dataset 
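# Illustrative sketch (not part of the diff): the effect of the config.py change
# above, which narrows preferred_*_tokens_source from an Optional Literal including
# "local" to a plain Literal["request", "response"]. A hypothetical pydantic model
# stands in for the real Settings(BaseSettings) class; values outside the Literal,
# including the old "local" option, are now rejected at validation time.
from typing import Literal

from pydantic import BaseModel, ValidationError


class TokenSourceSettings(BaseModel):  # hypothetical stand-in for guidellm's Settings
    preferred_prompt_tokens_source: Literal["request", "response"] = "response"


print(TokenSourceSettings().preferred_prompt_tokens_source)  # "response"

try:
    TokenSourceSettings(preferred_prompt_tokens_source="local")
except ValidationError:
    print("'local' is no longer an accepted token source")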
from guidellm.objects import StandardBaseModel -from guidellm.request.request import GenerationRequest __all__ = [ "GenerativeRequestLoader", diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py deleted file mode 100644 index 81c8cabd..00000000 --- a/src/guidellm/request/request.py +++ /dev/null @@ -1,79 +0,0 @@ -import uuid -from typing import Any, Literal, Optional - -from pydantic import Field - -from guidellm.objects.pydantic import StandardBaseModel - -__all__ = ["GenerationRequest"] - - -class GenerationRequest(StandardBaseModel): - """ - A class representing a request for generation. - This class is used to encapsulate the details of a generation request, - including the request ID, type, content, parameters, statistics, and constraints. - It is designed to be used with the BackendRequestsWorker class to handle - the generation process. - - :param request_id: The unique identifier for the request. - :param request_type: The type of request (e.g., text, chat). - :param content: The content for the request to send to the backend. - If request_type is 'text', this should be a string or list of strings - which will be resolved by backend.text_completions. - If request_type is 'chat', this should be a string, - a list of (str, Dict[str, Union[str, Dict[str, str]], Path, Image]), - or Any raw content which will be resolved by backend.chat_completions. - If raw content, raw_content=True must be passed in the params. - :param params: Additional parameters for the request passed in as kwargs. - For an http backend, these are passed into the body of the request. - :param stats: Statistics for the request, such as the number of prompt tokens. - Used for tracking and reporting purposes. - :param constraints: Constraints for the request, such as the maximum number - of output tokens. Used for controlling the behavior of the backend. - """ - - request_id: Optional[str] = Field( - default_factory=lambda: str(uuid.uuid4()), - description="The unique identifier for the request.", - ) - request_type: Literal["text_completions", "chat_completions"] = Field( - default="text_completions", - description=( - "The type of request (e.g., text, chat). " - "If request_type='text_completions', resolved by backend.text_completions. " - "If request_typ='chat_completions', resolved by backend.chat_completions." - ), - ) - content: Any = Field( - description=( - "The content for the request to send to the backend. " - "If request_type is 'text', this should be a string or list of strings " - "which will be resolved by backend.text_completions. " - "If request_type is 'chat', this should be a string, " - "a list of (str, Dict[str, Union[str, Dict[str, str]], Path, Image]), " - "or Any raw content which will be resolved by backend.chat_completions. " - "If raw content, raw_content=True must be passed in the params." - ) - ) - params: dict[str, Any] = Field( - default_factory=dict, - description=( - "Additional parameters for the request that will be passed in as kwargs. " - "For an http backend, these are passed into the body of the request. " - ), - ) - stats: dict[Literal["prompt_tokens"], int] = Field( - default_factory=dict, - description=( - "Statistics for the request, such as the number of prompt tokens. " - "Used for tracking and reporting purposes." - ), - ) - constraints: dict[Literal["output_tokens"], int] = Field( - default_factory=dict, - description=( - "Constraints for the request, such as the maximum number of output tokens. 
" - "Used for controlling the behavior of the backend." - ), - ) diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index 37bf1fd5..b957c622 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -1,52 +1,94 @@ -from .result import ( - SchedulerRequestInfo, - SchedulerRequestResult, - SchedulerResult, - SchedulerRunInfo, +from .constraints import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + MaxDurationConstraint, + MaxDurationConstraintInitializer, + MaxErrorRateConstraint, + MaxErrorRateConstraintInitializer, + MaxErrorsConstraint, + MaxErrorsConstraintInitializer, + MaxGlobalErrorRateConstraint, + MaxGlobalErrorRateConstraintInitializer, + MaxNumberConstraint, + MaxNumberConstraintInitializer, +) +from .environment import Environment, NonDistributedEnvironment +from .objects import ( + BackendInterface, + BackendT, + MeasuredRequestTimings, + MeasuredRequestTimingsT, + MultiTurnRequestT, + RequestSchedulerTimings, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, ) from .scheduler import Scheduler from .strategy import ( AsyncConstantStrategy, AsyncPoissonStrategy, ConcurrentStrategy, + ConstantRateRequestTimings, + LastCompletionRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, + ScheduledRequestTimings, SchedulingStrategy, + StrategyT, StrategyType, SynchronousStrategy, ThroughputStrategy, - strategy_display_str, -) -from .types import RequestT, ResponseT -from .worker import ( - GenerativeRequestsWorker, - GenerativeRequestsWorkerDescription, - RequestsWorker, - ResolveStatus, - WorkerDescription, - WorkerProcessRequest, - WorkerProcessResult, ) +from .worker import WorkerProcess +from .worker_group import WorkerProcessGroup __all__ = [ "AsyncConstantStrategy", "AsyncPoissonStrategy", + "BackendInterface", + "BackendT", "ConcurrentStrategy", - "GenerativeRequestsWorker", - "GenerativeRequestsWorkerDescription", + "ConstantRateRequestTimings", + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "Environment", + "LastCompletionRequestTimings", + "MaxDurationConstraint", + "MaxDurationConstraintInitializer", + "MaxErrorRateConstraint", + "MaxErrorRateConstraintInitializer", + "MaxErrorsConstraint", + "MaxErrorsConstraintInitializer", + "MaxGlobalErrorRateConstraint", + "MaxGlobalErrorRateConstraintInitializer", + "MaxNumberConstraint", + "MaxNumberConstraintInitializer", + "MeasuredRequestTimings", + "MeasuredRequestTimingsT", + "MultiTurnRequestT", + "NoDelayRequestTimings", + "NonDistributedEnvironment", + "PoissonRateRequestTimings", + "RequestSchedulerTimings", "RequestT", - "RequestsWorker", - "ResolveStatus", "ResponseT", + "ScheduledRequestInfo", + "ScheduledRequestTimings", "Scheduler", - "SchedulerRequestInfo", - "SchedulerRequestResult", - "SchedulerResult", - "SchedulerRunInfo", + "SchedulerState", + "SchedulerUpdateAction", + "SchedulerUpdateActionProgress", "SchedulingStrategy", + "StrategyT", "StrategyType", "SynchronousStrategy", "ThroughputStrategy", - "WorkerDescription", - "WorkerProcessRequest", - "WorkerProcessResult", - "strategy_display_str", + "WorkerProcess", + "WorkerProcessGroup", ] diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py new file mode 100644 index 00000000..5e11b71d --- /dev/null +++ b/src/guidellm/scheduler/constraints.py @@ -0,0 +1,536 @@ +""" +Constraint system for scheduler behavior 
control and request processing limits. + +Provides flexible constraints for managing scheduler behavior with configurable +thresholds based on time, error rates, and request counts. + +Classes: + ConstraintsInitializerFactory: Registry for constraint initializer functions. + MaxNumberConstraint: Limits execution by maximum request count. + MaxNumberConstraintInitializer: Factory for MaxNumberConstraint instances. + MaxDurationConstraint: Limits execution by maximum time duration. + MaxDurationConstraintInitializer: Factory for MaxDurationConstraint instances. + MaxErrorsConstraint: Limits execution by maximum absolute error count. + MaxErrorsConstraintInitializer: Factory for MaxErrorsConstraint instances. + MaxErrorRateConstraint: Limits execution by sliding window error rate. + MaxErrorRateConstraintInitializer: Factory for MaxErrorRateConstraint instances. + MaxGlobalErrorRateConstraint: Limits execution by global error rate. + MaxGlobalErrorRateConstraintInitializer: Factory for MaxGlobalErrorRateConstraint. + +Type Aliases: + Constraint: Function signature for constraint evaluation. + ConstraintInitializer: Function signature for constraint factory. +""" + +from __future__ import annotations + +import time +from typing import Any, Protocol, runtime_checkable + +from pydantic import Field + +from guidellm.objects import StandardBaseModel +from guidellm.scheduler.objects import ( + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, +) +from guidellm.utils import RegistryMixin + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "MaxDurationConstraint", + "MaxDurationConstraintInitializer", + "MaxErrorRateConstraint", + "MaxErrorRateConstraintInitializer", + "MaxErrorsConstraint", + "MaxErrorsConstraintInitializer", + "MaxGlobalErrorRateConstraint", + "MaxGlobalErrorRateConstraintInitializer", + "MaxNumberConstraint", + "MaxNumberConstraintInitializer", +] + + +@runtime_checkable +class Constraint(Protocol): + """Protocol for constraint evaluation functions.""" + + def __call__( + self, state: SchedulerState, request: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against scheduler state and request information. + + :param state: Current scheduler state with metrics and timing. + :param request: Individual request information and metadata. + :return: Action indicating whether to continue or stop operations. + """ + + +@runtime_checkable +class ConstraintInitializer(Protocol): + """Protocol for constraint initializer factory functions.""" + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance from configuration parameters. + + :param kwargs: Configuration parameters for constraint creation. + :return: Configured constraint evaluation function. + """ + + +class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): + """Registry factory for creating and managing constraint initializers.""" + + @classmethod + def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: + """ + Create a constraint initializer for the specified key. + + :param key: Registered constraint initializer key. + :param args: Positional arguments for initializer creation. + :param kwargs: Keyword arguments for initializer creation. + :return: Configured constraint initializer function. + :raises ValueError: If the key is not registered in the factory. 
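+
+        Example (illustrative sketch; "max_number" is the built-in registered key,
+        and a single scalar argument is routed through from_simple_value):
+        ::
+            initializer = ConstraintsInitializerFactory.create("max_number", 100)
+            constraint = initializer.create_constraint()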
+ """ + if not cls.is_registered(key): + raise ValueError(f"Unknown constraint initializer key: {key}") + + initializer_class = cls.get_registered_object(key) + + # Handle simple scalar values by delegating to the initializer class + if ( + len(args) == 1 + and not kwargs + and hasattr(initializer_class, "from_simple_value") + ): + return initializer_class.from_simple_value(args[0]) + + return initializer_class(*args, **kwargs) + + @classmethod + def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: + """ + Create a constraint instance for the specified key. + + :param key: Registered constraint initializer key. + :param kwargs: Keyword arguments for constraint creation. + :return: Configured constraint function ready for evaluation. + :raises ValueError: If the key is not registered in the factory. + """ + return cls.create(key, *args, **kwargs).create_constraint() + + @classmethod + def resolve( + cls, + initializers: dict[ + str, + Any | dict[str, Any] | Constraint | ConstraintInitializer, + ], + ) -> dict[str, Constraint]: + """ + Resolve mixed constraint specifications to callable constraints. + + :param initializers: Dictionary mapping constraint keys to specifications. + :return: Dictionary mapping constraint keys to callable functions. + :raises ValueError: If any key is not registered in the factory. + """ + constraints = {} + + for key, val in initializers.items(): + if isinstance(val, Constraint): + constraints[key] = val + elif isinstance(val, ConstraintInitializer): + constraints[key] = val.create_constraint() + elif isinstance(val, dict): + constraints[key] = cls.create_constraint(key, **val) + else: + constraints[key] = cls.create_constraint(key, val) + + return constraints + + @classmethod + def resolve_constraints( + cls, + constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> dict[str, Constraint]: + """ + Resolve constraints from mixed constraint specifications. + + :param constraints: Dictionary mapping constraint keys to specifications. + :return: Dictionary mapping constraint keys to callable functions. + :raises ValueError: If any constraint key is not registered. + """ + resolved_constraints = {} + + for key, val in constraints.items(): + if isinstance(val, Constraint): + resolved_constraints[key] = val + elif isinstance(val, dict): + resolved_constraints[key] = cls.create_constraint(key, **val) + else: + resolved_constraints[key] = cls.create_constraint(key, val) + + return resolved_constraints + + +class _MaxNumberBase(StandardBaseModel): + max_num: int | float = Field(gt=0, description="Maximum number of requests allowed") + + +class MaxNumberConstraint(_MaxNumberBase): + """Constraint that limits execution based on maximum request counts.""" + + def __call__( + self, state: SchedulerState, _request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state. + + :param state: Current scheduler state with request counts. + :param _request_info: Individual request information (unused). + :return: Action indicating whether to continue or stop operations. 
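+
+        Example (illustrative sketch; state and request_info are assumed to be the
+        SchedulerState and ScheduledRequestInfo supplied by the scheduler):
+        ::
+            constraint = MaxNumberConstraint(max_num=100)
+            action = constraint(state, request_info)
+            # action.request_queuing == "stop" once state.created_requests >= 100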
+ """ + create_exceeded = state.created_requests >= self.max_num + processed_exceeded = state.processed_requests >= self.max_num + + return SchedulerUpdateAction( + request_queuing="stop" if create_exceeded else "continue", + request_processing="stop_local" if processed_exceeded else "continue", + metadata={ + "max_number": self.max_num, + "create_exceeded": create_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": state.processed_requests, + }, + progress=SchedulerUpdateActionProgress( + remaining_fraction=max( + 0.0, 1.0 - state.processed_requests / float(self.max_num) + ), + remaining_requests=max(0, self.max_num - state.processed_requests), + ), + ) + + +@ConstraintsInitializerFactory.register("max_number") +class MaxNumberConstraintInitializer(_MaxNumberBase): + """Factory for creating MaxNumberConstraint instances.""" + + @classmethod + def from_simple_value(cls, value: int | float) -> MaxNumberConstraintInitializer: + """ + Create a MaxNumberConstraintInitializer from a simple scalar value. + + :param value: Maximum number of requests allowed. + :return: Configured MaxNumberConstraintInitializer instance. + """ + return cls(max_num=value) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a MaxNumberConstraint instance. + + :param _kwargs: Additional keyword arguments (unused). + :return: Configured MaxNumberConstraint instance. + """ + return MaxNumberConstraint( + max_num=self.max_num, + ) + + +class _MaxDurationBase(StandardBaseModel): + max_duration: int | float = Field(gt=0, description="Maximum duration in seconds") + + +class MaxDurationConstraint(_MaxDurationBase): + """Constraint that limits execution based on maximum time duration.""" + + def __call__( + self, state: SchedulerState, _request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state and elapsed time. + + :param state: Current scheduler state with start time. + :param _request_info: Individual request information (unused). + :return: Action indicating whether to continue or stop operations. + """ + current_time = time.time() + elapsed = current_time - state.start_time + duration_exceeded = elapsed >= self.max_duration + + return SchedulerUpdateAction( + request_queuing="stop" if duration_exceeded else "continue", + request_processing="stop_local" if duration_exceeded else "continue", + metadata={ + "max_duration": self.max_duration, + "elapsed_time": elapsed, + "duration_exceeded": duration_exceeded, + "start_time": state.start_time, + "current_time": current_time, + }, + progress=SchedulerUpdateActionProgress( + remaining_fraction=max(0.0, 1.0 - elapsed / float(self.max_duration)), + remaining_duration=max(0.0, self.max_duration - elapsed), + ), + ) + + +@ConstraintsInitializerFactory.register("max_duration") +class MaxDurationConstraintInitializer(_MaxDurationBase): + """Factory for creating MaxDurationConstraint instances.""" + + @classmethod + def from_simple_value(cls, value: int | float) -> MaxDurationConstraintInitializer: + """ + Create a MaxDurationConstraintInitializer from a simple scalar value. + + :param value: Maximum duration in seconds. + :return: Configured MaxDurationConstraintInitializer instance. + """ + return cls(max_duration=value) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a MaxDurationConstraint instance. + + :param _kwargs: Additional keyword arguments (unused). 
+ :return: Configured MaxDurationConstraint instance. + """ + return MaxDurationConstraint( + max_duration=self.max_duration, + ) + + +class _MaxErrorsBase(StandardBaseModel): + max_errors: int | float = Field( + gt=0, description="Maximum number of errors allowed" + ) + + +class MaxErrorsConstraint(_MaxErrorsBase): + """Constraint that limits execution based on absolute error count.""" + + def __call__( + self, state: SchedulerState, _request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current error count. + + :param state: Current scheduler state with error counts. + :param _request_info: Individual request information (unused). + :return: Action indicating whether to continue or stop operations. + """ + errors_exceeded = state.errored_requests >= self.max_errors + + return SchedulerUpdateAction( + request_queuing="stop" if errors_exceeded else "continue", + request_processing="stop_all" if errors_exceeded else "continue", + metadata={ + "max_errors": self.max_errors, + "errors_exceeded": errors_exceeded, + "current_errors": state.errored_requests, + }, + ) + + +@ConstraintsInitializerFactory.register("max_errors") +class MaxErrorsConstraintInitializer(_MaxErrorsBase): + """Factory for creating MaxErrorsConstraint instances.""" + + @classmethod + def from_simple_value(cls, value: int | float) -> MaxErrorsConstraintInitializer: + """ + Create a MaxErrorsConstraintInitializer from a simple scalar value. + + :param value: Maximum number of errors allowed. + :return: Configured MaxErrorsConstraintInitializer instance. + """ + return cls(max_errors=value) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a MaxErrorsConstraint instance. + + :param _kwargs: Additional keyword arguments (unused). + :return: Configured MaxErrorsConstraint instance. + """ + return MaxErrorsConstraint( + max_errors=self.max_errors, + ) + + +class _MaxErrorRateBase(StandardBaseModel): + max_error_rate: int | float = Field( + gt=0, le=1, description="Maximum error rate allowed (0.0 to 1.0)" + ) + window_size: int | float = Field( + default=50, + gt=0, + description="Size of sliding window for calculating error rate", + ) + + +class MaxErrorRateConstraint(_MaxErrorRateBase): + """Constraint that limits execution based on sliding window error rate.""" + + error_window: list[bool] = Field( + default_factory=list, + description="Sliding window tracking error status of recent requests", + ) + + def __call__( + self, state: SchedulerState, request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against sliding window error rate. + + :param state: Current scheduler state with request counts. + :param request_info: Individual request with completion status. + :return: Action indicating whether to continue or stop operations. 
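+
+        Example (illustrative sketch; only completed, errored, or cancelled requests
+        update the sliding window):
+        ::
+            constraint = MaxErrorRateConstraint(max_error_rate=0.1, window_size=20)
+            action = constraint(state, request_info)
+            # stops once at least 20 requests were processed and the windowed
+            # error rate reaches 10%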
+ """ + if request_info.status in ["completed", "errored", "cancelled"]: + self.error_window.append(request_info.status == "errored") + if len(self.error_window) > self.window_size: + self.error_window.pop(0) + + error_count = sum(self.error_window) + window_requests = len(self.error_window) + error_rate = ( + error_count / float(window_requests) if window_requests > 0 else 0.0 + ) + exceeded_min_processed = state.processed_requests >= self.window_size + exceeded_error_rate = error_rate >= self.max_error_rate + + return SchedulerUpdateAction( + request_queuing=( + "stop" if exceeded_min_processed and exceeded_error_rate else "continue" + ), + request_processing=( + "stop_all" + if exceeded_min_processed and exceeded_error_rate + else "continue" + ), + metadata={ + "max_error_rate": self.max_error_rate, + "window_size": self.window_size, + "error_count": error_count, + "processed_count": state.processed_requests, + "current_window_size": len(self.error_window), + "current_error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + }, + ) + + +@ConstraintsInitializerFactory.register("max_error_rate") +class MaxErrorRateConstraintInitializer(_MaxErrorRateBase): + """Factory for creating MaxErrorRateConstraint instances.""" + + @classmethod + def from_simple_value(cls, value: int | float) -> MaxErrorRateConstraintInitializer: + """ + Create a MaxErrorRateConstraintInitializer from a simple scalar value. + + :param value: Maximum error rate allowed (0.0 to 1.0). + :return: Configured MaxErrorRateConstraintInitializer instance. + """ + return cls(max_error_rate=value) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a MaxErrorRateConstraint instance. + + :param _kwargs: Additional keyword arguments (unused). + :return: Configured MaxErrorRateConstraint instance. + """ + return MaxErrorRateConstraint( + max_error_rate=self.max_error_rate, + window_size=self.window_size, + ) + + +class _MaxGlobalErrorRateBase(StandardBaseModel): + max_error_rate: int | float = Field( + gt=0, le=1, description="Maximum error rate allowed (0.0 to 1.0)" + ) + min_processed: int | float | None = Field( + default=50, + gt=30, + description=( + "Minimum number of processed requests before applying error rate constraint" + ), + ) + + +class MaxGlobalErrorRateConstraint(_MaxGlobalErrorRateBase): + """Constraint that limits execution based on global error rate.""" + + def __call__( + self, state: SchedulerState, _request_info: ScheduledRequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against global error rate. + + :param state: Current scheduler state with global request and error counts. + :param _request_info: Individual request information (unused). + :return: Action indicating whether to continue or stop operations. 
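+
+        Example (illustrative sketch of dict-style configuration via the factory):
+        ::
+            spec = {"max_error_rate": 0.05, "min_processed": 200}
+            constraints = ConstraintsInitializerFactory.resolve(
+                {"max_global_error_rate": spec}
+            )
+            action = constraints["max_global_error_rate"](state, request_info)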
+ """ + exceeded_min_processed = state.processed_requests >= self.min_processed + error_rate = ( + state.errored_requests / float(state.processed_requests) + if state.processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = error_rate >= self.max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + + return SchedulerUpdateAction( + request_queuing="stop" if should_stop else "continue", + request_processing="stop_all" if should_stop else "continue", + metadata={ + "max_error_rate": self.max_error_rate, + "min_processed": self.min_processed, + "processed_requests": state.processed_requests, + "errored_requests": state.errored_requests, + "error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + }, + ) + + +@ConstraintsInitializerFactory.register("max_global_error_rate") +class MaxGlobalErrorRateConstraintInitializer(_MaxGlobalErrorRateBase): + """Factory for creating MaxGlobalErrorRateConstraint instances.""" + + @classmethod + def from_simple_value( + cls, value: int | float + ) -> MaxGlobalErrorRateConstraintInitializer: + """ + Create a MaxGlobalErrorRateConstraintInitializer from a simple scalar value. + + :param value: Maximum error rate allowed (0.0 to 1.0). + :return: Configured MaxGlobalErrorRateConstraintInitializer instance. + """ + return cls(max_error_rate=value) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a MaxGlobalErrorRateConstraint instance. + + :param _kwargs: Additional keyword arguments (unused). + :return: Configured MaxGlobalErrorRateConstraint instance. + """ + return MaxGlobalErrorRateConstraint( + max_error_rate=self.max_error_rate, + min_processed=self.min_processed, + ) diff --git a/src/guidellm/scheduler/environment.py b/src/guidellm/scheduler/environment.py new file mode 100644 index 00000000..46be70aa --- /dev/null +++ b/src/guidellm/scheduler/environment.py @@ -0,0 +1,228 @@ +""" +Scheduler environment abstractions for distributed and non-distributed execution. + +Provides environment abstractions for coordinating scheduler execution across +single or multiple nodes, handling synchronization, error propagation, and lifecycle +management. + +Classes: + Environment: Abstract base for scheduler coordination across nodes. + NonDistributedEnvironment: Single-node implementation with minimal overhead. +""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Iterable +from typing import ( + Generic, +) + +from guidellm.config import settings +from guidellm.scheduler.constraints import Constraint +from guidellm.scheduler.objects import ( + MeasuredRequestTimingsT, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, +) +from guidellm.scheduler.strategy import SchedulingStrategy +from guidellm.utils import InfoMixin + +__all__ = ["Environment", "NonDistributedEnvironment"] + + +class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin): + """ + Abstract base for scheduler execution environments. + + Defines the interface for coordinating scheduler execution across single or + multiple nodes, handling parameter synchronization, timing, state updates, + error propagation, and cleanup. 
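+
+    Example (illustrative lifecycle sketch from the scheduler's perspective):
+    ::
+        requests, strategy, constraints = await env.sync_run_params(
+            requests, strategy, constraints
+        )
+        start_time = await env.sync_run_start()
+        # process requests, reporting each result via env.update_run_iteration(...)
+        async for response, request, request_info, state in env.sync_run_end():
+            ...  # merge results from remote nodes, if any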
+ """ + + @abstractmethod + async def sync_run_params( + self, + requests: Iterable[RequestT], + strategy: SchedulingStrategy, + constraints: dict[str, Constraint], + ) -> tuple[ + Iterable[RequestT], + SchedulingStrategy, + dict[str, Constraint], + ]: + """ + Synchronize run parameters across nodes and resolve local execution scope. + + Coordinates parameter distribution across active nodes. For distributed + environments, handles validation, node assignment, and workload partitioning. + For non-distributed environments, typically returns parameters unchanged. + + :param requests: Complete set of requests to process across all nodes. + :param strategy: Scheduling strategy to apply during execution. + :param constraints: Runtime constraints to enforce during execution. + :return: Tuple of (local_requests, strategy, constraints) for this node. + :raises Exception: If parameter synchronization fails or nodes inconsistent. + """ + ... + + @abstractmethod + async def sync_run_start(self) -> float: + """ + Coordinate global start time across nodes for synchronized execution. + + Ensures all nodes begin processing simultaneously for accurate benchmarking. + + :return: Unix timestamp when all nodes should begin processing. + :raises Exception: If startup synchronization fails across nodes. + """ + ... + + @abstractmethod + async def update_run_iteration( + self, + response: ResponseT | None, + request: RequestT, + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + state: SchedulerState, + ): + """ + Update environment state with completed request iteration. + + Called after each request is processed to update execution progress. + Enables state synchronization across nodes in distributed environments. + + :param response: Response generated for the request, if successful. + :param request: The processed request. + :param request_info: Metadata about request processing including timings. + :raises Exception: If state update fails or indicates critical errors. + """ + ... + + @abstractmethod + async def sync_run_error(self, err: list[Exception] | Exception): + """ + Handle and propagate errors across all nodes. + + Coordinates error handling when failures occur, ensuring all nodes are + notified and can perform appropriate cleanup or shutdown. + + :param err: The exception that occurred during execution. + """ + ... + + @abstractmethod + async def sync_run_end( + self, + ) -> AsyncIterator[ + tuple[ + ResponseT, + RequestT, + ScheduledRequestInfo[MeasuredRequestTimingsT], + SchedulerState, + ] + ]: + """ + Finalize execution and aggregate results from all nodes. + + Handles cleanup, result synchronization, and error propagation at run end. + Collects results from worker nodes in distributed environments. + + :return: Iterator of (response, request, request_info, state) tuples from + remote nodes in distributed environments, empty for non-distributed. + :raises Exception: Any errors that occurred during the run. + """ + ... + + +class NonDistributedEnvironment(Environment): + """ + Single-node scheduler execution environment. + + Simplified environment for running schedulers on a single node without + distributed coordination. Implements the Environment interface with minimal + synchronization overhead for local testing, development, and single-machine + benchmarking. 
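+
+    Example (illustrative):
+    ::
+        env = NonDistributedEnvironment()
+        start_time = await env.sync_run_start()
+        # time.time() plus settings.scheduler_start_delay_non_distributed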
+ """ + + def __init__(self): + """Initialize with no stored errors.""" + self.run_errors: list[Exception] = [] + + async def sync_run_params( + self, + requests: Iterable[RequestT], + strategy: SchedulingStrategy, + constraints: dict[str, Constraint], + ) -> tuple[Iterable[RequestT], SchedulingStrategy, dict[str, Constraint]]: + """ + Return parameters unchanged for single-node execution. + + :param requests: Iterable of requests to process. + :param strategy: Scheduling strategy to apply during execution. + :param constraints: Runtime constraints to enforce during execution. + :return: Tuple containing the original (requests, strategy, constraints). + """ + return requests, strategy, constraints + + async def sync_run_start(self) -> float: + """ + Return current time plus configuration delay. + + :return: Unix timestamp for when the run should start. + """ + return time.time() + settings.scheduler_start_delay_non_distributed + + async def update_run_iteration( + self, + response: ResponseT | None, + request: RequestT, + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + state: SchedulerState, + ): + """ + No-op for single-node execution. + + :param response: Response generated for the request, if successful. + :param request: The request that was processed. + :param request_info: Metadata about request processing including timings. + """ + + async def sync_run_error(self, err: Exception): + """ + Store error for later propagation during run finalization. + + :param err: The exception that occurred during execution. + """ + err = [err] if not isinstance(err, list) else err + self.run_errors.extend(err) + + async def sync_run_end( + self, + ) -> AsyncIterator[ + tuple[ + ResponseT, + RequestT, + ScheduledRequestInfo[MeasuredRequestTimingsT], + SchedulerState, + ] + ]: + """ + Finalize single-node execution and propagate any stored errors. + + :return: Empty iterator since there are no remote nodes. + :raises Exception: Any error stored during execution via sync_run_error. + """ + if self.run_errors: + if len(self.run_errors) == 1: + raise self.run_errors[0] + else: + raise RuntimeError( + f"Errors occurred during execution: {self.run_errors}" + ) + + return + yield # needed to force generator compilation diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py new file mode 100644 index 00000000..2d8ce4e3 --- /dev/null +++ b/src/guidellm/scheduler/objects.py @@ -0,0 +1,331 @@ +""" +Core data structures and interfaces for the GuideLLM scheduler system. + +Provides type-safe abstractions for distributed request processing, timing +measurements, and backend interfaces for benchmarking operations. + +Classes: + RequestSchedulerTimings: Scheduler-level request timing measurements. + RequestTimings: Base backend request timing measurements. + ScheduledRequestInfo: Complete request lifecycle information. + BackendInterface: Abstract backend processing interface. + SchedulerState: Scheduler operation state tracking. + SchedulerUpdateAction: Scheduler behavior control directives. + +Type Variables: + RequestT: Generic request object type. + ResponseT: Generic response object type. + RequestTimingsT: Generic request timing object type. + BackendT: Generic backend interface type. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import ( + Any, + Generic, + Literal, + TypeVar, + Union, +) + +from pydantic import Field, computed_field +from typing_extensions import TypeAliasType, TypedDict + +from guidellm.objects import StandardBaseModel + +__all__ = [ + "BackendInterface", + "BackendT", + "MeasuredRequestTimings", + "MeasuredRequestTimingsT", + "MultiTurnRequestT", + "RequestSchedulerTimings", + "RequestT", + "ResponseT", + "ScheduledRequestInfo", + "SchedulerState", + "SchedulerUpdateAction", + "SchedulerUpdateActionProgress", +] + +RequestT = TypeVar("RequestT") +MultiTurnRequestT = TypeAliasType( + "MultiTurnRequestT", + Union[ + list[Union[RequestT, tuple[RequestT, float]]], + tuple[Union[RequestT, tuple[RequestT, float]]], + ], + type_params=(RequestT,), +) +ResponseT = TypeVar("ResponseT") + + +class RequestSchedulerTimings(StandardBaseModel): + """Scheduler-level timing measurements for request lifecycle tracking.""" + + targeted_start: float | None = Field( + default=None, + description="When the request was initially targeted for execution", + ) + queued: float | None = Field( + default=None, + description="When the request was placed into the processing queue", + ) + dequeued: float | None = Field( + default=None, + description="When the request was removed from the queue for processing", + ) + scheduled_at: float | None = Field( + default=None, description="When the request was scheduled for processing" + ) + resolve_start: float | None = Field( + default=None, description="When backend resolution of the request began" + ) + resolve_end: float | None = Field( + default=None, description="When backend resolution of the request completed" + ) + finalized: float | None = Field( + default=None, + description="When the request was processed/acknowledged by the scheduler", + ) + + +class MeasuredRequestTimings(StandardBaseModel): + """Base timing measurements for backend request processing.""" + + request_start: float | None = Field( + default=None, description="When the backend began processing the request" + ) + request_end: float | None = Field( + default=None, description="When the backend completed processing the request" + ) + + +MeasuredRequestTimingsT = TypeVar( + "MeasuredRequestTimingsT", bound=MeasuredRequestTimings +) + + +class ScheduledRequestInfo(StandardBaseModel, Generic[MeasuredRequestTimingsT]): + """Complete request information including status, timings, and metadata.""" + + request_id: str = Field(description="Unique identifier for the request") + status: Literal[ + "queued", "pending", "in_progress", "completed", "errored", "cancelled" + ] = Field(description="Current processing status of the request") + scheduler_node_id: int = Field( + description="ID/rank of the scheduler node handling the request" + ) + scheduler_process_id: int = Field( + description="ID/rank of the node's scheduler process handling the request" + ) + scheduler_start_time: float = Field( + description="Unix timestamp for the local time when scheduler processing began" + ) + + error: str | None = Field( + default=None, description="Error message if the request.status is 'errored'" + ) + scheduler_timings: RequestSchedulerTimings = Field( + default_factory=RequestSchedulerTimings, + description="Scheduler-level timing measurements for request lifecycle", + ) + request_timings: MeasuredRequestTimingsT | None = Field( + default=None, + description="Backend-specific timing 
measurements for request processing",
+    )
+
+    @computed_field
+    @property
+    def started_at(self) -> float | None:
+        """
+        Get the effective request processing start time.
+
+        :return: Unix timestamp when processing began, or None if not started.
+        """
+        request_start = (
+            self.request_timings.request_start if self.request_timings else None
+        )
+
+        return request_start or self.scheduler_timings.resolve_start
+
+    @computed_field
+    @property
+    def completed_at(self) -> float | None:
+        """
+        Get the effective request processing completion time.
+
+        :return: Unix timestamp when processing completed, or None if not completed.
+        """
+        request_end = self.request_timings.request_end if self.request_timings else None
+
+        return request_end or self.scheduler_timings.resolve_end
+
+
+class BackendInterface(ABC, Generic[RequestT, MeasuredRequestTimingsT, ResponseT]):
+    """
+    Abstract interface for request processing backends. Note: before process_startup
+    is invoked, the implementation must ensure all properties are pickleable.
+    """
+
+    @property
+    @abstractmethod
+    def processes_limit(self) -> int | None:
+        """Maximum worker processes supported, or None if unlimited."""
+
+    @property
+    @abstractmethod
+    def requests_limit(self) -> int | None:
+        """Maximum concurrent requests supported, or None if unlimited."""
+
+    @abstractmethod
+    def info(self) -> dict[str, Any]:
+        """
+        :return: Backend metadata, including model, initialization, and
+            configuration information.
+        """
+        ...
+
+    @abstractmethod
+    async def process_startup(self) -> None:
+        """
+        Perform backend initialization and startup procedures.
+
+        :raises: Implementation-specific exceptions for startup failures.
+        """
+
+    @abstractmethod
+    async def validate(self) -> None:
+        """
+        Validate backend configuration and operational status.
+
+        :raises: Implementation-specific exceptions for validation failures.
+        """
+
+    @abstractmethod
+    async def process_shutdown(self) -> None:
+        """
+        Perform backend cleanup and shutdown procedures.
+
+        :raises: Implementation-specific exceptions for shutdown failures.
+        """
+
+    @abstractmethod
+    async def resolve(
+        self,
+        request: RequestT,
+        request_info: ScheduledRequestInfo[MeasuredRequestTimingsT],
+        history: list[tuple[RequestT, ResponseT]] | None = None,
+    ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo[MeasuredRequestTimingsT]]]:
+        """
+        Process a request and yield incremental response updates.
+
+        :param request: The request object to process.
+        :param request_info: Scheduling metadata and timing information.
+        :param history: Optional conversation history for multi-turn requests.
+        :yield: Tuples of (response, updated_request_info) for each response chunk.
+        :raises: Implementation-specific exceptions for processing failures.
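+
+        Example (illustrative implementation sketch; self._client stands in for a
+        hypothetical client created in process_startup and is not part of this
+        interface):
+        ::
+            async def resolve(self, request, request_info, history=None):
+                timings = MeasuredRequestTimings(request_start=time.time())
+                response = await self._client.generate(request)
+                timings.request_end = time.time()
+                request_info.request_timings = timings
+                yield response, request_info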
+ """ + + +BackendT = TypeVar("BackendT", bound=BackendInterface) + + +class SchedulerUpdateActionProgress(TypedDict, total=False): + """Progress information for a scheduler update action.""" + + remaining_fraction: float | None = None + remaining_requests: float | None = None + remaining_duration: float | None = None + + +class SchedulerUpdateAction(StandardBaseModel): + """Scheduler behavior control directives and actions.""" + + request_queuing: Literal["continue", "stop"] = Field( + default="continue", description="Action to take for request queuing operations" + ) + request_processing: Literal["continue", "stop_local", "stop_all"] = Field( + default="continue", + description="Action to take for request processing operations", + ) + metadata: dict[str, Any] = Field( + default_factory=dict, + description="Additional context and data for the scheduler action", + ) + progress: SchedulerUpdateActionProgress = Field( + default_factory=SchedulerUpdateActionProgress, + description="Progress information for the scheduler action", + ) + + +class SchedulerState(StandardBaseModel): + """Scheduler operation state tracking and statistics.""" + + node_id: int = Field(description="Unique identifier for this scheduler node") + num_processes: int = Field( + description="Number of worker processes in this scheduler" + ) + start_time: float = Field(description="Unix timestamp when the scheduler started") + end_time: float | None = Field( + default=None, description="Unix timestamp when the scheduler stopped" + ) + end_queuing_time: float | None = Field( + default=None, description="When request queuing stopped, if applicable" + ) + end_queuing_constraints: dict[str, SchedulerUpdateAction] = Field( + default_factory=dict, + description="Constraints that triggered queuing termination", + ) + end_processing_time: float | None = Field( + default=None, description="When request processing stopped, if applicable" + ) + end_processing_constraints: dict[str, SchedulerUpdateAction] = Field( + default_factory=dict, + description="Constraints that triggered processing termination", + ) + scheduler_constraints: dict[str, SchedulerUpdateAction] = Field( + default_factory=dict, + description="The latest state from all constraints applied during the scheduler run", + ) + + remaining_fraction: float | None = Field( + default=None, + description="Estimated fraction for the remaining progress of the scheduler run, if known", + ) + remaining_requests: int | None = Field( + default=None, + description="Estimated number of requests remaining to be processed, if known", + ) + remaining_duration: float | None = Field( + default=None, + description="Estimated time remaining in seconds for the scheduler run, if known", + ) + + created_requests: int = Field( + default=0, description="Total number of requests created" + ) + queued_requests: int = Field( + default=0, description="Total number of requests queued for processing" + ) + pending_requests: int = Field( + default=0, description="Number of requests currently pending processing" + ) + processing_requests: int = Field( + default=0, description="Number of requests currently being processed" + ) + processed_requests: int = Field( + default=0, description="Total number of requests that completed processing" + ) + successful_requests: int = Field( + default=0, description="Number of requests that completed successfully" + ) + errored_requests: int = Field( + default=0, description="Number of requests that failed with errors" + ) + cancelled_requests: int = Field( + default=0, 
description="Number of requests that were cancelled" + ) diff --git a/src/guidellm/scheduler/result.py b/src/guidellm/scheduler/result.py deleted file mode 100644 index 0f12687f..00000000 --- a/src/guidellm/scheduler/result.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import ( - Generic, - Literal, - Optional, -) - -from guidellm.objects import StandardBaseModel -from guidellm.scheduler.strategy import SchedulingStrategy -from guidellm.scheduler.types import RequestT, ResponseT - -__all__ = [ - "SchedulerRequestInfo", - "SchedulerRequestResult", - "SchedulerResult", - "SchedulerRunInfo", -] - - -class SchedulerRunInfo(StandardBaseModel): - """ - Information about the current run of the scheduler. - This class holds metadata about the scheduling run, - including the start and end times, the number of processes, - and the scheduling strategy used. - It also tracks the number of requests created, queued, pending, - and completed during the run. - - :param start_time: The start time of the scheduling run. - :param end_time: The end time of the scheduling run; - if None, then this will be math.inf. - :param end_number: The maximum number of requests to be processed; - if None, then this will be math.inf. - :param processes: The number of processes used in the scheduling run. - :param strategy: The scheduling strategy used in the run. - This should be an instance of SchedulingStrategy. - :param created_requests: The number of requests created during the run. - :param queued_requests: The number of requests queued during the run. - :param scheduled_requests: The number of requests scheduled during the run. - (requests pending being sent to the worker but recieved by a process) - :param processing_requests: The number of requests actively being run. - :param completed_requests: The number of requests completed during the run. - """ - - start_time: float - end_time: float - end_number: float - processes: int - strategy: SchedulingStrategy - - created_requests: int = 0 - queued_requests: int = 0 - scheduled_requests: int = 0 - processing_requests: int = 0 - completed_requests: int = 0 - - -class SchedulerRequestInfo(StandardBaseModel): - """ - Information about a specific request run through the scheduler. - This class holds metadata about the request, including - the targeted start time, queued time, start time, end time, - and the process ID that handled the request. - - :param targeted_start_time: The targeted start time for the request (time.time()). - :param queued_time: The time the request was queued (time.time()). - :param scheduled_time: The time the request was scheduled (time.time()) - (any sleep time before the request was sent to the worker). - :param worker_start: The time the worker started processing request (time.time()). - :param worker_end: The time the worker finished processing request. (time.time()). - :param process_id: The ID of the underlying process that handled the request. - """ - - requested: bool = False - completed: bool = False - errored: bool = False - canceled: bool = False - - targeted_start_time: float = -1 - queued_time: float = -1 - dequeued_time: float = -1 - scheduled_time: float = -1 - worker_start: float = -1 - request_start: float = -1 - request_end: float = -1 - worker_end: float = -1 - process_id: int = -1 - - -class SchedulerResult(StandardBaseModel): - """ - The yielded, iterative result for a scheduler run. - These are triggered on the start and end of the run, - as well as on the start and end of each request. 
- Depending on the type, it will hold the request and response - along with information and statistics about the request and general run. - - :param type_: The type of the result, which can be one of: - - "run_start": Indicates the start of the run. - - "run_complete": Indicates the completion of the run (teardown happens after). - - "request_start": Indicates the start of a request. - - "request_complete": Indicates the completion of a request. - :param request: The request that was processed. - :param response: The response from the worker for the request. - :param request_info: Information about the request, including - the targeted start time, queued time, start time, end time, - and the process ID that handled the request. - :param run_info: Information about the current run of the scheduler, - including the start and end times, the number of processes, - and the scheduling strategy used. - It also tracks the number of requests created, queued, pending, - and completed during the run. - """ - - pydantic_type: Literal["scheduler_result"] = "scheduler_result" - type_: Literal[ - "run_start", - "run_complete", - "request_scheduled", - "request_start", - "request_complete", - ] - run_info: SchedulerRunInfo - - -class SchedulerRequestResult( - SchedulerResult, - Generic[RequestT, ResponseT], -): - pydantic_type: Literal["scheduler_request_result"] = "scheduler_request_result" # type: ignore[assignment] - type_: Literal[ - "request_scheduled", - "request_start", - "request_complete", - ] - request: RequestT - request_info: SchedulerRequestInfo - response: Optional[ResponseT] = None diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index 06203827..d61b0ecb 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -1,382 +1,170 @@ -import asyncio -import math -import multiprocessing -import multiprocessing.queues -import time -from collections.abc import AsyncGenerator, Iterable, Iterator -from concurrent.futures import ProcessPoolExecutor -from typing import ( - Any, - Generic, - Optional, - Union, -) +""" +Scheduler for coordinating distributed load testing and benchmarking workloads. + +This module provides a thread-safe singleton scheduler for orchestrating +benchmarking operations across worker processes and distributed environments. + +Classes: + Scheduler: Generic singleton scheduler for distributed request processing. 
+""" -from loguru import logger +from __future__ import annotations -from guidellm.config import settings -from guidellm.scheduler.result import ( - SchedulerRequestResult, - SchedulerResult, - SchedulerRunInfo, +from collections.abc import AsyncIterator, Iterable +from typing import Any, Generic + +from guidellm.scheduler.constraints import ( + Constraint, + ConstraintsInitializerFactory, ) -from guidellm.scheduler.strategy import SchedulingStrategy -from guidellm.scheduler.types import RequestT, ResponseT -from guidellm.scheduler.worker import ( - RequestsWorker, - WorkerProcessRequest, - WorkerProcessResult, +from guidellm.scheduler.environment import Environment +from guidellm.scheduler.objects import ( + BackendInterface, + MeasuredRequestTimingsT, + MultiTurnRequestT, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, ) +from guidellm.scheduler.strategy import SchedulingStrategy +from guidellm.scheduler.worker_group import WorkerProcessGroup +from guidellm.utils.singleton import ThreadSafeSingletonMixin __all__ = ["Scheduler"] -class Scheduler(Generic[RequestT, ResponseT]): +class Scheduler( + Generic[RequestT, MeasuredRequestTimingsT, ResponseT], + ThreadSafeSingletonMixin, +): """ - A class that handles the scheduling of requests to a worker. - This class is responsible for managing the lifecycle of the requests, - including their creation, queuing, and processing. - It uses a multiprocessing approach to handle requests concurrently - and efficiently, based on the specified scheduling strategy. - The Scheduler class is designed to work with a RequestsWorker, - which is an abstract base class that defines the interface for a worker - that can resolve requests asynchronously or synchronously. - The Scheduler class also supports different scheduling strategies, - including synchronous, throughput, and concurrent strategies. + Generic singleton scheduler for coordinating distributed load testing workloads. + + Orchestrates benchmarking operations by managing request distribution across + worker processes, coordinating timing with distributed environments, and + aggregating results. Supports generic backend types for adaptability to + various testing scenarios including LLM inference and API testing. + + Example: + :: + from guidellm.scheduler import Scheduler + from guidellm.backend import ( + OpenAIBackend, + GenerationRequest, + GenerationResponse, + GenerationRequestTimings + ) - :param worker: The worker that will process the requests. - This should be an instance of RequestsWorker. - :param request_loader: An iterable that generates requests. - This can be a list, generator, or any other iterable. - The requests will be processed by the worker. 
+ scheduler = Scheduler[ + OpenAIBackend, + GenerationRequest, + GenerationRequestTimings, + GenerationResponse + ]() + async for response, request, info, state in scheduler.run( + requests=request_list, + backend=backend, + strategy=strategy, + env=environment, + max_requests=1000 + ): + print(f"Response: {response}") """ - def __init__( - self, - worker: RequestsWorker[RequestT, ResponseT], - request_loader: Iterable[RequestT], - ): - if not isinstance(worker, RequestsWorker): - raise ValueError(f"Invalid worker: {worker}") - - if not isinstance(request_loader, Iterable): - raise ValueError(f"Invalid request_loader: {request_loader}") - - self.worker = worker - self.request_loader = request_loader - async def run( self, - scheduling_strategy: SchedulingStrategy, - max_number: Optional[int] = None, - max_duration: Optional[float] = None, - ) -> AsyncGenerator[ - Union[SchedulerResult, SchedulerRequestResult[RequestT, ResponseT]], None + requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + strategy: SchedulingStrategy, + env: Environment, + **constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> AsyncIterator[ + tuple[ + ResponseT | None, + RequestT, + ScheduledRequestInfo[MeasuredRequestTimingsT], + SchedulerState, + ] ]: """ - The main method that runs the scheduler. - This method is a generator that yields SchedulerResult objects - at the start and end of the run, as well as at the start and end - of each request. - It uses multiprocessing to handle requests concurrently - and efficiently, based on the specified scheduling strategy. - The method also handles the lifecycle of the requests, - including their creation, queuing, and processing. - The method is designed to be used as an asynchronous generator, - allowing it to be used with asyncio and other asynchronous frameworks. - - :param scheduling_strategy: The scheduling strategy to use. - Specifies the times at which requests will be sent as well how many - worker processes are used and if requests are scheduled sync or async. - This can be one of the following: - - "synchronous": Requests are sent synchronously. - - "throughput": Requests are sent at the maximum rate possible. - - An instance of SchedulingStrategy. - :param max_number: The maximum number of requests to process. - If None, then no limit is set and either the iterator must be exhaustible - or the max_duration must be set. - :param max_duration: The maximum duration for the scheduling run. - If None, then no limit is set and either the iterator must be exhaustible - or the max_number must be set. - :return: An asynchronous generator that yields SchedulerResult objects. - Each SchedulerResult object contains information about the request, - the response, and the run information. + Execute request processing with the provided configuration. + + Coordinates execution across worker processes with the specified backend + and scheduling strategy. Manages timing, synchronization, and resource + cleanup while yielding real-time updates. + + :param requests: Requests to process. Supports single requests + (Iterable[RequestT]) or multi-turn sequences (Iterable[Iterable]) where + each item is either a RequestT or tuple of (RequestT, delay_seconds). + :param backend: Backend instance for processing requests. + :param strategy: Scheduling strategy for request timing. + :param env: Environment for distributed execution coordination. 
+ :param constraints: Execution control constraints (max_requests, duration, + etc.). Values can be primitives or callable functions. + :yields: Tuples of (response, request, scheduling_info, scheduler_state). + Response may be None for failed requests. + :raises Exception: Worker process, environment, or constraint evaluation errors + are propagated after cleanup. """ - if scheduling_strategy is None or not isinstance( - scheduling_strategy, SchedulingStrategy - ): - raise ValueError(f"Invalid scheduling strategy: {scheduling_strategy}") - - if max_number is not None and max_number < 1: - raise ValueError(f"Invalid max_number: {max_number}") - - if max_duration is not None and max_duration < 0: - raise ValueError(f"Invalid max_duration: {max_duration}") - - with ( - multiprocessing.Manager() as manager, - ProcessPoolExecutor( - max_workers=scheduling_strategy.processes_limit - ) as executor, - ): - requests_iter: Optional[Iterator[Any]] = None - futures, requests_queue, responses_queue = await self._start_processes( - manager, executor, scheduling_strategy - ) - run_info, requests_iter, times_iter = self._run_setup( - futures, scheduling_strategy, max_number, max_duration - ) - yield SchedulerResult( - type_="run_start", - run_info=run_info, - ) - + with self.thread_lock: + worker_group: ( + WorkerProcessGroup[RequestT, MeasuredRequestTimingsT, ResponseT] | None + ) = None + + # Any issues during the run will raise an error (local or remote), + # be caught and passed to the environment, + # and will ensure clean up before raising the error. try: - while True: - # check errors and raise them - for future in futures: - if future.done() and (err := future.exception()) is not None: - raise err - - if ( - requests_iter is None - and run_info.completed_requests >= run_info.created_requests - ): - # we've exhausted all requests we've wanted to run - # and yielded all responses - break - - requests_iter = self._add_requests( - requests_iter, - times_iter, - requests_queue, - run_info, - ) - await asyncio.sleep(0) # enable requests to start - - iter_result = self._check_result_ready( - responses_queue, - run_info, - ) - if iter_result is not None: - yield iter_result - - # yield control to the event loop - await asyncio.sleep(settings.default_async_loop_sleep) - except Exception as err: - raise RuntimeError(f"Scheduler run failed: {err}") from err - - yield SchedulerResult( - type_="run_complete", - run_info=run_info, - ) - - await self._stop_processes(futures, requests_queue) - - async def _start_processes( - self, - manager, - executor: ProcessPoolExecutor, - scheduling_strategy: SchedulingStrategy, - ) -> tuple[ - list[asyncio.Future], - multiprocessing.Queue, - multiprocessing.Queue, - ]: - await self.worker.prepare_multiprocessing() - requests_queue = manager.Queue( - maxsize=scheduling_strategy.queued_requests_limit - ) - responses_queue = manager.Queue() - - num_processes = min( - scheduling_strategy.processes_limit, - scheduling_strategy.processing_requests_limit, - ) - requests_limit_split = ( - scheduling_strategy.processing_requests_limit - // scheduling_strategy.processes_limit - ) - requests_limit_remain = ( - scheduling_strategy.processing_requests_limit - % scheduling_strategy.processes_limit - ) - process_ids = (id_ for id_ in range(num_processes)) - process_requests_limits = ( - requests_limit_split + 1 - if i < requests_limit_remain - else requests_limit_split - for i in range(num_processes) - ) - - futures = [] - loop = asyncio.get_event_loop() - for id_, requests_limit in 
zip(process_ids, process_requests_limits): - if scheduling_strategy.processing_mode == "sync": - futures.append( - loop.run_in_executor( - executor, - self.worker.process_loop_synchronous, - requests_queue, - responses_queue, - id_, - ) + # Setup local run parameters, sync with the environment + constraints = ConstraintsInitializerFactory.resolve_constraints( + constraints ) - elif scheduling_strategy.processing_mode == "async": - futures.append( - loop.run_in_executor( - executor, - self.worker.process_loop_asynchronous, - requests_queue, - responses_queue, - requests_limit, - id_, - ) + ( + local_requests, + local_strategy, + local_constraints, + ) = await env.sync_run_params(requests, strategy, constraints) + + # Setup the worker group, sync start with the environment + worker_group = WorkerProcessGroup[ + RequestT, MeasuredRequestTimingsT, ResponseT + ]( + backend=backend, + requests=local_requests, + strategy=local_strategy, + constraints=local_constraints, ) - else: - raise ValueError( - f"Invalid processing mode: {scheduling_strategy.processing_mode} " - f"for strategy: {scheduling_strategy}" - ) - - await asyncio.sleep(0.1) # give time for processes to start - - return futures, requests_queue, responses_queue - - def _run_setup( - self, - processes: list[asyncio.Future], - scheduling_strategy: SchedulingStrategy, - max_number: Optional[int], - max_duration: Optional[float], - ) -> tuple[SchedulerRunInfo, Iterator[Any], Iterator[float]]: - requests_iter = iter(self.request_loader) - start_time = time.time() - times_iter = iter(scheduling_strategy.request_times()) - end_time = time.time() + (max_duration or math.inf) - end_number = max_number or math.inf - - try: - # update end number if the request loader is finite and less than max - iter_length = len(self.request_loader) # type: ignore[arg-type] - if 0 < iter_length < end_number: - end_number = iter_length - except Exception: # noqa: BLE001, S110 - pass - - if end_number == math.inf and end_time is None: - logger.warning( - "No end number or end time set, " - "scheduler will run indefinitely until the request loader is exhausted." 
- ) - - info = SchedulerRunInfo( - start_time=start_time, - end_time=end_time, - end_number=end_number, - processes=len(processes), - strategy=scheduling_strategy, - ) - - return info, requests_iter, times_iter - - def _add_requests( - self, - requests_iter: Optional[Iterator[Any]], - times_iter: Iterator[float], - requests_queue: multiprocessing.Queue, - run_info: SchedulerRunInfo, - ) -> Optional[Iterator[Any]]: - if requests_iter is not None: - try: - added_count = 0 - - while ( - not requests_queue.full() - and added_count < settings.max_add_requests_per_loop - ): - if run_info.created_requests >= run_info.end_number: - raise StopIteration - - if ( - request_time := next(times_iter) - ) >= run_info.end_time or time.time() >= run_info.end_time: - raise StopIteration - - request = next(requests_iter) - work_req: WorkerProcessRequest[RequestT] = WorkerProcessRequest( - request=request, - start_time=request_time, - timeout_time=run_info.end_time, - queued_time=time.time(), + await worker_group.create_processes() + local_start_time = await env.sync_run_start() + await worker_group.start(local_start_time) + + # Yield any updates and sync with the environment for non-local updates + async for ( + response, + request, + request_info, + state, + ) in worker_group.request_updates(): + await env.update_run_iteration( + response, request, request_info, state ) - requests_queue.put(work_req) - - run_info.created_requests += 1 - run_info.queued_requests += 1 - added_count += 1 - except StopIteration: - # we've reached the limit number, limit time, or exhausted the requests - # set to None to stop adding more and tell the loop no more requests - requests_iter = None - - return requests_iter - - def _check_result_ready( - self, - responses_queue: multiprocessing.Queue, - run_info: SchedulerRunInfo, - ) -> Optional[SchedulerRequestResult[RequestT, ResponseT]]: - try: - process_response: WorkerProcessResult[RequestT, ResponseT] = ( - responses_queue.get_nowait() - ) - except multiprocessing.queues.Empty: # type: ignore[attr-defined] - return None - - if process_response.type_ == "request_scheduled": - run_info.queued_requests -= 1 - run_info.scheduled_requests += 1 - - return SchedulerRequestResult( - type_="request_scheduled", - run_info=run_info, - request=process_response.request, - request_info=process_response.info, - response=None, - ) - - if process_response.type_ == "request_start": - run_info.scheduled_requests -= 1 - run_info.processing_requests += 1 - - return SchedulerRequestResult( - type_="request_start", - run_info=run_info, - request=process_response.request, - request_info=process_response.info, - response=None, - ) - - if process_response.type_ == "request_complete": - run_info.processing_requests -= 1 - run_info.completed_requests += 1 - - return SchedulerRequestResult( - type_="request_complete", - run_info=run_info, - request=process_response.request, - request_info=process_response.info, - response=process_response.response, - ) - raise ValueError(f"Invalid process response type: {process_response}") - - async def _stop_processes( - self, - futures: list[asyncio.Future], - requests_queue: multiprocessing.Queue, - ): - for _ in futures: - requests_queue.put(None) - - await asyncio.gather(*futures) + yield response, request, request_info, state + except Exception as err: # noqa: BLE001 + await env.sync_run_error(err) + finally: + # Ensure all worker processes are cleaned up for error or completion + if worker_group is not None: + err = await worker_group.shutdown() + if err is not 
None: + await env.sync_run_error(err) + + # Ensure any errors are raised and all responses + # are yielded for aggregation on the primary node + async for ( + response, + request, + request_info, + state, + ) in env.sync_run_end(): + yield response, request, request_info, state diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py index 200c799e..60b1e03c 100644 --- a/src/guidellm/scheduler/strategy.py +++ b/src/guidellm/scheduler/strategy.py @@ -1,364 +1,650 @@ +""" +Request scheduling strategies for the GuideLLM toolkit. + +This module provides a comprehensive set of scheduling strategies that control how +requests are processed and timed within the GuideLLM benchmarking system. These +strategies enable fine-grained control over request concurrency, timing patterns, +and throughput characteristics to simulate various real-world usage scenarios. + +The scheduling system is built around abstract timing implementations that define +when requests should be executed, and concrete strategy classes that combine +timing behaviors with process and concurrency limits. + +Classes: + ScheduledRequestTimings: Abstract base class for request timing implementations + LastCompletionRequestTimings: Timing implementation for synchronous/concurrent + strategies + NoDelayRequestTimings: Timing implementation for throughput-maximizing strategies + ConstantRateRequestTimings: Timing implementation for constant-rate request + scheduling + PoissonRateRequestTimings: Timing implementation for Poisson-distributed request + scheduling + SchedulingStrategy: Abstract base class for all scheduling strategies + SynchronousStrategy: Sequential request processing with maximum throughput + ConcurrentStrategy: Parallel request processing with limited concurrency + ThroughputStrategy: Unrestricted request processing for maximum system throughput + AsyncConstantStrategy: Asynchronous request scheduling at a constant rate + AsyncPoissonStrategy: Asynchronous request scheduling with Poisson distribution +""" + +from __future__ import annotations + import math -import os import random import time -from collections.abc import Generator -from typing import ( - Literal, - Optional, - Union, -) +from abc import ABC, abstractmethod +from typing import Literal, TypeVar -from pydantic import Field +from pydantic import Field, PrivateAttr -from guidellm.config import settings from guidellm.objects import StandardBaseModel +from guidellm.scheduler.objects import ScheduledRequestInfo __all__ = [ "AsyncConstantStrategy", "AsyncPoissonStrategy", "ConcurrentStrategy", + "ConstantRateRequestTimings", + "LastCompletionRequestTimings", + "NoDelayRequestTimings", + "PoissonRateRequestTimings", + "ScheduledRequestTimings", "SchedulingStrategy", + "StrategyT", "StrategyType", "SynchronousStrategy", "ThroughputStrategy", - "strategy_display_str", ] StrategyType = Literal["synchronous", "concurrent", "throughput", "constant", "poisson"] -class SchedulingStrategy(StandardBaseModel): +def _exponential_decay_tau(max_progress: float, convergence: float = 0.99) -> float: """ - An abstract base class for scheduling strategies. - This class defines the interface for scheduling requests and provides - a common structure for all scheduling strategies. - Subclasses should implement the `request_times` method to provide - specific scheduling behavior. - - :param type_: The type of scheduling strategy to use. - This should be one of the predefined strategy types. 
+ :param max_progress: The max progress value to reach + :param convergence: The target convergence level for reaching max_progress. + Default 0.99 represents at 99% exponential decay reach max_progress. + :return: The calculated tau value for the given max_progress and convergence. """ + return max_progress / (-math.log(1 - convergence)) - type_: Literal["strategy"] = Field( - description="The type of scheduling strategy schedule requests with.", + +def _exponential_decay_fraction(progress: float, tau: float = 1.0) -> float: + """ + :param progress: The current progress value (>=0) + :param tau: The scale factor for the exponential decay (default: 1.0) + :return: The fraction of completion based on exponential decay (0 -> 1) + """ + return 1 - math.exp(-progress / tau) + + +class ScheduledRequestTimings(StandardBaseModel, ABC): + """ + Abstract base class for request timing implementations in scheduling strategies. + + This class defines the interface for controlling when requests are scheduled + and how timing offsets are calculated. Different implementations provide + various timing behaviors such as synchronous, constant-rate, or stochastic + request scheduling patterns. + + Implementations must provide logic for calculating the next request offset + and handling request completion events that may affect future timing decisions. + """ + + @abstractmethod + def next_offset(self) -> float: + """ + Calculate the time offset for the next request to be scheduled. + + :return: The offset in seconds from the scheduler start time when the + next request should be scheduled. + """ + + @abstractmethod + def request_completed(self, request_info: ScheduledRequestInfo): + """ + Handle the completion of a request and update internal timing state. + + This method is called when a request completes (successfully or with error) + and allows the timing implementation to update its internal state based on + the completion information. + + :param request_info: Information about the completed request including + timing details and completion status. + """ + + +class LastCompletionRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for synchronous and concurrent scheduling strategies. + + This implementation schedules the next request immediately after the last + request has completed, enabling sequential or limited concurrent processing. + It maintains an internal offset based on completion times to ensure proper + scheduling behavior. + """ + + offset: float = Field( + default=0.0, + description="The current time offset in seconds from scheduler start time.", + ) + startup_requests: int = Field( + default=0, + description=( + "Number of initial requests to schedule during startup phase with equal " + "spacing of startup_requests_delay before going to last request times." + ), + ge=0, + ) + startup_requests_delay: float = Field( + default=0.0, + description=( + "Delay in seconds used to add to the offset for each request " + "within the startup phase (_requests_count <= startup_requests)." + ), + ge=0, ) + _requests_count: int = PrivateAttr(0) - @property - def processing_mode(self) -> Literal["sync", "async"]: + def next_offset(self) -> float: + """ + :return: The current offset value in seconds from scheduler start time. """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. 
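Editor's note: the two decay helpers above drive the startup ramps used by several timing classes. A standalone numeric sketch (reimplementing the private helpers for illustration rather than importing them) shows how the fraction converges to roughly 99% once progress reaches max_progress:

    import math


    def decay_tau(max_progress: float, convergence: float = 0.99) -> float:
        # mirrors _exponential_decay_tau above
        return max_progress / (-math.log(1 - convergence))


    def decay_fraction(progress: float, tau: float = 1.0) -> float:
        # mirrors _exponential_decay_fraction above
        return 1 - math.exp(-progress / tau)


    tau = decay_tau(max_progress=100)  # ramp sized for ~100 requests
    for sent in (10, 50, 100, 200):
        print(sent, round(decay_fraction(sent, tau), 3))
    # approximately: 10 -> 0.369, 50 -> 0.9, 100 -> 0.99, 200 -> 1.0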
- This property should be implemented by subclasses to return - the appropriate processing mode. + self._requests_count += 1 + + if self._requests_count <= self.startup_requests: + self.offset += self.startup_requests_delay - :return: The processing mode for the scheduling strategy, - either 'sync' or 'async'. + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): """ - return "async" + Update timing state and offset based on the completed request. - @property - def processes_limit(self) -> int: + :param request_info: Information about the completed request including + timing details and completion status. """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. + if ( + self._requests_count > self.startup_requests + and request_info.completed_at is not None + ): + # set the next sync offset to the time when the previous request completed + self.offset = request_info.completed_at - request_info.scheduler_start_time + - :return: The number of processes for the scheduling strategy. +class NoDelayRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for throughput-maximizing scheduling strategies. + + This implementation schedules requests with no delay, allowing the system + to process requests as quickly as possible. It always returns a zero offset, + enabling maximum throughput by scheduling requests immediately without + waiting for previous requests to complete. + """ + + offset: float = Field( + default=0.0, + description="The time offset to apply in seconds from scheduler start time.", + ge=0, + ) + startup_duration: float = Field( + default=0.0, + description=( + "The duration of the startup phase in seconds to gradually ramp up " + "request processing." + ), + ge=0, + ) + startup_target_requests: int = Field( + default=1.0, + description=( + "The target number of requests to converge to in the startup phase." + ), + gt=0, + ) + startup_convergence: float = Field( + default=0.99, + description=("The target convergence rate during the startup phase."), + ) + _start_time: float | None = PrivateAttr(None) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: """ - cpu_cores = os.cpu_count() or 1 + :return: Static offset plus any startup adjustment. + """ + if self._start_time is None: + self._start_time = time.time() - return min(max(1, cpu_cores - 1), settings.max_worker_processes) + self._requests_count += 1 + elapsed = time.time() - self._start_time - @property - def queued_requests_limit(self) -> Optional[int]: + if self.startup_duration > 0 and elapsed < self.startup_duration: + startup_percent = _exponential_decay_fraction( + self._requests_count, + _exponential_decay_tau( + self.startup_target_requests, self.startup_convergence + ), + ) + else: + startup_percent = 1.0 + + return self.offset + startup_percent * self.startup_duration + + def request_completed(self, request_info: ScheduledRequestInfo): """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Handle request completion (no action needed for throughput strategy). - :return: The maximum number of queued requests for the scheduling strategy. + :param request_info: Information about the completed request (unused). 
""" - return settings.max_concurrency - @property - def processing_requests_limit(self) -> int: + +class ConstantRateRequestTimings(ScheduledRequestTimings): + """ + Timing implementation for constant-rate scheduling strategies. + + This implementation schedules requests at a constant rate defined in requests + per second. The offset for each subsequent request is calculated as a multiple + of the interval between requests, ensuring evenly spaced request scheduling. + """ + + rate: float = Field( + description="The target rate in requests per second. Must be positive.", + gt=0, + ) + offset: float = Field( + default=0.0, + description="The time offset to apply in seconds from scheduler start time.", + ge=0, + ) + _requests_count: int = PrivateAttr(0) + + def next_offset(self) -> float: """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Calculate the offset for the next request at a constant rate. - :return: The maximum number of processing requests for the scheduling strategy. + Each request is scheduled at a fixed interval based on the target rate, + with offsets increasing linearly: 0, 1/rate, 2/rate, 3/rate, etc. + + :return: The offset in seconds for the next request. """ - return settings.max_concurrency + num_requests = self._requests_count + self._requests_count += 1 + interval = 1.0 / self.rate + + return self.offset + interval * num_requests - def request_times(self) -> Generator[float, None, None]: + def request_completed(self, request_info: ScheduledRequestInfo): """ - A generator that yields timestamps for when requests should be sent. - This method should be implemented by subclasses to provide specific - scheduling behavior. + Handle request completion (no action needed for constant rate strategy). - :return: A generator that yields timestamps for request scheduling - or -1 for requests that should be sent immediately. + :param request_info: Information about the completed request (unused). """ - raise NotImplementedError("Subclasses must implement request_times() method.") -class SynchronousStrategy(SchedulingStrategy): +class PoissonRateRequestTimings(ScheduledRequestTimings): """ - A class representing a synchronous scheduling strategy. - This strategy schedules requests synchronously, one at a time, - with the maximum rate possible. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for synchronous scheduling. - - :param type_: The synchronous StrategyType to schedule requests synchronously. + Timing implementation for Poisson-distributed scheduling strategies. + + This implementation schedules requests following a Poisson process with + exponentially distributed inter-arrival times. The average rate is specified + in requests per second, but individual intervals vary randomly according to + the exponential distribution, simulating realistic traffic patterns. """ - type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] + rate: float = Field( + description="The target average rate in requests per second. Must be positive.", + gt=0, + ) + random_seed: int = Field( + default=42, + description=( + "Seed for the random number generator to ensure reproducible behavior." 
+ ), + ) + offset: float = Field( + default=0.0, + description="The time offset to apply in seconds from scheduler start time.", + ) + _requests_count: int = PrivateAttr(0) + _random: random.Random | None = PrivateAttr(None) - @property - def processing_mode(self) -> Literal["sync"]: + def next_offset(self) -> float: """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. + Calculate the offset for the next request using Poisson distribution. - :return: 'sync' for synchronous scheduling strategy - for the single worker process. + Uses exponential distribution to generate inter-arrival times that + follow a Poisson process. Each call advances the cumulative offset + by a randomly generated delay. + + :return: The cumulative offset in seconds for the next request. + """ + self._requests_count += 1 + + if self._random is None: + self._random = random.Random(self.random_seed) + else: + next_delay = self._random.expovariate(self.rate) + self.offset += next_delay + + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): """ - return "sync" + Handle request completion (no action needed for Poisson rate strategy). + + :param request_info: Information about the completed request (unused). + """ + + +class SchedulingStrategy(StandardBaseModel): + """ + An abstract base class for scheduling strategies enabling control over how + requests are processed by the scheduler. + """ + + type_: Literal["strategy"] = Field( + description="The type of scheduling strategy to schedule requests with.", + ) @property - def processes_limit(self) -> int: + def processes_limit(self) -> int | None: + """ + :return: The maximum number of worker processes supported by the + scheduling strategy. None if not limited. """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. + return None - :return: 1 for the synchronous scheduling strategy to limit - the worker processes to one. + @property + def requests_limit(self) -> int | None: """ - return 1 + :return: The maximum number of concurrent requests that can be processed + at once by the scheduling strategy. None if not limited. + """ + return None + + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: + """ + Create a ScheduledRequestTimings instance to define the timing behavior + for the worker process to schedule requests. + + :param local_rank: The rank of the worker process within the local world size. + :param local_world_size: The total num of worker processes in the local world. + :param local_max_concurrency: The maximum number of concurrent requests + for the worker process. + :return: A ScheduledRequestTimings instance for the worker process. + """ + raise NotImplementedError( + "create_worker_timings method must be implemented by subclasses." + ) + + +StrategyT = TypeVar("StrategyT", bound=SchedulingStrategy) + + +class SynchronousStrategy(SchedulingStrategy): + """ + Sequential request processing strategy with maximum throughput constraints. + + This strategy processes requests one at a time in strict sequential order, + waiting for each request to complete before starting the next. 
It provides + the most predictable timing behavior and is useful for measuring maximum + achievable throughput under sequential processing constraints. + + The strategy enforces a limit of one worker process and one concurrent request, + making it ideal for scenarios where request ordering and isolation are critical. + """ + + type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] + + def __str__(self) -> str: + """Return string representation of the strategy.""" + return "synchronous" @property - def queued_requests_limit(self) -> int: + def processes_limit(self) -> int | None: """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of worker processes for synchronous scheduling. - :return: 1 for the synchronous scheduling strategy to limit - the queued requests to one that is ready to be processed. + :return: Always returns 1 to enforce single-process constraint. """ return 1 @property - def processing_requests_limit(self) -> int: + def requests_limit(self) -> int | None: """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of concurrent requests for synchronous scheduling. - :return: 1 for the synchronous scheduling strategy to limit - the processing requests to one that is ready to be processed. + :return: Always returns 1 to enforce single-request constraint. """ return 1 - def request_times(self) -> Generator[float, None, None]: + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: """ - A generator that yields time.time() so requests are sent immediately, - while scheduling them synchronously. + Create timing implementation for synchronous request scheduling. - :return: A generator that yields time.time() for immediate request scheduling. + :param local_rank: The rank of the worker process. Must be 0. + :param local_world_size: Total number of worker processes. Must be 1. + :param local_max_concurrency: The maximum number of concurrent requests + for the worker process. Unused in this strategy. + :return: LastCompletionRequestTimings instance for sequential processing. + :raises ValueError: If multiple workers or non-zero rank is specified. """ - while True: - yield time.time() + if local_world_size > 1 or local_rank != 0: + raise ValueError( + "SynchronousStrategy can only be used with a single worker process." + ) + + return LastCompletionRequestTimings() class ConcurrentStrategy(SchedulingStrategy): """ - A class representing a concurrent scheduling strategy. - This strategy schedules requests concurrently with the specified - number of streams. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for concurrent scheduling. - - :param type_: The concurrent StrategyType to schedule requests concurrently. - :param streams: The number of concurrent streams to use for scheduling requests. - Each stream runs synchronously with the maximum rate possible. - This must be a positive integer. + Parallel request processing strategy with controlled concurrency limits. 
+ + This strategy enables concurrent request processing up to a specified number + of streams, allowing multiple requests to be processed simultaneously while + maintaining predictable resource usage. It provides a balance between + throughput and resource control. + + The number of concurrent streams determines both the maximum number of worker + processes and the maximum number of requests that can be processed in parallel. + Each worker process handles one stream and waits for request completion before + processing the next request in that stream. """ type_: Literal["concurrent"] = "concurrent" # type: ignore[assignment] streams: int = Field( description=( "The number of concurrent streams to use for scheduling requests. " - "Each stream runs sychronously with the maximum rate possible. " "This must be a positive integer." ), gt=0, ) + startup_duration: float = Field( + default=0.0, + description=( + "Duration in seconds over which startup requests are distributed " + "before switching to completion-based timing." + ), + ge=0, + ) - @property - def processing_mode(self) -> Literal["sync"]: - """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. - - :return: 'sync' for synchronous scheduling strategy - for the multiple worker processes equal to streams. - """ - return "sync" + def __str__(self) -> str: + """Return string representation of the strategy.""" + return f"concurrent@{self.streams}" @property def processes_limit(self) -> int: """ - The limit on the number of worker processes for the scheduling strategy. - It determines how many worker processes are created - for the scheduling strategy and must be implemented by subclasses. - - :return: {self.streams} for the concurrent scheduling strategy to limit - the worker processes to the number of streams. - """ - return self.streams - - @property - def queued_requests_limit(self) -> int: - """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of worker processes for concurrent scheduling. - :return: {self.streams} for the concurrent scheduling strategy to limit - the queued requests to the number of streams that are ready to be processed. + :return: The number of streams, which equals the maximum worker processes. """ return self.streams @property - def processing_requests_limit(self) -> int: + def requests_limit(self) -> int: """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of concurrent requests for concurrent scheduling. - :return: {self.streams} for the concurrent scheduling strategy to limit - the processing requests to the number of streams that ready to be processed. + :return: The number of streams, which equals the maximum concurrent requests. """ return self.streams - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields time.time() so requests are sent - immediately, while scheduling them concurrently with the specified - number of streams. 
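Editor's note: a minimal sketch of how the stream count maps onto the worker and concurrency limits documented above (the module path comes from this file's location in the diff):

    from guidellm.scheduler.strategy import ConcurrentStrategy

    strategy = ConcurrentStrategy(streams=8)
    print(str(strategy))             # "concurrent@8"
    print(strategy.processes_limit)  # 8 -- one worker process per stream
    print(strategy.requests_limit)   # 8 -- one in-flight request per stream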
+ def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> LastCompletionRequestTimings: + """ + Create timing implementation for concurrent request scheduling. + + :param local_rank: The rank of the worker process. Must be less than streams. + :param local_world_size: Total number of worker processes. Must not exceed + streams. + :param local_max_concurrency: The maximum number of concurrent requests + for the worker process. Unused in this strategy. + :return: LastCompletionRequestTimings instance for stream-based processing. + :raises ValueError: If worker configuration exceeds stream limits. + """ + if local_world_size > self.streams: + raise ValueError( + "ConcurrentStrategy can only be used with up to " + f"{self.streams} worker processes." + ) + + if local_rank >= self.streams: + raise ValueError( + f"Local rank {local_rank} exceeds the number of streams {self.streams}." + ) + + if self.startup_duration > 0: + # Ensure equal global distribution of the start up for concurrent streams + # Ex: for 10 streams, 2 workers, and 8 seconds start up duration, + # the first worker should start at 0.0, 1.6, 3.2, 4.8, 6.4 + # and the second worker should start at 0.8, 2.4, 4.0, 5.6, 7.2 + delay_per_stream = self.startup_duration / self.streams + streams_per_worker = self.streams // local_world_size + + offset = local_rank * streams_per_worker * delay_per_stream + startup_requests = streams_per_worker + ( + 1 + if local_world_size > 1 and local_rank < self.streams % local_world_size + else 0 + ) + startup_requests_delay = delay_per_stream * local_world_size + else: + offset = 0.0 + startup_requests = 0 + startup_requests_delay = 0.0 - :return: A generator that yields time.time() for immediate request scheduling. - """ - while True: - yield time.time() + return LastCompletionRequestTimings( + offset=offset, + startup_requests=startup_requests, + startup_requests_delay=startup_requests_delay, + ) class ThroughputStrategy(SchedulingStrategy): """ - A class representing a throughput scheduling strategy. - This strategy schedules as many requests asynchronously as possible, - with the maximum rate possible. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for throughput scheduling. - - :param type_: The throughput StrategyType to schedule requests asynchronously. + Maximum throughput strategy with optional concurrency limits. + + This strategy schedules requests to maximize system throughput by allowing + unlimited concurrent request processing. Requests are scheduled immediately + without waiting for previous requests to complete, enabling the system to + achieve its maximum processing capacity. + + An optional maximum concurrency limit can be set to prevent resource + exhaustion while still allowing high-throughput processing patterns. """ type_: Literal["throughput"] = "throughput" # type: ignore[assignment] - max_concurrency: Optional[int] = Field( + max_concurrency: int | None = Field( default=None, description=( "The maximum number of concurrent requests to schedule. " - "If set to None, the concurrency value from settings will be used. " "This must be a positive integer greater than 0." ), gt=0, ) + startup_duration: float = Field( + default=0.0, + description=( + "Duration in seconds over which startup requests are distributed " + "before switching to full throughput scheduling." 
+ ), + ge=0, + ) - @property - def processing_mode(self) -> Literal["async"]: - """ - The processing mode for the scheduling strategy, either 'sync' or 'async'. - This property determines how the worker processes are setup: - either to run synchronously with one request at a time or asynchronously. - - :return: 'async' for asynchronous scheduling strategy - for the multiple worker processes handling requests. - """ - return "async" + def __str__(self) -> str: + """Return string representation of the strategy.""" + return "throughput" @property - def queued_requests_limit(self) -> int: + def processes_limit(self) -> int | None: """ - The maximum number of queued requests for the scheduling strategy. - It determines how many requests can be queued at one time - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of worker processes for throughput scheduling. - :return: The processing requests limit to ensure that there are enough - requests even for the worst case scenario where the max concurrent - requests are pulled at once for processing. + :return: The max_concurrency value if set, otherwise None for unlimited + worker processes. """ - return self.processing_requests_limit + return self.max_concurrency @property - def processing_requests_limit(self) -> int: + def requests_limit(self) -> int | None: """ - The maximum number of processing requests for the scheduling strategy. - It determines how many requests can be processed at one time - for the scheduling strategy and must be implemented by subclasses. + Get the maximum number of concurrent requests for throughput scheduling. - :return: {self.max_concurrency} for the throughput scheduling strategy to limit - the processing requests to the maximum concurrency. - If max_concurrency is None, then the default processing requests limit - will be used. + :return: The max_concurrency value if set, otherwise None for unlimited + concurrent requests. """ - return self.max_concurrency or super().processing_requests_limit + return self.max_concurrency - def request_times(self) -> Generator[float, None, None]: + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: """ - A generator that yields the start time.time() so requests are sent - immediately, while scheduling as many asynchronously as possible. + Create timing implementation for throughput request scheduling. - :return: A generator that yields the start time.time() - for immediate request scheduling. + :param local_rank: The rank of the worker process (unused for throughput). + :param local_world_size: Total number of worker processes (unused for + throughput). + :param local_max_concurrency: The maximum number of concurrent requests + for the worker process. + :return: NoDelayRequestTimings instance for immediate request scheduling. 
""" - start_time = time.time() + if self.startup_duration > 0: + # Vary offset by up to 5% of the startup duration for a bit of variance + offset = 0.05 * self.startup_duration * (local_rank / local_world_size) + # Use local_max_concurrency as the target requests for startup convergence + startup_target_requests = local_max_concurrency + else: + offset = 0.0 + startup_target_requests = 1 - while True: - yield start_time + return NoDelayRequestTimings( + startup_duration=self.startup_duration, + startup_target_requests=startup_target_requests, + offset=offset, + ) class AsyncConstantStrategy(ThroughputStrategy): """ - A class representing an asynchronous constant scheduling strategy. - This strategy schedules requests asynchronously at a constant request rate - in requests per second. - If initial_burst is set, it will send an initial burst of math.floor(rate) - requests to reach the target rate. - This is useful to ensure that the target rate is reached quickly - and then maintained. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for asynchronous constant scheduling. - - :param type_: The constant StrategyType to schedule requests asynchronously. - :param rate: The rate at which to schedule requests asynchronously in - requests per second. This must be a positive float. - :param initial_burst: True to send an initial burst of requests - (math.floor(self.rate)) to reach target rate. - False to not send an initial burst. + Asynchronous constant-rate scheduling strategy for predictable load patterns. + + This strategy schedules requests at a fixed rate specified in requests per + second, distributed evenly across all worker processes. It provides predictable + timing behavior while allowing asynchronous processing, making it ideal for + simulating steady-state load conditions and measuring system performance + under consistent request rates. + + The total rate is divided equally among all worker processes, ensuring the + aggregate rate matches the specified value regardless of the number of workers. """ type_: Literal["constant"] = "constant" # type: ignore[assignment] @@ -369,64 +655,55 @@ class AsyncConstantStrategy(ThroughputStrategy): ), gt=0, ) - initial_burst: bool = Field( - default=True, + startup_duration: float = Field( + default=0.0, description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." + "Duration in seconds over which startup requests are distributed " + "to converge quickly to the desired rate before switching to " + "constant-rate scheduling." ), + ge=0, ) - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields timestamps for when requests should be sent. - This method schedules requests asynchronously at a constant rate - in requests per second. - If burst_time is set, it will send an initial burst of requests - to reach the target rate. - This is useful to ensure that the target rate is reached quickly - and then maintained. + def __str__(self) -> str: + """Return string representation of the strategy.""" + return f"constant@{self.rate:.2f}" - :return: A generator that yields timestamps for request scheduling. 
+ def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: """ - start_time = time.time() - constant_increment = 1.0 / self.rate + Create timing implementation for constant-rate request scheduling. - # handle bursts first to get to the desired rate - if self.initial_burst is not None: - # send an initial burst equal to the rate - # to reach the target rate - burst_count = math.floor(self.rate) - for _ in range(burst_count): - yield start_time + Divides the total rate evenly across all worker processes to maintain + the specified aggregate rate. - start_time += constant_increment - - counter = 0 + :param local_rank: The rank of the worker process (unused). + :param local_world_size: Total number of worker processes for rate division. + :param local_max_concurrency: The maximum number of concurrent requests + for the worker process. + :return: ConstantRateRequestTimings instance with per-worker rate. + """ + # Divide the rate evenly across all worker processes + worker_rate = self.rate / local_world_size - # continue with constant rate after bursting - while True: - yield start_time + constant_increment * counter - counter += 1 + return ConstantRateRequestTimings( + rate=worker_rate, + ) class AsyncPoissonStrategy(ThroughputStrategy): """ - A class representing an asynchronous Poisson scheduling strategy. - This strategy schedules requests asynchronously at a Poisson request rate - in requests per second. - If initial_burst is set, it will send an initial burst of math.floor(rate) - requests to reach the target rate. - It inherits from the `SchedulingStrategy` base class and - implements the `request_times` method to provide the specific - behavior for asynchronous Poisson scheduling. - - :param type_: The Poisson StrategyType to schedule requests asynchronously. - :param rate: The rate at which to schedule requests asynchronously in - requests per second. This must be a positive float. - :param initial_burst: True to send an initial burst of requests - (math.floor(self.rate)) to reach target rate. - False to not send an initial burst. + Asynchronous Poisson-distributed scheduling strategy for realistic load simulation. + + This strategy schedules requests following a Poisson process with exponentially + distributed inter-arrival times. The average rate is specified in requests per + second, but individual intervals vary randomly, providing a more realistic + simulation of user behavior and network traffic patterns. + + The total rate is divided equally among all worker processes, with each worker + using a different random seed to ensure independent request streams that + collectively achieve the target rate. """ type_: Literal["poisson"] = "poisson" # type: ignore[assignment] @@ -437,57 +714,45 @@ class AsyncPoissonStrategy(ThroughputStrategy): ), gt=0, ) - initial_burst: bool = Field( - default=True, + startup_duration: float = Field( + default=0.0, description=( - "True to send an initial burst of requests (math.floor(self.rate)) " - "to reach target rate. False to not send an initial burst." + "Duration in seconds over which startup requests are distributed " + "to converge quickly to the desired rate before switching to " + "constant-rate scheduling." ), + ge=0, ) random_seed: int = Field( default=42, - description=("The random seed to use for the Poisson distribution. 
"), + description=("The random seed to use for the Poisson distribution."), ) - def request_times(self) -> Generator[float, None, None]: - """ - A generator that yields timestamps for when requests should be sent. - This method schedules requests asynchronously at a Poisson rate - in requests per second. - The inter arrival time between requests is exponentially distributed - based on the rate. - - :return: A generator that yields timestamps for request scheduling. - """ - start_time = time.time() - - if self.initial_burst is not None: - # send an initial burst equal to the rate - # to reach the target rate - burst_count = math.floor(self.rate) - for _ in range(burst_count): - yield start_time - else: - yield start_time - - # set the random seed for reproducibility - rand = random.Random(self.random_seed) # noqa: S311 - - while True: - inter_arrival_time = rand.expovariate(self.rate) - start_time += inter_arrival_time - yield start_time - - -def strategy_display_str(strategy: Union[StrategyType, SchedulingStrategy]) -> str: - strategy_type = strategy if isinstance(strategy, str) else strategy.type_ - strategy_instance = strategy if isinstance(strategy, SchedulingStrategy) else None - - if strategy_type == "concurrent": - rate = f"@{strategy_instance.streams}" if strategy_instance else "@##" # type: ignore[attr-defined] - elif strategy_type in ("constant", "poisson"): - rate = f"@{strategy_instance.rate:.2f}" if strategy_instance else "@#.##" # type: ignore[attr-defined] - else: - rate = "" - - return f"{strategy_type}{rate}" + def __str__(self) -> str: + """Return string representation of the strategy.""" + return f"poisson@{self.rate:.2f}" + + def create_request_timings( + self, local_rank: int, local_world_size: int, local_max_concurrency: int + ) -> ScheduledRequestTimings: + """ + Create timing implementation for Poisson-distributed request scheduling. + + Divides the total rate evenly across all worker processes and assigns + unique random seeds to ensure independent but coordinated request streams. + + :param local_rank: The rank of the worker process for seed generation. + :param local_world_size: Total number of worker processes for rate division. + :param local_max_concurrency: The maximum number of concurrent requests + for the worker process. + :return: PoissonRateRequestTimings instance with per-worker rate and + unique seed. 
+ """ + # Divide the rate evenly across all worker processes + worker_rate = self.rate / local_world_size + # Use a different seed for each worker to ensure different sequences + worker_seed = self.random_seed + local_rank + return PoissonRateRequestTimings( + rate=worker_rate, + random_seed=worker_seed, + ) diff --git a/src/guidellm/scheduler/types.py b/src/guidellm/scheduler/types.py deleted file mode 100644 index 42535d71..00000000 --- a/src/guidellm/scheduler/types.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import TypeVar - -__all__ = ["RequestT", "ResponseT"] - - -RequestT = TypeVar("RequestT") -ResponseT = TypeVar("ResponseT") diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index a53b14c2..ac9c837f 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -1,513 +1,527 @@ -import asyncio -import math -import multiprocessing -import multiprocessing.queues -import time -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from dataclasses import dataclass -from typing import ( - Any, - Generic, - Literal, - Optional, - Union, -) - -from loguru import logger -from pydantic import Field - -from guidellm.backend import ( - Backend, - BackendType, - RequestArgs, - ResponseSummary, - StreamingTextResponse, -) -from guidellm.objects import StandardBaseModel -from guidellm.request import GenerationRequest -from guidellm.scheduler.result import SchedulerRequestInfo -from guidellm.scheduler.types import RequestT, ResponseT - -__all__ = [ - "GenerativeRequestsWorker", - "GenerativeRequestsWorkerDescription", - "RequestsWorker", - "ResolveStatus", - "WorkerDescription", - "WorkerProcessRequest", - "WorkerProcessResult", -] +""" +Worker process management for multi-process request scheduling and execution. +Provides infrastructure for managing individual worker processes that handle +request scheduling, processing, and coordination in multi-process environments. -@dataclass -class WorkerProcessRequest(Generic[RequestT]): - request: RequestT - start_time: float - timeout_time: float - queued_time: float +Classes: + WorkerProcess: Individual worker process for request processing and coordination. 
+""" +from __future__ import annotations -@dataclass -class WorkerProcessResult(Generic[RequestT, ResponseT]): - type_: Literal["request_scheduled", "request_start", "request_complete"] - request: RequestT - response: Optional[ResponseT] - info: SchedulerRequestInfo - - -@dataclass -class ResolveStatus: - requested: bool - completed: bool - errored: bool - canceled: bool - - request_start: float - request_end: float - +import asyncio +import time +from collections.abc import Generator +from multiprocessing import Queue +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from queue import Empty as QueueEmpty +from threading import Event as ThreadingEvent +from typing import Generic, Literal + +import culsans + +from guidellm.scheduler.objects import ( + BackendInterface, + MeasuredRequestTimingsT, + MultiTurnRequestT, + RequestT, + ResponseT, + ScheduledRequestInfo, +) +from guidellm.scheduler.strategy import ScheduledRequestTimings +from guidellm.utils import MsgpackEncoding, synchronous_to_exitable_async -class WorkerDescription(StandardBaseModel): - type_: Literal["worker"] = "worker" +__all__ = ["WorkerProcess"] -class RequestsWorker(ABC, Generic[RequestT, ResponseT]): +class WorkerProcess(Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): """ - An abstract base class for a worker that processes requests. - This class defines the interface for a worker that can resolve requests - asynchronously or synchronously within the Scheduler class. - Subclasses must implement the `resolve` method, - which takes a request directly given from the load generator, - along with the desired start_time for the request and a timeout_time. - The `resolve` method should return the response from the backend. + Individual worker process for request processing and coordination. + + Manages the complete lifecycle of requests from queue consumption through backend + processing and updates publication, maintaining synchronization with other + processes in the group. """ - @property - @abstractmethod - def description(self) -> WorkerDescription: + def __init__( + self, + local_rank: int, + local_world_size: int, + async_limit: int, + startup_barrier: ProcessingBarrier, + shutdown_event: ProcessingEvent, + error_event: ProcessingEvent, + requests_queue: Queue[ + tuple[ + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ], + updates_queue: Queue[ + tuple[ + ResponseT | None, + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ], + backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + request_timings: ScheduledRequestTimings, + poll_intervals: float = 0.1, + max_requests_queue_buffer: int = 2, + ): """ - An abstract property that must be implemented by subclasses. - This property should return a Serializable class representing the information - about the worker instance. + Initialize worker process instance. + + :param local_rank: Process rank within the worker group. + :param local_world_size: Total number of worker processes in the group. + :param async_limit: Maximum concurrent requests this worker can handle. + :param startup_barrier: Multiprocessing barrier for coordinated startup. + :param shutdown_event: Event for signaling graceful shutdown. + :param error_event: Event for signaling error conditions across processes. + :param requests_queue: Queue for receiving requests to process. 
+ :param updates_queue: Queue for publishing processing updates. + :param backend: Backend instance for processing requests. + :param request_timings: Timing strategy for request scheduling. + :param poll_intervals: Time interval for polling operations. """ - ... + # Worker info + self.local_rank = local_rank + self.local_world_size = local_world_size + self.async_limit = async_limit + + # Process synchronization + self.startup_barrier = startup_barrier + self.shutdown_event = shutdown_event + self.error_event = error_event + self.requests_queue = requests_queue + self.updates_queue = updates_queue + + # Local synchronization (initialized during start up) + self.pending_requests_queue: culsans.Queue[ + tuple[ + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ] = None + self.pending_updates_queue: culsans.Queue[ + tuple[ + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ] = None + self.requests_canceled: ThreadingEvent = None + self.pull_requests_stopped: ThreadingEvent = None + self.pull_task: asyncio.Task = None + self.push_task: asyncio.Task = None + + # Request processing + self.backend = backend + self.request_timings = request_timings + self.poll_intervals = poll_intervals + self.max_requests_queue_buffer = max_requests_queue_buffer + self.startup_completed: bool = False - @abstractmethod - async def prepare_multiprocessing(self): - """ - An abstract method that must be implemented by subclasses. - This is useful for workers that have instance state that can not - be shared across processes and should be cleared out and re-initialized - for each new process. + def run(self): """ - ... + Main entry point for worker process execution. - @abstractmethod - async def resolve( - self, - request: RequestT, - timeout_time: float, - ) -> tuple[ResolveStatus, ResponseT]: + Initializes asyncio event loop and starts worker async operations. + + :raises RuntimeError: If worker encounters unrecoverable error during execution. """ - An abstract method that must be implemented by subclasses. - This method should handle the resolution of a request through asyncio, - including any necessary backend processing and response handling. - - :param request: The request to be resolved generated by the load generator. - :param timeout_time: The timeout time for the request, if there is no timeout - given, then this will be math.inf. - :return: The response from the worker. + try: + asyncio.run(self.run_async()) + except Exception as exc: + self.error_event.set() + raise RuntimeError( + f"Worker process {self.local_rank} encountered an error: {exc}" + ) from exc + + async def run_async(self): """ - ... - - async def get_request( - self, requests_queue: multiprocessing.Queue - ) -> Optional[WorkerProcessRequest[RequestT]]: - return await asyncio.to_thread(requests_queue.get) # type: ignore[attr-defined] + Execute main asynchronous worker process logic. - async def send_result( - self, - results_queue: multiprocessing.Queue, - result: WorkerProcessResult[RequestT, ResponseT], - ): - await asyncio.to_thread(results_queue.put, result) # type: ignore[attr-defined] + Orchestrates concurrent execution of request processing and shutdown monitoring + tasks, handling cleanup and error propagation when tasks complete. 
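Editor's note: the orchestration described here follows the standard asyncio pattern of racing tasks and cancelling the losers. A generic, self-contained sketch of that pattern (using stand-in coroutines, not the worker's actual tasks):

    import asyncio


    async def monitor_stop() -> str:
        await asyncio.sleep(0.05)  # pretend a shutdown signal arrives
        return "stop requested"


    async def process_requests() -> str:
        await asyncio.sleep(10)  # normally runs until cancelled
        return "all requests processed"


    async def main() -> None:
        tasks = [
            asyncio.create_task(monitor_stop()),
            asyncio.create_task(process_requests()),
        ]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:  # cancel whichever side is still running
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        print(next(iter(done)).result())  # -> "stop requested"


    asyncio.run(main())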
- async def resolve_scheduler_request( - self, - request: Any, - queued_time: float, - dequeued_time: float, - start_time: float, - timeout_time: float, - results_queue: multiprocessing.Queue, - process_id: int, - ): - info = SchedulerRequestInfo( - targeted_start_time=start_time, - queued_time=queued_time, - dequeued_time=dequeued_time, - scheduled_time=time.time(), - process_id=process_id, - ) - result: WorkerProcessResult[RequestT, ResponseT] = WorkerProcessResult( - type_="request_scheduled", - request=request, - response=None, - info=info, - ) - asyncio.create_task(self.send_result(results_queue, result)) - - if (wait_time := start_time - time.time()) > 0: - await asyncio.sleep(wait_time) - - info.worker_start = time.time() - result = WorkerProcessResult( - type_="request_start", - request=request, - response=None, - info=info, - ) - asyncio.create_task(self.send_result(results_queue, result)) - - status, response = await self.resolve(request, timeout_time) - info.worker_end = time.time() - info.requested = status.requested - info.completed = status.completed - info.errored = status.errored - info.canceled = status.canceled - info.request_start = status.request_start - info.request_end = status.request_end - result = WorkerProcessResult( - type_="request_complete", - request=request, - response=response, - info=info, - ) - asyncio.create_task(self.send_result(results_queue, result)) - - def process_loop_synchronous( - self, - requests_queue: multiprocessing.Queue, - results_queue: multiprocessing.Queue, - process_id: int, - ): - async def _process_runner(): - while ( - process_request := await self.get_request(requests_queue) - ) is not None: - dequeued_time = time.time() - - await self.resolve_scheduler_request( - request=process_request.request, - queued_time=process_request.queued_time, - dequeued_time=dequeued_time, - start_time=process_request.start_time, - timeout_time=process_request.timeout_time, - results_queue=results_queue, - process_id=process_id, - ) - - try: - asyncio.run(_process_runner()) - except Exception as exc: # noqa: BLE001 - logger.error( - f"Error in worker process {process_id}: {exc}", - exc_info=True, - stack_info=True, - ) - - def process_loop_asynchronous( - self, - requests_queue: multiprocessing.Queue, - results_queue: multiprocessing.Queue, - max_concurrency: int, - process_id: int, - ): - async def _process_runner(): - pending = asyncio.Semaphore(max_concurrency) - - if pending.locked(): - raise ValueError("Async worker called with max_concurrency < 1") - - while ( - process_request := await self.get_request(requests_queue) - ) is not None: - dequeued_time = time.time() - - await pending.acquire() - - def _task_done(_: asyncio.Task): - nonlocal pending - pending.release() - - task = asyncio.create_task( - self.resolve_scheduler_request( - request=process_request.request, - queued_time=process_request.queued_time, - dequeued_time=dequeued_time, - start_time=process_request.start_time, - timeout_time=process_request.timeout_time, - results_queue=results_queue, - process_id=process_id, - ) - ) - task.add_done_callback(_task_done) - await asyncio.sleep(0) # enable start task immediately + :raises RuntimeError: If worker tasks encounter unrecoverable errors. 
+ """ + # Start both shutdown monitoring and request processing concurrently + tasks = [ + asyncio.create_task(self.run_async_stop_processing()), + asyncio.create_task(self.run_async_requests_processing()), + ] try: - asyncio.run(_process_runner()) - except Exception as exc: # noqa: BLE001 - logger.error( - f"Error in worker process {process_id}: {exc}", - exc_info=True, - stack_info=True, + # Wait for the first task to complete (shut down or error) + completed, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED ) + # Cancel remaining tasks + if pending: + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + # Check for exceptions in completed tasks + for task in completed: + if not task.cancelled() and (exception := task.exception()): + raise exception + except asyncio.CancelledError: + # Ensure all tasks are canceled before re-raising + for task in tasks: + if not task.done(): + task.cancel() + if any(not task.done() for task in tasks): + await asyncio.gather(*tasks, return_exceptions=True) + raise + + async def run_async_stop_processing(self): + """ + Monitor for shutdown and error signals. -class GenerativeRequestsWorkerDescription(WorkerDescription): - type_: Literal["generative_requests_worker"] = "generative_requests_worker" # type: ignore[assignment] - backend_type: BackendType - backend_target: str - backend_model: str - backend_info: dict[str, Any] = Field( - default_factory=dict, - ) + Runs in parallel with request processing to monitor for shutdown or error + events and trigger appropriate cleanup procedures. + :raises RuntimeError: If error event is signaled or unexpected exit occurs. + :raises asyncio.CancelledError: If shutdown event is signaled. + """ + exit_reason, _ = await synchronous_to_exitable_async( + synchronous=None, + exit_events={ + "error_event": self.error_event, + "shutdown_event": self.shutdown_event, + }, + poll_interval=self.poll_intervals, + ) -class GenerativeRequestsWorker(RequestsWorker[GenerationRequest, ResponseSummary]): - """ - A class that handles the execution of requests using a backend. - This class is responsible for sending requests to the backend, - handling responses, and managing errors. + if exit_reason == "error_event": + raise RuntimeError( + f"Worker process {self.local_rank} received error signal." + ) + elif exit_reason == "shutdown_event": + raise asyncio.CancelledError( + f"Worker process {self.local_rank} received shutdown signal." + ) + else: + raise RuntimeError( + f"Worker process {self.local_rank} received unexpected exit reason: " + f"{exit_reason}" + ) - :param backend: The backend to use for handling requests. - This should be an instance of Backend such as an OpenAIHTTPBackend. - """ + async def run_async_requests_processing(self): + """ + Process incoming requests from the queue. - def __init__(self, backend: Backend): - self.backend = backend + Handles backend initialization, process synchronization, concurrent request + processing with semaphore limiting, and graceful shutdown with task cleanup. - @property - def description(self) -> GenerativeRequestsWorkerDescription: - """ - Get the description of the worker. - :return: The description of the worker. + :raises RuntimeError: If backend initialization or startup synchronization + fails. + :raises asyncio.CancelledError: If shutdown is requested during processing. + :raises NotImplementedError: If multi-turn requests are encountered. 
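        Example (a minimal sketch of the culsans sync/async bridge this method sets
        up between the multiprocessing queues and the event loop; the integer
        payloads are stand-ins for (request, request_info) tuples):
        ::
            import asyncio
            import threading

            import culsans

            async def main():
                bridge = culsans.Queue(maxsize=4)

                def producer():  # plain thread, like the pull-requests generator
                    for item in range(3):
                        bridge.sync_put(item)

                threading.Thread(target=producer, daemon=True).start()
                for _ in range(3):
                    item = await bridge.async_get()  # consumed from the asyncio side
                    bridge.task_done()
                    print(item)

            asyncio.run(main())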
""" - return GenerativeRequestsWorkerDescription( - backend_type=self.backend.type_, - backend_target=self.backend.target, - backend_model=self.backend.model or "None", - backend_info=self.backend.info, + try: + await self._initialize_requests_processing() + await self._start_ready_requests_processing() + await self._loop_requests_processing() + except asyncio.CancelledError: + await self._shutdown_requests_processing() + + raise + + async def _initialize_requests_processing(self): + # Ensure backend is ready on this worker + await self.backend.process_startup() + await self.backend.validate() + + # Setup local queues + self.pending_requests_queue = culsans.Queue( + maxsize=self.max_requests_queue_buffer ) - - async def prepare_multiprocessing(self): - """ - Prepare the worker for multiprocessing. - This is useful for workers that have instance state that can not - be shared across processes and should be cleared out and re-initialized - for each new process. - """ - await self.backend.prepare_multiprocessing() - - def process_loop_synchronous( - self, - requests_queue: multiprocessing.Queue, - results_queue: multiprocessing.Queue, - process_id: int, - ): - asyncio.run(self.backend.validate()) - super().process_loop_synchronous( - requests_queue=requests_queue, - results_queue=results_queue, - process_id=process_id, + self.pending_updates_queue = culsans.Queue() + self.requests_canceled = ThreadingEvent() + self.pull_requests_stopped = ThreadingEvent() + + # Start background tasks for queue management + self.pull_task = asyncio.create_task( + synchronous_to_exitable_async( + self._pull_requests_generator(), + poll_interval=0, # no delays on thread for checking queue + ) ) - - def process_loop_asynchronous( - self, - requests_queue: multiprocessing.Queue, - results_queue: multiprocessing.Queue, - max_concurrency: int, - process_id: int, - ): - asyncio.run(self.backend.validate()) - super().process_loop_asynchronous( - requests_queue=requests_queue, - results_queue=results_queue, - max_concurrency=max_concurrency, - process_id=process_id, + self.push_task = asyncio.create_task( + synchronous_to_exitable_async( + self._push_updates_generator(), + poll_interval=0, # no delays on thread for checking queue + ) ) - async def resolve( - self, - request: GenerationRequest, - timeout_time: float, - ) -> tuple[ResolveStatus, ResponseSummary]: - """ - Resolve a request by sending it to the backend and handling the response. - This method sends the request to the backend, waits for a response, - and handles any errors that may occur during the process. - - :param request: The request to resolve. - :param timeout_time: The time to wait for a response before timing out. - If timeout_time is math.inf, the request will not timeout. - :return: A ResponseSummary object containing the response from the backend. - If an error occurs, the ResponseSummary will contain the error message. - """ - resolve_start_time = time.time() - response = None - error: Optional[str] = None - status = ResolveStatus( - requested=False, - completed=False, - errored=False, - canceled=False, - request_start=-1, - request_end=-1, + async def _start_ready_requests_processing(self): + # Wait for all processes to be ready + barrier_exit_reason, _ = await synchronous_to_exitable_async( + synchronous=None, + exit_barrier=self.startup_barrier, + poll_interval=self.poll_intervals, ) - try: - if timeout_time < time.time(): - raise asyncio.TimeoutError( - "The timeout time has already passed." 
- ) # exit early - - status.requested = True - request_func, request_kwargs = self._create_request_func_kwargs(request) - - async def _runner(): - # wrap function so we can enforce timeout and - # still return the latest state from the backend - async for resp in request_func(**request_kwargs): # type: ignore[operator] - nonlocal response - response = resp - - await asyncio.wait_for( - _runner(), - timeout=timeout_time - time.time() if timeout_time < math.inf else None, + if barrier_exit_reason not in ["barrier", "canceled"]: + raise RuntimeError( + f"Worker process {self.local_rank} failed to synchronize at " + f"startup: {barrier_exit_reason}" ) - if not response: - raise ValueError( - f"No response received for request: {request} " - f"and backend: {self.backend}" - ) - if not isinstance(response, ResponseSummary): - raise ValueError( - f"Received no ResponseSummary for request: {request} " - f"and backend: {self.backend}, received: {response}" - ) + self.startup_completed = True - status.completed = True - except asyncio.TimeoutError: - error = "TimeoutError: The request timed out before completing." - status.errored = True - status.canceled = True - except Exception as exc: # noqa: BLE001 - error = str(exc) - status.errored = True - - return self._handle_response( - status=status, - request=request, - response=response, - error=error, - resolve_start_time=resolve_start_time, - ) + async def _loop_requests_processing(self): + async_semaphore = asyncio.Semaphore(self.async_limit) + pending_tasks = set() - def _create_request_func_kwargs( - self, - request: GenerationRequest, - ) -> tuple[ - AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None], - dict[str, Any], - ]: - request_func: AsyncGenerator[ - Union[StreamingTextResponse, ResponseSummary], None - ] - request_kwargs: dict[str, Any] - - if request.request_type == "text_completions": - request_func = self.backend.text_completions # type: ignore[assignment] - request_kwargs = { - "prompt": request.content, - "request_id": request.request_id, - "prompt_token_count": request.stats.get("prompt_tokens", None), - "output_token_count": request.constraints.get("output_tokens", None), - **request.params, - } - elif request.request_type == "chat_completions": - request_func = self.backend.chat_completions # type: ignore[assignment] - request_kwargs = { - "content": request.content, - "request_id": request.request_id, - "prompt_token_count": request.stats.get("prompt_tokens", None), - "output_token_count": request.constraints.get("output_tokens", None), - **request.params, - } - else: - raise ValueError( - f"Invalid request type: {request.request_type} for {request}" - ) + def _task_done(task): + pending_tasks.discard(task) + async_semaphore.release() - return request_func, request_kwargs + if not task.cancelled() and (exception := task.exception()): + raise exception - def _handle_response( - self, - status: ResolveStatus, - request: GenerationRequest, - response: Any, - error: Optional[str], - resolve_start_time: float, - ) -> tuple[ResolveStatus, ResponseSummary]: - if response is None or not isinstance( - response, (ResponseSummary, StreamingTextResponse) - ): - # nothing received or invalid response, fill in defaults for error - if response: - error = str( - ValueError( - f"Invalid response: {type(response)} for request: {request}; " - ) - ) + (error or "") - - response = ResponseSummary( - value="", - request_args=RequestArgs( - target=self.backend.target, - headers={}, - params={}, - payload={}, - ), - 
start_time=resolve_start_time, - end_time=status.request_end, - first_iter_time=None, - last_iter_time=None, - request_id=request.request_id, - error=error or "Unknown error", - ) - elif isinstance(response, StreamingTextResponse): - response = ResponseSummary( - value=response.value, - request_args=RequestArgs( - target=self.backend.target, - headers={}, - params={}, - payload={}, - ), - start_time=response.start_time, - end_time=time.time(), - first_iter_time=response.first_iter_time, - last_iter_time=response.time if response.iter_count > 0 else None, - request_prompt_tokens=request.stats.get("prompt_tokens", None), - request_output_tokens=request.constraints.get("output_tokens", None), - response_prompt_tokens=None, - response_output_tokens=response.iter_count, - request_id=request.request_id, - error=error or "Unknown error", + try: + # Main loop; loop until canceled + while True: + await async_semaphore.acquire() + request_task = asyncio.create_task(self._process_next_request()) + pending_tasks.add(request_task) + request_task.add_done_callback(_task_done) + await asyncio.sleep(0) + except asyncio.CancelledError: + # Shut down requests queuing + self.requests_canceled.set() + + # Cancel pending requests + if pending_tasks: + for task in list(pending_tasks): + task.cancel() + await asyncio.gather(*pending_tasks, return_exceptions=True) + raise + + async def _shutdown_requests_processing(self): + if self.requests_canceled is not None: + # Queues have been constructed, cancel pending and ensure updates + self.requests_canceled.set() + await self._cancel_pending_requests() + await self.pending_updates_queue.async_join() + await self.pending_requests_queue.aclose() + await self.pending_updates_queue.aclose() + + # Cancel background tasks + tasks = [] + if self.push_task is not None and not self.push_task.done(): + self.push_task.cancel() + tasks.append(self.push_task) + if self.pull_task is not None and not self.pull_task.done(): + self.pull_task.cancel() + tasks.append(self.pull_task) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + # Shut down backend + await self.backend.process_shutdown() + + # Reset state + self.pending_requests_queue = None + self.pending_updates_queue = None + self.pull_task = None + self.push_task = None + self.requests_canceled = None + + async def _process_next_request(self): + request: RequestT | MultiTurnRequestT[RequestT] | None = None + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT] | None = None + response: ResponseT | None = None + + try: + # get next request to send + request, request_info = await self.pending_requests_queue.async_get() + current_time = time.time() + request_info.scheduler_timings.dequeued = current_time + await self._handle_request_update( + new_status="pending", + response=response, + request=request, + request_info=request_info, ) - response.error = error - status.request_start = response.start_time - status.request_end = response.end_time + if isinstance(request, (list, tuple)): + raise NotImplementedError("Multi-turn requests are not yet supported") + + # Calculate when to start processing request + timings_offset = self.request_timings.next_offset() + target_start = request_info.scheduler_start_time + timings_offset + request_info.scheduler_timings.targeted_start = target_start + + if target_start > current_time: + await asyncio.sleep(target_start - current_time) + request_info.scheduler_timings.scheduled_at = target_start + else: + request_info.scheduler_timings.scheduled_at = current_time + + # 
Process the request + request_info.scheduler_timings.resolve_start = time.time() + await self._handle_request_update( + new_status="in_progress", + response=response, + request=request, + request_info=request_info, + ) + async for resp in self.backend.resolve(request, request_info, None): + response = resp + + # Complete + request_info.scheduler_timings.resolve_end = time.time() + await self._handle_request_update( + new_status="completed", + response=response, + request=request, + request_info=request_info, + ) + except asyncio.CancelledError: + # Handle cancellation + if request is not None and request_info is not None: + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + await self._handle_request_update( + new_status="cancelled", + response=response, + request=request, + request_info=request_info, + ) + raise + except Exception as exc: # noqa: BLE001 + if request is not None and request_info is not None: + request_info.error = str(exc) + request_info.scheduler_timings.resolve_end = time.time() + await self._handle_request_update( + new_status="errored", + response=response, + request=request, + request_info=request_info, + ) - return status, response + async def _handle_request_update( + self, + new_status: Literal[ + "pending", "in_progress", "completed", "errored", "cancelled" + ], + response: ResponseT | None, + request: RequestT | MultiTurnRequestT[RequestT], + request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + ): + status_orders = { + "queued": -2, # does not send event + "pending": -1, # does not send event + "in_progress": 1, + "completed": 2, + "errored": 2, + "cancelled": 2, + } + prev_status = request_info.status + try: + if ( + status_orders[new_status] >= status_orders["in_progress"] + and status_orders[prev_status] < status_orders["in_progress"] + ): + # Haven't sent start update yet + request_info.status = "in_progress" + await self.pending_updates_queue.async_put( + (None, request, request_info.model_copy()) + ) + prev_status = "in_progress" + + if ( + status_orders[new_status] > status_orders["in_progress"] + and status_orders[new_status] > status_orders[prev_status] + ): + # Haven't sent resolved update yet + request_info.status = new_status + await self.pending_updates_queue.async_put( + (response, request, request_info.model_copy()) + ) + prev_status = new_status + + # Notify instance states + self.request_timings.request_completed(request_info) + self.pending_requests_queue.task_done() + except Exception as exc: + # Reset status to last one that succeeded or started function with + # Calling logic can retry after handling error, if possible + request_info.status = prev_status + raise exc + + async def _cancel_pending_requests(self): + while True: + try: + request, request_info = await asyncio.wait_for( + self.pending_requests_queue.async_get(), timeout=self.poll_intervals + ) + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + await self._handle_request_update( + new_status="cancelled", + response=None, + request=request, + request_info=request_info, + ) + except (culsans.QueueEmpty, asyncio.TimeoutError): + if self.pull_requests_stopped.is_set(): + # No more requests will be put on the Queue + break + + def _pull_requests_generator(self) -> Generator: + last_check = time.time() + + while True: + if self.requests_canceled.is_set(): + break + + try: + message = self.requests_queue.get(timeout=self.poll_intervals) + request_tuple = 
MsgpackEncoding.decode(message) + self.pending_requests_queue.sync_put(request_tuple) + except QueueEmpty: + pass # No update available, continue polling + except culsans.QueueShutDown: + break + except Exception: # noqa: BLE001, S110 + pass + + if time.time() - last_check > self.poll_intervals: + # Yield to allow cancel/error/stop checks in wrapper + last_check = time.time() + yield None + + self.pull_requests_stopped.set() + + def _push_updates_generator(self) -> Generator: + last_check = time.time() + + while True: + try: + update_tuple = self.pending_updates_queue.sync_get( + timeout=self.poll_intervals + ) + message = MsgpackEncoding.encode(update_tuple) + self.updates_queue.put(message) + self.pending_updates_queue.task_done() + except culsans.QueueEmpty: + pass # No update available, continue polling + except culsans.QueueShutDown: + break + except Exception: # noqa: BLE001, S110 + pass + + if time.time() - last_check > self.poll_intervals: + # Yield to allow cancel/error/stop checks in wrapper + last_check = time.time() + yield None diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py new file mode 100644 index 00000000..d4a3fc2d --- /dev/null +++ b/src/guidellm/scheduler/worker_group.py @@ -0,0 +1,616 @@ +""" +Multi-process worker group orchestration for distributed request scheduling. + +Provides infrastructure for coordinating worker processes with shared state +management, inter-process communication, and lifecycle coordination. + +Classes: + WorkerProcessGroup: Orchestrates multiple worker processes for distributed + request processing with centralized coordination. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import math +import queue +import threading +import time +from asyncio import Task +from collections.abc import AsyncIterator, Iterable, Iterator +from multiprocessing import Queue, get_context +from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Barrier, Event +from threading import Event as ThreadingEvent +from typing import Generic + +import culsans + +from guidellm.config import settings +from guidellm.scheduler.constraints import Constraint +from guidellm.scheduler.objects import ( + BackendInterface, + MeasuredRequestTimingsT, + MultiTurnRequestT, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, +) +from guidellm.scheduler.strategy import SchedulingStrategy +from guidellm.scheduler.worker import WorkerProcess +from guidellm.utils import MsgpackEncoding, synchronous_to_exitable_async + +__all__ = ["WorkerProcessGroup"] + + +class WorkerProcessGroup(Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): + """ + Orchestrates multiple worker processes for distributed request processing. + + Manages process lifecycle, request distribution, response collection, and state + synchronization across workers. Handles dynamic scaling, load balancing, and + constraint evaluation with graceful shutdown coordination. 
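    Example (an illustrative lifecycle sketch; requests, backend, strategy, and
    constraints are assumed to be built elsewhere, and cleanup is shown in the
    shutdown docstring below):
    ::
        import time

        group = WorkerProcessGroup(
            requests=requests,
            backend=backend,
            strategy=strategy,
            constraints=constraints,
            infinite_requests=False,
        )
        await group.create_processes()        # spawn workers and wait at the barrier
        await group.start(time.time() + 0.1)  # begin processing shortly in the future
        async for response, request, info, state in group.request_updates():
            ...                               # consume queued/started/completed updates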
+ """ + + def __init__( + self, + requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + strategy: SchedulingStrategy, + constraints: dict[str, Constraint], + infinite_requests: bool | None = None, + ): + self.requests = requests + self.backend = backend + self.strategy = strategy + self.constraints = constraints + self.infinite_requests = infinite_requests + + # Multiprocessing contexts and primitives, created in create_processes + self.mp_context = None + self.processes: list[BaseProcess] = None + self.startup_barrier: Barrier = None + self.shutdown_event: Event = None + self.error_event: Event = None + self.requests_queue: Queue[ + tuple[ + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ] = None + self.updates_queue: Queue[ + tuple[ + ResponseT | None, + RequestT, + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ] = None + + # Local process async/threading bridges + signals + self.pending_updates_queue: culsans.Queue[ + tuple[ + ResponseT | None, + RequestT | MultiTurnRequestT[RequestT], + ScheduledRequestInfo[MeasuredRequestTimingsT], + ] + ] = None + self.pending_requests_complete: ThreadingEvent = None + self.pending_updates_complete: ThreadingEvent = None + self.populate_requests_task: Task = None + self.populate_updates_task: Task = None + + # Scheduler state + self.state_update_lock: threading.Lock = None + self.scheduler_state: SchedulerState = None + + async def create_processes(self): + """ + Initialize and start the worker process group. + + Sets up multiprocessing infrastructure and worker processes based on + strategy constraints, backend capabilities, and system configuration. + + :param backend: Backend instance for processing requests. + :param requests: Iterable of requests to process. + :param strategy: Scheduling strategy configuration. + :param constraints: Dictionary of named constraints for controlling execution. + :raises RuntimeError: If process initialization or startup fails. 
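        Example (a worked illustration of how the worker count is resolved from the
        strategy limit, backend limit, and settings cap; the concrete numbers are
        hypothetical):
        ::
            import math

            strategy_limit = None  # strategy imposes no process limit
            backend_limit = 1      # e.g. a backend that must stay single-process
            settings_cap = 4       # settings.max_worker_processes
            num_processes = int(
                min(strategy_limit or math.inf, backend_limit or math.inf, settings_cap)
            )
            assert num_processes == 1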
+ """ + # Processes limits and params + num_processes = int( + min( + self.strategy.processes_limit or math.inf, + self.backend.processes_limit or math.inf, + settings.max_worker_processes, + ) + ) + if num_processes <= 0: + raise RuntimeError("num_processes resolved to 0; increase limits/config") + + max_conc = int( + min( + self.strategy.requests_limit or math.inf, + self.backend.requests_limit or math.inf, + settings.max_concurrency, + ) + ) + if max_conc <= 0: + raise RuntimeError("max_concurrency resolved to 0; increase limits/config") + + per_proc_max_conc = math.ceil(max_conc / num_processes) + per_proc_max_queue = min(2, per_proc_max_conc) + max_queued_requests = ( # Add queue buffer for each process + max_conc + (num_processes * per_proc_max_queue) + ) + + # Initialize multiprocessing components + self.mp_context = get_context("fork") + self.startup_barrier = self.mp_context.Barrier(num_processes + 1) + self.shutdown_event = self.mp_context.Event() + self.error_event = self.mp_context.Event() + self.requests_queue = self.mp_context.Queue(maxsize=max_queued_requests) + self.updates_queue = self.mp_context.Queue() + + # Initialize worker processes + self.processes = [] + for rank in range(num_processes): + async_limit = per_proc_max_conc + ( + 1 if rank < (max_conc % num_processes) else 0 + ) + worker = WorkerProcess[RequestT, MeasuredRequestTimingsT, ResponseT]( + local_rank=rank, + local_world_size=num_processes, + async_limit=async_limit, + startup_barrier=self.startup_barrier, + shutdown_event=self.shutdown_event, + error_event=self.error_event, + requests_queue=self.requests_queue, + updates_queue=self.updates_queue, + backend=self.backend, + request_timings=self.strategy.create_request_timings( + local_rank=rank, + local_world_size=num_processes, + local_max_concurrency=async_limit, + ), + poll_intervals=settings.scheduler_poll_interval, + ) + proc = self.mp_context.Process(target=worker.run, daemon=False) + proc.start() + self.processes.append(proc) + + reason, _ = await synchronous_to_exitable_async( + synchronous=None, + exit_events={ + "error_event": self.error_event, + "shutdown_event": self.shutdown_event, + }, + exit_barrier=self.startup_barrier, + poll_interval=settings.scheduler_poll_interval, + ) + if reason != "barrier": + raise RuntimeError( + f"Worker process group startup failed with exit reason: {reason}" + ) + + async def start(self, start_time: float): + """ + Begin request processing at the specified start time. + + Initializes scheduler state and background tasks, then waits until the + specified start time before beginning operations. + + :param start_time: Unix timestamp when processing should begin. + :raises RuntimeError: If workers encounter errors during startup. 
+ """ + if self.processes is None: + raise RuntimeError("create_processes() must be called before start()") + + self.state_update_lock = threading.Lock() + self.scheduler_state = SchedulerState( + node_id=0, # Process group node identifier + num_processes=len(self.processes), + start_time=start_time, + ) + self.pending_updates_queue = culsans.Queue() + self.pending_requests_complete = ThreadingEvent() + self.pending_updates_complete = ThreadingEvent() + + self.populate_requests_task = asyncio.create_task( + synchronous_to_exitable_async( + self._populate_requests_generator(start_time), + exit_events={"error_event": self.error_event}, + poll_interval=0.0, + ) + ) + self.populate_updates_task = asyncio.create_task( + synchronous_to_exitable_async( + self._populate_updates_generator(), + exit_events={"error_event": self.error_event}, + poll_interval=0.0, + ) + ) + + await asyncio.sleep(max(0, start_time - time.time())) + if self.error_event.is_set(): + raise RuntimeError( + "error_event is set in WorkerProcessGroup, " + "indicating an error occurred in one of the worker processes." + ) + + async def request_updates( + self, + ) -> AsyncIterator[ + tuple[ + ResponseT | None, + RequestT, + ScheduledRequestInfo[MeasuredRequestTimingsT], + SchedulerState, + ] + ]: + """ + Yield request processing updates as they become available. + + Returns an async iterator of request updates including response, request, + scheduling metadata, and scheduler state. Updates occur on request queued, + processing start, and completion. + + :return: Async iterator yielding (response, request, request_info, state) + tuples; response is None until processing is complete. + :raises RuntimeError: If workers encounter unrecoverable errors. + """ + last_check_time = -1 * math.inf + + while ( + not self.pending_updates_complete.is_set() + or not self.pending_updates_queue.empty() + ): + try: + ( + response, + request, + request_info, + scheduler_state, + ) = await asyncio.wait_for( + self.pending_updates_queue.async_get(), + timeout=settings.scheduler_poll_interval, + ) + + yield response, request, request_info, scheduler_state + except asyncio.TimeoutError: + pass + + if (time.time() - last_check_time) >= settings.scheduler_poll_interval: + if self.error_event.is_set(): + raise RuntimeError( + "error_event is set in WorkerProcessGroup, " + "indicating an error occurred in one of the worker processes." + ) + last_check_time = time.time() + + async def shutdown(self) -> list[Exception]: # noqa: C901 + """ + Gracefully shut down the worker process group and clean up resources. + + Performs safe shutdown of worker processes, background tasks, and + multiprocessing resources. + + :return: List of exceptions encountered during shutdown; empty if no errors. 
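        Example (an illustrative cleanup sketch; group is a started
        WorkerProcessGroup and the surrounding driver loop is elided):
        ::
            try:
                ...  # drive the run via request_updates()
            finally:
                errors = await group.shutdown()
                if errors:  # an empty list means a clean shutdown
                    raise RuntimeError(f"Shutdown reported errors: {errors}")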
+ """ + exceptions: list[Exception] = [] + + if self.shutdown_event is not None: + self.shutdown_event.set() + self.shutdown_event = None + + cancel_tasks = [ + task + for task in (self.populate_requests_task, self.populate_updates_task) + if task and not task.done() + ] + for task in cancel_tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + if cancel_tasks: + try: + await asyncio.gather(*cancel_tasks, return_exceptions=True) + except Exception as err: # noqa: BLE001 + exceptions.append(err) + self.populate_requests_task = None + self.populate_updates_task = None + + if self.processes: + for proc in self.processes: + await asyncio.to_thread(proc.join, 5) + if proc.exitcode not in (0, None): + exceptions.append( + RuntimeError( + f"Worker {proc.pid} exited with code {proc.exitcode}" + ) + ) + self.processes = None + self.mp_context = None + + self.startup_barrier = None + self.error_event = None + self.requests_queue = None + self.updates_queue = None + self.pending_updates_queue = None + + return exceptions + + def _update_state( + self, info: ScheduledRequestInfo[MeasuredRequestTimingsT] + ) -> tuple[SchedulerState, bool, bool]: + if not self.scheduler_state or not self.state_update_lock: + raise RuntimeError("workerProcessGroup not started") + + with self.state_update_lock: + state = self.scheduler_state + if info.status == "queued": + state.created_requests += 1 + state.queued_requests += 1 + elif info.status == "in_progress": + state.queued_requests -= 1 + state.processing_requests += 1 + elif info.status in ("completed", "errored", "cancelled"): + state.processing_requests -= 1 + state.processed_requests += 1 + state.successful_requests += 1 if info.status == "completed" else 0 + state.errored_requests += 1 if info.status == "errored" else 0 + state.cancelled_requests += 1 if info.status == "cancelled" else 0 + else: + raise ValueError( + f"Unknown request status: {info.status}. " + "Supported statuses are: queued, pending, in_progress, " + "completed, errored, cancelled." 
+ ) + + state.end_time = time.time() # Always update for last time update received + actions = { + name: const(state, info) for name, const in self.constraints.items() + } + state.scheduler_constraints = actions + + if state.end_queuing_time is None and ( + stop_queueing_actions := { + key: action + for key, action in actions.items() + if action.request_queuing == "stop" + } + ): + # Queuing not stopped and actions returned to stop it + state.end_queuing_constraints.update(stop_queueing_actions) + state.end_queuing_time = time.time() + + if state.end_processing_time is None and ( + stop_processing_actions := { + key: action + for key, action in actions.items() + if action.request_processing in ("stop_local", "stop_all") + } + ): + # Processing not stopped and actions returned to stop it + state.end_processing_constraints.update(stop_processing_actions) + state.end_processing_time = time.time() + + state_copy: SchedulerState = state.model_copy() + + return ( + state_copy, + state_copy.end_queuing_time is None, + state_copy.end_processing_time is None, + ) + + def _populate_requests_generator(self, scheduler_start_time: float): + last_check_time: float = time.time() + continue_requests: bool = True + message: bytes | None = None + request_iter: Iterator[RequestT] | None = ( + self._populate_requests_create_iterator(first=True) + ) + + try: + while continue_requests or message is not None: + if request_iter is None: + request_iter = self._populate_requests_create_iterator(first=False) + + if request_iter is None and continue_requests: + # Out of requests so stop + continue_requests = False + # Update scheduler state that requests were exhausted + with self.state_update_lock: + self.scheduler_state.end_queuing_constraints["request_iter"] = { + "status": "exhausted", + "time": time.time(), + } + self.scheduler_state.end_queuing_time = time.time() + + if continue_requests and message is None: + message, continue_requests = self._populate_requests_next_message( + request_iter, scheduler_start_time + ) + if message is None: + # No message returned because request_iter is exhausted + request_iter = None + + if message is not None: + with contextlib.suppress(queue.Full): + self.requests_queue.put( + message[0], timeout=settings.scheduler_poll_interval + ) + self.pending_updates_queue.sync_put(message[1]) + message = None + + if (time.time() - last_check_time) >= settings.scheduler_poll_interval: + last_check_time = time.time() + continue_requests = ( + continue_requests and not self.shutdown_event.is_set() + ) + yield None # Yield to check for error in wrapper to stop + except Exception as err: # noqa: BLE001 + self.error_event.set() + raise err + finally: + self.pending_requests_complete.set() + + def _populate_requests_create_iterator( + self, first: bool = False + ) -> Iterator[RequestT] | None: + if first: + # First invocation, get a new iterator if not already one + return ( + iter(self.requests) + if not isinstance(self.requests, Iterator) + else self.requests + ) + + if self.infinite_requests is True and isinstance(self.requests, Iterator): + # Out of requests and infinite set to True, but request_iter is Iterator + # Cannot create new, raise RuntimeError + raise RuntimeError( + f"Requests iterator {self.requests} exhausted and " + "infinite_requests is set to True" + ) + + if self.infinite_requests is not False and isinstance(self.requests, Iterable): + # Out of requests and infinite set to True or set to default + # Create new iterator out of the Iterable + return iter(self.requests) + + # 
Either infinite is False for Iterable or Iterator + # or infinite is None (default) for Iterator + # So, return None to stop + return None + + def _populate_requests_next_message( + self, request_iter: Iterator[RequestT], scheduler_start_time: float + ) -> tuple[tuple[bytes, bytes] | None, bool]: + try: + request = next(request_iter) + request_info = ScheduledRequestInfo[MeasuredRequestTimingsT]( + request_id=( + request + if isinstance(request, str) + else getattr(request, "id_", getattr(request, "id", id(request))) + ), + status="queued", + scheduler_node_id=-1, + scheduler_process_id=0, + scheduler_start_time=scheduler_start_time, + ) + state, continue_requests, _ = self._update_state(request_info) + + request_msg = MsgpackEncoding.encode((request, request_info)) + update_msg = (None, request, request_info, state) + + return (request_msg, update_msg), continue_requests + except StopIteration: + return None, True + + def _populate_updates_generator(self): + """Generator for populating updates from workers.""" + last_check_time = time.time() + last_state: SchedulerState = None + continue_processing = True + shutdown_set = False + canceled_remaining = False + + try: + while ( + continue_processing + or last_state is None + or (last_state.processed_requests < last_state.created_requests) + ): + next_state, continue_updates = self._populate_updates_process_next() + if next_state is not None: + last_state = next_state + continue_processing = continue_processing and continue_updates + + if not continue_processing and not shutdown_set: + self.shutdown_event.set() + shutdown_set = True + time.sleep( + settings.scheduler_poll_interval + ) # Ensure shut down propagates + + if not continue_processing and not canceled_remaining: + # We've shut down, no more requests will be added, cancel remaining + next_state = self._populate_updates_cancel_remaining() + if next_state is not None: + last_state = next_state + canceled_remaining = True + + if (time.time() - last_check_time) >= settings.scheduler_poll_interval: + last_check_time = time.time() + if not shutdown_set and self.shutdown_event.is_set(): + shutdown_set = True + continue_processing = False + with self.state_update_lock: + self.scheduler_state.end_queuing_constraints[ + "shutdown_event" + ] = { + "status": "set", + "time": time.time(), + } + self.scheduler_state.end_processing_time = time.time() + + yield None # Yield to check for error in wrapper to stop + except Exception as err: # noqa: BLE001 + self.error_event.set() + raise err + finally: + self.pending_updates_complete.set() + + def _populate_updates_process_next( + self, + ) -> tuple[SchedulerState | None, bool]: + try: + message = self.updates_queue.get(timeout=settings.scheduler_poll_interval) + response, request, request_info = MsgpackEncoding.decode(message) + + scheduler_state, _, continue_updates = self._update_state(request_info) + self.pending_updates_queue.sync_put( + (response, request, request_info, scheduler_state) + ) + + return scheduler_state, continue_updates + except queue.Empty: + return None, True + + def _populate_updates_cancel_remaining( + self, + ) -> SchedulerState | None: + last_state = None + + while True: + try: + message = self.requests_queue.get( + timeout=settings.scheduler_poll_interval + ) + request, request_info = MsgpackEncoding.decode(message) + + # Send start first + request_info.status = "in_progress" + scheduler_state, _, _ = self._update_state(request_info) + self.pending_updates_queue.sync_put( + (None, request, request_info.model_copy(), 
scheduler_state) + ) + + # Send canceled + request_info.status = "cancelled" + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + scheduler_state, _, _ = self._update_state(request_info) + self.pending_updates_queue.sync_put( + (None, request, request_info, scheduler_state) + ) + + last_state = scheduler_state + except queue.Empty: + if self.pending_requests_complete.is_set(): + # no more requests being pushed to queue, safe to exit + break + + return last_state diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index fb9262c3..0df3ad1b 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,5 +1,6 @@ from .colors import Colors from .default_group import DefaultGroupHandler +from .encoding import MsgpackEncoding from .hf_datasets import ( SUPPORTED_TYPES, save_dataset_to_file, @@ -7,29 +8,41 @@ from .hf_transformers import ( check_load_processor, ) +from .mixins import InfoMixin from .random import IntegerRangeSampler +from .registry import RegistryMixin +from .singleton import SingletonMixin, ThreadSafeSingletonMixin from .text import ( EndlessTextCreator, clean_text, filter_text, + format_value_display, is_puncutation, load_text, split_text, split_text_list_by_length, ) +from .threading import synchronous_to_exitable_async __all__ = [ "SUPPORTED_TYPES", "Colors", "DefaultGroupHandler", "EndlessTextCreator", + "InfoMixin", "IntegerRangeSampler", + "MsgpackEncoding", + "RegistryMixin", + "SingletonMixin", + "ThreadSafeSingletonMixin", "check_load_processor", "clean_text", "filter_text", + "format_value_display", "is_puncutation", "load_text", "save_dataset_to_file", "split_text", "split_text_list_by_length", + "synchronous_to_exitable_async", ] diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py new file mode 100644 index 00000000..42a94822 --- /dev/null +++ b/src/guidellm/utils/encoding.py @@ -0,0 +1,141 @@ +""" +MessagePack encoding utilities with Pydantic model support. + +Provides binary serialization and deserialization of Python objects using MessagePack, +with special handling for Pydantic models to preserve type information and generic +parameters for accurate reconstruction. + +Classes: + MsgpackEncoding: MessagePack encoder/decoder with Pydantic support. +""" + +import importlib +from typing import Any + +import msgpack +from pydantic import BaseModel + +__all__ = ["MsgpackEncoding"] + + +class MsgpackEncoding: + """ + MessagePack encoder/decoder with Pydantic model support. + + Provides binary serialization of Python objects with special handling + for Pydantic models to preserve type information and generic parameters. + """ + + PYDANTIC_TAG = "__pydantic__" + PYDANTIC_DATA = "data" + PYDANTIC_ARGS = "args" + + @classmethod + def encode(cls, obj: Any) -> bytes: + """ + Encode a Python object to MessagePack binary format. + + :param obj: The object to encode (supports Pydantic models, dicts, lists, etc.). + :return: Binary MessagePack representation. + """ + return msgpack.packb(cls.to_primitive(obj), use_bin_type=True) + + @classmethod + def decode(cls, data: bytes) -> Any: + """ + Decode MessagePack binary data back to Python objects. + + :param data: Binary MessagePack data to decode. + :return: Reconstructed Python object with original types preserved. 
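        Example (a round-trip sketch; Point is a hypothetical Pydantic model that
        must live in an importable module so it can be reconstructed on decode):
        ::
            from pydantic import BaseModel

            from guidellm.utils import MsgpackEncoding

            class Point(BaseModel):
                x: int
                y: int

            payload = MsgpackEncoding.encode({"point": Point(x=1, y=2), "tag": "demo"})
            restored = MsgpackEncoding.decode(payload)
            assert restored["point"] == Point(x=1, y=2)
            assert restored["tag"] == "demo"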
+ """ + return cls.from_primitive(msgpack.unpackb(data, raw=False)) + + @classmethod + def to_primitive(cls, obj: Any) -> Any: + """ + Convert objects to primitive types for MessagePack serialization. + + Recursively converts complex objects to primitives. Pydantic models are + converted to tagged dictionaries with type metadata for reconstruction. + + :param obj: The object to convert. + :return: Primitive representation suitable for MessagePack. + """ + if isinstance(obj, BaseModel): + model_cls = obj.__class__ + + origin = getattr(model_cls, "__origin__", None) + if origin is None and hasattr(model_cls, "__pydantic_generic_metadata__"): + origin = model_cls.__pydantic_generic_metadata__.get("origin", None) + if origin is None: + origin = model_cls + + args = getattr(model_cls, "__args__", ()) + if not args and hasattr(model_cls, "__pydantic_generic_metadata__"): + args = model_cls.__pydantic_generic_metadata__.get("args", ()) + + encoded = { + cls.PYDANTIC_TAG: f"{origin.__module__}.{origin.__name__}", + cls.PYDANTIC_DATA: obj.model_dump(), + } + + if args: + encoded[cls.PYDANTIC_ARGS] = [ + f"{arg.__module__}.{arg.__name__}" for arg in args + ] + + return encoded + + if isinstance(obj, dict): + return { + cls.to_primitive(key): cls.to_primitive(val) for key, val in obj.items() + } + + if isinstance(obj, list): + return [cls.to_primitive(val) for val in obj] + + if isinstance(obj, tuple): + return tuple(cls.to_primitive(val) for val in obj) + + return obj + + @classmethod + def from_primitive(cls, obj: Any) -> Any: + """ + Reconstruct objects from their primitive MessagePack representation. + + Recursively converts primitives back to original objects. Tagged dictionaries + are restored to Pydantic models with proper types and generic parameters. + + :param obj: The primitive representation to convert. + :return: Reconstructed object with original types. + :raises ImportError: If a Pydantic model's module cannot be imported. + :raises AttributeError: If a class reference cannot be found. + """ + if isinstance(obj, dict) and cls.PYDANTIC_TAG in obj: + origin_path = obj[cls.PYDANTIC_TAG] + module_name, class_name = origin_path.rsplit(".", 1) + origin_cls = getattr(importlib.import_module(module_name), class_name) + + type_args = [] + if cls.PYDANTIC_ARGS in obj: + for arg_path in obj[cls.PYDANTIC_ARGS]: + mod, clazz = arg_path.rsplit(".", 1) + type_args.append(getattr(importlib.import_module(mod), clazz)) + + model_cls = origin_cls[tuple(type_args)] if type_args else origin_cls + + return model_cls.model_validate(obj[cls.PYDANTIC_DATA]) + + if isinstance(obj, dict): + return { + cls.from_primitive(k): cls.from_primitive(v) for k, v in obj.items() + } + + if isinstance(obj, list): + return [cls.from_primitive(v) for v in obj] + + if isinstance(obj, tuple): + return tuple(cls.from_primitive(v) for v in obj) + + return obj diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py new file mode 100644 index 00000000..c9aa867e --- /dev/null +++ b/src/guidellm/utils/mixins.py @@ -0,0 +1,59 @@ +""" +Mixin classes for common metadata extraction and object introspection. + +Provides reusable mixins for extracting structured metadata from objects, +enabling consistent information exposure across different class hierarchies. + +Classes: + InfoMixin: Mixin providing standardized metadata extraction capabilities. 
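Example (a minimal sketch of metadata extraction; HttpTarget is a hypothetical
class that does not define an info attribute of its own):
::
    from guidellm.utils import InfoMixin

    class HttpTarget:
        def __init__(self, url: str, retries: int = 3):
            self.url = url
            self.retries = retries
            self._secret = "hidden"  # private attributes are excluded

    meta = InfoMixin.extract_from_obj(HttpTarget("http://localhost:8000"))
    assert meta["type"] == "HttpTarget"
    assert meta["attributes"] == {"url": "http://localhost:8000", "retries": 3}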
+""" + +from typing import Any + +__all__ = ["InfoMixin"] + + +class InfoMixin: + """Mixin class providing standardized metadata extraction for introspection.""" + + @classmethod + def extract_from_obj(cls, obj: Any) -> dict[str, Any]: + """ + Extract structured metadata from any object. + + Attempts to use the object's own `info` method or property if available, + otherwise constructs metadata from object attributes and type information. + + :param obj: Object to extract metadata from. + :return: Dictionary containing object metadata including type, class, + module, and public attributes. + """ + if hasattr(obj, "info"): + return obj.info() if callable(obj.info) else obj.info + + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val + if isinstance(val, (str, int, float, bool, list, dict)) + else str(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @property + def info(self) -> dict[str, Any]: + """ + Return structured metadata about this instance. + + :return: Dictionary containing class name, module, and public attributes. + """ + return self.extract_from_obj(self) diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py new file mode 100644 index 00000000..86342fc4 --- /dev/null +++ b/src/guidellm/utils/registry.py @@ -0,0 +1,184 @@ +""" +Registry system for objects in the GuideLLM toolkit. + +This module provides a flexible object registration and discovery system used +throughout the GuideLLM toolkit. It enables automatic registration of objects +and discovery of implementations through decorators. + +Classes: + RegistryMixin: Base mixin for creating object registries with decorators. +""" + +from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar + +__all__ = ["RegistryMixin"] + + +RegistryObjT = TypeVar("RegistryObjT", bound=Any) + + +class RegistryMixin(Generic[RegistryObjT]): + """ + A mixin class that provides a registration system for the specified object type. + + This mixin allows classes to maintain a registry of objects that can be + dynamically discovered and instantiated. Classes that inherit from this mixin + can use the @register decorator to add objects to the registry. + + The registry is class-specific, meaning each class that inherits from this mixin + will have its own separate registry of implementations. + + Example: + :: + class BaseAlgorithm(RegistryMixin): + pass + + @BaseAlgorithm.register() + class ConcreteAlgorithm(BaseAlgorithm): + pass + + @BaseAlgorithm.register("custom_name") + class AnotherAlgorithm(BaseAlgorithm): + pass + + # Get all registered algorithm implementations + algorithms = BaseAlgorithm.registered_objects() + + :cvar registry: A dictionary mapping object names to objects that have been + registered to the extending subclass through the @subclass.register() decorator + :cvar registry_populated: A flag that tracks whether the registry has been + populated with objects from the specified package(s). 
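    Example (a lookup-focused continuation of the registry example above, reusing
    its hypothetical class names):
    ::
        assert BaseAlgorithm.is_registered("custom_name")
        algorithm_cls = BaseAlgorithm.get_registered_object("custom_name")
        all_algorithms = BaseAlgorithm.registered_objects()
        assert set(all_algorithms) == {"ConcreteAlgorithm", "custom_name"}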
+ """ + + registry: ClassVar[Optional[dict[str, RegistryObjT]]] = None + registry_populated: ClassVar[bool] = False + + @classmethod + def register( + cls, name: Optional[str] = None + ) -> Callable[[RegistryObjT], RegistryObjT]: + """ + An invoked decorator that registers an object with the registry under + either the provided name or the object name if no name is provided. + + Example: + ```python + @RegistryMixin.register() + class ExampleClass: + ... + + @RegistryMixin.register("custom_name") + class AnotherExampleClass: + ... + ``` + + :param name: Optional name to register the object under. If None, the object + name is used as the registry key. + :return: A decorator function that registers the decorated object. + :raises ValueError: If name is provided but is not a string. + """ + if name is not None and not isinstance(name, str): + raise ValueError( + f"RegistryMixin.register() name must be a string or None. Got {name}." + ) + + return lambda obj: cls.register_decorator(obj, name=name) + + @classmethod + def register_decorator( + cls, obj: RegistryObjT, name: Optional[str] = None + ) -> RegistryObjT: + """ + A non-invoked decorator that registers the object with the registry. + If passed through a lambda, then name can be passed in as well. + Otherwise, the only argument is the decorated object. + + Example: + ```python + @RegistryMixin.register_decorator + class ExampleClass: + ... + ``` + + :param obj: The object to register + :param name: Optional name to register the object under. If None, the object + name is used as the registry key. + :return: The registered object. + :raises TypeError: If the decorator is used incorrectly. + :raises ValueError: If the object is already registered or if name is provided + but is not a string. + """ + + if not name: + name = getattr(obj, "__name__", str(obj)) + elif not isinstance(name, str): + raise ValueError( + "RegistryMixin.register_decorator must be used as a decorator " + "and without invocation. " + f"Got improper name arg {name}." + ) + + if cls.registry is None: + cls.registry = {} + + if name in cls.registry: + raise ValueError( + f"RegistryMixin.register_decorator cannot register an object " + f"{obj} with the name {name} because it is already registered." + ) + + cls.registry[name] = obj + + return obj + + @classmethod + def registered_objects(cls) -> dict[str, RegistryObjT]: + """ + :return: A dictionary mapping names to all registered objects. + """ + if cls.registry is None: + return {} + return dict(cls.registry) + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT: + """ + :param name: The name of the registered object. + :return: The registred object + """ + if cls.registry is None or name not in cls.registry: + raise ValueError(f"Object with name {name} is not registered.") + return cls.registry[name] + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + :param name: The name to check for registration. + :return: True if an object is registered with that name, False otherwise. + """ + if cls.registry is None: + return False + return name in cls.registry + + @classmethod + def unregister(cls, name: str) -> bool: + """ + :param name: The name of the object to unregister. + :return: True if the object was successfully unregistered, False if it + wasn't registered. 
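        Example (continuing the hypothetical BaseAlgorithm registry from the class
        docstring; useful, for instance, to clean up temporary registrations in
        tests):
        ::
            BaseAlgorithm.register("temporary")(AnotherAlgorithm)
            assert BaseAlgorithm.unregister("temporary") is True
            assert not BaseAlgorithm.is_registered("temporary")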
+ """ + if cls.registry is None: + return False + + if name in cls.registry: + del cls.registry[name] + return True + return False + + @classmethod + def clear_registry(cls) -> None: + """ + Clear all registered objects from the registry. + """ + if cls.registry is not None: + cls.registry.clear() diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py new file mode 100644 index 00000000..48f039cf --- /dev/null +++ b/src/guidellm/utils/singleton.py @@ -0,0 +1,78 @@ +""" +Singleton pattern implementations for ensuring single instance classes. + +Provides singleton mixins for creating classes that maintain a single instance +throughout the application lifecycle, with support for both basic and thread-safe +implementations. + +Classes: + SingletonMixin: Basic singleton implementation using class variables. + ThreadSafeSingletonMixin: Thread-safe singleton using locking mechanisms. +""" + +import threading +from typing import ClassVar + +__all__ = ["SingletonMixin", "ThreadSafeSingletonMixin"] + + +class SingletonMixin: + """ + Basic singleton mixin ensuring single instance per class. + + Implements the singleton pattern using class variables to control instance + creation. Subclasses must call super().__init__() for proper initialization + state management. + """ + + singleton_instance: ClassVar["SingletonMixin"] = None + + def __new__(cls, *args, **kwargs): + """ + Create or return the singleton instance. + + :param args: Positional arguments passed to the constructor. + :param kwargs: Keyword arguments passed to the constructor. + :return: The singleton instance of the class. + """ + if cls.singleton_instance is None: + cls.singleton_instance = super().__new__(cls, *args, **kwargs) + cls.singleton_instance.initialized = False + return cls.singleton_instance + + def __init__(self): + """Initialize the singleton instance exactly once.""" + if self.initialized: + return + self.initialized = True + + +class ThreadSafeSingletonMixin(SingletonMixin): + """ + Thread-safe singleton mixin with locking mechanisms. + + Extends SingletonMixin with thread safety using locks to prevent race + conditions during instance creation in multi-threaded environments. + """ + + singleton_lock: ClassVar[threading.Lock] = threading.Lock() + + def __new__(cls, *args, **kwargs): + """ + Create or return the singleton instance with thread safety. + + :param args: Positional arguments passed to the constructor. + :param kwargs: Keyword arguments passed to the constructor. + :return: The singleton instance of the class. 
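        Example (an illustrative sketch; Telemetry is a hypothetical subclass used
        only to show that repeated construction yields one shared instance):
        ::
            class Telemetry(ThreadSafeSingletonMixin):
                pass

            first = Telemetry()
            second = Telemetry()
            assert first is second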
+ """ + with cls.singleton_lock: + if cls.singleton_instance is None: + cls.singleton_instance = super().__new__(cls, *args, **kwargs) + cls.singleton_instance.initialized = False + return cls.singleton_instance + + def __init__(self): + """Initialize the singleton instance with thread-local lock.""" + if not self.initialized: + self.thread_lock = threading.Lock() + super().__init__() diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index cdefaa14..d2e84778 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -11,11 +11,13 @@ from guidellm import data as package_data from guidellm.config import settings +from guidellm.utils.colors import Colors __all__ = [ "EndlessTextCreator", "clean_text", "filter_text", + "format_value_display", "is_puncutation", "load_text", "split_text", @@ -25,6 +27,34 @@ MAX_PATH_LENGTH = 4096 +def format_value_display( + value: float, + label: str, + units: str = "", + total_characters: Optional[int] = None, + digits_places: Optional[int] = None, + decimal_places: Optional[int] = None, +) -> str: + if decimal_places is None and digits_places is None: + formatted_number = f"{value}:.0f" + elif digits_places is None: + formatted_number = f"{value:.{decimal_places}f}" + elif decimal_places is None: + formatted_number = f"{value:>{digits_places}f}" + else: + formatted_number = f"{value:>{digits_places}.{decimal_places}f}" + + result = f"{formatted_number}{units} [{Colors.INFO}]{label}[/{Colors.INFO}]" + + if total_characters is not None: + total_characters += len(Colors.INFO) * 2 + 5 + + if len(result) < total_characters: + result = result.rjust(total_characters) + + return result + + def split_text_list_by_length( text_list: list[Any], max_characters: Union[int, list[int]], diff --git a/src/guidellm/utils/threading.py b/src/guidellm/utils/threading.py new file mode 100644 index 00000000..b4802fc7 --- /dev/null +++ b/src/guidellm/utils/threading.py @@ -0,0 +1,147 @@ +import asyncio +import contextlib +import functools +import time +from collections.abc import Generator, Iterable, Iterator +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from threading import Barrier as ThreadingBarrier +from threading import BrokenBarrierError, Thread +from threading import Event as ThreadingEvent +from typing import Any, Callable, Literal, Optional, Union + +__all__ = ["synchronous_to_exitable_async"] + + +def _start_barrier_monitor_thread( + barrier: Optional[Union[ThreadingBarrier, ProcessingBarrier]], + barrier_event: ThreadingEvent, +): + if barrier is None: + return + + def _watch() -> None: + try: + barrier.wait() + except BrokenBarrierError: + pass + finally: + barrier_event.set() + + Thread(target=_watch, daemon=True).start() + + +def _check_event_set( + events: list[tuple[str, Union[ThreadingEvent, ProcessingEvent]]], +) -> Optional[str]: + for name, event in events: + if event.is_set(): + return name + return None + + +def _run_worker( + events_list: list[tuple[str, Union[ThreadingEvent, ProcessingEvent]]], + exit_barrier: Optional[Union[ThreadingBarrier, ProcessingBarrier]], + synchronous: Optional[Union[Iterator, Iterable, Generator, Callable]], + poll_interval: float, + args: tuple, + kwargs: dict, +) -> tuple[str, Any]: + finish_reason: str = "completed" + last_val: Any = None + + try: + barrier_event = list(filter(lambda x: x[0] == "barrier", events_list))[0][1] + _start_barrier_monitor_thread(exit_barrier, barrier_event) + + if 
isinstance(synchronous, Iterable): + synchronous = iter(synchronous) + + while True: + if (check_event := _check_event_set(events_list)) is not None: + finish_reason = check_event + break + + if isinstance(synchronous, (Iterator, Generator)): + try: + last_val = next(synchronous) + except StopIteration: + break + elif isinstance(synchronous, Callable): + last_val = synchronous(*args, **kwargs) + break + + time.sleep(poll_interval) + + if ( + finish_reason == "completed" + and (check_event := _check_event_set(events_list)) is not None + ): + # Final check for any exit signals + finish_reason = check_event + except Exception as err: # noqa: BLE001 + finish_reason = "internal_error" + last_val = err + finally: + if exit_barrier is not None: + with contextlib.suppress(BrokenBarrierError, RuntimeError): + exit_barrier.abort() + + return finish_reason, last_val + + +async def synchronous_to_exitable_async( + synchronous: Optional[Union[Iterator, Iterable, Generator, Callable]], + exit_events: Optional[dict[str, Union[ThreadingEvent, ProcessingEvent]]] = None, + exit_barrier: Optional[Union[ThreadingBarrier, ProcessingBarrier]] = None, + poll_interval: float = 0.1, + *args, + **kwargs, +) -> tuple[Union[Literal["completed", "canceled", "barrier"], str], Any]: + """ + Run a sync callable or iterable inside an async context with exit controls. + Supports cooperative termination via exit events and an optional barrier. + + :param synchronous: Callable (invoked once) or iterable/iterator (next()). If + None, only watch exit events (poll mode). + :param exit_events: Optional mapping of name -> Event objects to signal exit. + 'canceled', 'barrier', and 'internal_error' are reserved keywords. + :param exit_barrier: Optional barrier to coordinate shutdown; when it trips or is + aborted, the worker exits with reason "barrier". On exit, this function aborts + the barrier to release any waiters. + :param poll_interval: Sleep duration (seconds) used only in poll mode. + :param args: Positional arguments passed to the callable (if provided). + :param kwargs: Keyword arguments passed to the callable (if provided). + :return: (exit_reason, last_item). exit_reason is "completed", "canceled", + "barrier", or a key from exit_events. last_item is the last yielded value for + an iterator or the return value for a callable. + :raises asyncio.CancelledError: If the async task is canceled. + """ + events_map = exit_events or {} + + canceled_event = ThreadingEvent() + barrier_event = ThreadingEvent() + events_list = [ + ("canceled", canceled_event), + ("barrier", barrier_event), + *list(events_map.items()), + ] + worker = functools.partial( + _run_worker, + events_list, + exit_barrier, + synchronous, + poll_interval, + args, + kwargs, + ) + + try: + return await asyncio.to_thread(worker) + except asyncio.CancelledError: + if exit_barrier is not None: + with contextlib.suppress(BrokenBarrierError, RuntimeError): + exit_barrier.abort() + canceled_event.set() + raise diff --git a/tests/integration/scheduler/test_worker_group.py b/tests/integration/scheduler/test_worker_group.py new file mode 100644 index 00000000..4c39f36d --- /dev/null +++ b/tests/integration/scheduler/test_worker_group.py @@ -0,0 +1,185 @@ +""" +Integration tests for WorkerProcessGroup. + +Tests the complete lifecycle of the worker group with real multiprocessing +worker processes and a mock backend. Validates end-to-end functionality +across different scheduling strategies and constraints. 
+""" + +from __future__ import annotations + +import asyncio +import random +import time +from collections import defaultdict +from functools import wraps +from typing import Any + +import pytest + +from guidellm.scheduler import ( + AsyncConstantStrategy, + AsyncPoissonStrategy, + BackendInterface, + ConcurrentStrategy, + MaxDurationConstraintInitializer, + MaxErrorRateConstraintInitializer, + MaxErrorsConstraintInitializer, + MaxGlobalErrorRateConstraintInitializer, + MaxNumberConstraintInitializer, + MeasuredRequestTimings, + SynchronousStrategy, + ThroughputStrategy, + WorkerProcessGroup, +) +from guidellm.scheduler.constraints import ConstraintInitializer +from guidellm.scheduler.strategy import SchedulingStrategy + + +def async_timeout(delay): + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockRequestTimings(MeasuredRequestTimings): + """Mock timing implementation for integration testing.""" + + +class MockBackend(BackendInterface): + """Mock backend for integration testing with predictable responses.""" + + def __init__( + self, + processes_limit_value: int | None = None, + requests_limit_value: int | None = None, + error_rate: float = 0.2, + response_delay: float = 0.0, + ): + self._processes_limit = processes_limit_value + self._requests_limit = requests_limit_value + self._error_rate = error_rate + self._response_delay = response_delay + + @property + def processes_limit(self) -> int | None: + return self._processes_limit + + @property + def requests_limit(self) -> int | None: + return self._requests_limit + + def info(self) -> dict[str, Any]: + return {"type": "mock_integration", "delay": self._response_delay} + + async def process_startup(self): + pass + + async def validate(self): + pass + + async def process_shutdown(self): + pass + + async def resolve(self, request, request_info, request_history): + """Return predictable response based on input request.""" + # Simulate processing time + await asyncio.sleep(self._response_delay) + + if ( + self._error_rate + and self._error_rate > 0 + and random.random() < self._error_rate + ): + raise RuntimeError("Mock error for testing") + + yield f"response_for_{request}" + + +class TestWorkerGroup: + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(5) + @pytest.mark.parametrize( + "strategy", + [ + SynchronousStrategy(), + ConcurrentStrategy(streams=10), + ThroughputStrategy(max_concurrency=20), + AsyncConstantStrategy(rate=1000.0), + AsyncPoissonStrategy(rate=1000.0), + ], + ) + @pytest.mark.parametrize( + "constraints_inits", + [ + {"max_num": MaxNumberConstraintInitializer(max_num=100)}, + {"max_duration": MaxDurationConstraintInitializer(max_duration=0.5)}, + {"max_errors": MaxErrorsConstraintInitializer(max_errors=20)}, + {"max_error_rate": MaxErrorRateConstraintInitializer(max_error_rate=0.1)}, + { + "max_global_error_rate": MaxGlobalErrorRateConstraintInitializer( + max_error_rate=0.1 + ) + }, + ], + ) + async def test_lifecycle( + self, + strategy: SchedulingStrategy, + constraints_inits: dict[str, ConstraintInitializer], + ): + """Test comprehensive lifecycle with different strategies and constraints.""" + # Setup + backend = MockBackend(response_delay=0.01, processes_limit_value=1) + requests = [f"request_{ind}" for ind in range(1000)] + group = WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=strategy, + constraints={ + key: 
init.create_constraint() for key, init in constraints_inits.items() + }, + infinite_requests=False, + ) + + try: + # Create processes + await group.create_processes() + assert group.processes is not None + assert len(group.processes) > 0 + assert group.mp_context is not None + + # Start processing + start_time = time.time() + 0.1 + await group.start(start_time) + actual_start = time.time() + assert actual_start == pytest.approx(start_time) + + # Validate scheduler state + assert group.scheduler_state is not None + assert group.scheduler_state.start_time == start_time + assert group.scheduler_state.num_processes == len(group.processes) + + # Collect all request updates + received_updates = defaultdict(list) + received_responses = [] + + async for ( + response, + request, + request_info, + _state, + ) in group.request_updates(): + received_updates[request].append(request_info.status) + if response is not None: + received_responses.append(response) + finally: + # Clean shutdown + exceptions = await group.shutdown() + assert len(exceptions) == 0, f"Shutdown errors: {exceptions}" diff --git a/tests/unit/backend/test_backend.py b/tests/unit/backend/test_backend.py index 1115d509..c462112f 100644 --- a/tests/unit/backend/test_backend.py +++ b/tests/unit/backend/test_backend.py @@ -1,136 +1,289 @@ -import time +""" +Unit tests for the Backend base class and registry functionality. + +### WRITTEN BY AI ### +""" + +from typing import Any +from unittest.mock import Mock, patch import pytest -from guidellm.backend import ( - Backend, - ResponseSummary, - StreamingTextResponse, +from guidellm.backend.backend import Backend +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, ) +from guidellm.scheduler import ScheduledRequestInfo + + +class TestBackend: + """Test cases for Backend base class.""" + + @pytest.mark.smoke + def test_backend_default_properties(self): + """Test Backend default property implementations. + + ### WRITTEN BY AI ### + """ + + class TestBackend(Backend): + def info(self) -> dict[str, Any]: + return {"test": "info"} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self) -> str: + return "test-model" + + backend = TestBackend("openai_http") + assert backend.processes_limit is None + assert backend.requests_limit is None + assert backend.type_ == "openai_http" + + @pytest.mark.sanity + def test_backend_initialization(self): + """Test Backend initialization with type. + + ### WRITTEN BY AI ### + """ + + class TestBackend(Backend): + def info(self) -> dict[str, Any]: + return {"type": self.type_} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self) -> str: + return "test-model" + + backend = TestBackend("openai_http") + assert backend.type_ == "openai_http" + assert backend.info() == {"type": "openai_http"} + + @pytest.mark.sanity + def test_backend_registry_mixin(self): + """Test that Backend inherits from RegistryMixin. 
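+
+        The registry surface checked here follows the same pattern exercised
+        later in this file (illustrative type name, not a real registration)::
+
+            @Backend.register("custom_type")
+            class CustomBackend(Backend): ...
+
+            backend = Backend.create("custom_type")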
+ + ### WRITTEN BY AI ### + """ + from guidellm.utils.registry import RegistryMixin + + assert issubclass(Backend, RegistryMixin) + assert hasattr(Backend, "register") + assert hasattr(Backend, "get_registered_object") + assert hasattr(Backend, "create") + + @pytest.mark.sanity + def test_backend_create_method(self): + """Test Backend.create class method. + + ### WRITTEN BY AI ### + """ + # Mock a registered backend + mock_backend_class = Mock() + mock_backend_instance = Mock() + mock_backend_class.return_value = mock_backend_instance + + with patch.object( + Backend, "get_registered_object", return_value=mock_backend_class + ): + result = Backend.create("openai_http", test_arg="value") + + Backend.get_registered_object.assert_called_once_with("openai_http") + mock_backend_class.assert_called_once_with(test_arg="value") + assert result == mock_backend_instance + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_backend_interface_compatibility(self): + """Test that Backend is compatible with BackendInterface. + + ### WRITTEN BY AI ### + """ + from guidellm.scheduler import BackendInterface as SchedulerBackendInterface + + assert issubclass(Backend, SchedulerBackendInterface) + + # Test that Backend uses the correct generic types + class TestBackend(Backend): + def info(self) -> dict[str, Any]: + return {} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + # Verify types match the interface + assert isinstance(request, GenerationRequest) + assert isinstance(request_info, ScheduledRequestInfo) + yield request, request_info + + async def default_model(self) -> str: + return "test-model" + + backend = TestBackend("openai_http") + + # Create test request and info + request = GenerationRequest(content="test") + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Test resolve method + async for response, info in backend.resolve(request, request_info): + assert response == request + assert info == request_info + + @pytest.mark.regression + def test_backend_register_process(self): + """Test that Backend docstring examples are valid. + + ### WRITTEN BY AI ### + """ + + # Test that the pattern shown in docstring works + class MyBackend(Backend): + def __init__(self, api_key: str): + super().__init__("mock_backend") # type: ignore [arg-type] + self.api_key = api_key + + def info(self) -> dict[str, Any]: + return {"api_key": "***"} + + async def process_startup(self): + self.client = Mock() # Simulate API client + + async def process_shutdown(self): + self.client = None # type: ignore[assignment] + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self) -> str: + return "my-model" + + # Register the backend + Backend.register("my_backend")(MyBackend) + + # Create instance + backend = Backend.create("my_backend", api_key="secret") + assert isinstance(backend, MyBackend) + assert backend.api_key == "secret" + assert backend.type_ == "mock_backend" + + +class TestBackendRegistry: + """Test cases for Backend registry functionality.""" + + @pytest.mark.smoke + def test_openai_backend_registered(self): + """Test that OpenAI HTTP backend is registered. 
+ + ### WRITTEN BY AI ### + """ + from guidellm.backend.openai import OpenAIHTTPBackend + + # OpenAI backend should be registered + backend = Backend.create("openai_http", target="http://test") + assert isinstance(backend, OpenAIHTTPBackend) + assert backend.type_ == "openai_http" + + @pytest.mark.smoke + def test_backend_create_invalid_type(self): + """Test Backend.create with invalid type. + + ### WRITTEN BY AI ### + """ + with pytest.raises(ValueError): + Backend.create("invalid_type") + + @pytest.mark.sanity + def test_backend_registry_functionality(self): + """Test that backend registry functions work. + + ### WRITTEN BY AI ### + """ + from guidellm.backend.openai import OpenAIHTTPBackend + + # Test that we can get registered backends + openai_class = Backend.get_registered_object("openai_http") + assert openai_class == OpenAIHTTPBackend + + # Test creating with kwargs + backend = Backend.create( + "openai_http", target="http://localhost:8000", model="gpt-4" + ) + assert backend.target == "http://localhost:8000" + assert backend.model == "gpt-4" + + @pytest.mark.regression + def test_backend_registration_decorator(self): + """Test that backend registration decorator works. + + ### WRITTEN BY AI ### + """ + + # Create a test backend class + @Backend.register("test_backend") + class TestBackend(Backend): + def __init__(self, test_param="default"): + super().__init__("test_backend") # type: ignore + self._test_param = test_param + + def info(self): + return {"test_param": self._test_param} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + async def default_model(self): + return "test-model" -@pytest.mark.smoke -def test_backend_registry(): - assert Backend._registry["mock"] is not None # type: ignore - - backend_instance = Backend.create("mock") # type: ignore - assert backend_instance is not None - - with pytest.raises(ValueError): - Backend.register("mock")("backend") # type: ignore - - with pytest.raises(ValueError): - Backend.create("invalid_type") # type: ignore - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_text_completions(mock_backend): - index = 0 - prompt = "Test Prompt" - request_id = "test-request-id" - prompt_token_count = 3 - output_token_count = 10 - final_resp = None - - async for response in mock_backend.text_completions( - prompt=prompt, - request_id=request_id, - prompt_token_count=prompt_token_count, - output_token_count=output_token_count, - ): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.iterations > 0 - assert response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens == prompt_token_count - 
assert response.request_output_tokens == output_token_count - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens == 10 - assert response.request_id == request_id - - index += 1 - - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_chat_completions(mock_backend): - index = 0 - prompt = "Test Prompt" - request_id = "test-request-id" - prompt_token_count = 3 - output_token_count = 10 - final_resp = None - - async for response in mock_backend.chat_completions( - content=prompt, - request_id=request_id, - prompt_token_count=prompt_token_count, - output_token_count=output_token_count, - ): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.iterations > 0 - assert response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens == prompt_token_count - assert response.request_output_tokens == output_token_count - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens == 10 - assert response.request_id == request_id - - index += 1 - - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_models(mock_backend): - models = await mock_backend.available_models() - assert models == ["mock-model"] - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_validate(mock_backend): - await mock_backend.validate() + # Test that it's registered and can be created + backend = Backend.create("test_backend", test_param="custom") + assert isinstance(backend, TestBackend) + assert backend.info() == {"test_param": "custom"} diff --git a/tests/unit/backend/test_interface.py b/tests/unit/backend/test_interface.py new file mode 100644 index 00000000..bd1c3af4 --- /dev/null +++ b/tests/unit/backend/test_interface.py @@ -0,0 +1,88 @@ +""" +Unit tests for the BackendInterface abstract class. + +### WRITTEN BY AI ### +""" + +from typing import Any, Optional + +import pytest + +from guidellm.backend.interface import BackendInterface + + +class TestBackendInterface: + """Test cases for BackendInterface abstract class.""" + + @pytest.mark.sanity + def test_backend_interface_properties_are_abstract(self): + """Test that required properties are abstract. + + ### WRITTEN BY AI ### + """ + + # Create a partial implementation to verify abstract nature + class PartialBackend(BackendInterface): + # Missing required properties/methods + pass + + with pytest.raises(TypeError): + PartialBackend() + + @pytest.mark.sanity + def test_minimal_concrete_implementation(self): + """Test that a minimal concrete implementation can be created. 
+ + ### WRITTEN BY AI ### + """ + + class MinimalBackend(BackendInterface): + @property + def processes_limit(self) -> Optional[int]: + return None + + @property + def requests_limit(self) -> Optional[int]: + return None + + def info(self) -> dict[str, Any]: + return {} + + async def process_startup(self) -> None: + pass + + async def validate(self) -> None: + pass + + async def process_shutdown(self) -> None: + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self) -> str: + return "my-model" + + # Should be able to instantiate + backend = MinimalBackend() + assert backend is not None + assert isinstance(backend, BackendInterface) + + @pytest.mark.regression + def test_backend_interface_method_signatures(self): + """Test that BackendInterface methods have correct signatures. + + ### WRITTEN BY AI ### + """ + import inspect + + # Check resolve method signature + resolve_sig = inspect.signature(BackendInterface.resolve) + params = list(resolve_sig.parameters.keys()) + + expected_params = ["self", "request", "request_info", "history"] + assert params == expected_params + + # Check that history has default value + history_param = resolve_sig.parameters["history"] + assert history_param.default is None diff --git a/tests/unit/backend/test_mock_backend.py b/tests/unit/backend/test_mock_backend.py new file mode 100644 index 00000000..16098a81 --- /dev/null +++ b/tests/unit/backend/test_mock_backend.py @@ -0,0 +1,208 @@ +""" +Unit tests for the MockBackend implementation. + +### WRITTEN BY AI ### +""" + +import pytest + +from guidellm.backend import Backend +from guidellm.backend.objects import GenerationRequest, GenerationRequestTimings +from guidellm.scheduler import ScheduledRequestInfo +from tests.unit.mock_backend import MockBackend + + +class TestMockBackend: + """Test cases for MockBackend.""" + + @pytest.mark.smoke + def test_mock_backend_creation(self): + """Test MockBackend can be created. + + ### WRITTEN BY AI ### + """ + backend = MockBackend() + assert backend.type_ == "mock" + assert backend.model == "mock-model" + assert backend.target == "mock-target" + + @pytest.mark.smoke + def test_mock_backend_registration(self): + """Test MockBackend is properly registered. + + ### WRITTEN BY AI ### + """ + backend = Backend.create("mock") + assert isinstance(backend, MockBackend) + assert backend.type_ == "mock" + + @pytest.mark.smoke + def test_mock_backend_info(self): + """Test MockBackend info method. + + ### WRITTEN BY AI ### + """ + backend = MockBackend(model="test-model", target="test-target") + info = backend.info() + + assert info["type"] == "mock" + assert info["model"] == "test-model" + assert info["target"] == "test-target" + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_mock_backend_lifecycle(self): + """Test MockBackend process lifecycle. + + ### WRITTEN BY AI ### + """ + backend = MockBackend() + + # Test startup + await backend.process_startup() + assert backend._in_process is True + + # Test validation + await backend.validate() # Should not raise + + # Test default model + model = await backend.default_model() + assert model == "mock-model" + + # Test shutdown + await backend.process_shutdown() + assert backend._in_process is False + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_mock_backend_validate_not_started(self): + """Test validation fails when backend not started. 
+ + ### WRITTEN BY AI ### + """ + backend = MockBackend() + + with pytest.raises(RuntimeError, match="Backend not started up"): + await backend.validate() + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_mock_backend_resolve(self): + """Test MockBackend resolve method. + + ### WRITTEN BY AI ### + """ + backend = MockBackend(iter_delay=0.001) # Small delay for testing + await backend.process_startup() + + try: + request = GenerationRequest( + request_id="test-id", + content="Test prompt", + constraints={"output_tokens": 3}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + # Should get multiple responses (one per token + final) + assert len(responses) >= 2 + + # Check final response + final_response = responses[-1][0] + assert final_response.request_id == "test-id" + assert final_response.iterations > 0 + assert len(final_response.value) > 0 + assert final_response.delta is None # Final response has no delta + + # Check timing information + final_info = responses[-1][1] + assert final_info.request_timings.request_start is not None + assert final_info.request_timings.request_end is not None + assert final_info.request_timings.first_iteration is not None + assert final_info.request_timings.last_iteration is not None + + finally: + await backend.process_shutdown() + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_mock_backend_resolve_not_started(self): + """Test resolve fails when backend not started. + + ### WRITTEN BY AI ### + """ + backend = MockBackend() + + request = GenerationRequest(content="test") + request_info = ScheduledRequestInfo( + request_id="test", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + with pytest.raises(RuntimeError, match="Backend not started up"): + async for _ in backend.resolve(request, request_info): + pass + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_mock_backend_resolve_with_history(self): + """Test resolve method raises error with conversation history. + + ### WRITTEN BY AI ### + """ + backend = MockBackend() + await backend.process_startup() + + try: + request = GenerationRequest(content="test") + request_info = ScheduledRequestInfo( + request_id="test", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + history = [(request, None)] # Mock history + + with pytest.raises( + NotImplementedError, match="Multi-turn requests not supported" + ): + async for _ in backend.resolve(request, request_info, history): + pass + finally: + await backend.process_shutdown() + + @pytest.mark.regression + def test_mock_backend_token_generation(self): + """Test token generation methods. + + ### WRITTEN BY AI ### + """ + # Test with specific token count + tokens = MockBackend._get_tokens(5) + assert len(tokens) == 5 + assert tokens[-1] == "." 
# Should end with period + + # Test with None (random count) + tokens_random = MockBackend._get_tokens(None) + assert len(tokens_random) >= 8 + assert len(tokens_random) <= 512 + + # Test prompt token estimation + estimated = MockBackend._estimate_prompt_tokens("hello world test") + assert estimated == 3 # Three words diff --git a/tests/unit/backend/test_objects.py b/tests/unit/backend/test_objects.py new file mode 100644 index 00000000..682aaf4c --- /dev/null +++ b/tests/unit/backend/test_objects.py @@ -0,0 +1,351 @@ +""" +Unit tests for GenerationRequest, GenerationResponse, GenerationRequestTimings. +""" + +import uuid + +import pytest + +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) + + +class TestGenerationRequest: + """Test cases for GenerationRequest model.""" + + @pytest.mark.smoke + def test_generation_request_creation(self): + """Test basic GenerationRequest creation. + + ### WRITTEN BY AI ### + """ + request = GenerationRequest(content="test content") + + assert request.content == "test content" + assert request.request_type == "text_completions" # default + assert isinstance(request.request_id, str) + assert request.params == {} + assert request.stats == {} + assert request.constraints == {} + + @pytest.mark.smoke + def test_generation_request_with_all_fields(self): + """Test GenerationRequest creation with all fields. + + ### WRITTEN BY AI ### + """ + request_id = "test-123" + content = ["message1", "message2"] + params = {"temperature": 0.7, "max_tokens": 100} + stats = {"prompt_tokens": 50} + constraints = {"output_tokens": 100} + + request = GenerationRequest( + request_id=request_id, + request_type="chat_completions", + content=content, + params=params, + stats=stats, + constraints=constraints, + ) + + assert request.request_id == request_id + assert request.request_type == "chat_completions" + assert request.content == content + assert request.params == params + assert request.stats == stats + assert request.constraints == constraints + + @pytest.mark.sanity + def test_generation_request_auto_id_generation(self): + """Test that request_id is auto-generated if not provided. + + ### WRITTEN BY AI ### + """ + request1 = GenerationRequest(content="test1") + request2 = GenerationRequest(content="test2") + + assert request1.request_id != request2.request_id + assert len(request1.request_id) > 0 + assert len(request2.request_id) > 0 + + # Should be valid UUIDs + uuid.UUID(request1.request_id) + uuid.UUID(request2.request_id) + + @pytest.mark.sanity + def test_generation_request_type_validation(self): + """Test request_type field validation. + + ### WRITTEN BY AI ### + """ + # Valid types + request1 = GenerationRequest(content="test", request_type="text_completions") + request2 = GenerationRequest(content="test", request_type="chat_completions") + + assert request1.request_type == "text_completions" + assert request2.request_type == "chat_completions" + + @pytest.mark.regression + def test_generation_request_content_types(self): + """Test GenerationRequest with different content types. 
+ + ### WRITTEN BY AI ### + """ + # String content + request1 = GenerationRequest(content="string content") + assert request1.content == "string content" + + # List content + request2 = GenerationRequest(content=["item1", "item2"]) + assert request2.content == ["item1", "item2"] + + # Dict content + dict_content = {"role": "user", "content": "test"} + request3 = GenerationRequest(content=dict_content) + assert request3.content == dict_content + + +class TestGenerationResponse: + """Test cases for GenerationResponse model.""" + + @pytest.mark.smoke + def test_generation_response_creation(self): + """Test basic GenerationResponse creation. + + ### WRITTEN BY AI ### + """ + request_id = "test-123" + request_args = {"model": "gpt-3.5-turbo"} + + response = GenerationResponse(request_id=request_id, request_args=request_args) + + assert response.request_id == request_id + assert response.request_args == request_args + assert response.value is None + assert response.delta is None + assert response.iterations == 0 + assert response.request_prompt_tokens is None + assert response.request_output_tokens is None + assert response.response_prompt_tokens is None + assert response.response_output_tokens is None + + @pytest.mark.smoke + def test_generation_response_with_all_fields(self): + """Test GenerationResponse creation with all fields. + + ### WRITTEN BY AI ### + """ + response = GenerationResponse( + request_id="test-123", + request_args={"model": "gpt-4"}, + value="Generated text", + delta="new text", + iterations=5, + request_prompt_tokens=50, + request_output_tokens=100, + response_prompt_tokens=55, + response_output_tokens=95, + ) + + assert response.request_id == "test-123" + assert response.request_args == {"model": "gpt-4"} + assert response.value == "Generated text" + assert response.delta == "new text" + assert response.iterations == 5 + assert response.request_prompt_tokens == 50 + assert response.request_output_tokens == 100 + assert response.response_prompt_tokens == 55 + assert response.response_output_tokens == 95 + + @pytest.mark.sanity + def test_generation_response_prompt_tokens_property(self): + """Test prompt_tokens property logic. + + ### WRITTEN BY AI ### + """ + # When both are available, prefers response_prompt_tokens + response1 = GenerationResponse( + request_id="test", + request_args={}, + request_prompt_tokens=50, + response_prompt_tokens=55, + ) + assert response1.prompt_tokens == 55 + + # When only request_prompt_tokens is available + response2 = GenerationResponse( + request_id="test", request_args={}, request_prompt_tokens=50 + ) + assert response2.prompt_tokens == 50 + + # When only response_prompt_tokens is available + response3 = GenerationResponse( + request_id="test", request_args={}, response_prompt_tokens=55 + ) + assert response3.prompt_tokens == 55 + + # When neither is available + response4 = GenerationResponse(request_id="test", request_args={}) + assert response4.prompt_tokens is None + + @pytest.mark.sanity + def test_generation_response_output_tokens_property(self): + """Test output_tokens property logic. 
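+
+        Expected behavior, per the assertions below: response_output_tokens is
+        preferred when both values are set, request_output_tokens is the
+        fallback, and None is returned only when neither is available.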
+ + ### WRITTEN BY AI ### + """ + # When both are available, prefers response_output_tokens + response1 = GenerationResponse( + request_id="test", + request_args={}, + request_output_tokens=100, + response_output_tokens=95, + ) + assert response1.output_tokens == 95 + + # When only request_output_tokens is available + response2 = GenerationResponse( + request_id="test", request_args={}, request_output_tokens=100 + ) + assert response2.output_tokens == 100 + + # When only response_output_tokens is available + response3 = GenerationResponse( + request_id="test", request_args={}, response_output_tokens=95 + ) + assert response3.output_tokens == 95 + + # When neither is available + response4 = GenerationResponse(request_id="test", request_args={}) + assert response4.output_tokens is None + + @pytest.mark.sanity + def test_generation_response_total_tokens_property(self): + """Test total_tokens property calculation. + + ### WRITTEN BY AI ### + """ + # When both prompt and output tokens are available + response1 = GenerationResponse( + request_id="test", + request_args={}, + response_prompt_tokens=50, + response_output_tokens=100, + ) + assert response1.total_tokens == 150 + + # When one is missing + response2 = GenerationResponse( + request_id="test", request_args={}, response_prompt_tokens=50 + ) + assert response2.total_tokens is None + + # When both are missing + response3 = GenerationResponse(request_id="test", request_args={}) + assert response3.total_tokens is None + + @pytest.mark.regression + def test_generation_response_preferred_token_methods(self): + """Test preferred_*_tokens methods. + + ### WRITTEN BY AI ### + """ + response = GenerationResponse( + request_id="test", + request_args={}, + request_prompt_tokens=50, + request_output_tokens=100, + response_prompt_tokens=55, + response_output_tokens=95, + ) + + # Test preferred_prompt_tokens + assert response.preferred_prompt_tokens("request") == 50 + assert response.preferred_prompt_tokens("response") == 55 + + # Test preferred_output_tokens + assert response.preferred_output_tokens("request") == 100 + assert response.preferred_output_tokens("response") == 95 + + @pytest.mark.regression + def test_generation_response_preferred_tokens_fallback(self): + """Test preferred_*_tokens methods with fallback logic. + + ### WRITTEN BY AI ### + """ + # Only response tokens available + response1 = GenerationResponse( + request_id="test", + request_args={}, + response_prompt_tokens=55, + response_output_tokens=95, + ) + + assert response1.preferred_prompt_tokens("request") == 55 # Falls back + assert response1.preferred_output_tokens("request") == 95 # Falls back + + # Only request tokens available + response2 = GenerationResponse( + request_id="test", + request_args={}, + request_prompt_tokens=50, + request_output_tokens=100, + ) + + assert response2.preferred_prompt_tokens("response") == 50 # Falls back + assert response2.preferred_output_tokens("response") == 100 # Falls back + + +class TestGenerationRequestTimings: + """Test cases for GenerationRequestTimings model.""" + + @pytest.mark.smoke + def test_generation_request_timings_creation(self): + """Test basic GenerationRequestTimings creation. + + ### WRITTEN BY AI ### + """ + timings = GenerationRequestTimings() + + assert timings.first_iteration is None + assert timings.last_iteration is None + + @pytest.mark.smoke + def test_generation_request_timings_with_fields(self): + """Test GenerationRequestTimings creation with fields. 
+ + ### WRITTEN BY AI ### + """ + first_time = 1234567890.0 + last_time = 1234567895.0 + + timings = GenerationRequestTimings( + first_iteration=first_time, last_iteration=last_time + ) + + assert timings.first_iteration == first_time + assert timings.last_iteration == last_time + + @pytest.mark.regression + def test_generation_request_timings_fields_optional(self): + """Test that all timing fields are optional. + + ### WRITTEN BY AI ### + """ + # Should be able to create with no fields + timings1 = GenerationRequestTimings() + assert timings1.first_iteration is None + assert timings1.last_iteration is None + + # Should be able to create with only one field + timings2 = GenerationRequestTimings(first_iteration=123.0) + assert timings2.first_iteration == 123.0 + assert timings2.last_iteration is None + + timings3 = GenerationRequestTimings(last_iteration=456.0) + assert timings3.first_iteration is None + assert timings3.last_iteration == 456.0 diff --git a/tests/unit/backend/test_openai_backend.py b/tests/unit/backend/test_openai_backend.py index 0a4c2c38..54a560d6 100644 --- a/tests/unit/backend/test_openai_backend.py +++ b/tests/unit/backend/test_openai_backend.py @@ -1,207 +1,1151 @@ -import time +""" +Unit tests for OpenAIHTTPBackend implementation. +### WRITTEN BY AI ### +""" + +import base64 +from pathlib import Path +from unittest.mock import AsyncMock, Mock, patch + +import httpx import pytest +from PIL import Image + +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.backend.openai import OpenAIHTTPBackend, UsageStats +from guidellm.scheduler import ScheduledRequestInfo + + +class TestOpenAIHTTPBackend: + """Test cases for OpenAIHTTPBackend.""" + + @pytest.mark.smoke + def test_openai_backend_initialization_minimal(self): + """Test minimal OpenAIHTTPBackend initialization. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://localhost:8000") + + assert backend.target == "http://localhost:8000" + assert backend.model is None + assert backend.timeout == 60.0 + assert backend.http2 is True + assert backend.follow_redirects is True + assert backend.verify is False + assert backend.stream_response is True + assert backend._in_process is False + assert backend._async_client is None + + @pytest.mark.smoke + def test_openai_backend_initialization_full(self): + """Test full OpenAIHTTPBackend initialization. 
+ + ### WRITTEN BY AI ### + """ + extra_query = {"param": "value"} + extra_body = {"setting": "test"} + remove_from_body = ["unwanted"] + headers = {"Custom-Header": "value"} + + backend = OpenAIHTTPBackend( + target="https://localhost:8000/v1", + model="test-model", + api_key="test-key", + organization="test-org", + project="test-project", + timeout=120.0, + http2=False, + follow_redirects=False, + max_output_tokens=1000, + stream_response=False, + extra_query=extra_query, + extra_body=extra_body, + remove_from_body=remove_from_body, + headers=headers, + verify=True, + ) + + assert backend.target == "https://localhost:8000" + assert backend.model == "test-model" + assert backend.timeout == 120.0 + assert backend.http2 is False + assert backend.follow_redirects is False + assert backend.verify is True + assert backend.max_output_tokens == 1000 + assert backend.stream_response is False + assert backend.extra_query == extra_query + assert backend.extra_body == extra_body + assert backend.remove_from_body == remove_from_body + + @pytest.mark.sanity + def test_openai_backend_target_normalization(self): + """Test target URL normalization. + + ### WRITTEN BY AI ### + """ + # Remove trailing slashes and /v1 + backend1 = OpenAIHTTPBackend(target="http://localhost:8000/") + assert backend1.target == "http://localhost:8000" + + backend2 = OpenAIHTTPBackend(target="http://localhost:8000/v1") + assert backend2.target == "http://localhost:8000" + + backend3 = OpenAIHTTPBackend(target="http://localhost:8000/v1/") + assert backend3.target == "http://localhost:8000" + + @pytest.mark.sanity + def test_openai_backend_header_building(self): + """Test header building logic. + + ### WRITTEN BY AI ### + """ + # Test with API key + backend1 = OpenAIHTTPBackend(target="http://test", api_key="test-key") + assert "Authorization" in backend1.headers + assert backend1.headers["Authorization"] == "Bearer test-key" + + # Test with Bearer prefix already + backend2 = OpenAIHTTPBackend(target="http://test", api_key="Bearer test-key") + assert backend2.headers["Authorization"] == "Bearer test-key" + + # Test with organization and project + backend3 = OpenAIHTTPBackend( + target="http://test", organization="test-org", project="test-project" + ) + assert backend3.headers["OpenAI-Organization"] == "test-org" + assert backend3.headers["OpenAI-Project"] == "test-project" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_openai_backend_info(self): + """Test info method. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend( + target="http://test", model="test-model", timeout=30.0 + ) + + info = backend.info() + + assert info["target"] == "http://test" + assert info["model"] == "test-model" + assert info["timeout"] == 30.0 + assert info["health_path"] == "/health" + assert info["models_path"] == "/v1/models" + assert info["text_completions_path"] == "/v1/completions" + assert info["chat_completions_path"] == "/v1/chat/completions" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_openai_backend_process_startup(self): + """Test process startup. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + assert not backend._in_process + assert backend._async_client is None + + await backend.process_startup() + + assert backend._in_process + assert backend._async_client is not None + assert isinstance(backend._async_client, httpx.AsyncClient) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_openai_backend_process_startup_already_started(self): + """Test process startup when already started. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + with pytest.raises(RuntimeError, match="Backend already started up"): + await backend.process_startup() + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_openai_backend_process_shutdown(self): + """Test process shutdown. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + assert backend._in_process + assert backend._async_client is not None + + await backend.process_shutdown() + + assert not backend._in_process + assert backend._async_client is None + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_openai_backend_process_shutdown_not_started(self): + """Test process shutdown when not started. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_openai_backend_check_in_process(self): + """Test _check_in_process method. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + backend._check_in_process() + + await backend.process_startup() + backend._check_in_process() # Should not raise + + await backend.process_shutdown() + with pytest.raises(RuntimeError, match="Backend not started up"): + backend._check_in_process() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_openai_backend_available_models(self): + """Test available_models method. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + mock_response = Mock() + mock_response.json.return_value = { + "data": [{"id": "test-model1"}, {"id": "test-model2"}] + } + mock_response.raise_for_status = Mock() + + with patch.object(backend._async_client, "get", return_value=mock_response): + models = await backend.available_models() + + assert models == ["test-model1", "test-model2"] + backend._async_client.get.assert_called_once() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_openai_backend_default_model(self): + """Test default_model method. 
+ + ### WRITTEN BY AI ### + """ + # Test when model is already set + backend1 = OpenAIHTTPBackend(target="http://test", model="test-model") + result1 = await backend1.default_model() + assert result1 == "test-model" + + # Test when not in process + backend2 = OpenAIHTTPBackend(target="http://test") + result2 = await backend2.default_model() + assert result2 is None + + # Test when in process but no model set + backend3 = OpenAIHTTPBackend(target="http://test") + await backend3.process_startup() + + with patch.object(backend3, "available_models", return_value=["test-model2"]): + result3 = await backend3.default_model() + assert result3 == "test-model2" + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_validate_with_model(self): + """Test validate method when model is set. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="test-model") + await backend.process_startup() + + mock_response = Mock() + mock_response.raise_for_status = Mock() + + with patch.object(backend._async_client, "get", return_value=mock_response): + await backend.validate() # Should not raise + + backend._async_client.get.assert_called_once_with( + "http://test/health", headers={"Content-Type": "application/json"} + ) + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_validate_without_model(self): + """Test validate method when no model is set. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + with patch.object(backend, "available_models", return_value=["test-model"]): + await backend.validate() + assert backend.model == "test-model" + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_validate_fallback_to_text_completions(self): + """Test validate method fallback to text completions. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + # Mock health and models endpoints to fail + def mock_get(*args, **kwargs): + raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock()) + + # Mock text_completions to succeed + async def mock_text_completions(*args, **kwargs): + yield "test", UsageStats() + + with ( + patch.object(backend._async_client, "get", side_effect=mock_get), + patch.object( + backend, "text_completions", side_effect=mock_text_completions + ), + ): + await backend.validate() # Should not raise + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_validate_failure(self): + """Test validate method when all validation methods fail. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + def mock_fail(*args, **kwargs): + raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock()) + + def mock_http_error(*args, **kwargs): + raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock()) + + with ( + patch.object(backend._async_client, "get", side_effect=mock_http_error), + patch.object(backend, "text_completions", side_effect=mock_http_error), + pytest.raises(RuntimeError, match="Backend validation failed"), + ): + await backend.validate() + + @pytest.mark.sanity + def test_openai_backend_get_headers(self): + """Test _get_headers method. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend( + target="http://test", api_key="test-key", headers={"Custom": "value"} + ) + + headers = backend._get_headers() + + expected = { + "Content-Type": "application/json", + "Authorization": "Bearer test-key", + "Custom": "value", + } + assert headers == expected + + @pytest.mark.sanity + def test_openai_backend_get_params(self): + """Test _get_params method. + + ### WRITTEN BY AI ### + """ + extra_query = { + "general": "value", + "text_completions": {"specific": "text"}, + "chat_completions": {"specific": "chat"}, + } + + backend = OpenAIHTTPBackend(target="http://test", extra_query=extra_query) + + # Test endpoint-specific params + text_params = backend._get_params("text_completions") + assert text_params == {"specific": "text"} + + # Test fallback to general params + other_params = backend._get_params("other") + assert other_params == extra_query + + @pytest.mark.regression + def test_openai_backend_get_chat_messages_string(self): + """Test _get_chat_messages with string content. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + messages = backend._get_chat_messages("Hello world") + + expected = [{"role": "user", "content": "Hello world"}] + assert messages == expected + + @pytest.mark.regression + def test_openai_backend_get_chat_messages_list(self): + """Test _get_chat_messages with list content. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + content = [ + "Hello", + {"type": "text", "text": "world"}, + {"role": "assistant", "content": "existing message"}, + ] + + messages = backend._get_chat_messages(content) + + expected = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": "world"}, + {"role": "assistant", "content": "existing message"}, + ], + } + ] + assert messages == expected + + @pytest.mark.regression + def test_openai_backend_get_chat_messages_invalid(self): + """Test _get_chat_messages with invalid content. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(ValueError, match="Unsupported content type"): + backend._get_chat_messages(123) + + with pytest.raises(ValueError, match="Unsupported content item type"): + backend._get_chat_messages([123]) + + @pytest.mark.regression + def test_openai_backend_get_chat_message_media_item_image(self): + """Test _get_chat_message_media_item with PIL Image. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock PIL Image + mock_image = Mock(spec=Image.Image) + mock_image.tobytes.return_value = b"fake_image_data" + + result = backend._get_chat_message_media_item(mock_image) + + expected_data = base64.b64encode(b"fake_image_data").decode("utf-8") + expected = { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{expected_data}"}, + } + assert result == expected + + @pytest.mark.regression + def test_openai_backend_get_chat_message_media_item_path(self): + """Test _get_chat_message_media_item with file paths. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Test unsupported file type + unsupported_path = Path("test.txt") + with pytest.raises(ValueError, match="Unsupported file type: .txt"): + backend._get_chat_message_media_item(unsupported_path) + + @pytest.mark.regression + def test_openai_backend_get_body(self): + """Test _get_body method. 
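+
+        Precedence asserted below, roughly: request_kwargs entries override
+        extra_body defaults, and an explicit max_output_tokens caps both
+        max_tokens and max_completion_tokens in the generated body.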
+ + ### WRITTEN BY AI ### + """ + extra_body = {"general": "value", "text_completions": {"temperature": 0.5}} + + backend = OpenAIHTTPBackend( + target="http://test", + model="test-model", + max_output_tokens=1000, + extra_body=extra_body, + ) + + request_kwargs = {"temperature": 0.7} + + body = backend._get_body( + endpoint_type="text_completions", + request_kwargs=request_kwargs, + max_output_tokens=500, + prompt="test", + ) + + # Check that max_tokens settings are applied + assert body["temperature"] == 0.7 # request_kwargs override extra_body + assert body["model"] == "test-model" + assert body["max_tokens"] == 500 + assert body["max_completion_tokens"] == 500 + assert body["ignore_eos"] is True + assert body["prompt"] == "test" + # stop: None is filtered out by the None filter + assert "stop" not in body + + @pytest.mark.regression + def test_openai_backend_get_completions_text_content(self): + """Test _get_completions_text_content method. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Test with text field + data1 = {"choices": [{"text": "generated text"}]} + result1 = backend._get_completions_text_content(data1) + assert result1 == "generated text" + + # Test with delta content field + data2 = {"choices": [{"delta": {"content": "delta text"}}]} + result2 = backend._get_completions_text_content(data2) + assert result2 == "delta text" + + # Test with no choices + data3: dict[str, list] = {"choices": []} + result3 = backend._get_completions_text_content(data3) + assert result3 is None + + # Test with no choices key + data4: dict[str, str] = {} + result4 = backend._get_completions_text_content(data4) + assert result4 is None + + @pytest.mark.regression + def test_openai_backend_get_completions_usage_stats(self): + """Test _get_completions_usage_stats method. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Test with usage data + data1 = {"usage": {"prompt_tokens": 50, "completion_tokens": 100}} + result1 = backend._get_completions_usage_stats(data1) + assert isinstance(result1, UsageStats) + assert result1.prompt_tokens == 50 + assert result1.output_tokens == 100 + + # Test with no usage data + data2: dict[str, str] = {} + result2 = backend._get_completions_usage_stats(data2) + assert result2 is None + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_resolve_not_implemented_history(self): + """Test resolve method raises error for conversation history. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + request = GenerationRequest(content="test") + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + history = [(request, GenerationResponse(request_id="test", request_args={}))] + + with pytest.raises(NotImplementedError, match="Multi-turn requests"): + async for _ in backend.resolve(request, request_info, history): + pass + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_resolve_text_completions(self): + """Test resolve method for text completions. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + request = GenerationRequest( + content="test prompt", + request_type="text_completions", + params={"temperature": 0.7}, + constraints={"output_tokens": 100}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Mock text_completions method + async def mock_text_completions(*args, **kwargs): + yield None, None # Start signal + yield "Hello", None # First token + yield " world", UsageStats(prompt_tokens=10, output_tokens=2) # Final + + with patch.object( + backend, "text_completions", side_effect=mock_text_completions + ): + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + assert len(responses) >= 2 + final_response = responses[-1][0] + assert final_response.value == "Hello world" + assert final_response.request_id == request.request_id + assert final_response.iterations == 2 + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_resolve_chat_completions(self): + """Test resolve method for chat completions. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + request = GenerationRequest( + content="test message", + request_type="chat_completions", + params={"temperature": 0.5}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Mock chat_completions method + async def mock_chat_completions(*args, **kwargs): + yield None, None # Start signal + yield "Response", UsageStats(prompt_tokens=5, output_tokens=1) + + with patch.object( + backend, "chat_completions", side_effect=mock_chat_completions + ): + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + final_response = responses[-1][0] + assert final_response.value == "Response" + assert final_response.request_id == request.request_id + + +class TestOpenAICompletions: + """Test cases for completion methods.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_text_completions_not_in_process(self): + """Test text_completions when backend not started. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + async for _ in backend.text_completions("test", "req-id"): + pass + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_text_completions_basic(self): + """Test basic text_completions functionality. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"text": "Generated text"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ): + results = [] + async for result in backend.text_completions( + prompt="test prompt", request_id="req-123", stream_response=False + ): + results.append(result) + + assert len(results) == 2 + assert results[0] == (None, None) # Initial yield + assert results[1][0] == "Generated text" + assert isinstance(results[1][1], UsageStats) + assert results[1][1].prompt_tokens == 10 + assert results[1][1].output_tokens == 5 + finally: + await backend.process_shutdown() + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_chat_completions_not_in_process(self): + """Test chat_completions when backend not started. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + async for _ in backend.chat_completions("test"): + pass + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_chat_completions_basic(self): + """Test basic chat_completions functionality. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"delta": {"content": "Chat response"}}], + "usage": {"prompt_tokens": 8, "completion_tokens": 3}, + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ): + results = [] + async for result in backend.chat_completions( + content="Hello", request_id="req-456", stream_response=False + ): + results.append(result) + + assert len(results) == 2 + assert results[0] == (None, None) + assert results[1][0] == "Chat response" + assert isinstance(results[1][1], UsageStats) + assert results[1][1].prompt_tokens == 8 + assert results[1][1].output_tokens == 3 + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_text_completions_with_parameters(self): + """Test text_completions with additional parameters. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"text": "response"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 1}, + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ) as mock_post: + async for _ in backend.text_completions( + prompt="test", + request_id="req-123", + output_token_count=50, + temperature=0.7, + stream_response=False, + ): + pass + + # Check that the request body contains expected parameters + call_args = mock_post.call_args + body = call_args[1]["json"] + assert body["max_tokens"] == 50 + assert body["temperature"] == 0.7 + assert body["model"] == "gpt-4" + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_chat_completions_content_formatting(self): + """Test chat_completions content formatting. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"delta": {"content": "response"}}] + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ) as mock_post: + async for _ in backend.chat_completions( + content="Hello world", stream_response=False + ): + pass + + call_args = mock_post.call_args + body = call_args[1]["json"] + expected_messages = [{"role": "user", "content": "Hello world"}] + assert body["messages"] == expected_messages + finally: + await backend.process_shutdown() + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_openai_backend_validate_no_models_available(self): + """Test validate method when no models are available. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + try: + # Mock endpoints to fail, then available_models to return empty list + def mock_get_fail(*args, **kwargs): + raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock()) + + with ( + patch.object(backend._async_client, "get", side_effect=mock_get_fail), + patch.object(backend, "available_models", return_value=[]), + patch.object(backend, "text_completions", side_effect=mock_get_fail), + pytest.raises( + RuntimeError, + match="No model available and could not set a default model", + ), + ): + await backend.validate() + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_text_completions_streaming(self): + """Test text_completions with streaming enabled.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + # Mock streaming response + mock_stream = Mock() + mock_stream.raise_for_status = Mock() + + async def mock_aiter_lines(): + lines = [ + 'data: {"choices":[{"text":"Hello"}], "usage":{"prompt_tokens":5,"completion_tokens":1}}', # noqa: E501 + 'data: {"choices":[{"text":" world"}], "usage":{"prompt_tokens":5,"completion_tokens":2}}', # noqa: E501 + 'data: {"choices":[{"text":"!"}], "usage":{"prompt_tokens":5,"completion_tokens":3}}', # noqa: E501 + "data: [DONE]", + ] + for line in lines: + yield line + + mock_stream.aiter_lines = mock_aiter_lines + + mock_client_stream = AsyncMock() + mock_client_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_client_stream.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + backend._async_client, "stream", return_value=mock_client_stream + ): + results = [] + async for result in backend.text_completions( + prompt="test prompt", request_id="req-123", stream_response=True + ): + results.append(result) + + # Should get initial None, then tokens, then final with usage + assert len(results) >= 3 + assert results[0] == (None, None) # Initial yield + assert all( + isinstance(result[0], str) for result in results[1:] + ) # Has text content + assert all( + isinstance(result[1], UsageStats) for result in results[1:] + ) # Has usage stats + assert all( + result[1].output_tokens == i for i, result in enumerate(results[1:], 1) + ) + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_chat_completions_streaming(self): + """Test chat_completions with streaming enabled. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + # Mock streaming response + mock_stream = Mock() + mock_stream.raise_for_status = Mock() + + async def mock_aiter_lines(): + lines = [ + 'data: {"choices":[{"delta":{"content":"Hi"}}]}', + 'data: {"choices":[{"delta":{"content":" there"}}]}', + 'data: {"choices":[{"delta":{"content":"!"}}]}', + 'data: {"usage":{"prompt_tokens":3,"completion_tokens":3}}', + "data: [DONE]", + ] + for line in lines: + yield line + + mock_stream.aiter_lines = mock_aiter_lines + + mock_client_stream = AsyncMock() + mock_client_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_client_stream.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + backend._async_client, "stream", return_value=mock_client_stream + ): + results = [] + async for result in backend.chat_completions( + content="Hello", request_id="req-456", stream_response=True + ): + results.append(result) + + # Should get initial None, then deltas, then final with usage + assert len(results) >= 3 + assert results[0] == (None, None) # Initial yield + assert any(result[0] for result in results if result[0]) # Has content + assert any(result[1] for result in results if result[1]) # Has usage stats + finally: + await backend.process_shutdown() + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_streaming_response_edge_cases(self): + """Test streaming response edge cases for line processing. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + # Mock streaming response with edge cases + mock_stream = Mock() + mock_stream.raise_for_status = Mock() + + async def mock_aiter_lines(): + lines = [ + "", # Empty line + " ", # Whitespace only + "not data line", # Line without data prefix + 'data: {"choices":[{"text":"Hello"}]}', # Valid data + "data: [DONE]", # End marker + ] + for line in lines: + yield line + + mock_stream.aiter_lines = mock_aiter_lines + + mock_client_stream = AsyncMock() + mock_client_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_client_stream.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + backend._async_client, "stream", return_value=mock_client_stream + ): + results = [] + async for result in backend.text_completions( + prompt="test", request_id="req-123", stream_response=True + ): + results.append(result) + + # Should get initial None and the valid response + assert len(results) == 2 + assert results[0] == (None, None) + assert results[1][0] == "Hello" + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + def test_openai_backend_get_chat_message_media_item_jpeg_file(self): + """Test _get_chat_message_media_item with JPEG file path. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock Path object for JPEG file + mock_jpeg_path = Mock(spec=Path) + mock_jpeg_path.suffix.lower.return_value = ".jpg" + + # Mock Image.open to return a mock image + mock_image = Mock(spec=Image.Image) + mock_image.tobytes.return_value = b"fake_jpeg_data" + + with patch("guidellm.backend.openai.Image.open", return_value=mock_image): + result = backend._get_chat_message_media_item(mock_jpeg_path) + + expected_data = base64.b64encode(b"fake_jpeg_data").decode("utf-8") + expected = { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{expected_data}"}, + } + assert result == expected + + @pytest.mark.sanity + def test_openai_backend_get_chat_message_media_item_wav_file(self): + """Test _get_chat_message_media_item with WAV file path. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock Path object for WAV file + mock_wav_path = Mock(spec=Path) + mock_wav_path.suffix.lower.return_value = ".wav" + mock_wav_path.read_bytes.return_value = b"fake_wav_data" + + result = backend._get_chat_message_media_item(mock_wav_path) + + expected_data = base64.b64encode(b"fake_wav_data").decode("utf-8") + expected = { + "type": "input_audio", + "input_audio": {"data": expected_data, "format": "wav"}, + } + assert result == expected + + @pytest.mark.sanity + def test_openai_backend_get_chat_messages_with_pil_image(self): + """Test _get_chat_messages with PIL Image in content list. + + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock PIL Image + mock_image = Mock(spec=Image.Image) + mock_image.tobytes.return_value = b"fake_image_bytes" + + content = ["Hello", mock_image, "world"] + + result = backend._get_chat_messages(content) + + # Should have one user message with mixed content + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + + # Check text items + assert result[0]["content"][0] == {"type": "text", "text": "Hello"} + assert result[0]["content"][2] == {"type": "text", "text": "world"} + + # Check image item + image_item = result[0]["content"][1] + assert image_item["type"] == "image" + assert "data:image/jpeg;base64," in image_item["image"]["url"] + + @pytest.mark.regression + @pytest.mark.asyncio + async def test_resolve_timing_edge_cases(self): + """Test resolve method timing edge cases. 
+ + ### WRITTEN BY AI ### + """ + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + try: + request = GenerationRequest( + content="test prompt", + request_type="text_completions", + constraints={"output_tokens": 50}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Mock text_completions to test timing edge cases + async def mock_text_completions(*args, **kwargs): + yield None, None # Initial yield - tests line 343 + yield "token1", None # First token + yield "token2", UsageStats(prompt_tokens=10, output_tokens=2) # Final + + with patch.object( + backend, "text_completions", side_effect=mock_text_completions + ): + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + # Check that timing was properly set + final_response, final_info = responses[-1] + assert final_info.request_timings.request_start is not None + assert final_info.request_timings.first_iteration is not None + assert final_info.request_timings.last_iteration is not None + assert final_info.request_timings.request_end is not None + assert final_response.delta is None # Tests line 362 -from guidellm.backend import OpenAIHTTPBackend, ResponseSummary, StreamingTextResponse -from guidellm.config import settings - - -@pytest.mark.smoke -def test_openai_http_backend_default_initialization(): - backend = OpenAIHTTPBackend() - assert backend.target == settings.openai.base_url - assert backend.model is None - assert backend.headers.get("Authorization") == settings.openai.bearer_token - assert backend.organization == settings.openai.organization - assert backend.project == settings.openai.project - assert backend.timeout == settings.request_timeout - assert backend.http2 is True - assert backend.follow_redirects is True - assert backend.max_output_tokens == settings.openai.max_output_tokens - assert backend.extra_query is None - - -@pytest.mark.smoke -def test_openai_http_backend_intialization(): - backend = OpenAIHTTPBackend( - target="http://test-target", - model="test-model", - api_key="test-key", - organization="test-org", - project="test-proj", - timeout=10, - http2=False, - follow_redirects=False, - max_output_tokens=100, - extra_query={"foo": "bar"}, - ) - assert backend.target == "http://test-target" - assert backend.model == "test-model" - assert backend.headers.get("Authorization") == "Bearer test-key" - assert backend.organization == "test-org" - assert backend.project == "test-proj" - assert backend.timeout == 10 - assert backend.http2 is False - assert backend.follow_redirects is False - assert backend.max_output_tokens == 100 - assert backend.extra_query == {"foo": "bar"} - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_available_models(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock") - models = await backend.available_models() - assert models == ["mock-model"] - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_validate(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") - await backend.validate() - - backend = OpenAIHTTPBackend(target="http://target.mock") - await backend.validate() - assert backend.model == "mock-model" - - backend = OpenAIHTTPBackend(target="http://target.mock", model="invalid-model") - 
with pytest.raises(ValueError): - await backend.validate() - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_text_completions(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") - - index = 0 - final_resp = None - async for response in backend.text_completions("Test Prompt", request_id="test-id"): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.request_args is not None - assert response.iterations > 0 - assert response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens is None - assert response.request_output_tokens is None - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens > 0 # type: ignore - assert response.request_id == "test-id" - - index += 1 - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_text_completions_counts(httpx_openai_mock): - backend = OpenAIHTTPBackend( - target="http://target.mock", - model="mock-model", - max_output_tokens=100, - ) - final_resp = None - - async for response in backend.text_completions( - "Test Prompt", request_id="test-id", prompt_token_count=3, output_token_count=10 - ): - final_resp = response - - assert final_resp - assert isinstance(final_resp, ResponseSummary) - assert len(final_resp.value) > 0 - assert final_resp.request_args is not None - assert final_resp.request_prompt_tokens == 3 - assert final_resp.request_output_tokens == 10 - assert final_resp.response_prompt_tokens == 3 - assert final_resp.response_output_tokens == 10 - assert final_resp.request_id == "test-id" - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_chat_completions(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") - - index = 0 - final_resp = None - async for response in backend.chat_completions("Test Prompt", request_id="test-id"): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.request_args is not None - assert response.iterations > 0 - assert response.start_time 
> 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens is None - assert response.request_output_tokens is None - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens > 0 # type: ignore - assert response.request_id == "test-id" - - index += 1 - - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_chat_completions_counts(httpx_openai_mock): - backend = OpenAIHTTPBackend( - target="http://target.mock", - model="mock-model", - max_output_tokens=100, - ) - final_resp = None - - async for response in backend.chat_completions( - "Test Prompt", request_id="test-id", prompt_token_count=3, output_token_count=10 - ): - final_resp = response - - assert final_resp - assert isinstance(final_resp, ResponseSummary) - assert len(final_resp.value) > 0 - assert final_resp.request_args is not None - assert final_resp.request_prompt_tokens == 3 - assert final_resp.request_output_tokens == 10 - assert final_resp.response_prompt_tokens == 3 - assert final_resp.response_output_tokens == 10 - assert final_resp.request_id == "test-id" + finally: + await backend.process_shutdown() diff --git a/tests/unit/backend/test_openai_backend_custom_configs.py b/tests/unit/backend/test_openai_backend_custom_configs.py deleted file mode 100644 index 7f6706ad..00000000 --- a/tests/unit/backend/test_openai_backend_custom_configs.py +++ /dev/null @@ -1,88 +0,0 @@ -import pytest - -from guidellm.backend import OpenAIHTTPBackend -from guidellm.config import settings - - -@pytest.mark.smoke -def test_openai_http_backend_default_initialization(): - backend = OpenAIHTTPBackend() - assert backend.verify is True - - -@pytest.mark.smoke -def test_openai_http_backend_custom_ssl_verification(): - backend = OpenAIHTTPBackend(verify=False) - assert backend.verify is False - - -@pytest.mark.smoke -def test_openai_http_backend_custom_headers_override(): - # Set a default api_key, which would normally create an Authorization header - settings.openai.api_key = "default-api-key" - - # Set custom headers that override the default Authorization and add a new header - openshift_token = "Bearer sha256~xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - override_headers = { - "Authorization": openshift_token, - "Custom-Header": "Custom-Value", - } - - # Initialize the backend - backend = OpenAIHTTPBackend(headers=override_headers) - - # Check that the override headers are used - assert backend.headers["Authorization"] == openshift_token - assert backend.headers["Custom-Header"] == "Custom-Value" - assert len(backend.headers) == 2 - - # Reset the settings - settings.openai.api_key = None - settings.openai.headers = None - - -@pytest.mark.smoke -def test_openai_http_backend_kwarg_headers_override_settings(): - # Set headers via settings (simulating environment variables) - settings.openai.headers = {"Authorization": "Bearer settings-token"} - - # Set different headers via kwargs (simulating --backend-args) - override_headers = { - "Authorization": "Bearer kwargs-token", - "Custom-Header": "Custom-Value", - } - - # Initialize the backend with kwargs - backend = OpenAIHTTPBackend(headers=override_headers) - - # Check that the kwargs headers took precedence - assert backend.headers["Authorization"] == "Bearer kwargs-token" - assert backend.headers["Custom-Header"] == "Custom-Value" - assert len(backend.headers) == 2 - - # Reset the settings - settings.openai.headers = None - - -@pytest.mark.smoke -def 
test_openai_http_backend_remove_header_with_none(): - # Set a default api_key, which would normally create an Authorization header - settings.openai.api_key = "default-api-key" - - # Set a custom header and explicitly set Authorization to None to remove it - override_headers = { - "Authorization": None, - "Custom-Header": "Custom-Value", - } - - # Initialize the backend - backend = OpenAIHTTPBackend(headers=override_headers) - - # Check that the Authorization header is removed and the custom header is present - assert "Authorization" not in backend.headers - assert backend.headers["Custom-Header"] == "Custom-Value" - assert len(backend.headers) == 1 - - # Reset the settings - settings.openai.api_key = None - settings.openai.headers = None diff --git a/tests/unit/backend/test_response.py b/tests/unit/backend/test_response.py deleted file mode 100644 index b3dc99c9..00000000 --- a/tests/unit/backend/test_response.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import get_args - -import pytest - -from guidellm.backend import ( - RequestArgs, - ResponseSummary, - StreamingResponseType, - StreamingTextResponse, -) - - -@pytest.mark.smoke -def test_streaming_response_types(): - valid_types = get_args(StreamingResponseType) - assert valid_types == ("start", "iter") - - -@pytest.mark.smoke -def test_streaming_text_response_default_initilization(): - response = StreamingTextResponse( - type_="start", - value="", - start_time=0.0, - first_iter_time=None, - iter_count=0, - delta="", - time=0.0, - ) - assert response.request_id is None - - -@pytest.mark.smoke -def test_streaming_text_response_initialization(): - response = StreamingTextResponse( - type_="start", - value="Hello, world!", - start_time=0.0, - first_iter_time=0.0, - iter_count=1, - delta="Hello, world!", - time=1.0, - request_id="123", - ) - assert response.type_ == "start" - assert response.value == "Hello, world!" - assert response.start_time == 0.0 - assert response.first_iter_time == 0.0 - assert response.iter_count == 1 - assert response.delta == "Hello, world!" 
- assert response.time == 1.0 - assert response.request_id == "123" - - -@pytest.mark.smoke -def test_streaming_text_response_marshalling(): - response = StreamingTextResponse( - type_="start", - value="Hello, world!", - start_time=0.0, - first_iter_time=0.0, - iter_count=0, - delta="Hello, world!", - time=1.0, - request_id="123", - ) - serialized = response.model_dump() - deserialized = StreamingTextResponse.model_validate(serialized) - - for key, value in vars(response).items(): - assert getattr(deserialized, key) == value - - -@pytest.mark.smoke -def test_request_args_default_initialization(): - args = RequestArgs( - target="http://example.com", - headers={}, - params={}, - payload={}, - ) - assert args.timeout is None - assert args.http2 is None - assert args.follow_redirects is None - - -@pytest.mark.smoke -def test_request_args_initialization(): - args = RequestArgs( - target="http://example.com", - headers={ - "Authorization": "Bearer token", - }, - params={}, - payload={ - "query": "Hello, world!", - }, - timeout=10.0, - http2=True, - follow_redirects=True, - ) - assert args.target == "http://example.com" - assert args.headers == {"Authorization": "Bearer token"} - assert args.payload == {"query": "Hello, world!"} - assert args.timeout == 10.0 - assert args.http2 is True - assert args.follow_redirects is True - - -@pytest.mark.smoke -def test_response_args_marshalling(): - args = RequestArgs( - target="http://example.com", - headers={"Authorization": "Bearer token"}, - params={}, - payload={"query": "Hello, world!"}, - timeout=10.0, - http2=True, - ) - serialized = args.model_dump() - deserialized = RequestArgs.model_validate(serialized) - - for key, value in vars(args).items(): - assert getattr(deserialized, key) == value - - -@pytest.mark.smoke -def test_response_summary_default_initialization(): - summary = ResponseSummary( - value="Hello, world!", - request_args=RequestArgs( - target="http://example.com", - headers={}, - params={}, - payload={}, - ), - start_time=0.0, - end_time=0.0, - first_iter_time=None, - last_iter_time=None, - ) - assert summary.value == "Hello, world!" - assert summary.request_args.target == "http://example.com" - assert summary.request_args.headers == {} - assert summary.request_args.payload == {} - assert summary.start_time == 0.0 - assert summary.end_time == 0.0 - assert summary.first_iter_time is None - assert summary.last_iter_time is None - assert summary.iterations == 0 - assert summary.request_prompt_tokens is None - assert summary.request_output_tokens is None - assert summary.response_prompt_tokens is None - assert summary.response_output_tokens is None - assert summary.request_id is None - - -@pytest.mark.smoke -def test_response_summary_initialization(): - summary = ResponseSummary( - value="Hello, world!", - request_args=RequestArgs( - target="http://example.com", - headers={}, - params={}, - payload={}, - ), - start_time=1.0, - end_time=2.0, - iterations=3, - first_iter_time=1.0, - last_iter_time=2.0, - request_prompt_tokens=5, - request_output_tokens=10, - response_prompt_tokens=5, - response_output_tokens=10, - request_id="123", - ) - assert summary.value == "Hello, world!" 
- assert summary.request_args.target == "http://example.com" - assert summary.request_args.headers == {} - assert summary.request_args.payload == {} - assert summary.start_time == 1.0 - assert summary.end_time == 2.0 - assert summary.iterations == 3 - assert summary.first_iter_time == 1.0 - assert summary.last_iter_time == 2.0 - assert summary.request_prompt_tokens == 5 - assert summary.request_output_tokens == 10 - assert summary.response_prompt_tokens == 5 - assert summary.response_output_tokens == 10 - assert summary.request_id == "123" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a0457b6f..00d4eec1 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,195 +1,195 @@ -import json -from collections.abc import AsyncIterable -from typing import Any, Literal, Optional -from unittest.mock import MagicMock, patch - -import httpx -import pytest -import respx - -from guidellm.backend import ResponseSummary, StreamingTextResponse - -from .mock_backend import MockBackend - - -@pytest.fixture -def mock_auto_tokenizer(): - with patch("transformers.AutoTokenizer.from_pretrained") as mock_from_pretrained: - - def _fake_tokenize(text: str) -> list[int]: - tokens = text.split() - return [0] * len(tokens) - - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize = MagicMock(side_effect=_fake_tokenize) - mock_from_pretrained.return_value = mock_tokenizer - yield mock_tokenizer - - -@pytest.fixture -def mock_backend(request): - params = request.param if hasattr(request, "param") else {} - kwargs = {} - - for key in ("model", "target", "iter_delay"): - if key in params: - kwargs[key] = params[key] - - return MockBackend(**kwargs) - - -class MockCompletionsIter(AsyncIterable): - def __init__( - self, - type_: Literal["text", "chat"], - prompt: str, - output_token_count: Optional[int], - target: Optional[str] = None, - model: Optional[str] = None, - iter_delay: Optional[float] = None, - ): - self._type = type_ - self._backend = MockBackend( - model=model, - target=target, - iter_delay=iter_delay, - ) - self._prompt = prompt - self._output_token_count = output_token_count - - async def __aiter__(self): - async for token_iter in ( - self._backend.text_completions( - prompt=self._prompt, output_token_count=self._output_token_count - ) - if self._type == "text" - else self._backend.chat_completions( - content=self._prompt, output_token_count=self._output_token_count - ) - ): - if ( - isinstance(token_iter, StreamingTextResponse) - and token_iter.type_ == "start" - ): - continue - - data: dict[str, Any] - - if isinstance(token_iter, StreamingTextResponse): - if self._type == "text": - data = { - "choices": [ - { - "index": token_iter.iter_count, - "text": token_iter.delta, - } - ] - } - elif self._type == "chat": - data = { - "choices": [ - { - "index": token_iter.iter_count, - "delta": {"content": token_iter.delta}, - } - ] - } - else: - raise ValueError("Invalid type for mock completions") - elif isinstance(token_iter, ResponseSummary): - data = { - "usage": { - "prompt_tokens": ( - len(self._prompt.split()) + self._prompt.count(" ") - ), - "completion_tokens": token_iter.response_output_tokens, - } - } - else: - raise ValueError("Invalid token_iter type") - - yield f"data: {json.dumps(data)}\n".encode() - - yield b"data: [DONE]\n" - - -@pytest.fixture -def httpx_openai_mock(request): - params = request.param if hasattr(request, "param") else {} - model = params.get("model", "mock-model") - target = params.get("target", "http://target.mock") - iter_delay = 
params.get("iter_delay", None) - - with respx.mock(assert_all_mocked=True, assert_all_called=False) as mock_router: - - async def _mock_completions_response(request) -> AsyncIterable[str]: - headers = request.headers - payload = json.loads(request.content) - - assert headers["Content-Type"] == "application/json" - assert payload["model"] == model - assert payload["stream"] is True - assert payload["stream_options"] == {"include_usage": True} - assert payload["prompt"] is not None - assert len(payload["prompt"]) > 0 - assert payload["max_completion_tokens"] > 0 - assert payload["max_tokens"] > 0 - - return httpx.Response( # type: ignore - 200, - stream=MockCompletionsIter( # type: ignore - type_="text", - prompt=payload["prompt"], - output_token_count=( - payload["max_completion_tokens"] - if payload.get("ignore_eos", False) - else None - ), - target=target, - model=model, - iter_delay=iter_delay, - ), - ) - - async def _mock_chat_completions_response(request): - headers = request.headers - payload = json.loads(request.content) - - assert headers["Content-Type"] == "application/json" - assert payload["model"] == model - assert payload["stream"] is True - assert payload["stream_options"] == {"include_usage": True} - assert payload["messages"] is not None - assert len(payload["messages"]) > 0 - assert payload["max_completion_tokens"] > 0 - assert payload["max_tokens"] > 0 - - return httpx.Response( # type: ignore - 200, - stream=MockCompletionsIter( # type: ignore - type_="chat", - prompt=payload["messages"][0]["content"], - output_token_count=( - payload["max_completion_tokens"] - if payload.get("ignore_eos", False) - else None - ), - target=target, - model=model, - iter_delay=iter_delay, - ), - ) - - mock_router.route(method="GET", path="/v1/models").mock( - return_value=httpx.Response( - 200, json={"data": [{"id": model} if model else {"id": "mock-model"}]} - ) - ) - mock_router.route(method="POST", path="/v1/completions").mock( - side_effect=_mock_completions_response # type: ignore - ) - mock_router.route(method="POST", path="/v1/chat/completions").mock( - side_effect=_mock_chat_completions_response - ) - - yield mock_router +# import json +# from collections.abc import AsyncIterable +# from typing import Any, Literal, Optional +# from unittest.mock import MagicMock, patch + +# import httpx +# import pytest +# import respx + +# from guidellm.backend import ResponseSummary, StreamingTextResponse + +# from .mock_backend import MockBackend + + +# @pytest.fixture +# def mock_auto_tokenizer(): +# with patch("transformers.AutoTokenizer.from_pretrained") as mock_from_pretrained: + +# def _fake_tokenize(text: str) -> list[int]: +# tokens = text.split() +# return [0] * len(tokens) + +# mock_tokenizer = MagicMock() +# mock_tokenizer.tokenize = MagicMock(side_effect=_fake_tokenize) +# mock_from_pretrained.return_value = mock_tokenizer +# yield mock_tokenizer + + +# @pytest.fixture +# def mock_backend(request): +# params = request.param if hasattr(request, "param") else {} +# kwargs = {} + +# for key in ("model", "target", "iter_delay"): +# if key in params: +# kwargs[key] = params[key] + +# return MockBackend(**kwargs) + + +# class MockCompletionsIter(AsyncIterable): +# def __init__( +# self, +# type_: Literal["text", "chat"], +# prompt: str, +# output_token_count: Optional[int], +# target: Optional[str] = None, +# model: Optional[str] = None, +# iter_delay: Optional[float] = None, +# ): +# self._type = type_ +# self._backend = MockBackend( +# model=model, +# target=target, +# 
iter_delay=iter_delay, +# ) +# self._prompt = prompt +# self._output_token_count = output_token_count + +# async def __aiter__(self): +# async for token_iter in ( +# self._backend.text_completions( +# prompt=self._prompt, output_token_count=self._output_token_count +# ) +# if self._type == "text" +# else self._backend.chat_completions( +# content=self._prompt, output_token_count=self._output_token_count +# ) +# ): +# if ( +# isinstance(token_iter, StreamingTextResponse) +# and token_iter.type_ == "start" +# ): +# continue + +# data: dict[str, Any] + +# if isinstance(token_iter, StreamingTextResponse): +# if self._type == "text": +# data = { +# "choices": [ +# { +# "index": token_iter.iter_count, +# "text": token_iter.delta, +# } +# ] +# } +# elif self._type == "chat": +# data = { +# "choices": [ +# { +# "index": token_iter.iter_count, +# "delta": {"content": token_iter.delta}, +# } +# ] +# } +# else: +# raise ValueError("Invalid type for mock completions") +# elif isinstance(token_iter, ResponseSummary): +# data = { +# "usage": { +# "prompt_tokens": ( +# len(self._prompt.split()) + self._prompt.count(" ") +# ), +# "completion_tokens": token_iter.response_output_tokens, +# } +# } +# else: +# raise ValueError("Invalid token_iter type") + +# yield f"data: {json.dumps(data)}\n".encode() + +# yield b"data: [DONE]\n" + + +# @pytest.fixture +# def httpx_openai_mock(request): +# params = request.param if hasattr(request, "param") else {} +# model = params.get("model", "mock-model") +# target = params.get("target", "http://target.mock") +# iter_delay = params.get("iter_delay", None) + +# with respx.mock(assert_all_mocked=True, assert_all_called=False) as mock_router: + +# async def _mock_completions_response(request) -> AsyncIterable[str]: +# headers = request.headers +# payload = json.loads(request.content) + +# assert headers["Content-Type"] == "application/json" +# assert payload["model"] == model +# assert payload["stream"] is True +# assert payload["stream_options"] == {"include_usage": True} +# assert payload["prompt"] is not None +# assert len(payload["prompt"]) > 0 +# assert payload["max_completion_tokens"] > 0 +# assert payload["max_tokens"] > 0 + +# return httpx.Response( # type: ignore +# 200, +# stream=MockCompletionsIter( # type: ignore +# type_="text", +# prompt=payload["prompt"], +# output_token_count=( +# payload["max_completion_tokens"] +# if payload.get("ignore_eos", False) +# else None +# ), +# target=target, +# model=model, +# iter_delay=iter_delay, +# ), +# ) + +# async def _mock_chat_completions_response(request): +# headers = request.headers +# payload = json.loads(request.content) + +# assert headers["Content-Type"] == "application/json" +# assert payload["model"] == model +# assert payload["stream"] is True +# assert payload["stream_options"] == {"include_usage": True} +# assert payload["messages"] is not None +# assert len(payload["messages"]) > 0 +# assert payload["max_completion_tokens"] > 0 +# assert payload["max_tokens"] > 0 + +# return httpx.Response( # type: ignore +# 200, +# stream=MockCompletionsIter( # type: ignore +# type_="chat", +# prompt=payload["messages"][0]["content"], +# output_token_count=( +# payload["max_completion_tokens"] +# if payload.get("ignore_eos", False) +# else None +# ), +# target=target, +# model=model, +# iter_delay=iter_delay, +# ), +# ) + +# mock_router.route(method="GET", path="/v1/models").mock( +# return_value=httpx.Response( +# 200, json={"data": [{"id": model} if model else {"id": "mock-model"}]} +# ) +# ) +# 
mock_router.route(method="POST", path="/v1/completions").mock( +# side_effect=_mock_completions_response # type: ignore +# ) +# mock_router.route(method="POST", path="/v1/chat/completions").mock( +# side_effect=_mock_chat_completions_response +# ) + +# yield mock_router diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py index 27bfe382..4e1476d3 100644 --- a/tests/unit/mock_backend.py +++ b/tests/unit/mock_backend.py @@ -1,172 +1,186 @@ +""" +Mock backend implementation for testing purposes. +""" + import asyncio import random import time -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Any, Optional, Union - -from lorem.text import TextLorem # type: ignore -from PIL import Image - -from guidellm.backend import ( - Backend, - RequestArgs, - ResponseSummary, - StreamingTextResponse, +from collections.abc import AsyncIterator +from typing import Any, Optional + +from lorem.text import TextLorem + +from guidellm.backend.backend import Backend +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, ) +from guidellm.scheduler import ScheduledRequestInfo -@Backend.register("mock") # type: ignore +@Backend.register("mock") class MockBackend(Backend): + """ + Mock backend for testing that simulates text generation. + + Provides predictable responses with configurable delays and token counts + for testing the backend interface without requiring an actual LLM service. + """ + def __init__( self, - model: Optional[str] = "mock-model", - target: Optional[str] = "mock-target", + target: str = "mock-target", + model: str = "mock-model", iter_delay: Optional[float] = None, ): - super().__init__(type_="mock") # type: ignore + """ + Initialize mock backend. + + :param model: Model name to simulate. + :param target: Target URL to simulate. + :param iter_delay: Delay between iterations in seconds. + """ + super().__init__(type_="mock") # type: ignore [reportCallIssue] self._model = model self._target = target self._iter_delay = iter_delay + self._in_process = False @property def target(self) -> str: - return self._target # type: ignore + """Target URL for the mock backend.""" + return self._target @property def model(self) -> Optional[str]: + """Model name for the mock backend.""" return self._model - @property def info(self) -> dict[str, Any]: - return {} - - async def reset(self) -> None: - pass - - async def prepare_multiprocessing(self): - pass - - async def check_setup(self): - pass - - async def available_models(self) -> list[str]: - return [self.model] # type: ignore + """ + Return mock backend configuration information. + """ + return { + "type": "mock", + "model": self._model, + "target": self._target, + "iter_delay": self._iter_delay, + } + + async def process_startup(self) -> None: + """ + Initialize the mock backend process. + """ + self._in_process = True + + async def process_shutdown(self) -> None: + """ + Shutdown the mock backend process. + """ + self._in_process = False + + async def validate(self) -> None: + """ + Validate the mock backend configuration. + """ + if not self._in_process: + raise RuntimeError("Backend not started up for process") + + async def default_model(self) -> Optional[str]: + """ + Return the default model for the mock backend. 
+ """ + return self._model - async def text_completions( # type: ignore + async def resolve( self, - prompt: Union[str, list[str]], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - if not isinstance(prompt, str) or not prompt: - raise ValueError("Prompt must be a non-empty string") - - async for response in self._text_prompt_response_generator( - prompt, - request_id, - prompt_token_count, - output_token_count, - ): - yield response - - async def chat_completions( # type: ignore - self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - raw_content: bool = False, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - if not isinstance(content, str) or not content: - raise ValueError("Content must be a non-empty string") - - async for response in self._text_prompt_response_generator( - content, - request_id, - prompt_token_count, - output_token_count, - ): - yield response - - async def _text_prompt_response_generator( - self, - prompt: str, - request_id: Optional[str], - prompt_token_count: Optional[int], - output_token_count: Optional[int], - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - tokens = self._get_tokens(output_token_count) - start_time = time.time() - - yield StreamingTextResponse( - type_="start", + request: GenerationRequest, + request_info: ScheduledRequestInfo[GenerationRequestTimings], + history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + ) -> AsyncIterator[ + tuple[GenerationResponse, ScheduledRequestInfo[GenerationRequestTimings]] + ]: + """ + Process a generation request and yield progressive responses. 
+ + ### WRITTEN BY AI ### + """ + if not self._in_process: + raise RuntimeError("Backend not started up for process") + + if history is not None: + raise NotImplementedError( + "Multi-turn requests not supported in mock backend" + ) + + # Extract token counts from request + prompt_tokens = request.stats.get("prompt_tokens") + output_tokens = request.constraints.get("output_tokens") + + # Generate mock tokens + tokens = self._get_tokens(output_tokens) + + # Initialize response + response = GenerationResponse( + request_id=request.request_id, + request_args={ + "request_type": request.request_type, + "output_token_count": output_tokens, + **request.params, + }, value="", - start_time=start_time, - first_iter_time=None, - iter_count=0, - delta="", - time=start_time, - request_id=request_id, + request_prompt_tokens=prompt_tokens, + request_output_tokens=output_tokens, ) - first_iter_time = None - last_iter_time = None + # Initialize timings + request_info.request_timings = GenerationRequestTimings() + request_info.request_timings.request_start = time.time() + # Generate response iteratively for index, token in enumerate(tokens): if self._iter_delay: await asyncio.sleep(self._iter_delay) - if first_iter_time is None: - first_iter_time = time.time() - - yield StreamingTextResponse( - type_="iter", - value="".join(tokens[: index + 1]), - start_time=start_time, - first_iter_time=first_iter_time, - iter_count=index + 1, - delta=token, - time=time.time(), - request_id=request_id, - ) + if request_info.request_timings.first_iteration is None: + request_info.request_timings.first_iteration = time.time() - last_iter_time = time.time() - - yield ResponseSummary( - value="".join(tokens), - request_args=RequestArgs( - target=self.target, - headers={}, - params={}, - payload={"prompt": prompt, "output_token_count": output_token_count}, - ), - iterations=len(tokens), - start_time=start_time, - end_time=time.time(), - first_iter_time=first_iter_time, - last_iter_time=last_iter_time, - request_prompt_tokens=prompt_token_count, - request_output_tokens=output_token_count, - response_prompt_tokens=len(prompt.split()) + prompt.count(" "), - response_output_tokens=len(tokens), - request_id=request_id, + response.value += token # type: ignore [reportOperatorIssue] + response.delta = token + response.iterations = index + 1 + request_info.request_timings.last_iteration = time.time() + + yield response, request_info + + # Final response with usage stats + request_info.request_timings.request_end = time.time() + response.response_prompt_tokens = prompt_tokens or self._estimate_prompt_tokens( + str(request.content) ) + response.response_output_tokens = len(tokens) + response.delta = None + + yield response, request_info + + @staticmethod + def _estimate_prompt_tokens(content: str) -> int: + """ + Estimate prompt tokens from content. + """ + # Simple word-based token estimation + return len(str(content).split()) @staticmethod def _get_tokens(token_count: Optional[int] = None) -> list[str]: + """ + Generate mock tokens for response. 
+ """ if token_count is None: token_count = random.randint(8, 512) words = TextLorem(srange=(token_count, token_count)).sentence().split() - tokens = [] # type: ignore + tokens = [] for word in words: if len(tokens) == token_count - 1: diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index 81364fa1..511bacbf 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -1,9 +1,9 @@ from guidellm.benchmark import ( BenchmarkArgs, - BenchmarkRunStats, + BenchmarkSchedulerStats, GenerativeBenchmark, + GenerativeRequestStats, GenerativeTextErrorStats, - GenerativeTextResponseStats, SynchronousProfile, ) from guidellm.objects import StatusBreakdown @@ -21,7 +21,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: return GenerativeBenchmark.from_stats( run_id="fa4a92c1-9a1d-4c83-b237-83fcc7971bd3", successful=[ - GenerativeTextResponseStats( + GenerativeRequestStats( request_id="181a63e2-dc26-4268-9cfc-2ed9279aae63", request_type="text_completions", scheduler_info=SchedulerRequestInfo( @@ -48,7 +48,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: first_token_time=1744728125.2473357, last_token_time=1744728126.699908, ), - GenerativeTextResponseStats( + GenerativeRequestStats( request_id="8a7846d5-7624-420d-a269-831e568a848f", request_type="text_completions", scheduler_info=SchedulerRequestInfo( @@ -75,7 +75,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: first_token_time=1744728126.7526379, last_token_time=1744728128.1956792, ), - GenerativeTextResponseStats( + GenerativeRequestStats( request_id="4cde0e6c-4531-4e59-aac1-07bc8b6e4139", request_type="text_completions", scheduler_info=SchedulerRequestInfo( @@ -102,7 +102,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: first_token_time=1744728128.2481627, last_token_time=1744728129.6914039, ), - GenerativeTextResponseStats( + GenerativeRequestStats( request_id="a95b96be-05d4-4130-b0dd-9528c01c9909", request_type="text_completions", scheduler_info=SchedulerRequestInfo( @@ -129,7 +129,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: first_token_time=1744728129.7438853, last_token_time=1744728131.187019, ), - GenerativeTextResponseStats( + GenerativeRequestStats( request_id="714b751c-bbfe-4b2a-a0af-7c1bf2c224ae", request_type="text_completions", scheduler_info=SchedulerRequestInfo( @@ -156,7 +156,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: first_token_time=1744728131.2394557, last_token_time=1744728132.6828275, ), - GenerativeTextResponseStats( + GenerativeRequestStats( request_id="ef73ae8a-4c8f-4c88-b303-cfff152ce378", request_type="text_completions", scheduler_info=SchedulerRequestInfo( @@ -226,7 +226,7 @@ def mock_generative_benchmark() -> GenerativeBenchmark: cooldown_number=None, cooldown_duration=None, ), - run_stats=BenchmarkRunStats( + run_stats=BenchmarkSchedulerStats( start_time=1744728125.0772898, end_time=1744728135.8407037, requests_made=StatusBreakdown( diff --git a/tests/unit/scheduler/__init__.py b/tests/unit/scheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/scheduler/test_constraints.py b/tests/unit/scheduler/test_constraints.py new file mode 100644 index 00000000..487aa574 --- /dev/null +++ b/tests/unit/scheduler/test_constraints.py @@ -0,0 +1,1568 @@ +import inspect +import random +import time +from typing import Protocol + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + Constraint, + ConstraintInitializer, + 
ConstraintsInitializerFactory,
+    MaxDurationConstraint,
+    MaxDurationConstraintInitializer,
+    MaxErrorRateConstraint,
+    MaxErrorRateConstraintInitializer,
+    MaxErrorsConstraint,
+    MaxErrorsConstraintInitializer,
+    MaxGlobalErrorRateConstraint,
+    MaxGlobalErrorRateConstraintInitializer,
+    MaxNumberConstraint,
+    MaxNumberConstraintInitializer,
+    ScheduledRequestInfo,
+    SchedulerState,
+    SchedulerUpdateAction,
+)
+
+
+class TestConstraint:
+    """Test the Constraint protocol."""
+
+    @pytest.mark.smoke
+    def test_is_protocol(self):
+        """Test that Constraint is a protocol and runtime checkable."""
+        assert issubclass(Constraint, Protocol)
+        assert hasattr(Constraint, "_is_protocol")
+        assert Constraint._is_protocol is True
+        assert hasattr(Constraint, "_is_runtime_protocol")
+        assert Constraint._is_runtime_protocol is True
+
+    @pytest.mark.smoke
+    def test_protocol_method_signature(self):
+        """Test that the Constraint protocol has the correct method signature."""
+        call_method = Constraint.__call__
+        sig = inspect.signature(call_method)
+
+        expected_params = ["self", "state", "request"]
+        assert list(sig.parameters.keys()) == expected_params
+
+        params = sig.parameters
+        assert "state" in params
+        assert "request" in params
+
+    @pytest.mark.smoke
+    def test_runtime_is_constraint(self):
+        """Test that Constraint can be checked at runtime using isinstance."""
+
+        class ValidConstraint:
+            def __call__(
+                self,
+                state: SchedulerState,
+                request: ScheduledRequestInfo,
+            ) -> SchedulerUpdateAction:
+                return SchedulerUpdateAction()
+
+        valid_instance = ValidConstraint()
+        assert isinstance(valid_instance, Constraint)
+
+        class InvalidConstraint:
+            pass
+
+        invalid_instance = InvalidConstraint()
+        assert not isinstance(invalid_instance, Constraint)
+
+    @pytest.mark.smoke
+    def test_runtime_is_not_initializer(self):
+        """
+        Test that a class not implementing the ConstraintInitializer
+        protocol is not recognized as such. 
+ """ + + class ValidConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + not_initializer_instance = ValidConstraint() + assert not isinstance(not_initializer_instance, ConstraintInitializer) + + +class TestConstraintInitializer: + """Test the ConstraintInitializer protocol.""" + + @pytest.mark.smoke + def test_is_protocol(self): + """Test that ConstraintInitializer is a protocol and runtime checkable.""" + assert issubclass(ConstraintInitializer, Protocol) + assert hasattr(ConstraintInitializer, "_is_protocol") + assert ConstraintInitializer._is_protocol is True + assert hasattr(ConstraintInitializer, "_is_runtime_protocol") + assert ConstraintInitializer._is_runtime_protocol is True + + @pytest.mark.smoke + def test_protocol_method_signature(self): + """Test that ConstraintInitializer protocol has correct method signature.""" + create_constraint_method = ConstraintInitializer.create_constraint + sig = inspect.signature(create_constraint_method) + + expected_params = ["self", "kwargs"] + assert list(sig.parameters.keys()) == expected_params + kwargs_param = sig.parameters["kwargs"] + assert kwargs_param.kind == kwargs_param.VAR_KEYWORD + + @pytest.mark.smoke + def test_runtime_is_initializer(self): + """Test that ConstraintInitializer can be checked at runtime.""" + + class ValidInitializer: + def create_constraint(self, **kwargs) -> Constraint: + class SimpleConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + return SimpleConstraint() + + valid_instance = ValidInitializer() + assert isinstance(valid_instance, ConstraintInitializer) + + @pytest.mark.smoke + def test_runtime_is_not_constraint(self): + """ + Test that a class not implementing the Constraint protocol + is not recognized as such. + """ + + class ValidInitializer: + def create_constraint(self, **kwargs) -> Constraint: + class SimpleConstraint: + def __call__( + self, + state: SchedulerState, + request: ScheduledRequestInfo, + ) -> SchedulerUpdateAction: + return SchedulerUpdateAction() + + return SimpleConstraint() + + not_constraint_instance = ValidInitializer() + assert not isinstance(not_constraint_instance, Constraint) + + +class TestMaxNumberConstraint: + """Test the MaxNumberConstraint implementation.""" + + @pytest.fixture(params=[{"max_num": 100}, {"max_num": 50.5}, {"max_num": 1}]) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxNumberConstraint(**constructor_args) + + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxNumberConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.sanity + def test_is_not_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxNumberConstraint does not satisfy + the ConstraintInitializer protocol. 
+ """ + constraint, _ = valid_instances + assert not isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxNumberConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxNumberConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxNumberConstraint() + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num=-1) + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num=0) + with pytest.raises(ValidationError): + MaxNumberConstraint(max_num="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions and progress""" + instance, constructor_args = valid_instances + start_time = time.time() + + for num_requests in range(0, int(constructor_args["max_num"]) * 2 + 1, 1): + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=num_requests, + processed_requests=num_requests // 2, + ) + request = ScheduledRequestInfo( + request_id=f"test-{num_requests}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + created_exceeded = num_requests >= constructor_args["max_num"] + processed_exceeded = num_requests // 2 >= constructor_args["max_num"] + expected_queuing = "stop" if created_exceeded else "continue" + expected_processing = "stop_local" if processed_exceeded else "continue" + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_number": constructor_args["max_num"], + "create_exceeded": created_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": state.processed_requests, + } + assert isinstance(action.progress, dict) + processed_requests = num_requests // 2 + remaining_fraction = max( + 0.0, 1.0 - processed_requests / constructor_args["max_num"] + ) + remaining_requests = max( + 0.0, constructor_args["max_num"] - processed_requests + ) + assert action.progress["remaining_fraction"] == pytest.approx( + remaining_fraction + ) + assert action.progress["remaining_requests"] == remaining_requests + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxNumberConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxNumberConstraint.model_validate(data) + assert reconstructed.max_num == instance.max_num + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxNumberConstraintInitializer: + """Test the MaxNumberConstraintInitializer implementation.""" + + @pytest.fixture(params=[{"max_num": 100}, {"max_num": 50.5}, {"max_num": 1}]) + def valid_instances(self, request): + """Provide valid instances of MaxNumberConstraintInitializer.""" + params = request.param + instance = 
MaxNumberConstraintInitializer(**params) + return instance, params + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self): + """Test that MaxNumberConstraintInitializer satisfies the protocol.""" + initializer = MaxNumberConstraintInitializer(max_num=100) + assert isinstance(initializer, ConstraintInitializer) + + @pytest.mark.smoke + def test_is_not_constraint_protocol(self): + """ + Test that MaxNumberConstraintInitializer does not satisfy + the constraint protocol. + """ + initializer = MaxNumberConstraintInitializer(max_num=100) + assert not isinstance(initializer, Constraint) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that the initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that the initializer rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxNumberConstraintInitializer() + with pytest.raises(ValidationError): + MaxNumberConstraintInitializer(max_num=-1) + with pytest.raises(ValidationError): + MaxNumberConstraintInitializer(max_num=0) + with pytest.raises(ValidationError): + MaxNumberConstraintInitializer(max_num="invalid") + + def test_constraint_initialization_functionality(self, valid_instances): + """Test that the constraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxNumberConstraint) + assert constraint.max_num == constructor_args["max_num"] + + def test_marshalling(self, valid_instances): + """ + Test that MaxNumberConstraintInitializer can be + serialized and deserialized. + """ + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxNumberConstraintInitializer.model_validate(data) + assert reconstructed.max_num == instance.max_num + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxDurationConstraint: + """Test the MaxDurationConstraint implementation.""" + + @pytest.fixture( + params=[{"max_duration": 2.0}, {"max_duration": 1}, {"max_duration": 0.5}] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxDurationConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxDurationConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.sanity + def test_is_not_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxDurationConstraint does not satisfy + the ConstraintInitializer protocol. 
+ """ + constraint, _ = valid_instances + assert not isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxDurationConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxDurationConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxDurationConstraint() + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration=-1) + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration=0) + with pytest.raises(ValidationError): + MaxDurationConstraint(max_duration="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions and progress through a time loop""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_duration = constructor_args["max_duration"] + sleep_interval = max_duration * 0.05 + target_duration = max_duration * 1.5 + + elapsed = 0.0 + step = 0 + + while elapsed <= target_duration: + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=step + 1, + processed_requests=step, + ) + request = ScheduledRequestInfo( + request_id=f"test-{step}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + + duration_exceeded = elapsed >= max_duration + + if not duration_exceeded: + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + else: + assert action.request_queuing == "stop" + assert action.request_processing == "stop_local" + assert isinstance(action.metadata, dict) + assert action.metadata["max_duration"] == max_duration + assert action.metadata["elapsed_time"] == pytest.approx(elapsed, abs=0.01) + assert action.metadata["duration_exceeded"] == duration_exceeded + assert action.metadata["start_time"] == start_time + assert isinstance(action.progress, dict) + expected_remaining_fraction = max(0.0, 1.0 - elapsed / max_duration) + expected_remaining_duration = max(0.0, max_duration - elapsed) + assert action.progress["remaining_fraction"] == pytest.approx( + expected_remaining_fraction, abs=0.1 + ) + assert action.progress["remaining_duration"] == pytest.approx( + expected_remaining_duration, abs=0.1 + ) + time.sleep(sleep_interval) + elapsed = time.time() - start_time + step += 1 + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxDurationConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxDurationConstraint.model_validate(data) + assert reconstructed.max_duration == instance.max_duration + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxDurationConstraintInitializer: + """Test the MaxDurationConstraintInitializer implementation.""" + + @pytest.fixture( + params=[{"max_duration": 30.0}, {"max_duration": 60}, {"max_duration": 0.5}] + ) + def valid_instances(self, request): + """Provide valid 
instances of MaxDurationConstraintInitializer.""" + params = request.param + instance = MaxDurationConstraintInitializer(**params) + return instance, params + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self): + """Test that MaxDurationConstraintInitializer satisfies the protocol.""" + initializer = MaxDurationConstraintInitializer(max_duration=30.0) + assert isinstance(initializer, ConstraintInitializer) + + @pytest.mark.smoke + def test_is_not_constraint_protocol(self): + """ + Test that MaxDurationConstraintInitializer does not satisfy + the constraint protocol. + """ + initializer = MaxDurationConstraintInitializer(max_duration=30.0) + assert not isinstance(initializer, Constraint) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that the initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that the initializer rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxDurationConstraintInitializer() + with pytest.raises(ValidationError): + MaxDurationConstraintInitializer(max_duration=0) + with pytest.raises(ValidationError): + MaxDurationConstraintInitializer(max_duration=-1) + with pytest.raises(ValidationError): + MaxDurationConstraintInitializer(max_duration="invalid") + + def test_constraint_initialization_functionality(self, valid_instances): + """Test that the constraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxDurationConstraint) + assert constraint.max_duration == constructor_args["max_duration"] + + def test_marshalling(self, valid_instances): + """ + Test that MaxDurationConstraintInitializer can be + serialized and deserialized. + """ + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxDurationConstraintInitializer.model_validate(data) + assert reconstructed.max_duration == instance.max_duration + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxErrorsConstraint: + """Test the MaxErrorsConstraint implementation.""" + + @pytest.fixture(params=[{"max_errors": 10}, {"max_errors": 5.5}, {"max_errors": 1}]) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxErrorsConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxErrorsConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.sanity + def test_is_not_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxErrorsConstraint does not satisfy + the ConstraintInitializer protocol. 
+ """ + constraint, _ = valid_instances + assert not isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxErrorsConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxErrorsConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorsConstraint() + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors=-1) + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors=0) + with pytest.raises(ValidationError): + MaxErrorsConstraint(max_errors="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions""" + instance, constructor_args = valid_instances + start_time = time.time() + + for num_errors in range(int(constructor_args["max_errors"] * 2)): + created_requests = (num_errors + 1) * 2 + processed_requests = num_errors + 1 + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=created_requests, + processed_requests=processed_requests, + errored_requests=num_errors, + ) + request = ScheduledRequestInfo( + request_id=f"test-{num_errors}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + errors_exceeded = num_errors >= constructor_args["max_errors"] + if not errors_exceeded: + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + else: + assert action.request_queuing == "stop" + assert action.request_processing == "stop_all" + + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_errors": constructor_args["max_errors"], + "errors_exceeded": errors_exceeded, + "current_errors": num_errors, + } + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxErrorsConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorsConstraint.model_validate(data) + assert reconstructed.max_errors == instance.max_errors + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxErrorsConstraintInitializer: + """Test the MaxErrorsConstraintInitializer implementation.""" + + @pytest.fixture(params=[{"max_errors": 10}, {"max_errors": 5.5}, {"max_errors": 1}]) + def valid_instances(self, request): + """Provide valid instances of MaxErrorsConstraintInitializer.""" + params = request.param + instance = MaxErrorsConstraintInitializer(**params) + return instance, params + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self): + """Test that MaxErrorsConstraintInitializer satisfies the protocol.""" + initializer = MaxErrorsConstraintInitializer(max_errors=10) + assert isinstance(initializer, ConstraintInitializer) + + @pytest.mark.smoke + def test_is_not_constraint_protocol(self): + """ + Test that MaxErrorsConstraintInitializer does not satisfy + the constraint protocol. 
+ """ + initializer = MaxErrorsConstraintInitializer(max_errors=10) + assert not isinstance(initializer, Constraint) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that the initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that the initializer rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorsConstraintInitializer() + with pytest.raises(ValidationError): + MaxErrorsConstraintInitializer(max_errors=-1) + with pytest.raises(ValidationError): + MaxErrorsConstraintInitializer(max_errors=0) + with pytest.raises(ValidationError): + MaxErrorsConstraintInitializer(max_errors="invalid") + + def test_constraint_initialization_functionality(self, valid_instances): + """Test that the constraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxErrorsConstraint) + assert constraint.max_errors == constructor_args["max_errors"] + + def test_marshalling(self, valid_instances): + """ + Test that MaxErrorsConstraintInitializer can be + serialized and deserialized. + """ + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorsConstraintInitializer.model_validate(data) + assert reconstructed.max_errors == instance.max_errors + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxErrorRateConstraint: + """Test the MaxErrorRateConstraint implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "window_size": 40}, + {"max_error_rate": 0.5, "window_size": 50}, + {"max_error_rate": 0.05, "window_size": 55}, + ] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxErrorRateConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxErrorRateConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.sanity + def test_is_not_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxErrorRateConstraint does not satisfy + the ConstraintInitializer protocol. 
+ """ + constraint, _ = valid_instances + assert not isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that MaxErrorRateConstraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxErrorRateConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorRateConstraint() + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate=0.5, window_size=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraint(max_error_rate="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions with sliding window behavior""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_error_rate = constructor_args["max_error_rate"] + window_size = constructor_args["window_size"] + safety_factor = 1.5 + total_errors = 0 + error_window = [] + + for request_num in range(window_size * 2): + error_probability = max_error_rate * safety_factor + + if random.random() < error_probability: + total_errors += 1 + status = "errored" + error_window.append(1) + else: + status = "completed" + error_window.append(0) + error_window = ( + error_window[-window_size:] + if len(error_window) > window_size + else error_window + ) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=request_num + 1, + processed_requests=request_num + 1, + ) + request = ScheduledRequestInfo( + request_id=f"test-{request_num}", + status=status, + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + error_count = sum(instance.error_window) + processed_requests = state.processed_requests + exceeded_min_processed = processed_requests >= window_size + current_error_rate = ( + error_count / float(min(processed_requests, window_size)) + if processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = current_error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + expected_queuing = "stop" if should_stop else "continue" + expected_processing = "stop_all" if should_stop else "continue" + + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + assert isinstance(action.metadata, dict) + assert action.metadata["max_error_rate"] == max_error_rate + assert action.metadata["window_size"] == window_size + assert action.metadata["error_count"] == error_count + assert action.metadata["current_error_rate"] == current_error_rate + assert action.metadata["exceeded_error_rate"] == exceeded_error_rate + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxErrorRateConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in 
constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorRateConstraint.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.window_size == instance.window_size + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxErrorRateConstraintInitializer: + """Test the MaxErrorRateConstraintInitializer implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "window_size": 10}, + {"max_error_rate": 0.5, "window_size": 20}, + {"max_error_rate": 0.05, "window_size": 5}, + ] + ) + def valid_instances(self, request): + """Provide valid instances of MaxErrorRateConstraintInitializer.""" + params = request.param + instance = MaxErrorRateConstraintInitializer(**params) + return instance, params + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self): + """Test that MaxErrorRateConstraintInitializer satisfies the protocol.""" + initializer = MaxErrorRateConstraintInitializer( + max_error_rate=0.1, window_size=10 + ) + assert isinstance(initializer, ConstraintInitializer) + + @pytest.mark.smoke + def test_is_not_constraint_protocol(self): + """ + Test that MaxErrorRateConstraintInitializer does not satisfy + the constraint protocol. + """ + initializer = MaxErrorRateConstraintInitializer( + max_error_rate=0.1, window_size=10 + ) + assert not isinstance(initializer, Constraint) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that the initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that the initializer rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxErrorRateConstraintInitializer() + with pytest.raises(ValidationError): + MaxErrorRateConstraintInitializer(max_error_rate=0) + with pytest.raises(ValidationError): + MaxErrorRateConstraintInitializer(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxErrorRateConstraintInitializer(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxErrorRateConstraintInitializer(max_error_rate=0.5, window_size=0) + + def test_constraint_initialization_functionality(self, valid_instances): + """Test that the constraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxErrorRateConstraint) + assert constraint.max_error_rate == constructor_args["max_error_rate"] + assert constraint.window_size == constructor_args["window_size"] + + def test_marshalling(self, valid_instances): + """ + Test that MaxErrorRateConstraintInitializer can be + serialized and deserialized. 
+ """ + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxErrorRateConstraintInitializer.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.window_size == instance.window_size + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxGlobalErrorRateConstraint: + """Test the MaxGlobalErrorRateConstraint implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "min_processed": 50}, + {"max_error_rate": 0.2, "min_processed": 100}, + {"max_error_rate": 0.05, "min_processed": 31}, + ] + ) + def valid_instances(self, request): + constructor_args = request.param + instance = MaxGlobalErrorRateConstraint(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that MaxGlobalErrorRateConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.sanity + def test_is_not_constraint_initializer_protocol(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraint does not satisfy + the ConstraintInitializer protocol. + """ + constraint, _ = valid_instances + assert not isinstance(constraint, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraint can be initialized + with valid parameters. + """ + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that MaxGlobalErrorRateConstraint rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint() + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=0) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate=0.5, min_processed=30) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraint(max_error_rate="invalid") + + @pytest.mark.smoke + def test_constraint_functionality(self, valid_instances): + """Test constraint returns correct actions based on global error rate""" + instance, constructor_args = valid_instances + start_time = time.time() + + max_error_rate = constructor_args["max_error_rate"] + min_processed = constructor_args["min_processed"] + safety_factor = 1.5 + total_requests = min_processed * 2 + total_errors = 0 + + for request_num in range(total_requests): + error_probability = max_error_rate * safety_factor + + if random.random() < error_probability: + total_errors += 1 + status = "errored" + else: + status = "completed" + + processed_requests = request_num + 1 + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=processed_requests + 10, + processed_requests=processed_requests, + errored_requests=total_errors, + ) + request = ScheduledRequestInfo( + request_id=f"test-{request_num}", + status=status, + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = 
instance(state, request) + assert isinstance(action, SchedulerUpdateAction) + + exceeded_min_processed = processed_requests >= min_processed + error_rate = ( + total_errors / float(processed_requests) + if processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = error_rate >= max_error_rate + should_stop = exceeded_min_processed and exceeded_error_rate + + expected_queuing = "stop" if should_stop else "continue" + expected_processing = "stop_all" if should_stop else "continue" + + assert action.request_queuing == expected_queuing + assert action.request_processing == expected_processing + + assert isinstance(action.metadata, dict) + assert action.metadata == { + "max_error_rate": max_error_rate, + "min_processed": min_processed, + "processed_requests": processed_requests, + "errored_requests": total_errors, + "error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + } + + # Error constraints don't provide progress information + assert action.progress == {} + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that MaxGlobalErrorRateConstraint can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxGlobalErrorRateConstraint.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.min_processed == instance.min_processed + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestMaxGlobalErrorRateConstraintInitializer: + """Test the MaxGlobalErrorRateConstraintInitializer implementation.""" + + @pytest.fixture( + params=[ + {"max_error_rate": 0.1, "min_processed": 50}, + {"max_error_rate": 0.2, "min_processed": 100}, + {"max_error_rate": 0.05, "min_processed": 31}, + ] + ) + def valid_instances(self, request): + """Provide valid instances of MaxGlobalErrorRateConstraintInitializer.""" + params = request.param + instance = MaxGlobalErrorRateConstraintInitializer(**params) + return instance, params + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self): + """Test that MaxGlobalErrorRateConstraintInitializer satisfies the protocol.""" + initializer = MaxGlobalErrorRateConstraintInitializer( + max_error_rate=0.1, min_processed=50 + ) + assert isinstance(initializer, ConstraintInitializer) + + @pytest.mark.smoke + def test_is_not_constraint_protocol(self): + """ + Test that MaxGlobalErrorRateConstraintInitializer does not satisfy + the constraint protocol. 
+ """ + initializer = MaxGlobalErrorRateConstraintInitializer( + max_error_rate=0.1, min_processed=50 + ) + assert not isinstance(initializer, Constraint) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that the initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that the initializer rejects invalid parameters.""" + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraintInitializer() + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraintInitializer(max_error_rate=0) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraintInitializer(max_error_rate=-1) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraintInitializer(max_error_rate=1.5) + with pytest.raises(ValidationError): + MaxGlobalErrorRateConstraintInitializer( + max_error_rate=0.5, min_processed=30 + ) + + def test_constraint_initialization_functionality(self, valid_instances): + """Test that the constraint can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + constraint = instance.create_constraint() + assert isinstance(constraint, MaxGlobalErrorRateConstraint) + assert constraint.max_error_rate == constructor_args["max_error_rate"] + assert constraint.min_processed == constructor_args["min_processed"] + + def test_marshalling(self, valid_instances): + """ + Test that MaxGlobalErrorRateConstraintInitializer can be + serialized and deserialized. + """ + instance, constructor_args = valid_instances + + data = instance.model_dump() + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = MaxGlobalErrorRateConstraintInitializer.model_validate(data) + assert reconstructed.max_error_rate == instance.max_error_rate + assert reconstructed.min_processed == instance.min_processed + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestConstraintsInitializerFactory: + """Test the ConstraintsInitializerFactory implementation.""" + + EXPECTED_REGISTERED_KEYS = { + "max_number": MaxNumberConstraintInitializer, + "max_duration": MaxDurationConstraintInitializer, + "max_errors": MaxErrorsConstraintInitializer, + "max_error_rate": MaxErrorRateConstraintInitializer, + "max_global_error_rate": MaxGlobalErrorRateConstraintInitializer, + } + + @pytest.mark.smoke + def test_registered_constraint_keys(self): + """Test that all expected constraint keys are registered and no others.""" + registered_keys = set(ConstraintsInitializerFactory.registered_objects().keys()) + expected_keys = set(self.EXPECTED_REGISTERED_KEYS.keys()) + + assert registered_keys == expected_keys, ( + f"Registered keys {registered_keys} do not match " + f"expected keys {expected_keys}" + ) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("key", "expected_class"), + [ + ("max_number", MaxNumberConstraintInitializer), + ("max_duration", MaxDurationConstraintInitializer), + ("max_errors", MaxErrorsConstraintInitializer), + ("max_error_rate", MaxErrorRateConstraintInitializer), + ("max_global_error_rate", MaxGlobalErrorRateConstraintInitializer), + ], + ) + def test_registered_constraint_classes(self, key, expected_class): + """Test that each registered key maps to the expected initializer class.""" + assert 
ConstraintsInitializerFactory.is_registered(key) + registered_class = ConstraintsInitializerFactory.get_registered_object(key) + assert registered_class == expected_class + + @pytest.mark.sanity + def test_unregistered_key_fails(self): + """Test that unregistered keys raise ValueError.""" + unregistered_key = "nonexistent_constraint" + assert not ConstraintsInitializerFactory.is_registered(unregistered_key) + + with pytest.raises( + ValueError, match=f"Unknown constraint initializer key: {unregistered_key}" + ): + ConstraintsInitializerFactory.create(unregistered_key) + + with pytest.raises( + ValueError, match=f"Unknown constraint initializer key: {unregistered_key}" + ): + ConstraintsInitializerFactory.create_constraint(unregistered_key) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("key", "init_args", "expected_constraint_class"), + [ + ("max_number", {"max_num": 100}, MaxNumberConstraint), + ("max_duration", {"max_duration": 30.0}, MaxDurationConstraint), + ("max_errors", {"max_errors": 5}, MaxErrorsConstraint), + ( + "max_error_rate", + {"max_error_rate": 0.1, "window_size": 50}, + MaxErrorRateConstraint, + ), + ( + "max_global_error_rate", + {"max_error_rate": 0.05, "min_processed": 100}, + MaxGlobalErrorRateConstraint, + ), + ], + ) + def test_create_initializer(self, key, init_args, expected_constraint_class): + """Test that create method returns properly configured initializers.""" + initializer = ConstraintsInitializerFactory.create(key, **init_args) + + assert isinstance(initializer, ConstraintInitializer) + assert isinstance(initializer, self.EXPECTED_REGISTERED_KEYS[key]) + + for attr_name, attr_value in init_args.items(): + assert hasattr(initializer, attr_name) + assert getattr(initializer, attr_name) == attr_value + + constraint = initializer.create_constraint() + assert isinstance(constraint, Constraint) + assert isinstance(constraint, expected_constraint_class) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("key", "init_args", "expected_constraint_class"), + [ + ("max_number", {"max_num": 100}, MaxNumberConstraint), + ("max_duration", {"max_duration": 30.0}, MaxDurationConstraint), + ("max_errors", {"max_errors": 5}, MaxErrorsConstraint), + ( + "max_error_rate", + {"max_error_rate": 0.1, "window_size": 50}, + MaxErrorRateConstraint, + ), + ( + "max_global_error_rate", + {"max_error_rate": 0.05, "min_processed": 100}, + MaxGlobalErrorRateConstraint, + ), + ], + ) + def test_create_constraint_direct(self, key, init_args, expected_constraint_class): + """Test that create_constraint method returns configured constraints.""" + constraint = ConstraintsInitializerFactory.create_constraint(key, **init_args) + + assert isinstance(constraint, Constraint) + assert isinstance(constraint, expected_constraint_class) + + for attr_name, attr_value in init_args.items(): + assert hasattr(constraint, attr_name) + assert getattr(constraint, attr_name) == attr_value + + @pytest.mark.smoke + def test_resolve_with_constraint_instances(self): + """Test resolve method with pre-instantiated Constraint objects.""" + max_num_constraint = MaxNumberConstraint(max_num=50) + max_duration_constraint = MaxDurationConstraint(max_duration=60.0) + + initializers = { + "max_number": max_num_constraint, + "max_duration": max_duration_constraint, + } + + resolved = ConstraintsInitializerFactory.resolve(initializers) + + assert len(resolved) == 2 + assert resolved["max_number"] is max_num_constraint + assert resolved["max_duration"] is max_duration_constraint + assert all(isinstance(c, Constraint) 
for c in resolved.values()) + + @pytest.mark.smoke + def test_resolve_with_initializer_instances(self): + """Test resolve method with pre-instantiated ConstraintInitializer objects.""" + max_num_initializer = MaxNumberConstraintInitializer(max_num=75) + max_errors_initializer = MaxErrorsConstraintInitializer(max_errors=10) + + initializers = { + "max_number": max_num_initializer, + "max_errors": max_errors_initializer, + } + + resolved = ConstraintsInitializerFactory.resolve(initializers) + + assert len(resolved) == 2 + assert isinstance(resolved["max_number"], MaxNumberConstraint) + assert isinstance(resolved["max_errors"], MaxErrorsConstraint) + assert resolved["max_number"].max_num == 75 + assert resolved["max_errors"].max_errors == 10 + assert all(isinstance(c, Constraint) for c in resolved.values()) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("input_spec", "expected_values"), + [ + ( + { + "max_number": {"max_num": 200}, + "max_duration": {"max_duration": 45.0}, + "max_errors": {"max_errors": 3}, + }, + { + "max_number": ("max_num", 200), + "max_duration": ("max_duration", 45.0), + "max_errors": ("max_errors", 3), + }, + ), + ( + { + "max_error_rate": {"max_error_rate": 0.15, "window_size": 100}, + "max_global_error_rate": { + "max_error_rate": 0.08, + "min_processed": 50, + }, + }, + { + "max_error_rate": ("max_error_rate", 0.15), + "max_global_error_rate": ("max_error_rate", 0.08), + }, + ), + ], + ) + def test_resolve_with_dict_configs(self, input_spec, expected_values): + """Test resolve method with dictionary configurations.""" + resolved = ConstraintsInitializerFactory.resolve(input_spec) + + assert len(resolved) == len(input_spec) + assert all(isinstance(c, Constraint) for c in resolved.values()) + + for key, (attr_name, attr_value) in expected_values.items(): + assert key in resolved + constraint = resolved[key] + assert hasattr(constraint, attr_name) + assert getattr(constraint, attr_name) == attr_value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("input_spec", "expected_values"), + [ + ( + {"max_number": 150}, + {"max_number": ("max_num", 150)}, + ), + ( + {"max_duration": 90.0}, + {"max_duration": ("max_duration", 90.0)}, + ), + ( + {"max_errors": 8}, + {"max_errors": ("max_errors", 8)}, + ), + ( + {"max_error_rate": 0.15}, + {"max_error_rate": ("max_error_rate", 0.15)}, + ), + ( + {"max_global_error_rate": 0.05}, + {"max_global_error_rate": ("max_error_rate", 0.05)}, + ), + ], + ) + def test_resolve_with_simple_values(self, input_spec, expected_values): + """Test that resolve method now supports simple scalar values.""" + resolved = ConstraintsInitializerFactory.resolve(input_spec) + + assert len(resolved) == len(input_spec) + assert all(isinstance(c, Constraint) for c in resolved.values()) + + for key, (attr_name, attr_value) in expected_values.items(): + assert key in resolved + constraint = resolved[key] + assert hasattr(constraint, attr_name) + assert getattr(constraint, attr_name) == attr_value + + @pytest.mark.smoke + def test_resolve_mixed_types(self): + """Test resolve method with mixed constraint types including simple values.""" + max_num_constraint = MaxNumberConstraint(max_num=25) + max_duration_initializer = MaxDurationConstraintInitializer(max_duration=120.0) + + mixed_spec = { + "max_number": max_num_constraint, + "max_duration": max_duration_initializer, + "max_errors": {"max_errors": 15}, + "max_error_rate": 0.08, + "max_global_error_rate": {"max_error_rate": 0.12}, + } + + resolved = ConstraintsInitializerFactory.resolve(mixed_spec) + + 
assert len(resolved) == 5 + assert all(isinstance(c, Constraint) for c in resolved.values()) + assert resolved["max_number"] is max_num_constraint + assert isinstance(resolved["max_duration"], MaxDurationConstraint) + assert isinstance(resolved["max_errors"], MaxErrorsConstraint) + assert isinstance(resolved["max_error_rate"], MaxErrorRateConstraint) + assert isinstance( + resolved["max_global_error_rate"], MaxGlobalErrorRateConstraint + ) + + assert resolved["max_error_rate"].max_error_rate == 0.08 + + @pytest.mark.sanity + def test_resolve_constraints_method_bug_fixed(self): + """Test resolve_constraints method now works correctly after bug fix. + + Note: Previously resolve_constraints had a bug where the parameter name + 'constraints' was shadowed by a local variable, causing it to + always return an empty dictionary. This bug has been fixed. + """ + max_num_constraint = MaxNumberConstraint(max_num=80) + + constraints_spec = { + "max_number": max_num_constraint, + "max_duration": {"max_duration": 300.0}, + } + + resolved = ConstraintsInitializerFactory.resolve_constraints(constraints_spec) + + assert len(resolved) == 2 + assert resolved["max_number"] is max_num_constraint + assert isinstance(resolved["max_duration"], MaxDurationConstraint) + assert resolved["max_duration"].max_duration == 300.0 + + @pytest.mark.sanity + def test_resolve_with_invalid_key(self): + """Test that resolve raises ValueError for unregistered keys.""" + invalid_spec = { + "max_number": {"max_num": 100}, + "invalid_constraint": {"some_param": 42}, + } + + with pytest.raises( + ValueError, match="Unknown constraint initializer key: invalid_constraint" + ): + ConstraintsInitializerFactory.resolve(invalid_spec) + + @pytest.mark.sanity + def test_resolve_constraints_with_invalid_key_now_raises(self): + """Test that resolve_constraints now properly validates keys after bug fix. + + Note: Previously due to the variable shadowing bug in resolve_constraints, + it didn't actually process the input and therefore didn't validate keys, + always returning an empty dictionary. Now it properly validates. 
+ """ + invalid_spec = { + "max_duration": {"max_duration": 60.0}, + "nonexistent_key": {"param": "value"}, + } + + with pytest.raises(ValueError, match="Unknown constraint initializer key"): + ConstraintsInitializerFactory.resolve_constraints(invalid_spec) + + @pytest.mark.smoke + def test_functional_constraint_creation(self): + """Test that created constraints are functionally correct.""" + constraint = ConstraintsInitializerFactory.create_constraint( + "max_number", max_num=10 + ) + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=5, + processed_requests=5, + ) + request = ScheduledRequestInfo( + request_id="test-request", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + + state_exceeded = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=15, + processed_requests=15, + ) + action_exceeded = constraint(state_exceeded, request) + assert action_exceeded.request_queuing == "stop" + assert action_exceeded.request_processing == "stop_local" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("initializer_class", "key", "value", "expected_attr"), + [ + (MaxNumberConstraintInitializer, "max_number", 100, "max_num"), + (MaxDurationConstraintInitializer, "max_duration", 45.0, "max_duration"), + (MaxErrorsConstraintInitializer, "max_errors", 5, "max_errors"), + ( + MaxErrorRateConstraintInitializer, + "max_error_rate", + 0.1, + "max_error_rate", + ), + ( + MaxGlobalErrorRateConstraintInitializer, + "max_global_error_rate", + 0.05, + "max_error_rate", + ), + ], + ) + def test_from_simple_value_class_method( + self, initializer_class, key, value, expected_attr + ): + """Test that each initializer class properly handles from_simple_value.""" + initializer = initializer_class.from_simple_value(value) + assert hasattr(initializer, expected_attr) + assert getattr(initializer, expected_attr) == value + + constraint = initializer.create_constraint() + assert hasattr(constraint, expected_attr) + assert getattr(constraint, expected_attr) == value + + factory_result = ConstraintsInitializerFactory.resolve({key: value}) + assert key in factory_result + factory_constraint = factory_result[key] + assert hasattr(factory_constraint, expected_attr) + assert getattr(factory_constraint, expected_attr) == value diff --git a/tests/unit/scheduler/test_environment.py b/tests/unit/scheduler/test_environment.py new file mode 100644 index 00000000..2ff377ab --- /dev/null +++ b/tests/unit/scheduler/test_environment.py @@ -0,0 +1,276 @@ +import inspect +import time +from abc import ABC +from typing import Generic +from unittest.mock import patch + +import pytest + +from guidellm.scheduler import ( + Environment, + MaxNumberConstraint, + NonDistributedEnvironment, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, + SynchronousStrategy, +) + + +class TestEnvironment: + @pytest.mark.smoke + def test_is_abstract_base_class(self): + """Test that Environment is an abstract base class.""" + assert issubclass(Environment, ABC) + assert inspect.isabstract(Environment) + + @pytest.mark.smoke + def test_abstract_methods_defined(self): + """Test that the required abstract methods are defined.""" + abstract_methods = Environment.__abstractmethods__ + 
expected_methods = { + "sync_run_params", + "sync_run_start", + "update_run_iteration", + "sync_run_error", + "sync_run_end", + } + assert abstract_methods == expected_methods + + @pytest.mark.smoke + def test_generic_type_parameters(self): + """Test that Environment is generic with correct type parameters.""" + assert issubclass(Environment, Generic) + # Environment should be Generic[RequestT, ResponseT] + orig_bases = getattr(Environment, "__orig_bases__", ()) + assert len(orig_bases) > 0 + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is Generic + ), + None, + ) + assert generic_base is not None + type_args = getattr(generic_base, "__args__", ()) + assert RequestT in type_args + assert ResponseT in type_args + + @pytest.mark.smoke + def test_invalid_implementation(self): + """Test that invalid implementations raise TypeError.""" + + class InvalidImplementation(Environment): + pass + + with pytest.raises(TypeError): + InvalidImplementation() + + @pytest.mark.smoke + def test_implementation_construction(self): + """Test that concrete implementations can be constructed.""" + + class TestEnvironment(Environment): + async def sync_run_params(self, requests, strategy, constraints): + return requests, strategy, constraints + + async def sync_run_start(self): + return 0.0 + + async def update_run_iteration(self, response, request, request_info): + pass + + async def sync_run_error(self, err): + pass + + async def sync_run_end(self): + yield + + env = TestEnvironment() + assert isinstance(env, Environment) + + @pytest.mark.smoke + def test_method_signatures(self): + """Test that method signatures match expected interface.""" + params_sig = inspect.signature(Environment.sync_run_params) + assert len(params_sig.parameters) == 4 + param_names = list(params_sig.parameters.keys()) + assert param_names == ["self", "requests", "strategy", "constraints"] + + start_sig = inspect.signature(Environment.sync_run_start) + assert len(start_sig.parameters) == 1 + assert "self" in start_sig.parameters + + update_sig = inspect.signature(Environment.update_run_iteration) + assert len(update_sig.parameters) == 5 + param_names = list(update_sig.parameters.keys()) + assert param_names == ["self", "response", "request", "request_info", "state"] + + error_sig = inspect.signature(Environment.sync_run_error) + assert len(error_sig.parameters) == 2 + param_names = list(error_sig.parameters.keys()) + assert param_names == ["self", "err"] + + end_sig = inspect.signature(Environment.sync_run_end) + assert len(end_sig.parameters) == 1 + assert "self" in end_sig.parameters + + +class TestNonDistributedEnvironment: + @pytest.mark.smoke + def test_initialization(self): + """Test basic initialization of NonDistributedEnvironment.""" + env = NonDistributedEnvironment() + assert env.run_errors == [] + assert isinstance(env, Environment) + + @pytest.mark.sanity + def test_invalid_initialization(self): + """Test that initialization doesn't accept invalid arguments.""" + with pytest.raises(TypeError): + NonDistributedEnvironment("invalid_arg") + + @pytest.mark.smoke + def test_inheritance_and_typing(self): + """Test inheritance and type relationships.""" + env = NonDistributedEnvironment() + + # Should inherit from Environment + assert isinstance(env, Environment) + assert issubclass(NonDistributedEnvironment, Environment) + + # Should implement all required methods + required_methods = [ + "sync_run_params", + "sync_run_start", + "update_run_iteration", + 
"sync_run_error", + "sync_run_end", + ] + + for method_name in required_methods: + assert hasattr(env, method_name) + assert callable(getattr(env, method_name)) + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("requests", "strategy", "constraints", "error_to_inject"), + [ + ( + ["request1", "request2"], + SynchronousStrategy(), + {"max_requests": MaxNumberConstraint(max_num=10)}, + None, + ), + ( + [], + SynchronousStrategy(), + {}, + None, + ), + ( + ["single_request"], + SynchronousStrategy(), + {"max_requests": MaxNumberConstraint(max_num=1)}, + RuntimeError("Test error"), + ), + ( + range(5), + SynchronousStrategy(), + {"max_requests": MaxNumberConstraint(max_num=5)}, + ValueError("Connection failed"), + ), + ], + ids=[ + "normal_execution", + "empty_requests", + "with_error", + "multiple_requests_with_error", + ], + ) + async def test_lifecycle(self, requests, strategy, constraints, error_to_inject): + """Test the complete lifecycle of environment methods.""" + env = NonDistributedEnvironment() + + ( + returned_requests, + returned_strategy, + returned_constraints, + ) = await env.sync_run_params(requests, strategy, constraints) + assert returned_requests is requests + assert returned_strategy is strategy + assert returned_constraints is constraints + + with ( + patch("time.time", return_value=1000.0), + patch("guidellm.scheduler.environment.settings") as mock_settings, + ): + mock_settings.scheduler_start_delay_non_distributed = 2.5 + start_time = await env.sync_run_start() + assert start_time == 1002.5 + + mock_response = "mock_response" + mock_request = "mock_request" + mock_request_info = ScheduledRequestInfo( + request_id="test-123", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + mock_state = SchedulerState( + node_id=0, + num_processes=1, + start_time=time.time(), + ) + + await env.update_run_iteration( + mock_response, mock_request, mock_request_info, mock_state + ) + await env.update_run_iteration( + None, mock_request, mock_request_info, mock_state + ) + await env.update_run_iteration( + mock_response, None, mock_request_info, mock_state + ) + + if error_to_inject: + await env.sync_run_error(error_to_inject) + assert error_to_inject in env.run_errors + + if error_to_inject: + with pytest.raises(type(error_to_inject)) as exc_info: + async for _ in env.sync_run_end(): + pass + assert str(exc_info.value) == str(error_to_inject) + else: + results = [] + async for result in env.sync_run_end(): + results.append(result) + assert results == [] + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_sync_run_start_uses_config(self): + """Test that sync_run_start uses configuration value.""" + env = NonDistributedEnvironment() + + with ( + patch("time.time", return_value=500.0), + patch("guidellm.scheduler.environment.settings") as mock_settings, + ): + # Test different delay values + mock_settings.scheduler_start_delay_non_distributed = 0.0 + start_time = await env.sync_run_start() + assert start_time == 500.0 + + mock_settings.scheduler_start_delay_non_distributed = 1.5 + start_time = await env.sync_run_start() + assert start_time == 501.5 + + mock_settings.scheduler_start_delay_non_distributed = 10.0 + start_time = await env.sync_run_start() + assert start_time == 510.0 diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py new file mode 100644 index 00000000..a0e0bb73 --- /dev/null +++ b/tests/unit/scheduler/test_objects.py @@ -0,0 +1,1204 @@ 
+import inspect +import typing +from collections.abc import AsyncIterator +from typing import Any, Optional, TypeVar + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + BackendInterface, + BackendT, + MeasuredRequestTimings, + MeasuredRequestTimingsT, + RequestSchedulerTimings, + RequestT, + ResponseT, + ScheduledRequestInfo, + SchedulerState, + SchedulerUpdateAction, + SchedulerUpdateActionProgress, +) + + +def test_request_t(): + """Validate that RequestT is a TypeVar usable for generics and isn't bound.""" + assert isinstance(RequestT, TypeVar) + assert RequestT.__name__ == "RequestT" + assert RequestT.__bound__ is None + assert RequestT.__constraints__ == () + + +def test_response_t(): + """Validate that ResponseT is a TypeVar usable for generics and isn't bound.""" + assert isinstance(ResponseT, TypeVar) + assert ResponseT.__name__ == "ResponseT" + assert ResponseT.__bound__ is None + assert ResponseT.__constraints__ == () + + +def test_request_timings_t(): + """Validate MeasuredRequestTimingsT is a TypeVar bound to MeasuredRequestTimings.""" + assert isinstance(MeasuredRequestTimingsT, TypeVar) + assert MeasuredRequestTimingsT.__name__ == "MeasuredRequestTimingsT" + assert MeasuredRequestTimingsT.__bound__ == MeasuredRequestTimings + assert MeasuredRequestTimingsT.__constraints__ == () + + +def test_backend_t(): + """Validate that BackendT is a TypeVar bound to BackendInterface.""" + assert isinstance(BackendT, TypeVar) + assert BackendT.__name__ == "BackendT" + assert BackendT.__bound__.__name__ == "BackendInterface" + assert BackendT.__constraints__ == () + + +class TestBackendInterface: + """Test the BackendInterface abstract base class.""" + + @pytest.mark.smoke + def test_is_abstract_base_class(self): + """Test that BackendInterface is an ABC and cannot be instantiated directly.""" + from abc import ABC + + assert issubclass(BackendInterface, ABC) + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + BackendInterface() + + @pytest.mark.smoke + def test_abstract_methods_defined(self): + """Test that all expected abstract methods are defined.""" + expected_methods = { + "info", + "process_startup", + "validate", + "process_shutdown", + "resolve", + } + expected_properties = { + "processes_limit", + "requests_limit", + } + + for method_name in expected_methods: + assert hasattr(BackendInterface, method_name) + method = getattr(BackendInterface, method_name) + assert inspect.isfunction(method) or inspect.ismethod(method) + + for prop_name in expected_properties: + assert hasattr(BackendInterface, prop_name) + prop = getattr(BackendInterface, prop_name) + assert hasattr(prop, "__get__") + + @pytest.mark.smoke + def test_generic_type_parameters(self): + """Test that BackendInterface has the correct generic type parameters.""" + orig_bases = BackendInterface.__orig_bases__ + abc_base = None + generic_base = None + + for base in orig_bases: + if hasattr(base, "__origin__"): + if base.__origin__ is typing.Generic: + generic_base = base + elif base.__name__ == "ABC": + abc_base = base + + assert abc_base is not None, "Should inherit from ABC" + assert generic_base is not None, "Should inherit from Generic" + + if hasattr(generic_base, "__args__"): + type_params = generic_base.__args__ + assert len(type_params) == 3, "Should have 3 type parameters" + param_names = [param.__name__ for param in type_params] + expected_names = ["RequestT", "MeasuredRequestTimingsT", "ResponseT"] + assert param_names == expected_names + + 
@pytest.mark.sanity + def test_invalid_implementation(self): + """Test that a concrete implementation must implement all abstract methods.""" + + class PartialBackend(BackendInterface): + @property + def processes_limit(self): + return 1 + + @property + def requests_limit(self): + return 10 + + def info(self): + return {} + + async def process_startup(self): + pass + + # Missing: validate, process_shutdown, resolve + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + PartialBackend() + + @pytest.mark.smoke + def test_implementation_construction(self): + """Test that a complete concrete implementation can be instantiated.""" + + class ConcreteBackend(BackendInterface[str, MeasuredRequestTimings, str]): + @property + def processes_limit(self) -> Optional[int]: + return 4 + + @property + def requests_limit(self) -> Optional[int]: + return 100 + + def info(self) -> dict[str, Any]: + return {"model": "test", "version": "1.0"} + + async def process_startup(self) -> None: + pass + + async def validate(self) -> None: + pass + + async def process_shutdown(self) -> None: + pass + + async def resolve( + self, + request: str, + request_info: ScheduledRequestInfo[MeasuredRequestTimings], + history: Optional[list[tuple[str, str]]] = None, + ) -> AsyncIterator[ + tuple[str, ScheduledRequestInfo[MeasuredRequestTimings]] + ]: + yield f"Response to: {request}", request_info + + backend = ConcreteBackend() + assert isinstance(backend, BackendInterface) + assert isinstance(backend, ConcreteBackend) + assert backend.processes_limit == 4 + assert backend.requests_limit == 100 + info = backend.info() + assert info == {"model": "test", "version": "1.0"} + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_implementation_async_methods(self): + """Test that async methods work correctly in concrete implementation.""" + + class AsyncBackend(BackendInterface[dict, MeasuredRequestTimings, dict]): + def __init__(self): + self.startup_called = False + self.validate_called = False + self.shutdown_called = False + + @property + def processes_limit(self) -> Optional[int]: + return None # Unlimited + + @property + def requests_limit(self) -> Optional[int]: + return None # Unlimited + + def info(self) -> dict[str, Any]: + return {"backend": "async_test"} + + async def process_startup(self) -> None: + self.startup_called = True + + async def validate(self) -> None: + self.validate_called = True + + async def process_shutdown(self) -> None: + self.shutdown_called = True + + async def resolve( + self, + request: dict, + request_info: ScheduledRequestInfo[MeasuredRequestTimings], + history: Optional[list[tuple[dict, dict]]] = None, + ) -> AsyncIterator[ + tuple[dict, ScheduledRequestInfo[MeasuredRequestTimings]] + ]: + response = {"result": request.get("input", ""), "status": "success"} + yield response, request_info + + backend = AsyncBackend() + await backend.process_startup() + assert backend.startup_called + + await backend.validate() + assert backend.validate_called + + await backend.process_shutdown() + assert backend.shutdown_called + + request = {"input": "test_request"} + request_info = ScheduledRequestInfo( + request_id="test-123", + status="queued", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + results = [] + async for response, updated_info in backend.resolve(request, request_info): + results.append((response, updated_info)) + + assert len(results) == 1 + response, updated_info = results[0] + assert response == {"result": "test_request", 
"status": "success"} + assert updated_info == request_info + + @pytest.mark.smoke + def test_method_signatures(self): + """Test that abstract methods have the expected signatures.""" + info_sig = inspect.signature(BackendInterface.info) + assert len(info_sig.parameters) == 1 + assert list(info_sig.parameters.keys()) == ["self"] + + startup_sig = inspect.signature(BackendInterface.process_startup) + assert len(startup_sig.parameters) == 1 # Only self + assert list(startup_sig.parameters.keys()) == ["self"] + + validate_sig = inspect.signature(BackendInterface.validate) + assert len(validate_sig.parameters) == 1 # Only self + assert list(validate_sig.parameters.keys()) == ["self"] + + shutdown_sig = inspect.signature(BackendInterface.process_shutdown) + assert len(shutdown_sig.parameters) == 1 # Only self + assert list(shutdown_sig.parameters.keys()) == ["self"] + + resolve_sig = inspect.signature(BackendInterface.resolve) + expected_params = ["self", "request", "request_info", "history"] + assert list(resolve_sig.parameters.keys()) == expected_params + + history_param = resolve_sig.parameters["history"] + assert history_param.default is None + + +class TestRequestSchedulerTimings: + CHECK_KEYS = [ + "targeted_start", + "queued", + "dequeued", + "resolve_start", + "resolve_end", + "finalized", + ] + + @pytest.fixture( + params=[ + # Default empty configuration + {}, + # All None values explicitly set + { + "targeted_start": None, + "queued": None, + "dequeued": None, + "resolve_start": None, + "resolve_end": None, + "finalized": None, + }, + # Complete timing sequence + { + "targeted_start": 1000.0, + "queued": 200.0, + "dequeued": 800.0, + "resolve_start": 1000.5, + "resolve_end": 1100.0, + "finalized": 1100.5, + }, + # Partial timing data + { + "queued": 200.0, + "resolve_start": 1000.5, + "resolve_end": 1100.0, + }, + # Edge case: zero timestamps + { + "targeted_start": 0.0, + "queued": 0.0, + "dequeued": 0.0, + "resolve_start": 0.0, + "resolve_end": 0.0, + "finalized": 0.0, + }, + ], + ids=[ + "default_empty", + "all_none_explicit", + "complete_sequence", + "partial_data", + "zero_timestamps", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of RequestSchedulerTimings. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + RequestSchedulerTimings and constructor_args are the kwargs used. 
+ """ + constructor_args = request.param + instance = RequestSchedulerTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, RequestSchedulerTimings) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("targeted_start", "invalid_string"), + ("queued", "invalid_string"), + ("dequeued", [1, 2, 3]), + ("resolve_start", {"key": "value"}), + ("resolve_end", [1, 2, 3]), + ("finalized", object()), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + RequestSchedulerTimings(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = RequestSchedulerTimings.model_validate(data) + assert isinstance(reconstructed, RequestSchedulerTimings) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestRequestTimings: + CHECK_KEYS = [ + "request_start", + "request_end", + ] + + @pytest.fixture( + params=[ + # Default empty configuration + {}, + # All None values explicitly set + { + "request_start": None, + "request_end": None, + }, + # Complete timing sequence + { + "request_start": 1000.0, + "request_end": 1100.0, + }, + # Partial timing data + { + "request_start": 1000.0, + }, + # Edge case: zero timestamps + { + "request_start": 0.0, + "request_end": 0.0, + }, + ], + ids=[ + "default_empty", + "all_none_explicit", + "complete_sequence", + "partial_data", + "zero_timestamps", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of RequestTimings. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + RequestTimings and constructor_args are the kwargs used. 
+ """ + constructor_args = request.param + instance = MeasuredRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, MeasuredRequestTimings) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_start", "invalid_string"), + ("request_end", [1, 2, 3]), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + MeasuredRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = MeasuredRequestTimings.model_validate(data) + assert isinstance(reconstructed, MeasuredRequestTimings) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestScheduledRequestInfo: + CHECK_KEYS = [ + "request_id", + "status", + "error", + "scheduler_node_id", + "scheduler_process_id", + "scheduler_start_time", + "scheduler_timings", + "request_timings", + ] + + @pytest.fixture( + params=[ + # Minimal required configuration + { + "request_id": "test-req-123", + "status": "queued", + "scheduler_node_id": 1, + "scheduler_process_id": 0, + "scheduler_start_time": 1000.0, + }, + # Complete configuration with all fields + { + "request_id": "test-req-456", + "status": "completed", + "error": None, + "scheduler_node_id": 2, + "scheduler_process_id": 1, + "scheduler_start_time": 2000.0, + "scheduler_timings": { + "targeted_start": 1900.0, + "queued": 1950.0, + "dequeued": 2000.0, + "resolve_start": 2050.0, + "resolve_end": 2100.0, + "finalized": 2150.0, + }, + "request_timings": { + "request_start": 2060.0, + "request_end": 2110.0, + }, + }, + # Error state configuration + { + "request_id": "test-req-error", + "status": "errored", + "error": "Connection timeout", + "scheduler_node_id": 0, + "scheduler_process_id": 0, + "scheduler_start_time": 3000.0, + }, + # Different status values + { + "request_id": "test-req-pending", + "status": "pending", + "scheduler_node_id": 1, + "scheduler_process_id": 2, + "scheduler_start_time": 4000.0, + }, + { + "request_id": "test-req-in-progress", + "status": "in_progress", + "scheduler_node_id": 2, + "scheduler_process_id": 1, + "scheduler_start_time": 5000.0, + }, + ], + ids=[ + "minimal_required", + "complete_configuration", + "error_state", + "pending_status", + "in_progress_status", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of ScheduledRequestInfo. 
+ + Returns: + tuple: (instance, constructor_args) where instance is the constructed + ScheduledRequestInfo and constructor_args are the kwargs used. + """ + constructor_args = request.param.copy() + + # Handle nested objects + if "scheduler_timings" in constructor_args: + constructor_args["scheduler_timings"] = RequestSchedulerTimings( + **constructor_args["scheduler_timings"] + ) + if "request_timings" in constructor_args: + constructor_args["request_timings"] = MeasuredRequestTimings( + **constructor_args["request_timings"] + ) + + instance = ScheduledRequestInfo(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ScheduledRequestInfo) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + if field in ["scheduler_timings", "request_timings"]: + actual_value = getattr(instance, field) + if expected_value is None: + assert actual_value is None or ( + field == "scheduler_timings" + and isinstance(actual_value, RequestSchedulerTimings) + ) + else: + assert isinstance(actual_value, type(expected_value)) + else: + assert getattr(instance, field) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_id", None), # Required field + ("request_id", 123), # Wrong type + ("status", "invalid_status"), # Invalid literal + ("scheduler_node_id", "not_an_int"), + ("scheduler_process_id", -1.5), + ("scheduler_start_time", "not_a_float"), + ("error", 123), # Should be string or None + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + # Start with valid base config + base_kwargs = { + "request_id": "test-req", + "status": "queued", + "scheduler_node_id": 1, + "scheduler_process_id": 0, + "scheduler_start_time": 1000.0, + } + base_kwargs[field] = value + with pytest.raises(ValidationError): + ScheduledRequestInfo(**base_kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = ScheduledRequestInfo.model_validate(data) + assert isinstance(reconstructed, ScheduledRequestInfo) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + original_value = getattr(instance, field) + reconstructed_value = getattr(reconstructed, field) + + if field in ["scheduler_timings", "request_timings"]: + if original_value is not None and reconstructed_value is not None: + assert ( + original_value.model_dump() == reconstructed_value.model_dump() + ) + else: + assert original_value is None or isinstance( + original_value, + (RequestSchedulerTimings, MeasuredRequestTimings), + ) + assert reconstructed_value is None or isinstance( + reconstructed_value, + (RequestSchedulerTimings, MeasuredRequestTimings), + ) + else: + assert original_value == reconstructed_value + + @pytest.mark.smoke + def test_started_at_property(self): + """Test the started_at property logic.""" + # Test with request_timings.request_start (should take 
precedence) + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0), + request_timings=MeasuredRequestTimings(request_start=2100.0), + ) + assert instance.started_at == 2100.0 + + # Test with only scheduler_timings.resolve_start + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_start=2000.0), + ) + assert instance.started_at == 2000.0 + + # Test with no timing info + instance = ScheduledRequestInfo( + request_id="test-req", + status="queued", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + assert instance.started_at is None + + @pytest.mark.smoke + def test_completed_at_property(self): + """Test the completed_at property logic.""" + # Test with request_timings.request_end (should take precedence) + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0), + request_timings=MeasuredRequestTimings(request_end=2100.0), + ) + assert instance.completed_at == 2100.0 + + # Test with only scheduler_timings.resolve_end + instance = ScheduledRequestInfo( + request_id="test-req", + status="completed", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + scheduler_timings=RequestSchedulerTimings(resolve_end=2000.0), + ) + assert instance.completed_at == 2000.0 + + # Test with no timing info + instance = ScheduledRequestInfo( + request_id="test-req", + status="queued", + scheduler_node_id=1, + scheduler_process_id=0, + scheduler_start_time=1000.0, + ) + assert instance.completed_at is None + + +class TestSchedulerState: + CHECK_KEYS = [ + "node_id", + "num_processes", + "start_time", + "end_time", + "end_queuing_time", + "end_queuing_constraints", + "end_processing_time", + "end_processing_constraints", + "scheduler_constraints", + "remaining_fraction", + "remaining_requests", + "remaining_duration", + "created_requests", + "queued_requests", + "pending_requests", + "processing_requests", + "processed_requests", + "successful_requests", + "errored_requests", + "cancelled_requests", + ] + + @pytest.fixture( + params=[ + # Minimal required configuration + { + "node_id": 0, + "num_processes": 1, + "start_time": 1000.0, + }, + # Complete configuration with all fields + { + "node_id": 1, + "num_processes": 4, + "start_time": 2000.0, + "end_time": 3000.0, + "end_queuing_time": 2500.0, + "end_queuing_constraints": { + "time_limit": SchedulerUpdateAction( + request_queuing="stop", metadata={"max_duration": 1500} + ) + }, + "end_processing_time": 2800.0, + "end_processing_constraints": { + "request_limit": SchedulerUpdateAction( + request_processing="stop_all", metadata={"max_requests": 1000} + ) + }, + "scheduler_constraints": { + "rate_limit": SchedulerUpdateAction(metadata={"max_rps": 100}) + }, + "remaining_fraction": 0.25, + "remaining_requests": 50, + "remaining_duration": 300.0, + "created_requests": 200, + "queued_requests": 180, + "pending_requests": 20, + "processing_requests": 10, + "processed_requests": 150, + "successful_requests": 140, + "errored_requests": 8, + "cancelled_requests": 2, + }, + # Partial configuration with some stats + { + 
"node_id": 2, + "num_processes": 2, + "start_time": 4000.0, + "created_requests": 50, + "processed_requests": 30, + "successful_requests": 28, + "errored_requests": 2, + }, + # Edge case: zero values + { + "node_id": 0, + "num_processes": 1, + "start_time": 0.0, + "created_requests": 0, + "processed_requests": 0, + "successful_requests": 0, + }, + ], + ids=[ + "minimal_required", + "complete_configuration", + "partial_stats", + "zero_values", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of SchedulerState. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + SchedulerState and constructor_args are the kwargs used. + """ + constructor_args = request.param + instance = SchedulerState(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, SchedulerState) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args + for field, expected_value in constructor_args.items(): + assert getattr(instance, field) == expected_value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("field", "value"), + [ + ("node_id", "not_an_int"), + ("start_time", "not_a_float"), + ("end_time", [1, 2, 3]), + ("remaining_fraction", "not_a_float"), + ("created_requests", "not_an_int"), + ("end_queuing_constraints", "not_a_dict"), + ("scheduler_constraints", ["not", "a", "dict"]), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + # Start with valid base config + base_kwargs = { + "node_id": 0, + "num_processes": 1, + "start_time": 1000.0, + } + base_kwargs[field] = value + with pytest.raises(ValidationError): + SchedulerState(**base_kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = SchedulerState.model_validate(data) + assert isinstance(reconstructed, SchedulerState) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches original constructor args + for field, expected_value in constructor_args.items(): + assert getattr(reconstructed, field) == expected_value + + +class TestSchedulerUpdateAction: + CHECK_KEYS = [ + "request_queuing", + "request_processing", + "metadata", + "progress", + ] + + @pytest.fixture( + params=[ + # Default configuration + {}, + # All explicit default values + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": {}, + "progress": {}, + }, + # Stop queuing configuration + { + "request_queuing": "stop", + "request_processing": "continue", + "metadata": {"reason": "rate_limit_exceeded"}, + }, + # Stop local processing configuration + { + "request_queuing": "continue", + "request_processing": "stop_local", + "metadata": {"node_id": 1, "reason": "resource_exhausted"}, + }, + # Stop all processing configuration + { + "request_queuing": "stop", + "request_processing": 
"stop_all", + "metadata": { + "emergency_stop": True, + "reason": "critical_error", + "error_details": {"code": 500, "message": "Internal server error"}, + }, + }, + # Complex metadata configuration + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": { + "stats": {"processed": 100, "pending": 50}, + "constraints": {"max_rps": 10, "max_concurrent": 20}, + "config": {"batch_size": 32, "timeout": 30.0}, + }, + }, + # Progress with remaining_fraction only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_fraction": 0.75}, + }, + # Progress with remaining_requests only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_requests": 250.0}, + }, + # Progress with remaining_duration only + { + "request_queuing": "continue", + "request_processing": "continue", + "progress": {"remaining_duration": 120.5}, + }, + # Complete progress configuration + { + "request_queuing": "stop", + "request_processing": "stop_all", + "metadata": {"shutdown_reason": "completion"}, + "progress": { + "remaining_fraction": 0.0, + "remaining_requests": 0.0, + "remaining_duration": 0.0, + }, + }, + # Partial progress configuration + { + "request_queuing": "continue", + "request_processing": "continue", + "metadata": {"checkpoint": "mid_benchmark"}, + "progress": { + "remaining_fraction": 0.45, + "remaining_duration": 180.0, + }, + }, + ], + ids=[ + "default_empty", + "explicit_defaults", + "stop_queuing", + "stop_local_processing", + "stop_all_processing", + "complex_metadata", + "progress_fraction_only", + "progress_requests_only", + "progress_duration_only", + "complete_progress", + "partial_progress", + ], + ) + def valid_instances(self, request): + """Creates various valid configurations of SchedulerUpdateAction. + + Returns: + tuple: (instance, constructor_args) where instance is the constructed + SchedulerUpdateAction and constructor_args are the kwargs used. 
+ """ + constructor_args = request.param + instance = SchedulerUpdateAction(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, SchedulerUpdateAction) + for key in self.CHECK_KEYS: + assert hasattr(instance, key) + + # Validate that the instance attributes match the constructor args or defaults + for field in self.CHECK_KEYS: + if field in constructor_args: + assert getattr(instance, field) == constructor_args[field] + elif field in ["request_queuing", "request_processing"]: + assert getattr(instance, field) == "continue" + elif field in ["metadata", "progress"]: + assert getattr(instance, field) == {} + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("field", "value"), + [ + ("request_queuing", "invalid_action"), + ("request_queuing", 123), + ("request_processing", "invalid_action"), + ("request_processing", ["stop"]), + ("metadata", "not_a_dict"), + ("metadata", [{"key": "value"}]), + ("progress", "not_a_dict"), + ("progress", [{"remaining_fraction": 0.5}]), + ("progress", {"remaining_fraction": "not_a_float"}), + ("progress", {"remaining_requests": "not_a_float"}), + ("progress", {"remaining_duration": "not_a_float"}), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + SchedulerUpdateAction(**kwargs) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + # Test model_dump + data = instance.model_dump() + assert isinstance(data, dict) + assert all(key in data for key in self.CHECK_KEYS) + + # Test model_validate + reconstructed = SchedulerUpdateAction.model_validate(data) + assert isinstance(reconstructed, SchedulerUpdateAction) + + # Validate that all fields match between original and reconstructed instances + for field in self.CHECK_KEYS: + assert getattr(reconstructed, field) == getattr(instance, field) + + # Validate that the reconstructed instance matches expected values + for field in self.CHECK_KEYS: + if field in constructor_args: + assert getattr(reconstructed, field) == constructor_args[field] + elif field in ["request_queuing", "request_processing"]: + assert getattr(reconstructed, field) == "continue" + elif field in ["metadata", "progress"]: + assert getattr(reconstructed, field) == {} + + @pytest.mark.smoke + def test_progress_field_behavior(self): + """Test the progress field specific behavior and validation.""" + # Test empty progress (default) + instance = SchedulerUpdateAction() + assert instance.progress == {} + assert isinstance(instance.progress, dict) + + # Test progress with all valid fields + progress_data = { + "remaining_fraction": 0.75, + "remaining_requests": 100.0, + "remaining_duration": 30.5, + } + instance = SchedulerUpdateAction(progress=progress_data) + assert instance.progress == progress_data + + # Test progress with partial fields (TypedDict allows partial) + partial_progress = {"remaining_fraction": 0.25} + instance = SchedulerUpdateAction(progress=partial_progress) + assert instance.progress == partial_progress + + # Test progress with zero values + zero_progress = { + "remaining_fraction": 0.0, + "remaining_requests": 0.0, + "remaining_duration": 0.0, + } + instance = 
SchedulerUpdateAction(progress=zero_progress) + assert instance.progress == zero_progress + + # Test that progress field persists through marshalling + data = instance.model_dump() + assert "progress" in data + assert data["progress"] == zero_progress + + reconstructed = SchedulerUpdateAction.model_validate(data) + assert reconstructed.progress == zero_progress + + @pytest.mark.smoke + @pytest.mark.parametrize( + "progress_value", + [ + {"remaining_fraction": 0.0}, + {"remaining_fraction": 1.0}, + {"remaining_requests": 0.0}, + {"remaining_requests": 1000.0}, + {"remaining_duration": 0.0}, + {"remaining_duration": 3600.0}, + {"remaining_fraction": 0.5, "remaining_requests": 50.0}, + {"remaining_requests": 25.0, "remaining_duration": 120.0}, + {"remaining_fraction": 0.33, "remaining_duration": 45.0}, + ], + ) + def test_progress_valid_combinations(self, progress_value): + """Test various valid combinations of progress field values.""" + instance = SchedulerUpdateAction(progress=progress_value) + assert instance.progress == progress_value + + # Verify marshalling works correctly + data = instance.model_dump() + reconstructed = SchedulerUpdateAction.model_validate(data) + assert reconstructed.progress == progress_value + + @pytest.mark.smoke + def test_scheduler_update_action_progress_typeddict(self): + """Test the SchedulerUpdateActionProgress TypedDict behavior.""" + # Test that SchedulerUpdateActionProgress is a proper TypedDict + # Verify it's a TypedDict (has the special attributes) + assert hasattr(SchedulerUpdateActionProgress, "__annotations__") + assert hasattr(SchedulerUpdateActionProgress, "__total__") + assert hasattr(SchedulerUpdateActionProgress, "__required_keys__") + assert hasattr(SchedulerUpdateActionProgress, "__optional_keys__") + + # Check that all keys are optional (total=False) + expected_keys = { + "remaining_fraction", + "remaining_requests", + "remaining_duration", + } + actual_keys = set(SchedulerUpdateActionProgress.__annotations__.keys()) + assert actual_keys == expected_keys + assert SchedulerUpdateActionProgress.__total__ is False + assert SchedulerUpdateActionProgress.__required_keys__ == frozenset() + assert SchedulerUpdateActionProgress.__optional_keys__ == expected_keys + + # Test that type annotations are correct + annotations = SchedulerUpdateActionProgress.__annotations__ + assert "remaining_fraction" in annotations + assert "remaining_requests" in annotations + assert "remaining_duration" in annotations + + # Test creation of valid TypedDict instances + valid_progress_1: SchedulerUpdateActionProgress = {} + valid_progress_2: SchedulerUpdateActionProgress = {"remaining_fraction": 0.5} + valid_progress_3: SchedulerUpdateActionProgress = { + "remaining_fraction": 0.25, + "remaining_requests": 100.0, + "remaining_duration": 60.0, + } + + # All should be valid dict instances + assert isinstance(valid_progress_1, dict) + assert isinstance(valid_progress_2, dict) + assert isinstance(valid_progress_3, dict) diff --git a/tests/unit/scheduler/test_scheduler.py b/tests/unit/scheduler/test_scheduler.py new file mode 100644 index 00000000..49c153ee --- /dev/null +++ b/tests/unit/scheduler/test_scheduler.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import asyncio +import random +import uuid +from collections import defaultdict +from typing import Any + +import pytest +from pydantic import BaseModel, Field + +from guidellm.scheduler import ( + BackendInterface, + ConstraintInitializer, + Environment, + MaxNumberConstraintInitializer, + 
NonDistributedEnvironment, + ScheduledRequestInfo, + Scheduler, + SchedulerState, + SchedulingStrategy, + SynchronousStrategy, +) + + +class MockRequest(BaseModel): + payload: str + id_: str = Field(default_factory=lambda: str(uuid.uuid4())) + + +class MockBackend(BackendInterface): + """Mock backend for integration testing with predictable responses.""" + + def __init__( + self, + processes_limit_value: int | None = None, + requests_limit_value: int | None = None, + error_rate: float = 0.2, + response_delay: float = 0.0, + ): + self._processes_limit = processes_limit_value + self._requests_limit = requests_limit_value + self._error_rate = error_rate + self._response_delay = response_delay + + @property + def processes_limit(self) -> int | None: + return self._processes_limit + + @property + def requests_limit(self) -> int | None: + return self._requests_limit + + def info(self) -> dict[str, Any]: + return {"type": "mock_integration", "delay": self._response_delay} + + async def process_startup(self): + pass + + async def validate(self): + pass + + async def process_shutdown(self): + pass + + async def resolve(self, request: MockRequest, request_info, request_history): + """Return predictable response based on input request.""" + # Simulate processing time + await asyncio.sleep(self._response_delay) + + if ( + self._error_rate + and self._error_rate > 0 + and random.random() < self._error_rate + ): + raise RuntimeError(f"mock_error_for_{request.payload}") + + yield f"response_for_{request.payload}" + + +@pytest.mark.smoke +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("strategy", "env", "constraint_inits"), + [ + ( + SynchronousStrategy(), + NonDistributedEnvironment(), + {"max_number": MaxNumberConstraintInitializer(max_num=100)}, + ), + ], +) +async def test_scheduler_run( + strategy: SchedulingStrategy, + env: Environment, + constraint_inits: dict[str, ConstraintInitializer], +): + scheduler = Scheduler() + constraints = { + key: init.create_constraint() for key, init in constraint_inits.items() + } + received_updates = defaultdict(list) + received_responses = [] + last_state = None + num_requests = 50 + + async for resp, req, info, state in scheduler.run( + requests=[ + MockRequest(payload=f"req_{ind}") for ind in range(num_requests) + ], # less than total requests sent to test new request iter + backend=MockBackend(), + strategy=strategy, + env=env, + **constraints, + ): + assert req is not None + assert isinstance(req, MockRequest) + assert isinstance(info, ScheduledRequestInfo) + assert info.status != "cancelled" + assert isinstance(state, SchedulerState) + if info.status == "completed": + assert resp == f"response_for_{req.payload}" + received_responses.append(resp) + elif info.status == "errored": + assert resp is None + assert info.error is not None + assert info.error == f"mock_error_for_{req.payload}" + received_responses.append(info.error) + + if len(received_updates[req.payload]) < 3: + received_updates[req.payload].append(info.status) + last_state = state + + assert len(received_updates) == num_requests + assert len(received_responses) == constraints["max_number"].max_num + assert last_state.created_requests == constraints["max_number"].max_num + assert last_state.queued_requests == 0 + assert last_state.processing_requests == 0 + assert last_state.processed_requests == constraints["max_number"].max_num + assert last_state.cancelled_requests == 0 + assert ( + last_state.successful_requests + last_state.errored_requests + ) == constraints["max_number"].max_num + + def 
_request_indices(): + while True: + yield from range(num_requests) + + for index, req, statuses, resp in zip( + _request_indices(), + received_updates.keys(), + received_updates.values(), + received_responses, + ): + assert req == f"req_{index}" + assert resp in (f"response_for_{req}", f"mock_error_for_{req}") + assert statuses in ( + ["queued", "in_progress", "completed"], + ["queued", "in_progress", "errored"], + ) diff --git a/tests/unit/scheduler/test_strategy.py b/tests/unit/scheduler/test_strategy.py new file mode 100644 index 00000000..6057b731 --- /dev/null +++ b/tests/unit/scheduler/test_strategy.py @@ -0,0 +1,923 @@ +import inspect +import math +import statistics +import time +from abc import ABC + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + AsyncConstantStrategy, + AsyncPoissonStrategy, + ConcurrentStrategy, + ConstantRateRequestTimings, + LastCompletionRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, + ScheduledRequestInfo, + ScheduledRequestTimings, + SchedulingStrategy, + SynchronousStrategy, + ThroughputStrategy, +) +from guidellm.scheduler.strategy import ( + _exponential_decay_fraction, + _exponential_decay_tau, +) + + +class TestExponentialDecayHelpers: + @pytest.mark.smoke + @pytest.mark.parametrize( + ("max_progress", "convergence", "expected_range"), + [ + (1.0, 0.99, (0.21, 0.22)), + (5.0, 0.99, (1.08, 1.09)), + (10.0, 0.95, (3.33, 3.35)), + ], + ) + def test_exponential_decay_tau(self, max_progress, convergence, expected_range): + """Test exponential decay tau calculation.""" + tau = _exponential_decay_tau(max_progress, convergence) + assert expected_range[0] <= tau <= expected_range[1] + expected_tau = max_progress / (-math.log(1 - convergence)) + assert tau == pytest.approx(expected_tau, rel=1e-10) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("progress", "tau", "expected_min", "expected_max"), + [ + (0.0, 1.0, 0.0, 0.0), # No progress = 0 + (1.0, 1.0, 0.6, 0.7), # 1 tau ≈ 63.2% + (2.0, 1.0, 0.85, 0.87), # 2 tau ≈ 86.5% + (3.0, 1.0, 0.95, 0.96), # 3 tau ≈ 95.0% + ], + ) + def test_exponential_decay_fraction( + self, progress, tau, expected_min, expected_max + ): + """Test exponential decay fraction calculation.""" + fraction = _exponential_decay_fraction(progress, tau) + assert expected_min <= fraction <= expected_max + expected_fraction = 1 - math.exp(-progress / tau) + assert fraction == pytest.approx(expected_fraction, rel=1e-10) + + @pytest.mark.smoke + def test_exponential_decay_fraction_boundary_conditions(self): + """Test boundary conditions for exponential decay fraction.""" + assert _exponential_decay_fraction(0.0, 1.0) == 0.0 + assert _exponential_decay_fraction(0.0, 10.0) == 0.0 + large_progress = 100.0 + fraction = _exponential_decay_fraction(large_progress, 1.0) + assert fraction > 0.99999 + + +class TestScheduledRequestTimings: + @pytest.mark.smoke + def test_is_abstract_base_class(self): + """Test that ScheduledRequestTimings is an abstract base class.""" + assert issubclass(ScheduledRequestTimings, ABC) + assert inspect.isabstract(ScheduledRequestTimings) + + @pytest.mark.smoke + def test_abstract_methods_defined(self): + """Test that the required abstract methods are defined.""" + abstract_methods = ScheduledRequestTimings.__abstractmethods__ + expected_methods = {"next_offset", "request_completed"} + assert abstract_methods == expected_methods + + # Validate method signatures + next_offset_method = ScheduledRequestTimings.next_offset + assert callable(next_offset_method) + 
request_completed_method = ScheduledRequestTimings.request_completed + assert callable(request_completed_method) + + # Check signature parameters using inspect + next_offset_sig = inspect.signature(next_offset_method) + assert len(next_offset_sig.parameters) == 1 + assert str(next_offset_sig.return_annotation) == "float" + request_completed_sig = inspect.signature(request_completed_method) + assert len(request_completed_sig.parameters) == 2 + params = list(request_completed_sig.parameters.values()) + param_annotation = params[1].annotation + assert param_annotation in {ScheduledRequestInfo, "ScheduledRequestInfo"} + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid implementations raise TypeError.""" + + class InvalidImplementation(ScheduledRequestTimings): + pass # Missing required abstract methods + + with pytest.raises(TypeError): + InvalidImplementation() + + @pytest.mark.smoke + def test_child_implementation(self): + """Test that concrete implementations can be constructed.""" + + # Test with a proper concrete implementation in this test scope + class TestRequestTimings(ScheduledRequestTimings): + offset: float = 0.0 + + def next_offset(self) -> float: + self.offset += 1.0 + return self.offset + + def request_completed(self, request_info: ScheduledRequestInfo): + pass + + timing = TestRequestTimings() + assert isinstance(timing, ScheduledRequestTimings) + + assert timing.next_offset() == 1.0 + assert timing.next_offset() == 2.0 + + mock_request = ScheduledRequestInfo( + request_id="test", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + timing.request_completed(mock_request) + + +class TestLastCompletionRequestTimings: + @pytest.fixture( + params=[ + {}, + {"offset": 10.0}, + {"startup_requests": 5, "startup_requests_delay": 0.5}, + { + "offset": 0.0, + "startup_requests": 0, + "startup_requests_delay": 0.0, + }, + { + "offset": 2.5, + "startup_requests": 3, + "startup_requests_delay": 1.0, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of LastCompletionRequestTimings.""" + constructor_args = request.param + instance = LastCompletionRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, LastCompletionRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("startup_requests", -1), + ("startup_requests_delay", -0.5), + ("offset", "invalid"), + ("startup_requests", 1.5), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + LastCompletionRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test the complete lifecycle of next_offset and request_completed calls.""" + instance, constructor_args = valid_instances + initial_offset = instance.offset + startup_requests = constructor_args.get("startup_requests", 0) + startup_delay = constructor_args.get("startup_requests_delay", 0.0) + request_times = [] + + for index in range(max(5, 
startup_requests + 2)): + offset = instance.next_offset() + assert isinstance(offset, (int, float)) + + if index < startup_requests: + expected_offset = initial_offset + (index + 1) * startup_delay + assert offset == pytest.approx(expected_offset, abs=1e-5) + + completion_time = time.time() + offset + request_times.append(completion_time) + + mock_request = ScheduledRequestInfo( + request_id=f"test-{index}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ) + mock_request.scheduler_timings.resolve_end = completion_time + instance.request_completed(mock_request) + + @pytest.mark.smoke + def test_marshalling( + self, valid_instances: tuple[LastCompletionRequestTimings, dict] + ): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = LastCompletionRequestTimings.model_validate(data) + assert isinstance(reconstructed, LastCompletionRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestNoDelayRequestTimings: + @pytest.fixture( + params=[ + {}, + {"offset": 0.2}, + {"startup_duration": 0.3, "startup_target_requests": 5}, + { + "offset": 0.15, + "startup_duration": 0.2, + "startup_target_requests": 20, + "startup_convergence": 0.9, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of NoDelayRequestTimings.""" + constructor_args = request.param + instance = NoDelayRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, NoDelayRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("offset", -1.0), + ("startup_duration", -1.0), + ("startup_target_requests", 0), + ("startup_target_requests", -1), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + NoDelayRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_lifecycle(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test the complete lifecycle of timing methods.""" + instance, constructor_args = valid_instances + startup_duration = constructor_args.get("startup_duration", 0.0) + base_offset = constructor_args.get("offset", 0.0) + start_time = time.time() + min_time = base_offset + startup_duration + 0.2 + end_time = start_time + min_time + last_offset = -1 * math.inf + + while (current_time := time.time()) < end_time: + offset = instance.next_offset() + + if startup_duration > 0 and (current_time - start_time) <= startup_duration: + assert offset < base_offset + startup_duration + assert offset > last_offset + elif startup_duration > 0: + assert offset == base_offset + startup_duration + else: + assert offset == base_offset + + last_offset = offset + time.sleep(0.025) + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[NoDelayRequestTimings, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, 
constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = NoDelayRequestTimings.model_validate(data) + assert isinstance(reconstructed, NoDelayRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestConstantRateRequestTimings: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0, "offset": 2.0}, + {"rate": 10.5, "offset": 1.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ConstantRateRequestTimings.""" + constructor_args = request.param + instance = ConstantRateRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ConstantRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ("offset", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization scenarios.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + ConstantRateRequestTimings(**kwargs) + + @pytest.mark.smoke + def test_constant_rate_behavior( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test that requests are scheduled at constant intervals.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + expected_interval = 1.0 / rate + base_offset = constructor_args.get("offset", 0.0) + num_requests = int(5 * rate) # simulate 5 seconds + + for ind in range(num_requests): + offset = instance.next_offset() + assert offset >= base_offset + assert offset == pytest.approx( + base_offset + ind * expected_interval, rel=1e-2 + ) + + @pytest.mark.smoke + def test_marshalling( + self, valid_instances: tuple[ConstantRateRequestTimings, dict] + ): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = ConstantRateRequestTimings.model_validate(data) + assert isinstance(reconstructed, ConstantRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestPoissonRateRequestTimings: + @pytest.fixture( + params=[ + {"rate": 1.0}, + { + "rate": 5.0, + "random_seed": 123, + "offset": 1.0, + }, + { + "rate": 0.5, + }, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of PoissonRateRequestTimings.""" + constructor_args = request.param + instance = PoissonRateRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization( + self, valid_instances: tuple[PoissonRateRequestTimings, dict] + ): + """Test initialization with valid configurations.""" + instance, constructor_args = valid_instances + assert isinstance(instance, PoissonRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.smoke + def 
test_lifecycle(self, valid_instances: tuple[PoissonRateRequestTimings, dict]): + """Test that Poisson timing produces variable intervals.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + base_offset = constructor_args.get("offset", 0.0) + num_requests = 200 + last_offset = 0.0 + intervals = [] + + for index in range(num_requests): + offset = instance.next_offset() + + if index == 0: + assert offset == base_offset + else: + assert offset > last_offset + + intervals.append(offset - last_offset) + last_offset = offset + + expected_mean_interval = 1.0 / rate + actual_mean_interval = statistics.mean(intervals) + tolerance = 0.2 * expected_mean_interval + assert abs(actual_mean_interval - expected_mean_interval) < tolerance + + @pytest.mark.smoke + def test_marshalling(self, valid_instances: tuple[PoissonRateRequestTimings, dict]): + """Test marshalling to/from pydantic dict formats.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert isinstance(data, dict) + + for key, value in constructor_args.items(): + assert data[key] == value + + reconstructed = PoissonRateRequestTimings.model_validate(data) + assert isinstance(reconstructed, PoissonRateRequestTimings) + + for key, value in constructor_args.items(): + assert getattr(reconstructed, key) == value + + +class TestSchedulingStrategy: + @pytest.mark.smoke + def test_base(self): + """Test that base methods are defined in SchedulingStrategy.""" + # Validate inheritance and interface compliance + assert issubclass(SchedulingStrategy, object) + + # Validate expected methods exist + expected_methods = { + "processes_limit", + "requests_limit", + "create_request_timings", + } + strategy_methods = set(dir(SchedulingStrategy)) + for method in expected_methods: + assert method in strategy_methods + + # validate expected properties + processes_limit_prop = SchedulingStrategy.processes_limit + assert isinstance(processes_limit_prop, property) + requests_limit_prop = SchedulingStrategy.requests_limit + assert isinstance(requests_limit_prop, property) + create_request_timings_method = SchedulingStrategy.create_request_timings + assert callable(create_request_timings_method) + + # Validate method signature + sig = inspect.signature(create_request_timings_method) + params = list(sig.parameters.keys()) + expected_params = [ + "self", + "local_rank", + "local_world_size", + "local_max_concurrency", + ] + assert params == expected_params + + @pytest.mark.sanity + def test_invalid_implementation(self): + """Test that invalid implementations raise NotImplementedError.""" + + class InvalidStrategy(SchedulingStrategy): + pass + + strategy = InvalidStrategy(type_="strategy") + with pytest.raises(NotImplementedError): + strategy.create_request_timings(0, 1, 1) + + @pytest.mark.smoke + def test_concrete_implementation(self): + """Test that concrete implementations can be constructed.""" + + class TestStrategy(SchedulingStrategy): + type_: str = "strategy" + + def create_request_timings( + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, + ): + return LastCompletionRequestTimings() + + strategy = TestStrategy() + assert isinstance(strategy, SchedulingStrategy) + timing = strategy.create_request_timings(0, 1, 1) + assert isinstance(timing, ScheduledRequestTimings) + + +class TestSynchronousStrategy: + @pytest.mark.smoke + def test_initialization(self): + """Test initialization of SynchronousStrategy.""" + strategy = SynchronousStrategy() + assert strategy.type_ == 
"synchronous" + + @pytest.mark.smoke + def test_limits(self): + """Test that SynchronousStrategy enforces proper limits.""" + strategy = SynchronousStrategy() + assert strategy.processes_limit == 1 + assert strategy.requests_limit == 1 + + @pytest.mark.smoke + def test_create_timings_valid(self): + """Test creating timings with valid parameters.""" + strategy = SynchronousStrategy() + timing = strategy.create_request_timings(0, 1, 1) + assert isinstance(timing, LastCompletionRequestTimings) + + @pytest.mark.sanity + def test_create_timings_invalid(self): + """Test that invalid parameters raise ValueError.""" + strategy = SynchronousStrategy() + + with pytest.raises(ValueError): + strategy.create_request_timings(1, 1, 1) # rank != 0 + + with pytest.raises(ValueError): + strategy.create_request_timings(0, 2, 1) # world_size > 1 + + @pytest.mark.smoke + def test_string_representation(self): + """Test __str__ method for SynchronousStrategy.""" + strategy = SynchronousStrategy() + result = str(strategy) + assert result == "synchronous" + + +class TestConcurrentStrategy: + @pytest.fixture( + params=[ + {"streams": 1}, + {"streams": 4}, + {"streams": 8, "startup_duration": 2.0}, + {"streams": 2, "startup_duration": 0.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ConcurrentStrategy.""" + constructor_args = request.param + instance = ConcurrentStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test initialization of ConcurrentStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "concurrent" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("streams", 0), + ("streams", -1), + ("startup_duration", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"streams": 2} + kwargs[field] = value + with pytest.raises(ValidationError): + ConcurrentStrategy(**kwargs) + + @pytest.mark.smoke + def test_limits(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test that ConcurrentStrategy returns correct limits.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + assert instance.processes_limit == streams + assert instance.requests_limit == streams + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[ConcurrentStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + startup_duration = constructor_args.get("startup_duration", 0.0) + + # Test with different rank and world_size combinations + for local_rank in range(min(streams, 2)): + for local_world_size in range(1, min(streams + 1, 3)): + if local_rank < local_world_size: + timing = instance.create_request_timings( + local_rank, local_world_size, streams + ) + assert isinstance(timing, LastCompletionRequestTimings) + + # Verify startup behavior + if startup_duration > 0: + # Check that timing has proper startup configuration + expected_delay_per_stream = startup_duration / streams + streams_per_worker = streams // local_world_size + expected_offset = ( + local_rank * streams_per_worker * expected_delay_per_stream + ) + assert timing.offset == pytest.approx(expected_offset, abs=1e-5) + + @pytest.mark.sanity + def 
test_create_timings_invalid( + self, valid_instances: tuple[ConcurrentStrategy, dict] + ): + """Test invalid inputs for create request timings.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + + # Test various invalid configurations + invalid_configs = [ + (streams, 1, 1), # rank >= streams + (0, streams + 1, 1), # world_size > streams + ] + + for local_rank, local_world_size, local_max_concurrency in invalid_configs: + if local_rank >= streams or local_world_size > streams: + with pytest.raises(ValueError): + instance.create_request_timings( + local_rank, local_world_size, local_max_concurrency + ) + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[ConcurrentStrategy, dict] + ): + """Test __str__ method for ConcurrentStrategy.""" + instance, constructor_args = valid_instances + streams = constructor_args["streams"] + result = str(instance) + assert result == f"concurrent@{streams}" + + +class TestThroughputStrategy: + @pytest.fixture( + params=[ + {}, + {"max_concurrency": 10}, + {"startup_duration": 5.0}, + {"max_concurrency": 5, "startup_duration": 2.0}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of ThroughputStrategy.""" + constructor_args = request.param + instance = ThroughputStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test initialization of ThroughputStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "throughput" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("max_concurrency", 0), + ("max_concurrency", -1), + ("startup_duration", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + ThroughputStrategy(**kwargs) + + @pytest.mark.smoke + def test_limits(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test that ThroughputStrategy returns correct limits.""" + instance, constructor_args = valid_instances + max_concurrency = constructor_args.get("max_concurrency") + assert instance.processes_limit == max_concurrency + assert instance.requests_limit == max_concurrency + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[ThroughputStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + startup_duration = constructor_args.get("startup_duration", 0.0) + + # Test with different configurations + for local_rank in range(3): + for local_world_size in range(1, 4): + for local_max_concurrency in range(1, 6): + timing = instance.create_request_timings( + local_rank, local_world_size, local_max_concurrency + ) + assert isinstance(timing, NoDelayRequestTimings) + + # Verify startup configuration + if startup_duration > 0: + assert timing.startup_duration == startup_duration + assert timing.startup_target_requests == local_max_concurrency + expected_offset = ( + 0.05 * startup_duration * (local_rank / local_world_size) + ) + assert timing.offset == pytest.approx(expected_offset, abs=1e-5) + else: + assert timing.startup_duration == 0.0 + assert timing.offset == 0.0 + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[ThroughputStrategy, dict] + ): + """Test 
__str__ method for ThroughputStrategy.""" + instance, _ = valid_instances + result = str(instance) + assert result == "throughput" + + +class TestAsyncConstantStrategy: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0}, + {"rate": 10.3, "max_concurrency": 8}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of AsyncConstantStrategy.""" + constructor_args = request.param + instance = AsyncConstantStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test initialization of AsyncConstantStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "constant" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"rate": 1.0} + kwargs[field] = value + with pytest.raises(ValidationError): + AsyncConstantStrategy(**kwargs) + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[AsyncConstantStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + + # Test with different worker configurations + for local_world_size in range(1, 5): + timing = instance.create_request_timings(0, local_world_size, 1) + assert isinstance(timing, ConstantRateRequestTimings) + + # Rate should be distributed across workers + expected_worker_rate = rate / local_world_size + assert timing.rate == pytest.approx(expected_worker_rate, abs=1e-5) + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[AsyncConstantStrategy, dict] + ): + """Test __str__ method for AsyncConstantStrategy.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + result = str(instance) + assert result == f"constant@{rate:.2f}" + + +class TestAsyncPoissonStrategy: + @pytest.fixture( + params=[ + {"rate": 1.0}, + {"rate": 5.0, "random_seed": 123}, + {"rate": 10.3, "random_seed": 456, "max_concurrency": 8}, + ] + ) + def valid_instances(self, request): + """Creates various valid configurations of AsyncPoissonStrategy.""" + constructor_args = request.param + instance = AsyncPoissonStrategy(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test initialization of AsyncPoissonStrategy.""" + instance, constructor_args = valid_instances + assert instance.type_ == "poisson" + + for key, value in constructor_args.items(): + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("rate", 0), + ("rate", -1.0), + ], + ) + def test_invalid_initialization(self, field, value): + """Test invalid initialization.""" + kwargs = {"rate": 1.0, "random_seed": 42} + kwargs[field] = value + with pytest.raises(ValidationError): + AsyncPoissonStrategy(**kwargs) + + @pytest.mark.smoke + def test_create_timings(self, valid_instances: tuple[AsyncPoissonStrategy, dict]): + """Test creating timings.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + base_seed = constructor_args.get("random_seed", 42) + + # Test with different worker configurations + for 
local_rank in range(3): + for local_world_size in range(1, 4): + timing = instance.create_request_timings( + local_rank, local_world_size, 1 + ) + assert isinstance(timing, PoissonRateRequestTimings) + + # Rate should be distributed across workers + expected_worker_rate = rate / local_world_size + assert timing.rate == pytest.approx(expected_worker_rate, abs=1e-5) + + # Each worker should have a unique seed + expected_seed = base_seed + local_rank + assert timing.random_seed == expected_seed + + @pytest.mark.smoke + def test_string_representation( + self, valid_instances: tuple[AsyncPoissonStrategy, dict] + ): + """Test __str__ method for AsyncPoissonStrategy.""" + instance, constructor_args = valid_instances + rate = constructor_args["rate"] + result = str(instance) + assert result == f"poisson@{rate:.2f}" diff --git a/tests/unit/scheduler/test_worker.py b/tests/unit/scheduler/test_worker.py new file mode 100644 index 00000000..48effbd2 --- /dev/null +++ b/tests/unit/scheduler/test_worker.py @@ -0,0 +1,711 @@ +import asyncio +import contextlib +import inspect +import math +import threading +import time +from collections import defaultdict +from multiprocessing import Barrier, Event, Queue +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from queue import Empty +from typing import Any, Callable, Generic, Literal, Optional +from unittest.mock import AsyncMock, patch + +import pytest + +from guidellm.scheduler import ( + BackendInterface, + LastCompletionRequestTimings, + MeasuredRequestTimings, + ScheduledRequestInfo, + ScheduledRequestTimings, + WorkerProcess, +) +from guidellm.scheduler.strategy import ( + ConstantRateRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, +) +from guidellm.utils import MsgpackEncoding, random + + +class MockRequestTimings(MeasuredRequestTimings): + """Mock timing implementation for testing.""" + + +class MockBackend(BackendInterface): + """Mock backend for testing worker functionality.""" + + def __init__( + self, + delay: float = 0.01, + should_fail: bool = False, + request_error_rate: float = 0.0, + ): + self.delay = delay + self.should_fail = should_fail + self.request_error_rate = request_error_rate + self.process_startup_called = False + self.validate_called = False + self.process_shutdown_called = False + self.resolve_called = False + + @property + def processes_limit(self) -> Optional[int]: + return None + + @property + def requests_limit(self) -> Optional[int]: + return None + + def info(self) -> dict[str, Any]: + return {"type": "mock", "delay": self.delay} + + async def process_startup(self): + await asyncio.sleep(self.delay) + self.process_startup_called = True + + async def validate(self): + await asyncio.sleep(self.delay) + self.validate_called = True + if self.should_fail: + raise RuntimeError("Mock validation failed") + + async def process_shutdown(self): + await asyncio.sleep(0.1) + self.process_shutdown_called = True + + async def resolve(self, request, request_info, request_history): + self.resolve_called = True + await asyncio.sleep(self.delay) + if self.should_fail: + raise RuntimeError("Mock resolve failed") + if self.request_error_rate > 0.0 and random.random() < self.request_error_rate: + raise RuntimeError("Mock resolve failed") + yield f"response_for_{request}" + + +class TestWorkerProcess: + """Test suite for WorkerProcess class.""" + + @pytest.fixture + def worker_process(self): + """Create a WorkerProcess instance for 
testing.""" + backend = MockBackend() + request_timings = LastCompletionRequestTimings() + + return WorkerProcess( + local_rank=0, + local_world_size=2, + async_limit=5, + startup_barrier=Barrier(2), + shutdown_event=Event(), + error_event=Event(), + requests_queue=Queue(), + updates_queue=Queue(), + backend=backend, + request_timings=request_timings, + poll_intervals=0.01, + ) + + @pytest.mark.smoke + def test_class_signatures(self, worker_process: WorkerProcess): + """Test inheritance and type relationships.""" + # Class + assert isinstance(worker_process, Generic) + assert issubclass(WorkerProcess, Generic) + + # Generics + orig_bases = getattr(WorkerProcess, "__orig_bases__", ()) + assert len(orig_bases) > 0 + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is Generic + ), + None, + ) + assert generic_base is not None + type_args = getattr(generic_base, "__args__", ()) + assert len(type_args) == 3 # RequestT, MeasuredRequestTimingsT, ResponseT + + # Function signatures + run_sig = inspect.signature(WorkerProcess.run) + assert len(run_sig.parameters) == 1 + assert "self" in run_sig.parameters + + run_async_sig = inspect.signature(WorkerProcess.run_async) + assert len(run_async_sig.parameters) == 1 + assert "self" in run_async_sig.parameters + + stop_processing_sig = inspect.signature(WorkerProcess.run_async_stop_processing) + assert len(stop_processing_sig.parameters) == 1 + assert "self" in stop_processing_sig.parameters + + requests_processing_sig = inspect.signature( + WorkerProcess.run_async_requests_processing + ) + assert len(requests_processing_sig.parameters) == 1 + assert "self" in requests_processing_sig.parameters + + @pytest.mark.smoke + def test_initialization(self, worker_process: WorkerProcess): + """Test basic initialization of WorkerProcess.""" + # worker info + assert worker_process.local_rank == 0 + assert worker_process.local_world_size == 2 + assert worker_process.async_limit == 5 + + # process synchronization + assert isinstance(worker_process.startup_barrier, ProcessingBarrier) + assert isinstance(worker_process.shutdown_event, ProcessingEvent) + assert isinstance(worker_process.error_event, ProcessingEvent) + assert hasattr(worker_process.requests_queue, "put") + assert hasattr(worker_process.requests_queue, "get") + assert hasattr(worker_process.updates_queue, "put") + assert hasattr(worker_process.updates_queue, "get") + + # local synchronization + assert worker_process.pending_requests_queue is None + assert worker_process.pending_updates_queue is None + + # request processing + assert isinstance(worker_process.backend, MockBackend) + assert worker_process.poll_intervals == 0.01 + assert isinstance(worker_process.request_timings, LastCompletionRequestTimings) + + @pytest.mark.sanity + def test_invalid_initialization(self): + """Test that invalid initialization raises appropriate errors.""" + # Test with missing required parameters + with pytest.raises(TypeError): + WorkerProcess() + + # Create a complete set of valid parameters + backend = MockBackend() + request_timings = LastCompletionRequestTimings() + barrier = Barrier(2) + shutdown_event = Event() + error_event = Event() + requests_queue = Queue() + updates_queue = Queue() + + # Test missing each required parameter one by one + required_params = [ + "local_rank", + "local_world_size", + "async_limit", + "startup_barrier", + "shutdown_event", + "error_event", + "requests_queue", + "updates_queue", + "backend", + "request_timings", + ] + + for 
param_to_remove in required_params: + kwargs = { + "local_rank": 0, + "local_world_size": 2, + "async_limit": 5, + "startup_barrier": barrier, + "shutdown_event": shutdown_event, + "error_event": error_event, + "requests_queue": requests_queue, + "updates_queue": updates_queue, + "backend": backend, + "request_timings": request_timings, + "poll_intervals": 0.01, + } + + del kwargs[param_to_remove] + + with pytest.raises(TypeError): + WorkerProcess(**kwargs) + + @pytest.mark.smoke + @patch("asyncio.run") + def test_run(self, mock_asyncio_run, worker_process: WorkerProcess): + """ + Test that run method functions as expected (calls run_async, handles errors) + """ + # Test successful execution + with patch.object( + worker_process, "run_async", new_callable=AsyncMock + ) as mock_run_async: + worker_process.run() + mock_asyncio_run.assert_called_once() + mock_run_async.assert_called_once() + + mock_asyncio_run.reset_mock() + + # Test exception during execution + test_exception = RuntimeError("Test error in run_async") + with patch.object( + worker_process, "run_async", new_callable=AsyncMock + ) as mock_run_async: + mock_asyncio_run.side_effect = test_exception + + with pytest.raises( + RuntimeError, match="Worker process 0 encountered an error" + ): + worker_process.run() + + assert worker_process.error_event.is_set() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("stop_action", "req_action"), + [ + ("complete_short", "complete_short"), + ("complete_long", "error"), + ("error", "complete_long"), + ("error", "error"), + ("complete_long", "cancel"), + ("cancel", "complete_long"), + ("cancel", "cancel"), + ], + ) + async def test_run_async( # noqa: C901 + self, + worker_process: WorkerProcess, + stop_action: Literal["complete_short", "complete_long", "error", "cancel"], + req_action: Literal["complete_short", "complete_long", "error", "cancel"], + ): + def make_task(action: str, state: dict): + loops = {"error": 1, "cancel": 2, "complete_short": 3, "complete_long": 50}[ + action + ] + + async def _run(self): + state.update(called=True, iterations=0) + try: + for _ in range(loops): + await asyncio.sleep(0.01) + state["iterations"] += 1 + if action == "error": + state["errored"] = True + raise RuntimeError(state["error_message"]) + if action == "cancel": + state["cancelled"] = True + raise asyncio.CancelledError(state["cancel_message"]) + if action == "complete_short": + state["completed_short"] = True + if action == "complete_long": + state["completed_long"] = True + except asyncio.CancelledError: + state["cancelled"] = True + raise + + return _run, loops + + def init_state(prefix): + return { + "called": False, + "iterations": 0, + "completed_short": False, + "completed_long": False, + "errored": False, + "cancelled": False, + "error_message": f"{prefix} processing error", + "cancel_message": f"{prefix} processing cancelled", + } + + stop_state, req_state = init_state("Stop"), init_state("Requests") + stop_fn, stop_loops = make_task(stop_action, stop_state) + req_fn, req_loops = make_task(req_action, req_state) + + expected_exc = RuntimeError if "error" in {stop_action, req_action} else None + with ( + patch.object( + type(worker_process), "run_async_stop_processing", new=stop_fn + ), + patch.object( + type(worker_process), "run_async_requests_processing", new=req_fn + ), + ): + if expected_exc: + with pytest.raises(expected_exc): + await worker_process.run_async() + else: + await worker_process.run_async() + + assert stop_state["called"] + assert 
req_state["called"] + + # build unified expected outcome table + def is_long(a): + return a == "complete_long" + + def is_short(a): + return a in {"complete_short", "error", "cancel"} + + expectations = { + "stop": { + "errored": stop_action == "error", + "cancelled": stop_action == "cancel" + or (is_short(req_action) and is_long(stop_action)) + or (req_action == "error" and is_long(stop_action)), + }, + "req": { + "errored": req_action == "error", + "cancelled": req_action == "cancel" + or (is_short(stop_action) and is_long(req_action)) + or (stop_action == "error" and is_long(req_action)), + }, + } + + # assert final state matches expectations + for label, (state, action) in { + "stop": (stop_state, stop_action), + "req": (req_state, req_action), + }.items(): + if expectations[label]["errored"]: + assert state["errored"] + if expectations[label]["cancelled"]: + assert state["cancelled"] + if action.startswith("complete_") and not expectations[label]["cancelled"]: + key = ( + "completed_short" + if action == "complete_short" + else "completed_long" + ) + assert state[key] + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "stop_action", + ["error_event", "shutdown_event", "cancel_event"], + ) + async def test_run_async_stop_processing( + self, worker_process: WorkerProcess, stop_action + ): + # ensure initial state + assert not worker_process.error_event.is_set() + assert not worker_process.shutdown_event.is_set() + + action = stop_action + early_check_delay = 0.01 + trigger_delay = 0.05 + + task = asyncio.create_task(worker_process.run_async_stop_processing()) + time_start = time.time() + await asyncio.sleep(early_check_delay) + assert not task.done(), "Task finished before any stop signal was triggered" + + async def trigger(): + await asyncio.sleep(trigger_delay - early_check_delay) + if action == "error_event": + worker_process.error_event.set() + elif action == "shutdown_event": + worker_process.shutdown_event.set() + elif action == "cancel_event": + task.cancel() + + trigger_task = asyncio.create_task(trigger()) + + if action == "error_event": + with pytest.raises(RuntimeError): + await asyncio.wait_for(task, timeout=1.0) + elif action in {"shutdown_event", "cancel_event"}: + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(task, timeout=1.0) + else: + raise ValueError(f"Unknown stop action: {action}") + + await asyncio.gather(trigger_task, return_exceptions=True) + + # validate correct ending states + elapsed = time.time() - time_start + assert elapsed >= trigger_delay - 0.01, ( + "Task completed too early: " + f"elapsed={elapsed:.3f}s < trigger={trigger_delay:.3f}s" + ) + if action == "error_event": + assert worker_process.error_event.is_set() + assert not worker_process.shutdown_event.is_set() + elif action == "shutdown_event": + assert worker_process.shutdown_event.is_set() + assert not worker_process.error_event.is_set() + elif action == "cancel_event": + assert not worker_process.error_event.is_set() + assert not worker_process.shutdown_event.is_set() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("request_timings_const", "async_limit"), + [ + (lambda: LastCompletionRequestTimings(), 1), + (lambda: PoissonRateRequestTimings(rate=10000), 2), + (lambda: ConstantRateRequestTimings(rate=10000), 3), + (lambda: NoDelayRequestTimings(), 4), + ], + ) + async def test_run_async_requests_processing( # noqa: C901 + self, + request_timings_const: Callable[[], ScheduledRequestTimings], + async_limit: int, + ): + 
startup_barrier = Barrier(2) + requests_queue = Queue() + updates_queue = Queue() + backend = MockBackend(delay=0.001) + worker_process = WorkerProcess( + local_rank=0, + local_world_size=1, + async_limit=async_limit, + startup_barrier=startup_barrier, + shutdown_event=Event(), + error_event=Event(), + requests_queue=requests_queue, + updates_queue=updates_queue, + backend=backend, + request_timings=request_timings_const(), + poll_intervals=0.01, + ) + + def _trip_barrier_later(): + time.sleep(0.02) + with contextlib.suppress(RuntimeError): + # barrier may be aborted (suppressed) during cancellation + worker_process.startup_barrier.wait(timeout=1.0) + + threading.Thread(target=_trip_barrier_later, daemon=True).start() + + run_task = asyncio.create_task(worker_process.run_async_requests_processing()) + await asyncio.sleep(0.05) # small delay to allow start up first + + # validate start up + assert worker_process.backend.process_startup_called + assert worker_process.backend.validate_called + assert worker_process.pending_requests_queue is not None + assert worker_process.pending_updates_queue is not None + assert worker_process.startup_completed + + # ensure full processing of requests + for index in range(20): + requests_queue.put( + MsgpackEncoding.encode( + ( + f"req-{index}", + ScheduledRequestInfo[MeasuredRequestTimings]( + request_id=f"req-{index}", + status="queued", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ), + ) + ) + ) + + updates = [] + num_failures = 0 + max_wait_time = 5.0 + start_time = time.time() + while time.time() - start_time < max_wait_time: + try: + update_message = updates_queue.get_nowait() + updates.append(MsgpackEncoding.decode(update_message)) + num_failures = 0 + except Empty: + num_failures += 1 + if len(updates) >= 40: # We got all expected updates + break + await asyncio.sleep(0.05) + + # validate updates are correct for each request + assert len(updates) == 40 + per_request = defaultdict(dict) + for update in updates: + response, request, info = update + if info.status == "in_progress": + per_request[info.request_id]["start"] = (response, request, info) + per_request[info.request_id]["targeted_start"] = ( + info.scheduler_timings.targeted_start + ) + per_request[info.request_id]["resolve_start"] = ( + info.scheduler_timings.resolve_start + ) + elif info.status == "completed": + per_request[info.request_id]["complete"] = (response, request, info) + per_request[info.request_id]["resolve_end"] = ( + info.scheduler_timings.resolve_end + ) + assert len(per_request) == 20 + assert all( + "start" in parts and "complete" in parts for parts in per_request.values() + ) + + # validate request times match expected + last_targeted_start = -1 * math.inf + for index in range(20): + targeted_start = per_request[f"req-{index}"]["targeted_start"] + resolve_start = per_request[f"req-{index}"]["resolve_start"] + resolve_end = per_request[f"req-{index}"]["resolve_end"] + assert targeted_start >= last_targeted_start + assert targeted_start < resolve_start + assert resolve_start == pytest.approx(targeted_start) + assert resolve_end == pytest.approx(resolve_start + backend.delay) + + # Validate concurrency limits are respected + events = [] + for req_id in per_request: + events.append((per_request[req_id]["resolve_start"], 1)) + events.append((per_request[req_id]["resolve_end"], -1)) + events.sort() + max_concurrent = concurrent = 0 + for _, delta in events: + concurrent += delta + max_concurrent = max(max_concurrent, concurrent) + assert 
max_concurrent <= async_limit + + # validate cancellation + backend.delay = 10 + # max concurrent for backend + 2 queued for backend + num_cancel_tasks = (async_limit + 2) * 2 + for index in range(20, 20 + num_cancel_tasks): + requests_queue.put( + MsgpackEncoding.encode( + ( + f"req-{index}", + ScheduledRequestInfo[MeasuredRequestTimings]( + request_id=f"req-{index}", + status="queued", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ), + ) + ) + ) + await asyncio.sleep(0.5) + run_task.cancel() + await asyncio.gather(run_task, return_exceptions=True) + assert worker_process.backend.process_shutdown_called + assert worker_process.pending_requests_queue is None + assert worker_process.pending_updates_queue is None + + # validate canceled tasks + updates = [] + num_failures = 0 + while True: + try: + update_message = updates_queue.get_nowait() + updates.append(MsgpackEncoding.decode(update_message)) + except Empty: + num_failures += 1 + if num_failures > 3: + break + await asyncio.sleep(0.1) + # Ensure we get all updates we expected (async_limit for pending + 2 for queued) + assert len(updates) >= 2 * (async_limit + 2) + # Ensure we didn't process all requests on the queue and shutdown early + assert len(updates) < 2 * 2 * (async_limit + 2) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("request_timings_const", "async_limit", "request_error_rate"), + [ + (lambda: LastCompletionRequestTimings(), 1, 0.1), + (lambda: PoissonRateRequestTimings(rate=10000), 2, 0.2), + (lambda: ConstantRateRequestTimings(rate=10000), 3, 0.3), + (lambda: NoDelayRequestTimings(), 4, 0.4), + ], + ) + def test_run_lifecycle( + self, + request_timings_const: Callable[[], ScheduledRequestTimings], + async_limit: int, + request_error_rate: float, + ): + backend = MockBackend( + delay=0.01, + request_error_rate=request_error_rate, + ) + startup_barrier = Barrier(2) + shutdown_event = Event() + requests_queue = Queue() + updates_queue = Queue() + backend = MockBackend(delay=0.001) + worker_process = WorkerProcess( + local_rank=0, + local_world_size=1, + async_limit=async_limit, + startup_barrier=startup_barrier, + shutdown_event=shutdown_event, + error_event=Event(), + requests_queue=requests_queue, + updates_queue=updates_queue, + backend=backend, + request_timings=request_timings_const(), + poll_intervals=0.01, + ) + + def _background_thread(): + time.sleep(0.1) # delay for startup + startup_barrier.wait() + + for index in range(20): + requests_queue.put( + MsgpackEncoding.encode( + ( + f"req-{index}", + ScheduledRequestInfo[MeasuredRequestTimings]( + request_id=f"req-{index}", + status="queued", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=time.time(), + ), + ) + ) + ) + + time.sleep(0.5) # delay for processing + shutdown_event.set() + + threading.Thread(target=_background_thread).start() + worker_process.run() + + updates = [] + max_attempts = 50 + attempts = 0 + while attempts < max_attempts: + try: + update_message = updates_queue.get_nowait() + updates.append(MsgpackEncoding.decode(update_message)) + except Empty: + attempts += 1 + if len(updates) >= 40: # We got all expected updates + break + time.sleep(0.05) + + # Validate updates + assert len(updates) == 40 + per_request = defaultdict(dict) + for update in updates: + response, request, info = update + if info.status == "in_progress": + per_request[info.request_id]["start"] = (response, request, info) + per_request[info.request_id]["targeted_start"] = ( + info.scheduler_timings.targeted_start 
+ ) + per_request[info.request_id]["resolve_start"] = ( + info.scheduler_timings.resolve_start + ) + elif info.status == "completed": + per_request[info.request_id]["complete"] = (response, request, info) + per_request[info.request_id]["resolve_end"] = ( + info.scheduler_timings.resolve_end + ) + assert len(per_request) == 20 + assert all( + "start" in parts and "complete" in parts for parts in per_request.values() + ) diff --git a/tests/unit/scheduler/test_worker_group.py b/tests/unit/scheduler/test_worker_group.py new file mode 100644 index 00000000..173a214b --- /dev/null +++ b/tests/unit/scheduler/test_worker_group.py @@ -0,0 +1,672 @@ +import asyncio +import inspect +import math +import os +import queue +import threading +import time +from collections import defaultdict +from multiprocessing import get_context +from queue import Empty +from typing import Any, Generic, Optional + +import culsans +import pytest + +from guidellm.scheduler import ( + AsyncConstantStrategy, + AsyncPoissonStrategy, + BackendInterface, + ConcurrentStrategy, + MaxNumberConstraint, + MeasuredRequestTimings, + ScheduledRequestInfo, + SchedulerState, + SynchronousStrategy, + ThroughputStrategy, + WorkerProcessGroup, + worker_group, +) +from guidellm.utils import MsgpackEncoding + + +class MockWorker: + """Picklable mock worker used to validate create_processes logic.""" + + @classmethod + def __class_getitem__(cls, item): + return cls + + def __init__( + self, + local_rank, + local_world_size, + async_limit, + startup_barrier, + shutdown_event, + error_event, + requests_queue, + updates_queue, + backend, + request_timings, + poll_intervals, + ): + self.local_rank = local_rank + self.local_world_size = local_world_size + self.async_limit = async_limit + self.startup_barrier = startup_barrier + self.shutdown_event = shutdown_event + self.error_event = error_event + self.requests_queue = requests_queue + self.updates_queue = updates_queue + self.backend = backend + self.request_timings = request_timings + self.poll_intervals = poll_intervals + + def run(self): + try: + # Access parameters to ensure they're usable and wait for barrier + shutdown_is_set = self.shutdown_event.is_set() + error_is_set = self.error_event.is_set() + backend_info = self.backend.info() + + self.startup_barrier.wait() + + # Publish diagnostics back to parent for assertions + payload = ( + "diag", + self.local_rank, + { + "child_pid": os.getpid(), + "local_rank": self.local_rank, + "local_world_size": self.local_world_size, + "async_limit": self.async_limit, + "backend_info": backend_info, + "shutdown_is_set": shutdown_is_set, + "error_is_set": error_is_set, + "passed_barrier": True, + "request_timings_type": type(self.request_timings).__name__, + }, + ) + self.updates_queue.put(payload) + except Exception as err: # noqa: BLE001 + try: + self.error_event.set() + self.updates_queue.put(("error", self.local_rank, repr(err))) + finally: + raise + + +class MockWorkerProcessor(MockWorker): + def run(self): + self.startup_barrier.wait() + + while not self.shutdown_event.is_set() and not self.error_event.is_set(): + try: + request_msg = self.requests_queue.get(timeout=0.1) + except queue.Empty: + continue + + request, request_info = MsgpackEncoding.decode(request_msg) + request_info.status = "in_progress" + self.updates_queue.put( + MsgpackEncoding.encode((None, request, request_info)) + ) + time.sleep(0.01) + request_info.status = "completed" + response = f"response_for_{request}" + self.updates_queue.put( + MsgpackEncoding.encode((response, 
request, request_info)) + ) + + +class MockRequestTimings(MeasuredRequestTimings): + """Mock timing implementation for testing.""" + + +class MockBackend(BackendInterface): + """Mock backend for testing worker group functionality.""" + + def __init__( + self, + processes_limit_value: Optional[int] = None, + requests_limit_value: Optional[int] = None, + ): + self._processes_limit = processes_limit_value + self._requests_limit = requests_limit_value + + @property + def processes_limit(self) -> Optional[int]: + return self._processes_limit + + @property + def requests_limit(self) -> Optional[int]: + return self._requests_limit + + def info(self) -> dict[str, Any]: + return {"type": "mock"} + + async def process_startup(self): + pass + + async def validate(self): + pass + + async def process_shutdown(self): + pass + + async def resolve(self, request, request_info, request_history): + yield f"response_for_{request}" + + +class TestWorkerProcessGroup: + """Test suite for WorkerProcessGroup class.""" + + @pytest.fixture + def worker_process_group(self): + """Create a WorkerProcessGroup instance for testing.""" + backend = MockBackend() + requests = ["request1", "request2", "request3"] + strategy = SynchronousStrategy() + constraints = {"max_requests": MaxNumberConstraint(max_num=10)} + + return WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=strategy, + constraints=constraints, + ) + + @pytest.mark.smoke + def test_class_signatures(self, worker_process_group: WorkerProcessGroup): + """Test inheritance and type relationships.""" + # Class + assert isinstance(worker_process_group, Generic) + assert issubclass(WorkerProcessGroup, Generic) + + # Generics + orig_bases = getattr(WorkerProcessGroup, "__orig_bases__", ()) + assert len(orig_bases) > 0 + generic_base = next( + ( + base + for base in orig_bases + if hasattr(base, "__origin__") and base.__origin__ is Generic + ), + None, + ) + assert generic_base is not None + type_args = getattr(generic_base, "__args__", ()) + assert len(type_args) == 3 + + # Function signatures + create_processes_sig = inspect.signature(WorkerProcessGroup.create_processes) + assert len(create_processes_sig.parameters) == 1 + assert "self" in create_processes_sig.parameters + + start_sig = inspect.signature(WorkerProcessGroup.start) + assert len(start_sig.parameters) == 2 + assert "self" in start_sig.parameters + assert "start_time" in start_sig.parameters + + request_updates_sig = inspect.signature(WorkerProcessGroup.request_updates) + assert len(request_updates_sig.parameters) == 1 + assert "self" in request_updates_sig.parameters + + shutdown_sig = inspect.signature(WorkerProcessGroup.shutdown) + assert len(shutdown_sig.parameters) == 1 + assert "self" in shutdown_sig.parameters + + @pytest.mark.smoke + def test_initialization(self, worker_process_group: WorkerProcessGroup): + """Test basic initialization of WorkerProcessGroup.""" + # Core attributes + assert isinstance(worker_process_group.backend, MockBackend) + expected_requests = ["request1", "request2", "request3"] + assert list(worker_process_group.requests) == expected_requests + assert isinstance(worker_process_group.strategy, SynchronousStrategy) + assert isinstance(worker_process_group.constraints, dict) + assert "max_requests" in worker_process_group.constraints + constraint = worker_process_group.constraints["max_requests"] + assert isinstance(constraint, MaxNumberConstraint) + + # Multiprocessing attributes (should be None initially) + assert worker_process_group.mp_context is None + 
assert worker_process_group.processes is None + + # Synchronization primitives (should be None initially) + assert worker_process_group.startup_barrier is None + assert worker_process_group.shutdown_event is None + assert worker_process_group.error_event is None + + # Queues (should be None initially) + assert worker_process_group.requests_queue is None + assert worker_process_group.updates_queue is None + assert worker_process_group.pending_updates_queue is None + assert worker_process_group.pending_updates_complete is None + + # Scheduler state and tasks (should be None initially) + assert worker_process_group.state_update_lock is None + assert worker_process_group.scheduler_state is None + assert worker_process_group.populate_requests_task is None + assert worker_process_group.populate_updates_task is None + + @pytest.mark.sanity + def test_invalid_initialization(self): + """Test that invalid initialization raises appropriate errors.""" + # Test with missing required parameters + with pytest.raises(TypeError): + WorkerProcessGroup() + + # Create a complete set of valid parameters + backend = MockBackend() + requests = ["request1", "request2"] + strategy = SynchronousStrategy() + constraints = {"max_requests": MaxNumberConstraint(max_num=10)} + + # Test missing each required parameter one by one + required_params = [ + "backend", + "requests", + "strategy", + "constraints", + ] + + for param_to_remove in required_params: + kwargs = { + "backend": backend, + "requests": requests, + "strategy": strategy, + "constraints": constraints, + } + + del kwargs[param_to_remove] + + with pytest.raises(TypeError): + WorkerProcessGroup(**kwargs) + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("strategy", "expected_num_procs", "expected_max_conc"), + [ + (SynchronousStrategy(), 1, 1), + (ConcurrentStrategy(streams=3), 3, 3), + (ThroughputStrategy(max_concurrency=6), 3, 6), + (AsyncConstantStrategy(rate=100.0), 3, 12), + (AsyncPoissonStrategy(rate=100.0), 3, 12), + ], + ) + async def test_create_processes( + self, + monkeypatch, + strategy, + expected_num_procs, + expected_max_conc, + ): + # Patch required mock settings + monkeypatch.setattr( + worker_group.settings, "max_worker_processes", 3, raising=False + ) + monkeypatch.setattr(worker_group.settings, "max_concurrency", 12, raising=False) + monkeypatch.setattr( + worker_group.settings, "scheduler_poll_interval", 0.01, raising=False + ) + monkeypatch.setattr(worker_group, "WorkerProcess", MockWorker, raising=True) + + # Setup group to test + backend = MockBackend() + requests = [f"r{i}" for i in range(10)] + constraints = {"max_requests": MaxNumberConstraint(max_num=100)} + group = WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=strategy, + constraints=constraints, + ) + + # Run within a reasonable time limit + try: + await asyncio.wait_for(group.create_processes(), timeout=5.0) + except asyncio.TimeoutError: + pytest.fail("create_processes() timed out after 5 seconds") + + # Check expected attributes are created + assert group.mp_context is not None + assert hasattr(group.mp_context, "Barrier") + assert hasattr(group.mp_context, "Event") + assert hasattr(group.mp_context, "Queue") + assert group.processes is not None + assert len(group.processes) == expected_num_procs + + # Validate processes ran correctly + diags: dict[int, dict] = {} + for _ in range(expected_num_procs): + kind, rank, payload = group.updates_queue.get(timeout=3) + if kind == "error": + pytest.fail(f"Worker {rank} reported error: 
{payload}") + assert kind == "diag" + diags[rank] = payload + + # Verify returned processes state + main_pid = os.getpid() + assert len(diags) == expected_num_procs + for rank, payload in diags.items(): + assert payload["local_rank"] == rank + assert payload["local_world_size"] == expected_num_procs + assert payload["passed_barrier"] is True + assert payload["shutdown_is_set"] is False + assert payload["error_is_set"] is False + assert isinstance(payload["backend_info"], dict) + assert payload["child_pid"] != main_pid + per_proc = math.ceil(expected_max_conc / expected_num_procs) + expected_last = expected_max_conc - per_proc * (expected_num_procs - 1) + for rank, payload in diags.items(): + exp_limit = per_proc if rank < expected_num_procs - 1 else expected_last + assert payload["async_limit"] == exp_limit + + exceptions = await group.shutdown() + assert len(exceptions) == 0, f"Shutdown encountered exceptions: {exceptions}" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_start(self, monkeypatch): + # Patch required mock settings + monkeypatch.setattr( + worker_group.settings, "max_worker_processes", 1, raising=False + ) + monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) + monkeypatch.setattr( + worker_group.settings, "scheduler_poll_interval", 0.01, raising=False + ) + monkeypatch.setattr(worker_group, "WorkerProcess", MockWorker, raising=True) + + # Setup group and mimic create_processes + backend = MockBackend() + requests = [f"r{i}" for i in range(5)] # to few requests, test new iter logic + group = WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=SynchronousStrategy(), + constraints={"max_num": MaxNumberConstraint(max_num=10)}, + ) + group.mp_context = get_context("fork") + group.startup_barrier = group.mp_context.Barrier(2) + group.shutdown_event = group.mp_context.Event() + group.error_event = group.mp_context.Event() + group.requests_queue = group.mp_context.Queue() + group.updates_queue = group.mp_context.Queue() + group.pending_updates_queue = culsans.Queue() + group.pending_updates_complete = threading.Event() + group.processes = [None] + + # Validate function runs and returns at start_time + start_time = time.time() + 0.2 + await asyncio.wait_for(group.start(start_time), timeout=3.0) + end_time = time.time() + assert end_time == pytest.approx(start_time, abs=0.01) + + # Validate instance state + assert group.state_update_lock is not None + assert hasattr(group.state_update_lock, "acquire") + assert group.scheduler_state is not None + assert group.scheduler_state.num_processes == 1 + assert group.scheduler_state.start_time == start_time + assert isinstance(group.populate_requests_task, asyncio.Task) + assert isinstance(group.populate_updates_task, asyncio.Task) + + # Pull the queued requests + await asyncio.sleep(0.1) + sent_requests = [] + while True: + await asyncio.sleep(0) + try: + req = group.requests_queue.get(timeout=1.0) + sent_requests.append(req) + except Empty: + break + assert len(sent_requests) == 10 + + # Enqueue lifecycle updates + for req in requests + requests: + group.updates_queue.put( + MsgpackEncoding.encode( + ( + None, + req, + ScheduledRequestInfo[MockRequestTimings]( + request_id=str(req), + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ), + ) + ) + ) + group.updates_queue.put( + MsgpackEncoding.encode( + ( + None, + req, + ScheduledRequestInfo[MockRequestTimings]( + request_id=str(req), + status="completed", + 
scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ), + ) + ) + ) + await asyncio.sleep(0) + + # Drain 3 updates per request (queued, started, completed) + await asyncio.sleep(0.1) + updates = [] + for _ in range(3 * 10): + try: + update = await asyncio.wait_for( + group.pending_updates_queue.async_get(), timeout=1.0 + ) + updates.append(update) + except asyncio.TimeoutError: + break + assert len(updates) == 3 * 10 + + # Ensure tasks finish + if not group.populate_requests_task.done(): + await asyncio.wait_for(group.populate_requests_task, timeout=1.0) + if not group.populate_updates_task.done(): + await asyncio.wait_for(group.populate_updates_task, timeout=1.0) + + # Clean up resources + group.processes = None + exceptions = await group.shutdown() + assert len(exceptions) == 0, f"Shutdown encountered exceptions: {exceptions}" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_start_cancel_requests_handling(self, monkeypatch): + """Test the start() method's async tasks handle shutdown correctly""" + # Patch required mock settings + monkeypatch.setattr( + worker_group.settings, "max_worker_processes", 1, raising=False + ) + monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) + monkeypatch.setattr( + worker_group.settings, "scheduler_poll_interval", 0.01, raising=False + ) + monkeypatch.setattr(worker_group, "WorkerProcess", MockWorker, raising=True) + + # Setup group and mimic create_processes + backend = MockBackend() + requests = [f"req_{i}" for i in range(10)] + group = WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=SynchronousStrategy(), + constraints={}, + ) + group.mp_context = get_context("fork") + group.startup_barrier = group.mp_context.Barrier(2) + group.shutdown_event = group.mp_context.Event() + group.error_event = group.mp_context.Event() + group.requests_queue = group.mp_context.Queue(maxsize=1) # Ensure saturated + group.updates_queue = group.mp_context.Queue() + group.pending_updates_queue = culsans.Queue() + group.pending_updates_complete = threading.Event() + group.processes = [None] + + # Validate function runs and returns at start_time + start_time = time.time() + 0.1 + await asyncio.wait_for(group.start(start_time), timeout=3.0) + end_time = time.time() + assert end_time == pytest.approx(start_time, abs=0.01) + + # Verify tasks are running + assert isinstance(group.populate_requests_task, asyncio.Task) + assert isinstance(group.populate_updates_task, asyncio.Task) + assert not group.populate_requests_task.done() + assert not group.populate_updates_task.done() + + def _process_request(): + req, req_info = MsgpackEncoding.decode( + group.requests_queue.get(timeout=1.0) + ) + group.updates_queue.put( + MsgpackEncoding.encode( + ( + None, + req, + ScheduledRequestInfo[MockRequestTimings]( + request_id=str(req), + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ), + ) + ) + ) + group.updates_queue.put( + MsgpackEncoding.encode( + ( + None, + req, + ScheduledRequestInfo[MockRequestTimings]( + request_id=str(req), + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ), + ) + ) + ) + + # Pull a few requests and push updates to ensure saturation + for _ in range(3): + await asyncio.sleep(0) + _process_request() + + # Check that we've received all expected updates so far + updates_by_request = defaultdict(list) + while True: + try: + resp, req, req_info, state = await 
asyncio.wait_for( + group.pending_updates_queue.async_get(), + timeout=0.1, + ) + updates_by_request[req].append(req_info.status) + except asyncio.TimeoutError: + break + for index, (_, statuses) in enumerate(updates_by_request.items()): + if index < 3: + assert statuses == ["queued", "in_progress", "completed"] + else: + assert statuses == ["queued"] + + # Test that shutdown event stops the tasks + group.shutdown_event.set() + await asyncio.sleep(0.1) # allow propagation + assert group.pending_requests_complete.is_set() + assert group.populate_requests_task.done() + await asyncio.sleep(0.1) # allow processing + assert group.pending_updates_complete.is_set() + assert group.populate_updates_task.done() + + # Check all expected pending updates and statuses processed + while True: + try: + resp, req, req_info, state = await asyncio.wait_for( + group.pending_updates_queue.async_get(), timeout=0.1 + ) + updates_by_request[req].append(req_info.status) + except asyncio.TimeoutError: + break + + for index, (_, statuses) in enumerate(updates_by_request.items()): + if index < 3: + assert statuses == ["queued", "in_progress", "completed"] + else: + assert statuses == ["queued", "in_progress", "cancelled"] + + # Clean up resources + group.processes = None + exceptions = await group.shutdown() + assert len(exceptions) == 0, f"Shutdown encountered exceptions: {exceptions}" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_request_updates(self, monkeypatch): + """Test the request_updates async iterator functionality.""" + # Configure settings for controlled testing + monkeypatch.setattr( + worker_group.settings, "max_worker_processes", 1, raising=False + ) + monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) + monkeypatch.setattr( + worker_group.settings, "scheduler_poll_interval", 0.01, raising=False + ) + monkeypatch.setattr( + worker_group, "WorkerProcess", MockWorkerProcessor, raising=True + ) + + # Setup group + backend = MockBackend() + requests = [f"req_{index}" for index in range(20)] + group = WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=SynchronousStrategy(), + constraints={"max_num": MaxNumberConstraint(max_num=10)}, + ) + + # Mimic create_processes to set required state + await group.create_processes() + await group.start(time.time() + 0.05) + + # Collect all updates from request_updates iterator + received_updates = defaultdict(list) + received_responses = [] + count = 0 + async for resp, req, req_info, state in group.request_updates(): + assert isinstance(req_info, ScheduledRequestInfo) + assert isinstance(state, SchedulerState) + received_updates[req].append(req_info.status) + if resp is not None: + received_responses.append(resp) + count += 1 + + # Check we have all expected updates (10 requests) + assert len(received_updates) == 10 + for index, (req, statuses, resp) in enumerate( + zip(received_updates.keys(), received_updates.values(), received_responses) + ): + assert req == f"req_{index}" + assert resp == f"response_for_req_{index}" + assert statuses == ["queued", "in_progress", "completed"] + + # Cleanup + await group.shutdown() diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py new file mode 100644 index 00000000..404a8671 --- /dev/null +++ b/tests/unit/utils/test_encoding.py @@ -0,0 +1,222 @@ +from typing import Any, Generic, TypeVar + +import pytest +from pydantic import BaseModel, Field + +from guidellm.utils.encoding import MsgpackEncoding + + +class SimpleModel(BaseModel): + 
name: str + value: int + + +class NestedModel(BaseModel): + simple: SimpleModel + items: list[str] + metadata: dict[str, Any] + + +T = TypeVar("T") + + +class GenericModel(BaseModel, Generic[T]): + data: T + count: int + + +class ComplexModel(BaseModel): + id: str = Field(description="Unique identifier") + nested: NestedModel + numbers: list[int] + mapping: dict[str, SimpleModel] + + +class TestMsgpackEncoding: + @pytest.mark.smoke + @pytest.mark.parametrize( + "primitive_data", + [ + # Basic primitives + 42, + 3.14, + True, + False, + None, + "hello world", + "", + [], + [1, 2, 3], + {}, + {"key": "value"}, + # Nested collections + [1, [2, 3], {"nested": True}], + {"outer": {"inner": [1, 2, 3]}}, + # Mixed types + [1, "string", 3.14, True, None], + {"int": 42, "str": "hello", "float": 3.14, "bool": True, "null": None}, + ], + ) + def test_encode_decode_primitives(self, primitive_data): + """Test encoding and decoding of Python primitives and collections.""" + encoded = MsgpackEncoding.encode(primitive_data) + assert isinstance(encoded, bytes) + + decoded = MsgpackEncoding.decode(encoded) + assert decoded == primitive_data + assert isinstance(decoded, type(primitive_data)) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("tuple_data", "expected_list"), + [ + ((), []), + ((1, 2, 3), [1, 2, 3]), + ((1, (2, 3), {"tuple_dict": True}), [1, [2, 3], {"tuple_dict": True}]), + ], + ) + def test_encode_decode_tuples(self, tuple_data, expected_list): + encoded = MsgpackEncoding.encode(tuple_data) + assert isinstance(encoded, bytes) + + decoded = MsgpackEncoding.decode(encoded) + assert decoded == expected_list + assert isinstance(decoded, list) + + @pytest.mark.smoke + @pytest.mark.parametrize( + "model_data", + [ + SimpleModel(name="test", value=42), + NestedModel( + simple=SimpleModel(name="nested", value=100), + items=["a", "b", "c"], + metadata={"key": "value", "number": 123}, + ), + ComplexModel( + id="test-123", + nested=NestedModel( + simple=SimpleModel(name="complex", value=999), + items=["x", "y"], + metadata={"complex": True}, + ), + numbers=[1, 2, 3, 4, 5], + mapping={ + "first": SimpleModel(name="first", value=1), + "second": SimpleModel(name="second", value=2), + }, + ), + ], + ) + def test_encode_decode_pydantic_models(self, model_data): + """Test encoding and decoding of Pydantic models.""" + encoded = MsgpackEncoding.encode(model_data) + assert isinstance(encoded, bytes) + + decoded = MsgpackEncoding.decode(encoded) + assert decoded == model_data + assert isinstance(decoded, type(model_data)) + assert decoded.model_dump() == model_data.model_dump() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("generic_model", "expected_type"), + [ + (GenericModel[str](data="hello", count=1), str), + (GenericModel[int](data=42, count=2), int), + (GenericModel[list[str]](data=["a", "b"], count=3), list), + ], + ) + def test_encode_decode_generic_models(self, generic_model, expected_type): + """Test encoding and decoding of generic Pydantic models.""" + encoded = MsgpackEncoding.encode(generic_model) + assert isinstance(encoded, bytes) + + decoded = MsgpackEncoding.decode(encoded) + assert decoded == generic_model + assert decoded.data == generic_model.data + assert decoded.count == generic_model.count + assert isinstance(decoded.data, expected_type) + + @pytest.mark.smoke + @pytest.mark.parametrize( + "mixed_data", + [ + [SimpleModel(name="item1", value=1), SimpleModel(name="item2", value=2)], + {"model": SimpleModel(name="dict_value", value=42), "primitive": "string"}, + { + "models": [ + 
SimpleModel(name="item1", value=1), + SimpleModel(name="item2", value=2), + ], + "data": {"nested": {"deep": SimpleModel(name="deep", value=999)}}, + }, + [ + { + "id": "test", + "model": NestedModel( + simple=SimpleModel(name="nested_in_list", value=456), + items=["nested", "list"], + metadata={"in_list": True}, + ), + "primitives": [1, 2, 3], + } + ], + ], + ) + def test_encode_decode_mixed_collections(self, mixed_data): + encoded = MsgpackEncoding.encode(mixed_data) + assert isinstance(encoded, bytes) + + decoded = MsgpackEncoding.decode(encoded) + assert decoded == mixed_data + assert isinstance(decoded, type(mixed_data)) + + @pytest.mark.smoke + def test_round_trip_consistency(self): + original_data = { + "simple": SimpleModel(name="test", value=42), + "nested": NestedModel( + simple=SimpleModel(name="nested", value=100), + items=["a", "b", "c"], + metadata={"key": "value"}, + ), + "primitives": [1, 2, 3, "string", True, None], + "list_data": [1, 2, SimpleModel(name="list", value=999)], + } + + current_data = original_data + for _ in range(3): + encoded = MsgpackEncoding.encode(current_data) + current_data = MsgpackEncoding.decode(encoded) + + assert current_data == original_data + + @pytest.mark.smoke + def test_empty_collections(self): + test_cases = [[], {}] + + for empty_collection in test_cases: + encoded = MsgpackEncoding.encode(empty_collection) + decoded = MsgpackEncoding.decode(encoded) + assert decoded == empty_collection + assert isinstance(decoded, type(empty_collection)) + + @pytest.mark.smoke + def test_pydantic_constants(self): + """Test that the Pydantic-related constants are properly defined.""" + assert MsgpackEncoding.PYDANTIC_TAG == "__pydantic__" + assert MsgpackEncoding.PYDANTIC_DATA == "data" + assert MsgpackEncoding.PYDANTIC_ARGS == "args" + + @pytest.mark.sanity + def test_encode_invalid_data(self): + """Test encoding behavior with edge cases.""" + + class CustomClass: + def __init__(self, value): + self.value = value + + custom_obj = CustomClass(42) + primitive = MsgpackEncoding.to_primitive(custom_obj) + assert primitive is custom_obj diff --git a/tests/unit/utils/test_threading.py b/tests/unit/utils/test_threading.py new file mode 100644 index 00000000..887bf82c --- /dev/null +++ b/tests/unit/utils/test_threading.py @@ -0,0 +1,141 @@ +import asyncio +import threading +from collections.abc import Iterator + +import pytest + +from guidellm.utils.threading import synchronous_to_exitable_async + + +def _infinite_counter() -> Iterator[int]: + i = 0 + while True: + i += 1 + yield i + + +@pytest.mark.smoke +@pytest.mark.asyncio +async def test_callable_completed_returns_value(): + async def run(): + def add(a: int, b: int) -> int: + return a + b + + reason, value = await synchronous_to_exitable_async(add, None, None, 0.01, 2, 3) + return reason, value + + reason, value = await run() + assert reason == "completed" + assert value == 5 + + +@pytest.mark.smoke +@pytest.mark.asyncio +async def test_iterable_completed_returns_last_item(): + items = ["a", "b", "c"] + reason, value = await synchronous_to_exitable_async(items, None, None, 0.005) + assert reason == "completed" + assert value == "c" + + +@pytest.mark.smoke +@pytest.mark.asyncio +async def test_iterator_exits_on_custom_event(): + stop_event = threading.Event() + + async def trigger_event(): + await asyncio.sleep(0.02) + stop_event.set() + + task = asyncio.create_task( + synchronous_to_exitable_async( + _infinite_counter(), + exit_events={"stop": stop_event}, + exit_barrier=None, + poll_interval=0.005, + ) + ) 
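+    # The infinite counter only stops once the custom "stop" event is set by
+    # trigger_event(); the helper should report that event's name together
+    # with the last value the iterator yielded (asserted below).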
+ trigger = asyncio.create_task(trigger_event()) + reason, value = await task + await trigger + + assert reason == "stop" + assert isinstance(value, int) + + +@pytest.mark.smoke +@pytest.mark.asyncio +async def test_barrier_triggers_exit(): + barrier = threading.Barrier(2) + + waiter = threading.Thread(target=barrier.wait, daemon=True) + waiter.start() + + reason, _ = await synchronous_to_exitable_async( + _infinite_counter(), + exit_events=None, + exit_barrier=barrier, + poll_interval=0.005, + ) + + assert reason == "barrier" + + +@pytest.mark.sanity +@pytest.mark.asyncio +async def test_cancellation_sets_canceled_and_aborts_barrier(): + barrier = threading.Barrier(2) + + async def runner(): + return await synchronous_to_exitable_async( + _infinite_counter(), + exit_events=None, + exit_barrier=barrier, + poll_interval=0.01, + ) + + task = asyncio.create_task(runner()) + await asyncio.sleep(0.02) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for _ in range(50): + if barrier.broken: + break + await asyncio.sleep(0.01) + assert barrier.broken is True + + +@pytest.mark.smoke +@pytest.mark.asyncio +async def test_callable_internal_error_propagates_in_tuple(): + def boom(): + raise ValueError("boom!") + + reason, err = await synchronous_to_exitable_async(boom, None, None, 0.001) + assert reason == "internal_error" + assert isinstance(err, ValueError) + assert str(err) == "boom!" + + +@pytest.mark.smoke +@pytest.mark.asyncio +async def test_poll_mode_only_exits_on_custom_event(): + stop_event = threading.Event() + + async def trigger(): + await asyncio.sleep(0.02) + stop_event.set() + + trigger_task = asyncio.create_task(trigger()) + reason, last = await synchronous_to_exitable_async( + None, + exit_events={"stop": stop_event}, + exit_barrier=None, + poll_interval=0.005, + ) + await trigger_task + + assert reason == "stop" + assert last is None
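
For quick reference, a minimal sketch (not part of this diff) of the per-worker timing derivation that the strategy tests above assert: the configured rate is split evenly across worker processes, and Poisson workers additionally receive a seed offset by their local rank. It assumes the same guidellm.scheduler imports used in tests/unit/scheduler/test_worker_group.py.

from guidellm.scheduler import AsyncConstantStrategy, AsyncPoissonStrategy

world_size = 4
constant = AsyncConstantStrategy(rate=12.0)
poisson = AsyncPoissonStrategy(rate=12.0, random_seed=42)

for rank in range(world_size):
    # Both strategies divide the global rate evenly across worker processes.
    ct = constant.create_request_timings(rank, world_size, 1)
    pt = poisson.create_request_timings(rank, world_size, 1)
    print(ct.rate)         # 3.0 for every worker (12.0 / 4)
    print(pt.rate)         # 3.0 as well
    print(pt.random_seed)  # 42 + rank, so each worker samples a distinct stream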