diff --git a/src/guidellm/backend/__init__.py b/src/guidellm/backend/__init__.py deleted file mode 100644 index 315a28f0..00000000 --- a/src/guidellm/backend/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .backend import ( - Backend, - BackendType, -) -from .openai import CHAT_COMPLETIONS_PATH, TEXT_COMPLETIONS_PATH, OpenAIHTTPBackend -from .response import ( - RequestArgs, - ResponseSummary, - StreamingResponseType, - StreamingTextResponse, -) - -__all__ = [ - "CHAT_COMPLETIONS_PATH", - "TEXT_COMPLETIONS_PATH", - "Backend", - "BackendType", - "OpenAIHTTPBackend", - "RequestArgs", - "ResponseSummary", - "StreamingResponseType", - "StreamingTextResponse", -] diff --git a/src/guidellm/backend/backend.py b/src/guidellm/backend/backend.py deleted file mode 100644 index ceffdc77..00000000 --- a/src/guidellm/backend/backend.py +++ /dev/null @@ -1,259 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Any, Literal, Optional, Union - -from loguru import logger -from PIL import Image - -from guidellm.backend.response import ResponseSummary, StreamingTextResponse -from guidellm.settings import settings - -__all__ = [ - "Backend", - "BackendType", -] - - -BackendType = Literal["openai_http"] - - -class Backend(ABC): - """ - Abstract base class for generative AI backends. - - This class provides a common interface for creating and interacting with different - generative AI backends. Subclasses should implement the abstract methods to - define specific backend behavior. - - :cvar _registry: A registration dictionary that maps BackendType to backend classes. - :param type_: The type of the backend. - """ - - _registry: dict[BackendType, "type[Backend]"] = {} - - @classmethod - def register(cls, backend_type: BackendType): - """ - A decorator to register a backend class in the backend registry. - - :param backend_type: The type of backend to register. - :type backend_type: BackendType - :return: The decorated backend class. - :rtype: Type[Backend] - """ - if backend_type in cls._registry: - raise ValueError(f"Backend type already registered: {backend_type}") - - if not issubclass(cls, Backend): - raise TypeError("Only subclasses of Backend can be registered") - - def inner_wrapper(wrapped_class: type["Backend"]): - cls._registry[backend_type] = wrapped_class - logger.info("Registered backend type: {}", backend_type) - return wrapped_class - - return inner_wrapper - - @classmethod - def create(cls, type_: BackendType, **kwargs) -> "Backend": - """ - Factory method to create a backend instance based on the backend type. - - :param type_: The type of backend to create. - :type type_: BackendType - :param kwargs: Additional arguments for backend initialization. - :return: An instance of a subclass of Backend. - :rtype: Backend - :raises ValueError: If the backend type is not registered. - """ - - logger.info("Creating backend of type {}", type_) - - if type_ not in cls._registry: - err = ValueError(f"Unsupported backend type: {type_}") - logger.error("{}", err) - raise err - - return Backend._registry[type_](**kwargs) - - def __init__(self, type_: BackendType): - self._type = type_ - - @property - def type_(self) -> BackendType: - """ - :return: The type of the backend. - """ - return self._type - - @property - @abstractmethod - def target(self) -> str: - """ - :return: The target location for the backend. - """ - ... 
- - @property - @abstractmethod - def model(self) -> Optional[str]: - """ - :return: The model used for the backend requests. - """ - ... - - @property - @abstractmethod - def info(self) -> dict[str, Any]: - """ - :return: The information about the backend. - """ - ... - - @abstractmethod - async def reset(self) -> None: - """ - Reset the connection object. This is useful for backends that - reuse connections or have state that needs to be cleared. - """ - ... - - async def validate(self): - """ - Handle final setup and validate the backend is ready for use. - If not successful, raises the appropriate exception. - """ - logger.info("{} validating backend {}", self.__class__.__name__, self.type_) - await self.check_setup() - models = await self.available_models() - if not models: - raise ValueError("No models available for the backend") - - # Use the preferred route defined in the global settings when performing the - # validation request. This avoids calling an unavailable endpoint (ie - # /v1/completions) when the deployment only supports the chat completions - # endpoint. - if settings.preferred_route == "chat_completions": - async for _ in self.chat_completions( # type: ignore[attr-defined] - content="Test connection", output_token_count=1 - ): - pass - else: - async for _ in self.text_completions( # type: ignore[attr-defined] - prompt="Test connection", output_token_count=1 - ): - pass - - await self.reset() - - @abstractmethod - async def check_setup(self): - """ - Check the setup for the backend. - If unsuccessful, raises the appropriate exception. - - :raises ValueError: If the setup check fails. - """ - ... - - @abstractmethod - async def prepare_multiprocessing(self): - """ - Prepare the backend for use in a multiprocessing environment. - This is useful for backends that have instance state that can not - be shared across processes and should be cleared out and re-initialized - for each new process. - """ - ... - - @abstractmethod - async def available_models(self) -> list[str]: - """ - Get the list of available models for the backend. - - :return: The list of available models. - :rtype: List[str] - """ - ... - - @abstractmethod - async def text_completions( - self, - prompt: Union[str, list[str]], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - """ - Generate text only completions for the given prompt. - Does not support multiple modalities, complicated chat interfaces, - or chat templates. Specifically, it requests with only the prompt. - - :param prompt: The prompt (or list of prompts) to generate a completion for. - If a list is supplied, these are concatenated and run through the model - for a single prompt. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. - """ - ... 
- - @abstractmethod - async def chat_completions( - self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - raw_content: bool = False, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - """ - Generate chat completions for the given content. - Supports multiple modalities, complicated chat interfaces, and chat templates. - Specifically, it requests with the content, which can be any combination of - text, images, and audio provided the target model supports it, - and returns the output text. Additionally, any chat templates - for the model are applied within the backend. - - :param content: The content (or list of content) to generate a completion for. - This supports any combination of text, images, and audio (model dependent). - Supported text only request examples: - content="Sample prompt", content=["Sample prompt", "Second prompt"], - content=[{"type": "text", "value": "Sample prompt"}. - Supported text and image request examples: - content=["Describe the image", PIL.Image.open("image.jpg")], - content=["Describe the image", Path("image.jpg")], - content=["Describe the image", {"type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}]. - Supported text and audio request examples: - content=["Transcribe the audio", Path("audio.wav")], - content=["Transcribe the audio", {"type": "input_audio", - "input_audio": {"data": f"{base64_bytes}", "format": "wav}]. - Additionally, if raw_content=True then the content is passed directly to the - backend without any processing. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. - """ - ... 
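Not part of the diff: a minimal sketch of driving the replacement interface that the new guidellm.backends package (added later in this diff) exposes, assuming an OpenAI-compatible server at http://localhost:8000 and that ScheduledRequestInfo can be constructed with defaults (its definition sits outside this diff). By contrast, the removed guidellm.backend interface above streamed StreamingTextResponse objects and a final ResponseSummary directly from text_completions()/chat_completions().

import asyncio

from guidellm.backends import Backend, GenerationRequest
from guidellm.scheduler import ScheduledRequestInfo


async def main() -> None:
    # "openai_http" resolves to OpenAIHTTPBackend through the registry.
    backend = Backend.create("openai_http", target="http://localhost:8000")
    await backend.process_startup()    # create the HTTP client in this process
    await backend.validate()           # health check / model auto-selection

    request = GenerationRequest(
        content="Write a haiku about latency.",
        constraints={"output_tokens": 32},  # enforced max output tokens
    )
    info = ScheduledRequestInfo()      # assumption: default construction works

    # resolve() yields (GenerationResponse, ScheduledRequestInfo) pairs;
    # delta carries streamed text, and the final yield carries usage counts.
    async for response, _ in backend.resolve(request, info):
        if response.delta:
            print(response.delta, end="", flush=True)

    print("\noutput tokens:", response.output_tokens)
    await backend.process_shutdown()


asyncio.run(main())
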
diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py deleted file mode 100644 index e1fcdf89..00000000 --- a/src/guidellm/backend/openai.py +++ /dev/null @@ -1,705 +0,0 @@ -import base64 -import copy -import json -import time -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Any, Literal, Optional, Union - -import httpx -from loguru import logger -from PIL import Image - -from guidellm.backend.backend import Backend -from guidellm.backend.response import ( - RequestArgs, - ResponseSummary, - StreamingTextResponse, -) -from guidellm.settings import settings - -__all__ = [ - "CHAT_COMPLETIONS", - "CHAT_COMPLETIONS_PATH", - "MODELS", - "TEXT_COMPLETIONS", - "TEXT_COMPLETIONS_PATH", - "OpenAIHTTPBackend", -] - - -TEXT_COMPLETIONS_PATH = "/v1/completions" -CHAT_COMPLETIONS_PATH = "/v1/chat/completions" - -EndpointType = Literal["chat_completions", "models", "text_completions"] -CHAT_COMPLETIONS: EndpointType = "chat_completions" -MODELS: EndpointType = "models" -TEXT_COMPLETIONS: EndpointType = "text_completions" - - -@Backend.register("openai_http") -class OpenAIHTTPBackend(Backend): - """ - A HTTP-based backend implementation for requests to an OpenAI compatible server. - For example, a vLLM server instance or requests to OpenAI's API. - - :param target: The target URL string for the OpenAI server. ex: http://0.0.0.0:8000 - :param model: The model to use for all requests on the target server. - If none is provided, the first available model will be used. - :param api_key: The API key to use for requests to the OpenAI server. - If provided, adds an Authorization header with the value - "Authorization: Bearer {api_key}". - If not provided, no Authorization header is added. - :param organization: The organization to use for requests to the OpenAI server. - For example, if set to "org_123", adds an OpenAI-Organization header with the - value "OpenAI-Organization: org_123". - If not provided, no OpenAI-Organization header is added. - :param project: The project to use for requests to the OpenAI server. - For example, if set to "project_123", adds an OpenAI-Project header with the - value "OpenAI-Project: project_123". - If not provided, no OpenAI-Project header is added. - :param timeout: The timeout to use for requests to the OpenAI server. - If not provided, the default timeout provided from settings is used. - :param http2: If True, uses HTTP/2 for requests to the OpenAI server. - Defaults to True. - :param follow_redirects: If True, the HTTP client will follow redirect responses. - If not provided, the default value from settings is used. - :param max_output_tokens: The maximum number of tokens to request for completions. - If not provided, the default maximum tokens provided from settings is used. - :param extra_query: Query parameters to include in requests to the OpenAI server. - If "chat_completions", "models", or "text_completions" are included as keys, - the values of these keys will be used as the parameters for the respective - endpoint. - If not provided, no extra query parameters are added. - :param extra_body: Body parameters to include in requests to the OpenAI server. - If "chat_completions", "models", or "text_completions" are included as keys, - the values of these keys will be included in the body for the respective - endpoint. - If not provided, no extra body parameters are added. - :param remove_from_body: Parameters that should be removed from the body of each - request. 
- If not provided, no parameters are removed from the body. - """ - - def __init__( - self, - target: Optional[str] = None, - model: Optional[str] = None, - api_key: Optional[str] = None, - organization: Optional[str] = None, - project: Optional[str] = None, - timeout: Optional[float] = None, - http2: Optional[bool] = True, - follow_redirects: Optional[bool] = None, - max_output_tokens: Optional[int] = None, - extra_query: Optional[dict] = None, - extra_body: Optional[dict] = None, - remove_from_body: Optional[list[str]] = None, - headers: Optional[dict] = None, - verify: Optional[bool] = None, - ): - super().__init__(type_="openai_http") - self._target = target or settings.openai.base_url - - if not self._target: - raise ValueError("Target URL must be provided for OpenAI HTTP backend.") - - if self._target.endswith("/v1") or self._target.endswith("/v1/"): - # backwards compatability, strip v1 off - self._target = self._target[:-3] - - if self._target.endswith("/"): - self._target = self._target[:-1] - - self._model = model - - # Start with default headers based on other params - default_headers: dict[str, str] = {} - api_key = api_key or settings.openai.api_key - bearer_token = settings.openai.bearer_token - if api_key: - default_headers["Authorization"] = f"Bearer {api_key}" - elif bearer_token: - default_headers["Authorization"] = bearer_token - - self.organization = organization or settings.openai.organization - if self.organization: - default_headers["OpenAI-Organization"] = self.organization - - self.project = project or settings.openai.project - if self.project: - default_headers["OpenAI-Project"] = self.project - - # User-provided headers from kwargs or settings override defaults - merged_headers = default_headers.copy() - merged_headers.update(settings.openai.headers or {}) - if headers: - merged_headers.update(headers) - - # Remove headers with None values for backward compatibility and convenience - self.headers = {k: v for k, v in merged_headers.items() if v is not None} - - self.timeout = timeout if timeout is not None else settings.request_timeout - self.http2 = http2 if http2 is not None else settings.request_http2 - self.follow_redirects = ( - follow_redirects - if follow_redirects is not None - else settings.request_follow_redirects - ) - self.verify = verify if verify is not None else settings.openai.verify - self.max_output_tokens = ( - max_output_tokens - if max_output_tokens is not None - else settings.openai.max_output_tokens - ) - self.extra_query = extra_query - self.extra_body = extra_body - self.remove_from_body = remove_from_body - self._async_client: Optional[httpx.AsyncClient] = None - - @property - def target(self) -> str: - """ - :return: The target URL string for the OpenAI server. - """ - return self._target - - @property - def model(self) -> Optional[str]: - """ - :return: The model to use for all requests on the target server. - If validate hasn't been called yet and no model was passed in, - this will be None until validate is called to set the default. - """ - return self._model - - @property - def info(self) -> dict[str, Any]: - """ - :return: The information about the backend. - """ - return { - "max_output_tokens": self.max_output_tokens, - "timeout": self.timeout, - "http2": self.http2, - "follow_redirects": self.follow_redirects, - "headers": self.headers, - "text_completions_path": TEXT_COMPLETIONS_PATH, - "chat_completions_path": CHAT_COMPLETIONS_PATH, - } - - async def reset(self) -> None: - """ - Reset the connection object. 
This is useful for backends that - reuse connections or have state that needs to be cleared. - For this backend, it closes the async client if it exists. - """ - if self._async_client is not None: - await self._async_client.aclose() - - async def check_setup(self): - """ - Check if the backend is setup correctly and can be used for requests. - Specifically, if a model is not provided, it grabs the first available model. - If no models are available, raises a ValueError. - If a model is provided and not available, raises a ValueError. - - :raises ValueError: If no models or the provided model is not available. - """ - models = await self.available_models() - if not models: - raise ValueError(f"No models available for target: {self.target}") - - if not self.model: - self._model = models[0] - elif self.model not in models: - raise ValueError( - f"Model {self.model} not found in available models:" - f"{models} for target: {self.target}" - ) - - async def prepare_multiprocessing(self): - """ - Prepare the backend for use in a multiprocessing environment. - Clears out the sync and async clients to ensure they are re-initialized - for each process. - """ - if self._async_client is not None: - await self._async_client.aclose() - self._async_client = None - - async def available_models(self) -> list[str]: - """ - Get the available models for the target server using the OpenAI models endpoint: - /v1/models - """ - target = f"{self.target}/v1/models" - headers = self._headers() - params = self._params(MODELS) - response = await self._get_async_client().get( - target, headers=headers, params=params - ) - response.raise_for_status() - - models = [] - - for item in response.json()["data"]: - models.append(item["id"]) - - return models - - async def text_completions( # type: ignore[override] - self, - prompt: Union[str, list[str]], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - """ - Generate text completions for the given prompt using the OpenAI - completions endpoint: /v1/completions. - - :param prompt: The prompt (or list of prompts) to generate a completion for. - If a list is supplied, these are concatenated and run through the model - for a single prompt. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. - """ - logger.debug("{} invocation with args: {}", self.__class__.__name__, locals()) - - if isinstance(prompt, list): - raise ValueError( - "List prompts (batching) is currently not supported for " - f"text_completions OpenAI pathways. 
Received: {prompt}" - ) - - headers = self._headers() - params = self._params(TEXT_COMPLETIONS) - payload = self._completions_payload( - endpoint_type=TEXT_COMPLETIONS, - orig_kwargs=kwargs, - max_output_tokens=output_token_count, - prompt=prompt, - ) - - try: - async for resp in self._iterative_completions_request( - type_="text_completions", - request_id=request_id, - request_prompt_tokens=prompt_token_count, - request_output_tokens=output_token_count, - headers=headers, - params=params, - payload=payload, - ): - yield resp - except Exception as ex: - logger.error( - "{} request with headers: {} and params: {} and payload: {} failed: {}", - self.__class__.__name__, - headers, - params, - payload, - ex, - ) - raise ex - - async def chat_completions( # type: ignore[override] - self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - raw_content: bool = False, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - """ - Generate chat completions for the given content using the OpenAI - chat completions endpoint: /v1/chat/completions. - - :param content: The content (or list of content) to generate a completion for. - This supports any combination of text, images, and audio (model dependent). - Supported text only request examples: - content="Sample prompt", content=["Sample prompt", "Second prompt"], - content=[{"type": "text", "value": "Sample prompt"}. - Supported text and image request examples: - content=["Describe the image", PIL.Image.open("image.jpg")], - content=["Describe the image", Path("image.jpg")], - content=["Describe the image", {"type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}]. - Supported text and audio request examples: - content=["Transcribe the audio", Path("audio.wav")], - content=["Transcribe the audio", {"type": "input_audio", - "input_audio": {"data": f"{base64_bytes}", "format": "wav}]. - Additionally, if raw_content=True then the content is passed directly to the - backend without any processing. - :param request_id: The unique identifier for the request, if any. - Added to logging statements and the response for tracking purposes. - :param prompt_token_count: The number of tokens measured in the prompt, if any. - Returned in the response stats for later analysis, if applicable. - :param output_token_count: If supplied, the number of tokens to enforce - generation of for the output for this request. - :param kwargs: Additional keyword arguments to pass with the request. - :return: An async generator that yields a StreamingTextResponse for start, - a StreamingTextResponse for each received iteration, - and a ResponseSummary for the final response. 
- """ - logger.debug("{} invocation with args: {}", self.__class__.__name__, locals()) - headers = self._headers() - params = self._params(CHAT_COMPLETIONS) - messages = ( - content if raw_content else self._create_chat_messages(content=content) - ) - payload = self._completions_payload( - endpoint_type=CHAT_COMPLETIONS, - orig_kwargs=kwargs, - max_output_tokens=output_token_count, - messages=messages, - ) - - try: - async for resp in self._iterative_completions_request( - type_="chat_completions", - request_id=request_id, - request_prompt_tokens=prompt_token_count, - request_output_tokens=output_token_count, - headers=headers, - params=params, - payload=payload, - ): - yield resp - except Exception as ex: - logger.error( - "{} request with headers: {} and params: {} and payload: {} failed: {}", - self.__class__.__name__, - headers, - params, - payload, - ex, - ) - raise ex - - def _get_async_client(self) -> httpx.AsyncClient: - """ - Get the async HTTP client for making requests. - If the client has not been created yet, it will create one. - - :return: The async HTTP client. - """ - if self._async_client is None or self._async_client.is_closed: - client = httpx.AsyncClient( - http2=self.http2, - timeout=self.timeout, - follow_redirects=self.follow_redirects, - verify=self.verify, - ) - self._async_client = client - else: - client = self._async_client - - return client - - def _headers(self) -> dict[str, str]: - headers = { - "Content-Type": "application/json", - } - headers.update(self.headers) - return headers - - def _params(self, endpoint_type: EndpointType) -> dict[str, str]: - if self.extra_query is None: - return {} - - if ( - CHAT_COMPLETIONS in self.extra_query - or MODELS in self.extra_query - or TEXT_COMPLETIONS in self.extra_query - ): - return self.extra_query.get(endpoint_type, {}) - - return self.extra_query - - def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]: - if self.extra_body is None: - return {} - - if ( - CHAT_COMPLETIONS in self.extra_body - or MODELS in self.extra_body - or TEXT_COMPLETIONS in self.extra_body - ): - return copy.deepcopy(self.extra_body.get(endpoint_type, {})) - - return copy.deepcopy(self.extra_body) - - def _completions_payload( - self, - endpoint_type: EndpointType, - orig_kwargs: Optional[dict], - max_output_tokens: Optional[int], - **kwargs, - ) -> dict: - payload = self._extra_body(endpoint_type) - payload.update(orig_kwargs or {}) - payload.update(kwargs) - payload["model"] = self.model - payload["stream"] = True - payload["stream_options"] = { - "include_usage": True, - } - - if max_output_tokens or self.max_output_tokens: - logger.debug( - "{} adding payload args for setting output_token_count: {}", - self.__class__.__name__, - max_output_tokens or self.max_output_tokens, - ) - payload["max_tokens"] = max_output_tokens or self.max_output_tokens - payload["max_completion_tokens"] = payload["max_tokens"] - - if max_output_tokens: - # only set stop and ignore_eos if max_output_tokens set at request level - # otherwise the instance value is just the max to enforce we stay below - payload["stop"] = None - payload["ignore_eos"] = True - - if self.remove_from_body: - for key in self.remove_from_body: - payload.pop(key, None) - - return payload - - @staticmethod - def _create_chat_messages( - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - ) -> list[dict]: - if isinstance(content, str): - return [ - { - "role": "user", - "content": content, - } - ] - - if 
isinstance(content, list): - resolved_content = [] - - for item in content: - if isinstance(item, dict): - resolved_content.append(item) - elif isinstance(item, str): - resolved_content.append({"type": "text", "text": item}) - elif isinstance(item, Image.Image) or ( - isinstance(item, Path) and item.suffix.lower() in [".jpg", ".jpeg"] - ): - image = item if isinstance(item, Image.Image) else Image.open(item) - encoded = base64.b64encode(image.tobytes()).decode("utf-8") - resolved_content.append( - { - "type": "image", - "image": { - "url": f"data:image/jpeg;base64,{encoded}", - }, - } - ) - elif isinstance(item, Path) and item.suffix.lower() in [".wav"]: - encoded = base64.b64encode(item.read_bytes()).decode("utf-8") - resolved_content.append( - { - "type": "input_audio", - "input_audio": { - "data": f"{encoded}", - "format": "wav", - }, - } - ) - else: - raise ValueError( - f"Unsupported content item type: {item} in list: {content}" - ) - - return [ - { - "role": "user", - "content": resolved_content, - } - ] - - raise ValueError(f"Unsupported content type: {content}") - - async def _iterative_completions_request( - self, - type_: Literal["text_completions", "chat_completions"], - request_id: Optional[str], - request_prompt_tokens: Optional[int], - request_output_tokens: Optional[int], - headers: dict[str, str], - params: dict[str, str], - payload: dict[str, Any], - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - if type_ == "text_completions": - target = f"{self.target}{TEXT_COMPLETIONS_PATH}" - elif type_ == "chat_completions": - target = f"{self.target}{CHAT_COMPLETIONS_PATH}" - else: - raise ValueError(f"Unsupported type: {type_}") - - logger.info( - "{} making request: {} to target: {} using http2: {} following " - "redirects: {} for timeout: {} with headers: {} and params: {} and ", - "payload: {}", - self.__class__.__name__, - request_id, - target, - self.http2, - self.follow_redirects, - self.timeout, - headers, - params, - payload, - ) - - response_value = "" - response_prompt_count: Optional[int] = None - response_output_count: Optional[int] = None - iter_count = 0 - start_time = time.time() - iter_time = start_time - first_iter_time: Optional[float] = None - last_iter_time: Optional[float] = None - - yield StreamingTextResponse( - type_="start", - value="", - start_time=start_time, - first_iter_time=None, - iter_count=iter_count, - delta="", - time=start_time, - request_id=request_id, - ) - - # reset start time after yielding start response to ensure accurate timing - start_time = time.time() - - async with self._get_async_client().stream( - "POST", target, headers=headers, params=params, json=payload - ) as stream: - stream.raise_for_status() - - async for line in stream.aiter_lines(): - iter_time = time.time() - logger.debug( - "{} request: {} recieved iter response line: {}", - self.__class__.__name__, - request_id, - line, - ) - - if not line or not line.strip().startswith("data:"): - continue - - if line.strip() == "data: [DONE]": - break - - data = json.loads(line.strip()[len("data: ") :]) - if delta := self._extract_completions_delta_content(type_, data): - if first_iter_time is None: - first_iter_time = iter_time - last_iter_time = iter_time - - iter_count += 1 - response_value += delta - - yield StreamingTextResponse( - type_="iter", - value=response_value, - iter_count=iter_count, - start_time=start_time, - first_iter_time=first_iter_time, - delta=delta, - time=iter_time, - request_id=request_id, - ) - - if usage := 
self._extract_completions_usage(data): - response_prompt_count = usage["prompt"] - response_output_count = usage["output"] - - logger.info( - "{} request: {} with headers: {} and params: {} and payload: {} completed" - "with: {}", - self.__class__.__name__, - request_id, - headers, - params, - payload, - response_value, - ) - - yield ResponseSummary( - value=response_value, - request_args=RequestArgs( - target=target, - headers=headers, - params=params, - payload=payload, - timeout=self.timeout, - http2=self.http2, - follow_redirects=self.follow_redirects, - ), - start_time=start_time, - end_time=iter_time, - first_iter_time=first_iter_time, - last_iter_time=last_iter_time, - iterations=iter_count, - request_prompt_tokens=request_prompt_tokens, - request_output_tokens=request_output_tokens, - response_prompt_tokens=response_prompt_count, - response_output_tokens=response_output_count, - request_id=request_id, - ) - - @staticmethod - def _extract_completions_delta_content( - type_: Literal["text_completions", "chat_completions"], data: dict - ) -> Optional[str]: - if "choices" not in data or not data["choices"]: - return None - - if type_ == "text_completions": - return data["choices"][0]["text"] - - if type_ == "chat_completions": - return data.get("choices", [{}])[0].get("delta", {}).get("content") - - raise ValueError(f"Unsupported type: {type_}") - - @staticmethod - def _extract_completions_usage( - data: dict, - ) -> Optional[dict[Literal["prompt", "output"], int]]: - if "usage" not in data or not data["usage"]: - return None - - return { - "prompt": data["usage"]["prompt_tokens"], - "output": data["usage"]["completion_tokens"], - } diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py deleted file mode 100644 index f2272a73..00000000 --- a/src/guidellm/backend/response.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import Any, Literal, Optional - -from pydantic import computed_field - -from guidellm.objects.pydantic import StandardBaseModel -from guidellm.settings import settings - -__all__ = [ - "RequestArgs", - "ResponseSummary", - "StreamingResponseType", - "StreamingTextResponse", -] - - -StreamingResponseType = Literal["start", "iter"] - - -class StreamingTextResponse(StandardBaseModel): - """ - A model representing the response content for a streaming text request. - - :param type_: The type of the response; either 'start' or 'iter'. - :param value: The value of the response up to this iteration. - :param start_time: The time.time() the request started. - :param iter_count: The iteration count for the response. For 'start' this is 0 - and for the first 'iter' it is 1. - :param delta: The text delta added to the response for this stream iteration. - :param time: If 'start', the time.time() the request started. - If 'iter', the time.time() the iteration was received. - :param request_id: The unique identifier for the request, if any. - """ - - type_: StreamingResponseType - value: str - start_time: float - first_iter_time: Optional[float] - iter_count: int - delta: str - time: float - request_id: Optional[str] = None - - -class RequestArgs(StandardBaseModel): - """ - A model representing the arguments for a request to a backend. - Biases towards an HTTP request, but can be used for other types of backends. - - :param target: The target URL or function for the request. - :param headers: The headers, if any, included in the request such as authorization. - :param params: The query parameters, if any, included in the request. 
- :param payload: The payload / arguments for the request including the prompt / - content and other configurations. - :param timeout: The timeout for the request in seconds, if any. - :param http2: Whether HTTP/2 was used for the request, if applicable. - :param follow_redirects: Whether the request should follow redirect responses. - """ - - target: str - headers: dict[str, str] - params: dict[str, str] - payload: dict[str, Any] - timeout: Optional[float] = None - http2: Optional[bool] = None - follow_redirects: Optional[bool] = None - - -class ResponseSummary(StandardBaseModel): - """ - A model representing a summary of a backend request. - Always returned as the final iteration of a streaming request. - - :param value: The final value returned from the request. - :param request_args: The arguments used to make the request. - :param iterations: The number of iterations in the request. - :param start_time: The time the request started. - :param end_time: The time the request ended. - :param first_iter_time: The time the first iteration was received. - :param last_iter_time: The time the last iteration was received. - :param request_prompt_tokens: The number of tokens measured in the prompt - for the request, if any. - :param request_output_tokens: The number of tokens enforced for the output - for the request, if any. - :param response_prompt_tokens: The number of tokens measured in the prompt - for the response, if any. - :param response_output_tokens: The number of tokens measured in the output - for the response, if any. - :param request_id: The unique identifier for the request, if any. - :param error: The error message, if any, returned from making the request. - """ - - value: str - request_args: RequestArgs - iterations: int = 0 - start_time: float - end_time: float - first_iter_time: Optional[float] - last_iter_time: Optional[float] - request_prompt_tokens: Optional[int] = None - request_output_tokens: Optional[int] = None - response_prompt_tokens: Optional[int] = None - response_output_tokens: Optional[int] = None - request_id: Optional[str] = None - error: Optional[str] = None - - @computed_field # type: ignore[misc] - @property - def prompt_tokens(self) -> Optional[int]: - """ - The number of tokens measured in the prompt based on preferences - for trusting the input or response. - - :return: The number of tokens in the prompt, if any. - """ - if settings.preferred_prompt_tokens_source == "request": - return self.request_prompt_tokens or self.response_prompt_tokens - - return self.response_prompt_tokens or self.request_prompt_tokens - - @computed_field # type: ignore[misc] - @property - def output_tokens(self) -> Optional[int]: - """ - The number of tokens measured in the output based on preferences - for trusting the input or response. - - :return: The number of tokens in the output, if any. - """ - if self.error is not None: - # error occurred, can't trust request tokens were all generated - return self.response_prompt_tokens - - if settings.preferred_output_tokens_source == "request": - return self.request_output_tokens or self.response_output_tokens - - return self.response_output_tokens or self.request_output_tokens diff --git a/src/guidellm/backends/__init__.py b/src/guidellm/backends/__init__.py new file mode 100644 index 00000000..064722ac --- /dev/null +++ b/src/guidellm/backends/__init__.py @@ -0,0 +1,26 @@ +""" +Backend infrastructure for GuideLLM language model interactions. 
+ +Provides abstract base classes, implemented backends, request/response objects, +and timing utilities for standardized communication with LLM providers. +""" + +from .backend import ( + Backend, + BackendType, +) +from .objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from .openai import OpenAIHTTPBackend + +__all__ = [ + "Backend", + "BackendType", + "GenerationRequest", + "GenerationRequestTimings", + "GenerationResponse", + "OpenAIHTTPBackend", +] diff --git a/src/guidellm/backends/backend.py b/src/guidellm/backends/backend.py new file mode 100644 index 00000000..8f91d5e7 --- /dev/null +++ b/src/guidellm/backends/backend.py @@ -0,0 +1,119 @@ +""" +Backend interface and registry for generative AI model interactions. + +Provides the abstract base class for implementing backends that communicate with +generative AI models. Backends handle the lifecycle of generation requests. + +Classes: + Backend: Abstract base class for generative AI backends with registry support. + +Type Aliases: + BackendType: Literal type defining supported backend implementations. +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Literal + +from guidellm.backends.objects import ( + GenerationRequest, + GenerationResponse, +) +from guidellm.scheduler import BackendInterface +from guidellm.utils import RegistryMixin + +__all__ = [ + "Backend", + "BackendType", +] + + +BackendType = Literal["openai_http"] + + +class Backend( + RegistryMixin["type[Backend]"], + BackendInterface[GenerationRequest, GenerationResponse], +): + """ + Base class for generative AI backends with registry and lifecycle. + + Provides a standard interface for backends that communicate with generative AI + models. Combines the registry pattern for automatic discovery with a defined + lifecycle for process-based distributed execution. + + Backend lifecycle phases: + 1. Creation and configuration + 2. Process startup - Initialize resources in worker process + 3. Validation - Verify backend readiness + 4. Request resolution - Process generation requests + 5. Process shutdown - Clean up resources + + Backend state (excluding process_startup resources) must be pickleable for + distributed execution across process boundaries. + + Example: + :: + @Backend.register("my_backend") + class MyBackend(Backend): + def __init__(self, api_key: str): + super().__init__("my_backend") + self.api_key = api_key + + async def process_startup(self): + self.client = MyAPIClient(self.api_key) + + backend = Backend.create("my_backend", api_key="secret") + """ + + @classmethod + def create(cls, type_: BackendType, **kwargs) -> Backend: + """ + Create a backend instance based on the backend type. + + :param type_: The type of backend to create. + :param kwargs: Additional arguments for backend initialization. + :return: An instance of a subclass of Backend. + :raises ValueError: If the backend type is not registered. + """ + + backend = cls.get_registered_object(type_) + + if backend is None: + raise ValueError( + f"Backend type '{type_}' is not registered. " + f"Available types: {list(cls.registry.keys()) if cls.registry else []}" + ) + + return backend(**kwargs) + + def __init__(self, type_: BackendType): + """ + Initialize a backend instance. + + :param type_: The backend type identifier. + """ + self.type_ = type_ + + @property + def processes_limit(self) -> int | None: + """ + :return: Maximum number of worker processes supported. None if unlimited. 
+ """ + return None + + @property + def requests_limit(self) -> int | None: + """ + :return: Maximum number of concurrent requests supported globally. + None if unlimited. + """ + return None + + @abstractmethod + async def default_model(self) -> str | None: + """ + :return: The default model name or identifier for generation requests. + """ + ... diff --git a/src/guidellm/backends/objects.py b/src/guidellm/backends/objects.py new file mode 100644 index 00000000..05280940 --- /dev/null +++ b/src/guidellm/backends/objects.py @@ -0,0 +1,156 @@ +""" +Backend object models for request and response handling. + +Provides standardized models for generation requests, responses, and timing +information to ensure consistent data handling across different backend +implementations. +""" + +import uuid +from typing import Any, Literal, Optional + +from pydantic import Field + +from guidellm.scheduler import ( + MeasuredRequestTimings, + SchedulerMessagingPydanticRegistry, +) +from guidellm.utils import StandardBaseModel + +__all__ = [ + "GenerationRequest", + "GenerationRequestTimings", + "GenerationResponse", +] + + +@SchedulerMessagingPydanticRegistry.register() +class GenerationRequest(StandardBaseModel): + """Request model for backend generation operations.""" + + request_id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for the request.", + ) + request_type: Literal["text_completions", "chat_completions"] = Field( + default="text_completions", + description=( + "Type of request. 'text_completions' uses backend.text_completions(), " + "'chat_completions' uses backend.chat_completions()." + ), + ) + content: Any = Field( + description=( + "Request content. For text_completions: string or list of strings. " + "For chat_completions: string, list of messages, or raw content " + "(set raw_content=True in params)." + ) + ) + params: dict[str, Any] = Field( + default_factory=dict, + description=( + "Additional parameters passed to backend methods. " + "Common: max_tokens, temperature, stream." + ), + ) + stats: dict[Literal["prompt_tokens"], int] = Field( + default_factory=dict, + description="Request statistics including prompt token count.", + ) + constraints: dict[Literal["output_tokens"], int] = Field( + default_factory=dict, + description="Request constraints such as maximum output tokens.", + ) + + +@SchedulerMessagingPydanticRegistry.register() +class GenerationResponse(StandardBaseModel): + """Response model for backend generation operations.""" + + request_id: str = Field( + description="Unique identifier matching the original GenerationRequest." + ) + request_args: dict[str, Any] = Field( + description="Arguments passed to the backend for this request." + ) + value: Optional[str] = Field( + default=None, + description="Complete generated text content. None for streaming responses.", + ) + delta: Optional[str] = Field( + default=None, description="Incremental text content for streaming responses." + ) + iterations: int = Field( + default=0, description="Number of generation iterations completed." + ) + request_prompt_tokens: Optional[int] = Field( + default=None, description="Token count from the original request prompt." + ) + request_output_tokens: Optional[int] = Field( + default=None, + description="Expected output token count from the original request.", + ) + response_prompt_tokens: Optional[int] = Field( + default=None, description="Actual prompt token count reported by the backend." 
+ ) + response_output_tokens: Optional[int] = Field( + default=None, description="Actual output token count reported by the backend." + ) + + @property + def prompt_tokens(self) -> Optional[int]: + """ + :return: The number of prompt tokens used in the request + (response_prompt_tokens if available, otherwise request_prompt_tokens). + """ + return self.response_prompt_tokens or self.request_prompt_tokens + + @property + def output_tokens(self) -> Optional[int]: + """ + :return: The number of output tokens generated in the response + (response_output_tokens if available, otherwise request_output_tokens). + """ + return self.response_output_tokens or self.request_output_tokens + + @property + def total_tokens(self) -> Optional[int]: + """ + :return: The total number of tokens used in the request and response. + Sum of prompt_tokens and output_tokens. + """ + if self.prompt_tokens is None or self.output_tokens is None: + return None + return self.prompt_tokens + self.output_tokens + + def preferred_prompt_tokens( + self, preferred_source: Literal["request", "response"] + ) -> Optional[int]: + if preferred_source == "request": + return self.request_prompt_tokens or self.response_prompt_tokens + else: + return self.response_prompt_tokens or self.request_prompt_tokens + + def preferred_output_tokens( + self, preferred_source: Literal["request", "response"] + ) -> Optional[int]: + if preferred_source == "request": + return self.request_output_tokens or self.response_output_tokens + else: + return self.response_output_tokens or self.request_output_tokens + + +@SchedulerMessagingPydanticRegistry.register() +@MeasuredRequestTimings.register("generation_request_timings") +class GenerationRequestTimings(MeasuredRequestTimings): + """Timing model for tracking generation request lifecycle events.""" + + timings_type: Literal["generation_request_timings"] = "generation_request_timings" + first_iteration: Optional[float] = Field( + default=None, + description="Unix timestamp when the first generation iteration began.", + ) + last_iteration: Optional[float] = Field( + default=None, + description="Unix timestamp when the last generation iteration completed.", + ) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py new file mode 100644 index 00000000..ce83076f --- /dev/null +++ b/src/guidellm/backends/openai.py @@ -0,0 +1,649 @@ +""" +OpenAI HTTP backend implementation for GuideLLM. + +Provides HTTP-based backend for OpenAI-compatible servers including OpenAI API, +vLLM servers, and other compatible inference engines. Supports text and chat +completions with streaming, authentication, and multimodal capabilities. + +Classes: + UsageStats: Token usage statistics for generation requests. + OpenAIHTTPBackend: HTTP backend for OpenAI-compatible API servers. 
+""" + +import base64 +import contextlib +import copy +import json +import time +from collections.abc import AsyncIterator +from pathlib import Path +from typing import Any, ClassVar, Optional, Union + +import httpx +from PIL import Image +from pydantic import dataclasses + +from guidellm.backends.backend import Backend +from guidellm.backends.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.scheduler import ScheduledRequestInfo + +__all__ = ["OpenAIHTTPBackend", "UsageStats"] + + +@dataclasses.dataclass +class UsageStats: + """Token usage statistics for generation requests.""" + + prompt_tokens: Optional[int] = None + output_tokens: Optional[int] = None + + +@Backend.register("openai_http") +class OpenAIHTTPBackend(Backend): + """ + HTTP backend for OpenAI-compatible servers. + + Supports OpenAI API, vLLM servers, and other compatible endpoints with + text/chat completions, streaming, authentication, and multimodal inputs. + Handles request formatting, response parsing, error handling, and token + usage tracking with flexible parameter customization. + + Example: + :: + backend = OpenAIHTTPBackend( + target="http://localhost:8000", + model="gpt-3.5-turbo", + api_key="your-api-key" + ) + + await backend.process_startup() + async for response, request_info in backend.resolve(request, info): + process_response(response) + await backend.process_shutdown() + """ + + HEALTH_PATH: ClassVar[str] = "/health" + MODELS_PATH: ClassVar[str] = "/v1/models" + TEXT_COMPLETIONS_PATH: ClassVar[str] = "/v1/completions" + CHAT_COMPLETIONS_PATH: ClassVar[str] = "/v1/chat/completions" + + MODELS_KEY: ClassVar[str] = "models" + TEXT_COMPLETIONS_KEY: ClassVar[str] = "text_completions" + CHAT_COMPLETIONS_KEY: ClassVar[str] = "chat_completions" + + def __init__( + self, + target: str, + model: Optional[str] = None, + api_key: Optional[str] = None, + organization: Optional[str] = None, + project: Optional[str] = None, + timeout: float = 60.0, + http2: bool = True, + follow_redirects: bool = True, + max_output_tokens: Optional[int] = None, + stream_response: bool = True, + extra_query: Optional[dict] = None, + extra_body: Optional[dict] = None, + remove_from_body: Optional[list[str]] = None, + headers: Optional[dict] = None, + verify: bool = False, + ): + """ + Initialize OpenAI HTTP backend. + + :param target: Target URL for the OpenAI server (e.g., "http://localhost:8000"). + :param model: Model to use for requests. If None, uses first available model. + :param api_key: API key for authentication. Adds Authorization header + if provided. + :param organization: Organization ID. Adds OpenAI-Organization header + if provided. + :param project: Project ID. Adds OpenAI-Project header if provided. + :param timeout: Request timeout in seconds. Defaults to 60 seconds. + :param http2: Whether to use HTTP/2. Defaults to True. + :param follow_redirects: Whether to follow redirects. Default True. + :param max_output_tokens: Maximum tokens for completions. If None, none is set. + :param stream_response: Whether to stream responses by default. Can be + overridden per request. Defaults to True. + :param extra_query: Additional query parameters. Both general and + endpoint-specific with type keys supported. + :param extra_body: Additional body parameters. Both general and + endpoint-specific with type keys supported. + :param remove_from_body: Parameter names to remove from request bodies. + :param headers: Additional HTTP headers. 
+ :param verify: Whether to verify SSL certificates. Default False. + """ + super().__init__(type_="openai_http") + + # Request Values + self.target = target.rstrip("/").removesuffix("/v1") + self.model = model + self.headers = self._build_headers(api_key, organization, project, headers) + + # Store configuration + self.timeout = timeout + self.http2 = http2 + self.follow_redirects = follow_redirects + self.verify = verify + self.max_output_tokens = max_output_tokens + self.stream_response = stream_response + self.extra_query = extra_query or {} + self.extra_body = extra_body or {} + self.remove_from_body = remove_from_body or [] + + # Runtime state + self._in_process = False + self._async_client: Optional[httpx.AsyncClient] = None + + @property + def info(self) -> dict[str, Any]: + """ + :return: Dictionary containing backend configuration details. + """ + return { + "target": self.target, + "model": self.model, + "headers": self.headers, + "timeout": self.timeout, + "http2": self.http2, + "follow_redirects": self.follow_redirects, + "verify": self.verify, + "max_output_tokens": self.max_output_tokens, + "stream_response": self.stream_response, + "extra_query": self.extra_query, + "extra_body": self.extra_body, + "remove_from_body": self.remove_from_body, + "health_path": self.HEALTH_PATH, + "models_path": self.MODELS_PATH, + "text_completions_path": self.TEXT_COMPLETIONS_PATH, + "chat_completions_path": self.CHAT_COMPLETIONS_PATH, + } + + async def process_startup(self): + """ + Initialize HTTP client and backend resources. + + :raises RuntimeError: If backend is already initialized. + :raises httpx.Exception: If HTTP client cannot be created. + """ + if self._in_process: + raise RuntimeError("Backend already started up for process.") + + self._async_client = httpx.AsyncClient( + http2=self.http2, + timeout=self.timeout, + follow_redirects=self.follow_redirects, + verify=self.verify, + ) + self._in_process = True + + async def process_shutdown(self): + """ + Clean up HTTP client and backend resources. + + :raises RuntimeError: If backend was not properly initialized. + :raises httpx.Exception: If HTTP client cannot be closed. + """ + if not self._in_process: + raise RuntimeError("Backend not started up for process.") + + await self._async_client.aclose() # type: ignore [union-attr] + self._async_client = None + self._in_process = False + + async def validate(self): + """ + Validate backend configuration and connectivity. + + Validate backend configuration and connectivity through test requests, + and auto-selects first available model if none is configured. + + :raises RuntimeError: If backend cannot connect or validate configuration. + """ + self._check_in_process() + + if self.model: + with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): + # Model is set, use /health endpoint as first check + target = f"{self.target}{self.HEALTH_PATH}" + headers = self._get_headers() + response = await self._async_client.get(target, headers=headers) # type: ignore [union-attr] + response.raise_for_status() + + return + + with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): + # Check if models endpoint is available next + models = await self.available_models() + if models and not self.model: + self.model = models[0] + elif not self.model: + raise RuntimeError( + "No model available and could not set a default model " + "from the server's available models." 
+ ) + + return + + with contextlib.suppress(httpx.TimeoutException, httpx.HTTPStatusError): + # Last check, fall back on dummy request to text completions + async for _, __ in self.text_completions( + prompt="Validate backend", + request_id="validate", + output_token_count=1, + stream_response=False, + ): + pass + + return + + raise RuntimeError( + "Backend validation failed. Could not connect to the server or " + "validate the backend configuration." + ) + + async def available_models(self) -> list[str]: + """ + Get available models from the target server. + + :return: List of model identifiers. + :raises HTTPError: If models endpoint returns an error. + :raises RuntimeError: If backend is not initialized. + """ + self._check_in_process() + + target = f"{self.target}{self.MODELS_PATH}" + headers = self._get_headers() + params = self._get_params(self.MODELS_KEY) + response = await self._async_client.get(target, headers=headers, params=params) # type: ignore [union-attr] + response.raise_for_status() + + return [item["id"] for item in response.json()["data"]] + + async def default_model(self) -> Optional[str]: + """ + Get the default model for this backend. + + :return: Model name or None if no model is available. + """ + if self.model or not self._in_process: + return self.model + + models = await self.available_models() + return models[0] if models else None + + async def resolve( + self, + request: GenerationRequest, + request_info: ScheduledRequestInfo, + history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: + """ + Process a generation request and yield progressive responses. + + Handles request formatting, timing tracking, API communication, and + response parsing with streaming support. + + :param request: Generation request with content and parameters. + :param request_info: Request tracking info updated with timing metadata. + :param history: Conversation history. Currently not supported. + :raises NotImplementedError: If history is provided. + :yields: Tuples of (response, updated_request_info) as generation progresses. 
+ """ + self._check_in_process() + if history is not None: + raise NotImplementedError( + "Multi-turn requests with conversation history are not yet supported" + ) + + response = GenerationResponse( + request_id=request.request_id, + request_args={ + "request_type": request.request_type, + "output_token_count": request.constraints.get("output_tokens"), + **request.params, + }, + value="", + request_prompt_tokens=request.stats.get("prompt_tokens"), + request_output_tokens=request.constraints.get("output_tokens"), + ) + request_info.request_timings = GenerationRequestTimings() + request_info.request_timings.request_start = time.time() + + completion_method = ( + self.text_completions + if request.request_type == "text_completions" + else self.chat_completions + ) + completion_kwargs = ( + { + "prompt": request.content, + "request_id": request.request_id, + "output_token_count": request.constraints.get("output_tokens"), + "stream_response": request.params.get("stream", self.stream_response), + **request.params, + } + if request.request_type == "text_completions" + else { + "content": request.content, + "request_id": request.request_id, + "output_token_count": request.constraints.get("output_tokens"), + "stream_response": request.params.get("stream", self.stream_response), + **request.params, + } + ) + + async for delta, usage_stats in completion_method(**completion_kwargs): + if request_info.request_timings.request_start is None: + request_info.request_timings.request_start = time.time() + + if delta is not None: + if request_info.request_timings.first_iteration is None: + request_info.request_timings.first_iteration = time.time() + response.value += delta # type: ignore [operator] + response.delta = delta + request_info.request_timings.last_iteration = time.time() + response.iterations += 1 + + if usage_stats is not None: + request_info.request_timings.request_end = time.time() + response.response_output_tokens = usage_stats.output_tokens + response.response_prompt_tokens = usage_stats.prompt_tokens + + yield response, request_info + + if request_info.request_timings.request_end is None: + request_info.request_timings.request_end = time.time() + response.delta = None + yield response, request_info + + async def text_completions( + self, + prompt: Union[str, list[str]], + request_id: Optional[str], # noqa: ARG002 + output_token_count: Optional[int] = None, + stream_response: bool = True, + **kwargs, + ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: + """ + Generate text completions using the /v1/completions endpoint. + + :param prompt: Text prompt(s) for completion. Single string or list. + :param request_id: Request identifier for tracking. + :param output_token_count: Maximum tokens to generate. Overrides default + if specified. + :param stream_response: Whether to stream response progressively. + :param kwargs: Additional request parameters (temperature, top_p, etc.). + :yields: Tuples of (generated_text, usage_stats). First yield is (None, None). + :raises RuntimeError: If backend is not initialized. + :raises HTTPError: If API request fails. 
+ """ + self._check_in_process() + target = f"{self.target}{self.TEXT_COMPLETIONS_PATH}" + headers = self._get_headers() + params = self._get_params(self.TEXT_COMPLETIONS_KEY) + body = self._get_body( + endpoint_type=self.TEXT_COMPLETIONS_KEY, + request_kwargs=kwargs, + max_output_tokens=output_token_count, + prompt=prompt, + ) + yield None, None # Initial yield for async iterator to signal start + + if not stream_response: + response = await self._async_client.post( # type: ignore [union-attr] + target, + headers=headers, + params=params, + json=body, + ) + response.raise_for_status() + data = response.json() + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), + ) + return + + body.update({"stream": True, "stream_options": {"include_usage": True}}) + async with self._async_client.stream( # type: ignore [union-attr] + "POST", + target, + headers=headers, + params=params, + json=body, + ) as stream: + stream.raise_for_status() + async for line in stream.aiter_lines(): + if not line or not line.strip().startswith("data:"): + continue + if line.strip() == "data: [DONE]": + break + data = json.loads(line.strip()[len("data: ") :]) + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), + ) + + async def chat_completions( + self, + content: Union[ + str, + list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], + Any, + ], + request_id: Optional[str] = None, # noqa: ARG002 + output_token_count: Optional[int] = None, + raw_content: bool = False, + stream_response: bool = True, + **kwargs, + ) -> AsyncIterator[tuple[Optional[str], Optional[UsageStats]]]: + """ + Generate chat completions using the /v1/chat/completions endpoint. + + Supports multimodal inputs including text and images with message formatting. + + :param content: Chat content - string, list of mixed content, or raw content + when raw_content=True. + :param request_id: Request identifier (currently unused). + :param output_token_count: Maximum tokens to generate. Overrides default + if specified. + :param raw_content: If True, passes content directly without formatting. + :param stream_response: Whether to stream response progressively. + :param kwargs: Additional request parameters (temperature, top_p, tools, etc.). + :yields: Tuples of (generated_text, usage_stats). First yield is (None, None). + :raises RuntimeError: If backend is not initialized. + :raises HTTPError: If API request fails. 
+ """ + self._check_in_process() + target = f"{self.target}{self.CHAT_COMPLETIONS_PATH}" + headers = self._get_headers() + params = self._get_params(self.CHAT_COMPLETIONS_KEY) + body = self._get_body( + endpoint_type=self.CHAT_COMPLETIONS_KEY, + request_kwargs=kwargs, + max_output_tokens=output_token_count, + messages=self._get_chat_messages(content) if not raw_content else content, + **kwargs, + ) + yield None, None # Initial yield for async iterator to signal start + + if not stream_response: + response = await self._async_client.post( # type: ignore [union-attr] + target, headers=headers, params=params, json=body + ) + response.raise_for_status() + data = response.json() + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), + ) + return + + body.update({"stream": True, "stream_options": {"include_usage": True}}) + async with self._async_client.stream( # type: ignore [union-attr] + "POST", target, headers=headers, params=params, json=body + ) as stream: + stream.raise_for_status() + async for line in stream.aiter_lines(): + if not line or not line.strip().startswith("data:"): + continue + if line.strip() == "data: [DONE]": + break + data = json.loads(line.strip()[len("data: ") :]) + yield ( + self._get_completions_text_content(data), + self._get_completions_usage_stats(data), + ) + + def _build_headers( + self, + api_key: Optional[str], + organization: Optional[str], + project: Optional[str], + user_headers: Optional[dict], + ) -> dict[str, str]: + headers = {} + + if api_key: + headers["Authorization"] = ( + f"Bearer {api_key}" if not api_key.startswith("Bearer") else api_key + ) + if organization: + headers["OpenAI-Organization"] = organization + if project: + headers["OpenAI-Project"] = project + if user_headers: + headers.update(user_headers) + + return {key: val for key, val in headers.items() if val is not None} + + def _check_in_process(self): + if not self._in_process or self._async_client is None: + raise RuntimeError( + "Backend not started up for process, cannot process requests." 
+ ) + + def _get_headers(self) -> dict[str, str]: + return { + "Content-Type": "application/json", + **self.headers, + } + + def _get_params(self, endpoint_type: str) -> dict[str, str]: + if endpoint_type in self.extra_query: + return copy.deepcopy(self.extra_query[endpoint_type]) + return copy.deepcopy(self.extra_query) + + def _get_chat_messages( + self, + content: Union[ + str, + list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], + Any, + ], + ) -> list[dict[str, Any]]: + if isinstance(content, str): + return [{"role": "user", "content": content}] + + if not isinstance(content, list): + raise ValueError(f"Unsupported content type: {type(content)}") + + resolved_content = [] + for item in content: + if isinstance(item, dict): + resolved_content.append(item) + elif isinstance(item, str): + resolved_content.append({"type": "text", "text": item}) + elif isinstance(item, (Image.Image, Path)): + resolved_content.append(self._get_chat_message_media_item(item)) + else: + raise ValueError(f"Unsupported content item type: {type(item)}") + + return [{"role": "user", "content": resolved_content}] + + def _get_chat_message_media_item( + self, item: Union[Path, Image.Image] + ) -> dict[str, Any]: + if isinstance(item, Image.Image): + encoded = base64.b64encode(item.tobytes()).decode("utf-8") + return { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{encoded}"}, + } + + # Handle file paths + suffix = item.suffix.lower() + if suffix in [".jpg", ".jpeg"]: + image = Image.open(item) + encoded = base64.b64encode(image.tobytes()).decode("utf-8") + return { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{encoded}"}, + } + elif suffix == ".wav": + encoded = base64.b64encode(item.read_bytes()).decode("utf-8") + return { + "type": "input_audio", + "input_audio": {"data": encoded, "format": "wav"}, + } + else: + raise ValueError(f"Unsupported file type: {suffix}") + + def _get_body( + self, + endpoint_type: str, + request_kwargs: Optional[dict[str, Any]], + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> dict[str, Any]: + # Start with endpoint-specific extra body parameters + extra_body: dict = self.extra_body.get(endpoint_type, self.extra_body) + + body = copy.deepcopy(extra_body) + body.update(request_kwargs or {}) + body.update(kwargs) + body["model"] = self.model + + # Handle token limits + max_tokens = max_output_tokens or self.max_output_tokens + if max_tokens is not None: + body.update( + { + "max_tokens": max_tokens, + "max_completion_tokens": max_tokens, + } + ) + # Set stop conditions only for request-level limits + if max_output_tokens: + body.update({"stop": None, "ignore_eos": True}) + + if self.remove_from_body: + for key in self.remove_from_body: + body.pop(key, None) + + return {key: val for key, val in body.items() if val is not None} + + def _get_completions_text_content(self, data: dict) -> Optional[str]: + if not data.get("choices"): + return None + + choice: dict = data["choices"][0] + return ( + choice.get("text") + or choice.get("delta", {}).get("content") + or choice.get("message", {}).get("content") + ) + + def _get_completions_usage_stats(self, data: dict) -> Optional[UsageStats]: + if not data.get("usage"): + return None + + return UsageStats( + prompt_tokens=data["usage"].get("prompt_tokens"), + output_tokens=data["usage"].get("completion_tokens"), + ) diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index de0660e2..e7d8b2c6 100644 --- a/src/guidellm/scheduler/scheduler.py +++ 
b/src/guidellm/scheduler/scheduler.py @@ -50,7 +50,7 @@ class Scheduler( Example: :: from guidellm.scheduler import Scheduler - from guidellm.backend import OpenAIBackend + from guidellm.backends import OpenAIBackend from guidellm.scheduler import NonDistributedEnvironment, SynchronousStrategy scheduler = Scheduler() diff --git a/tests/unit/backend/test_backend.py b/tests/unit/backend/test_backend.py index 1115d509..49b65077 100644 --- a/tests/unit/backend/test_backend.py +++ b/tests/unit/backend/test_backend.py @@ -1,136 +1,332 @@ -import time +""" +Unit tests for the Backend base class and registry functionality. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from functools import wraps +from typing import Any +from unittest.mock import Mock, patch import pytest -from guidellm.backend import ( - Backend, - ResponseSummary, - StreamingTextResponse, +from guidellm.backends.backend import Backend, BackendType +from guidellm.backends.objects import ( + GenerationRequest, + GenerationRequestTimings, ) +from guidellm.scheduler import BackendInterface, ScheduledRequestInfo +from guidellm.utils import RegistryMixin + + +def async_timeout(delay): + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +def test_backend_type(): + """Test that BackendType is defined correctly as a Literal type.""" + assert BackendType is not None + # BackendType should be a literal type containing "openai_http" + assert "openai_http" in str(BackendType) + + +class TestBackend: + """Test cases for Backend base class.""" + + @pytest.fixture( + params=[ + {"type_": "openai_http"}, + {"type_": "openai_http"}, # Test multiple instances with same type + ] + ) + def valid_instances(self, request): + """Fixture providing valid Backend instances.""" + constructor_args = request.param + + class TestBackend(Backend): + def info(self) -> dict[str, Any]: + return {"type": self.type_} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve( + self, request, request_info, history=None + ) -> AsyncIterator[tuple[Any, Any]]: + yield request, request_info + + async def default_model(self) -> str | None: + return "test-model" + + instance = TestBackend(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test Backend inheritance and type relationships.""" + assert issubclass(Backend, RegistryMixin) + assert issubclass(Backend, BackendInterface) + assert hasattr(Backend, "create") + assert hasattr(Backend, "register") + assert hasattr(Backend, "get_registered_object") + + # Check properties exist + assert hasattr(Backend, "processes_limit") + assert hasattr(Backend, "requests_limit") + + # Check abstract method exists + assert hasattr(Backend, "default_model") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test Backend initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, Backend) + assert instance.type_ == constructor_args["type_"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("type_", None), + ("type_", 123), + ("type_", ""), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test Backend with invalid field values.""" + + class TestBackend(Backend): + def 
info(self) -> dict[str, Any]: + return {} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self) -> str | None: + return "test-model" + + data = {field: value} + # Backend itself doesn't validate types, but we test that it accepts the value + backend = TestBackend(**data) + assert getattr(backend, field) == value + + @pytest.mark.smoke + def test_default_properties(self, valid_instances): + """Test Backend default property implementations.""" + instance, _ = valid_instances + assert instance.processes_limit is None + assert instance.requests_limit is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(5.0) + async def test_default_model_abstract(self): + """Test that default_model is abstract and must be implemented.""" + # Backend itself is abstract and cannot be instantiated + with pytest.raises(TypeError): + Backend("openai_http") # type: ignore + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(5.0) + async def test_interface_compatibility(self, valid_instances): + """Test that Backend is compatible with BackendInterface.""" + instance, _ = valid_instances + + # Test that Backend uses the correct generic types + request = GenerationRequest(content="test") + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Test resolve method + async for response, info in instance.resolve(request, request_info): + assert response == request + assert info == request_info + break # Only test first iteration + + @pytest.mark.smoke + def test_create_method_valid(self): + """Test Backend.create class method with valid backend.""" + # Mock a registered backend + mock_backend_class = Mock() + mock_backend_instance = Mock() + mock_backend_class.return_value = mock_backend_instance + + with patch.object( + Backend, "get_registered_object", return_value=mock_backend_class + ): + result = Backend.create("openai_http", test_arg="value") + + Backend.get_registered_object.assert_called_once_with("openai_http") + mock_backend_class.assert_called_once_with(test_arg="value") + assert result == mock_backend_instance + + @pytest.mark.sanity + def test_create_method_invalid(self): + """Test Backend.create class method with invalid backend type.""" + with pytest.raises( + ValueError, match="Backend type 'invalid_type' is not registered" + ): + Backend.create("invalid_type") + + @pytest.mark.regression + def test_docstring_example_pattern(self): + """Test that Backend docstring examples work as documented.""" + + # Test the pattern shown in docstring + class MyBackend(Backend): + def __init__(self, api_key: str): + super().__init__("mock_backend") # type: ignore [arg-type] + self.api_key = api_key + + def info(self) -> dict[str, Any]: + return {"api_key": "***"} + + async def process_startup(self): + self.client = Mock() # Simulate API client + + async def process_shutdown(self): + self.client = None # type: ignore[assignment] + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self) -> str | None: + return "my-model" + + # Register the backend + Backend.register("my_backend")(MyBackend) + + # Create instance + backend = 
Backend.create("my_backend", api_key="secret") + assert isinstance(backend, MyBackend) + assert backend.api_key == "secret" + assert backend.type_ == "mock_backend" + + +class TestBackendRegistry: + """Test cases for Backend registry functionality.""" + + @pytest.mark.smoke + def test_openai_backend_registered(self): + """Test that OpenAI HTTP backend is registered.""" + from guidellm.backends.openai import OpenAIHTTPBackend + + # OpenAI backend should be registered + backend = Backend.create("openai_http", target="http://test") + assert isinstance(backend, OpenAIHTTPBackend) + assert backend.type_ == "openai_http" + + @pytest.mark.sanity + def test_backend_create_invalid_type(self): + """Test Backend.create with invalid type raises appropriate error.""" + with pytest.raises( + ValueError, match="Backend type 'invalid_type' is not registered" + ): + Backend.create("invalid_type") + + @pytest.mark.smoke + def test_backend_registry_functionality(self): + """Test that backend registry functions work.""" + from guidellm.backends.openai import OpenAIHTTPBackend + + # Test that we can get registered backends + openai_class = Backend.get_registered_object("openai_http") + assert openai_class == OpenAIHTTPBackend + + # Test creating with kwargs + backend = Backend.create( + "openai_http", target="http://localhost:8000", model="gpt-4" + ) + assert backend.target == "http://localhost:8000" + assert backend.model == "gpt-4" + + @pytest.mark.smoke + def test_backend_is_registered(self): + """Test Backend.is_registered method.""" + # Test with a known registered backend + assert Backend.is_registered("openai_http") + + # Test with unknown backend + assert not Backend.is_registered("unknown_backend") + + @pytest.mark.regression + def test_backend_registration_decorator(self): + """Test that backend registration decorator works.""" + + # Create a test backend class + @Backend.register("test_backend") + class TestBackend(Backend): + def __init__(self, test_param="default"): + super().__init__("test_backend") # type: ignore + self._test_param = test_param + + def info(self): + return {"test_param": self._test_param} + + async def process_startup(self): + pass + + async def process_shutdown(self): + pass + + async def validate(self): + pass + + async def resolve(self, request, request_info, history=None): + yield request, request_info + + async def default_model(self): + return "test-model" + + # Test that it's registered and can be created + backend = Backend.create("test_backend", test_param="custom") + assert isinstance(backend, TestBackend) + assert backend.info() == {"test_param": "custom"} + + @pytest.mark.smoke + def test_backend_registered_objects(self): + """Test Backend.registered_objects method returns registered backends.""" + # Should include at least the openai_http backend + registered = Backend.registered_objects() + assert isinstance(registered, tuple) + assert len(registered) > 0 + # Check that openai backend is in the registered objects + from guidellm.backends.openai import OpenAIHTTPBackend -@pytest.mark.smoke -def test_backend_registry(): - assert Backend._registry["mock"] is not None # type: ignore - - backend_instance = Backend.create("mock") # type: ignore - assert backend_instance is not None - - with pytest.raises(ValueError): - Backend.register("mock")("backend") # type: ignore - - with pytest.raises(ValueError): - Backend.create("invalid_type") # type: ignore - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_text_completions(mock_backend): - index = 0 - 
prompt = "Test Prompt" - request_id = "test-request-id" - prompt_token_count = 3 - output_token_count = 10 - final_resp = None - - async for response in mock_backend.text_completions( - prompt=prompt, - request_id=request_id, - prompt_token_count=prompt_token_count, - output_token_count=output_token_count, - ): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.iterations > 0 - assert response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens == prompt_token_count - assert response.request_output_tokens == output_token_count - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens == 10 - assert response.request_id == request_id - - index += 1 - - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_chat_completions(mock_backend): - index = 0 - prompt = "Test Prompt" - request_id = "test-request-id" - prompt_token_count = 3 - output_token_count = 10 - final_resp = None - - async for response in mock_backend.chat_completions( - content=prompt, - request_id=request_id, - prompt_token_count=prompt_token_count, - output_token_count=output_token_count, - ): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == request_id - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.iterations > 0 - assert response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens == prompt_token_count - assert response.request_output_tokens == output_token_count - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens == 10 - assert response.request_id == request_id - - index += 1 - - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_models(mock_backend): - models = await mock_backend.available_models() - assert models == ["mock-model"] - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_backend_validate(mock_backend): - await mock_backend.validate() + assert OpenAIHTTPBackend in registered diff --git a/tests/unit/backend/test_objects.py b/tests/unit/backend/test_objects.py new file mode 100644 index 
00000000..34a6350c --- /dev/null +++ b/tests/unit/backend/test_objects.py @@ -0,0 +1,467 @@ +""" +Unit tests for GenerationRequest, GenerationResponse, GenerationRequestTimings. +""" + +from __future__ import annotations + +import uuid + +import pytest +from pydantic import ValidationError + +from guidellm.backends.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.scheduler import MeasuredRequestTimings +from guidellm.utils import StandardBaseModel + + +class TestGenerationRequest: + """Test cases for GenerationRequest model.""" + + @pytest.fixture( + params=[ + {"content": "test content"}, + { + "content": ["message1", "message2"], + "request_type": "chat_completions", + "params": {"temperature": 0.7}, + }, + { + "request_id": "custom-id", + "content": {"role": "user", "content": "test"}, + "stats": {"prompt_tokens": 50}, + "constraints": {"output_tokens": 100}, + }, + ] + ) + def valid_instances(self, request): + """Fixture providing valid GenerationRequest instances.""" + constructor_args = request.param + instance = GenerationRequest(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test GenerationRequest inheritance and type relationships.""" + assert issubclass(GenerationRequest, StandardBaseModel) + assert hasattr(GenerationRequest, "model_dump") + assert hasattr(GenerationRequest, "model_validate") + + # Check all expected fields are defined + fields = GenerationRequest.model_fields + expected_fields = [ + "request_id", + "request_type", + "content", + "params", + "stats", + "constraints", + ] + for field in expected_fields: + assert field in fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test GenerationRequest initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, GenerationRequest) + assert instance.content == constructor_args["content"] + + # Check defaults + expected_request_type = constructor_args.get("request_type", "text_completions") + assert instance.request_type == expected_request_type + + if "request_id" in constructor_args: + assert instance.request_id == constructor_args["request_id"] + else: + assert isinstance(instance.request_id, str) + # Should be valid UUID + uuid.UUID(instance.request_id) + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + """Test GenerationRequest with invalid field values.""" + # Invalid request_type + with pytest.raises(ValidationError): + GenerationRequest(content="test", request_type="invalid_type") + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test GenerationRequest initialization without required field.""" + with pytest.raises(ValidationError): + GenerationRequest() # Missing required 'content' field + + @pytest.mark.smoke + def test_auto_id_generation(self): + """Test that request_id is auto-generated if not provided.""" + request1 = GenerationRequest(content="test1") + request2 = GenerationRequest(content="test2") + + assert request1.request_id != request2.request_id + assert len(request1.request_id) > 0 + assert len(request2.request_id) > 0 + + # Should be valid UUIDs + uuid.UUID(request1.request_id) + uuid.UUID(request2.request_id) + + @pytest.mark.regression + def test_content_types(self): + """Test GenerationRequest with different content types.""" + # String content + request1 = GenerationRequest(content="string content") + assert request1.content == "string content" + + # List 
content + request2 = GenerationRequest(content=["item1", "item2"]) + assert request2.content == ["item1", "item2"] + + # Dict content + dict_content = {"role": "user", "content": "test"} + request3 = GenerationRequest(content=dict_content) + assert request3.content == dict_content + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test GenerationRequest serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["content"] == constructor_args["content"] + + # Test reconstruction + reconstructed = GenerationRequest.model_validate(data_dict) + assert reconstructed.content == instance.content + assert reconstructed.request_type == instance.request_type + assert reconstructed.request_id == instance.request_id + + +class TestGenerationResponse: + """Test cases for GenerationResponse model.""" + + @pytest.fixture( + params=[ + { + "request_id": "test-123", + "request_args": {"model": "gpt-3.5-turbo"}, + }, + { + "request_id": "test-456", + "request_args": {"model": "gpt-4"}, + "value": "Generated text", + "delta": "new text", + "iterations": 5, + "request_prompt_tokens": 50, + "request_output_tokens": 100, + "response_prompt_tokens": 55, + "response_output_tokens": 95, + }, + ] + ) + def valid_instances(self, request): + """Fixture providing valid GenerationResponse instances.""" + constructor_args = request.param + instance = GenerationResponse(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test GenerationResponse inheritance and type relationships.""" + assert issubclass(GenerationResponse, StandardBaseModel) + assert hasattr(GenerationResponse, "model_dump") + assert hasattr(GenerationResponse, "model_validate") + + # Check all expected fields and properties are defined + fields = GenerationResponse.model_fields + expected_fields = [ + "request_id", + "request_args", + "value", + "delta", + "iterations", + "request_prompt_tokens", + "request_output_tokens", + "response_prompt_tokens", + "response_output_tokens", + ] + for field in expected_fields: + assert field in fields + + # Check properties exist + assert hasattr(GenerationResponse, "prompt_tokens") + assert hasattr(GenerationResponse, "output_tokens") + assert hasattr(GenerationResponse, "total_tokens") + assert hasattr(GenerationResponse, "preferred_prompt_tokens") + assert hasattr(GenerationResponse, "preferred_output_tokens") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test GenerationResponse initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, GenerationResponse) + assert instance.request_id == constructor_args["request_id"] + assert instance.request_args == constructor_args["request_args"] + + # Check defaults for optional fields + if "value" not in constructor_args: + assert instance.value is None + if "delta" not in constructor_args: + assert instance.delta is None + if "iterations" not in constructor_args: + assert instance.iterations == 0 + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + """Test GenerationResponse with invalid field values.""" + # Invalid iterations type + with pytest.raises(ValidationError): + GenerationResponse(request_id="test", request_args={}, iterations="not_int") + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test GenerationResponse initialization without required 
fields.""" + with pytest.raises(ValidationError): + GenerationResponse() # Missing required fields + + with pytest.raises(ValidationError): + GenerationResponse(request_id="test") # Missing request_args + + @pytest.mark.smoke + def test_prompt_tokens_property(self): + """Test prompt_tokens property logic.""" + # When both are available, prefers response_prompt_tokens + response1 = GenerationResponse( + request_id="test", + request_args={}, + request_prompt_tokens=50, + response_prompt_tokens=55, + ) + assert response1.prompt_tokens == 55 + + # When only request_prompt_tokens is available + response2 = GenerationResponse( + request_id="test", request_args={}, request_prompt_tokens=50 + ) + assert response2.prompt_tokens == 50 + + # When only response_prompt_tokens is available + response3 = GenerationResponse( + request_id="test", request_args={}, response_prompt_tokens=55 + ) + assert response3.prompt_tokens == 55 + + # When neither is available + response4 = GenerationResponse(request_id="test", request_args={}) + assert response4.prompt_tokens is None + + @pytest.mark.smoke + def test_output_tokens_property(self): + """Test output_tokens property logic.""" + # When both are available, prefers response_output_tokens + response1 = GenerationResponse( + request_id="test", + request_args={}, + request_output_tokens=100, + response_output_tokens=95, + ) + assert response1.output_tokens == 95 + + # When only request_output_tokens is available + response2 = GenerationResponse( + request_id="test", request_args={}, request_output_tokens=100 + ) + assert response2.output_tokens == 100 + + # When only response_output_tokens is available + response3 = GenerationResponse( + request_id="test", request_args={}, response_output_tokens=95 + ) + assert response3.output_tokens == 95 + + # When neither is available + response4 = GenerationResponse(request_id="test", request_args={}) + assert response4.output_tokens is None + + @pytest.mark.smoke + def test_total_tokens_property(self): + """Test total_tokens property calculation.""" + # When both prompt and output tokens are available + response1 = GenerationResponse( + request_id="test", + request_args={}, + response_prompt_tokens=50, + response_output_tokens=100, + ) + assert response1.total_tokens == 150 + + # When one is missing + response2 = GenerationResponse( + request_id="test", request_args={}, response_prompt_tokens=50 + ) + assert response2.total_tokens is None + + # When both are missing + response3 = GenerationResponse(request_id="test", request_args={}) + assert response3.total_tokens is None + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("preferred_source", "expected_prompt", "expected_output"), + [ + ("request", 50, 100), + ("response", 55, 95), + ], + ) + def test_preferred_token_methods( + self, preferred_source, expected_prompt, expected_output + ): + """Test preferred_*_tokens methods.""" + response = GenerationResponse( + request_id="test", + request_args={}, + request_prompt_tokens=50, + request_output_tokens=100, + response_prompt_tokens=55, + response_output_tokens=95, + ) + + assert response.preferred_prompt_tokens(preferred_source) == expected_prompt + assert response.preferred_output_tokens(preferred_source) == expected_output + + @pytest.mark.regression + def test_preferred_tokens_fallback(self): + """Test preferred_*_tokens methods with fallback logic.""" + # Only response tokens available + response1 = GenerationResponse( + request_id="test", + request_args={}, + response_prompt_tokens=55, + response_output_tokens=95, + 
) + + assert response1.preferred_prompt_tokens("request") == 55 # Falls back + assert response1.preferred_output_tokens("request") == 95 # Falls back + + # Only request tokens available + response2 = GenerationResponse( + request_id="test", + request_args={}, + request_prompt_tokens=50, + request_output_tokens=100, + ) + + assert response2.preferred_prompt_tokens("response") == 50 # Falls back + assert response2.preferred_output_tokens("response") == 100 # Falls back + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test GenerationResponse serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["request_id"] == constructor_args["request_id"] + assert data_dict["request_args"] == constructor_args["request_args"] + + # Test reconstruction + reconstructed = GenerationResponse.model_validate(data_dict) + assert reconstructed.request_id == instance.request_id + assert reconstructed.request_args == instance.request_args + assert reconstructed.value == instance.value + assert reconstructed.iterations == instance.iterations + + +class TestGenerationRequestTimings: + """Test cases for GenerationRequestTimings model.""" + + @pytest.fixture( + params=[ + {}, + {"first_iteration": 1234567890.0}, + {"last_iteration": 1234567895.0}, + { + "first_iteration": 1234567890.0, + "last_iteration": 1234567895.0, + }, + ] + ) + def valid_instances(self, request): + """Fixture providing valid GenerationRequestTimings instances.""" + constructor_args = request.param + instance = GenerationRequestTimings(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test GenerationRequestTimings inheritance and type relationships.""" + assert issubclass(GenerationRequestTimings, MeasuredRequestTimings) + assert issubclass(GenerationRequestTimings, StandardBaseModel) + assert hasattr(GenerationRequestTimings, "model_dump") + assert hasattr(GenerationRequestTimings, "model_validate") + + # Check inherited fields from MeasuredRequestTimings + fields = GenerationRequestTimings.model_fields + expected_inherited_fields = ["request_start", "request_end"] + for field in expected_inherited_fields: + assert field in fields + + # Check own fields + expected_own_fields = ["first_iteration", "last_iteration"] + for field in expected_own_fields: + assert field in fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test GenerationRequestTimings initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, GenerationRequestTimings) + assert isinstance(instance, MeasuredRequestTimings) + + # Check field values + expected_first = constructor_args.get("first_iteration") + expected_last = constructor_args.get("last_iteration") + assert instance.first_iteration == expected_first + assert instance.last_iteration == expected_last + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + """Test GenerationRequestTimings with invalid field values.""" + # Invalid timestamp type + with pytest.raises(ValidationError): + GenerationRequestTimings(first_iteration="not_float") + + with pytest.raises(ValidationError): + GenerationRequestTimings(last_iteration="not_float") + + @pytest.mark.smoke + def test_optional_fields(self): + """Test that all timing fields are optional.""" + # Should be able to create with no fields + timings1 = GenerationRequestTimings() + 
assert timings1.first_iteration is None + assert timings1.last_iteration is None + + # Should be able to create with only one field + timings2 = GenerationRequestTimings(first_iteration=123.0) + assert timings2.first_iteration == 123.0 + assert timings2.last_iteration is None + + timings3 = GenerationRequestTimings(last_iteration=456.0) + assert timings3.first_iteration is None + assert timings3.last_iteration == 456.0 + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test GenerationRequestTimings serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + + # Test reconstruction + reconstructed = GenerationRequestTimings.model_validate(data_dict) + assert reconstructed.first_iteration == instance.first_iteration + assert reconstructed.last_iteration == instance.last_iteration + assert reconstructed.request_start == instance.request_start + assert reconstructed.request_end == instance.request_end diff --git a/tests/unit/backend/test_openai_backend.py b/tests/unit/backend/test_openai_backend.py index 7123c590..7c7f528d 100644 --- a/tests/unit/backend/test_openai_backend.py +++ b/tests/unit/backend/test_openai_backend.py @@ -1,207 +1,1178 @@ -import time +""" +Unit tests for OpenAIHTTPBackend implementation. +""" +from __future__ import annotations + +import asyncio +import base64 +from functools import wraps +from pathlib import Path +from unittest.mock import AsyncMock, Mock, patch + +import httpx import pytest +from PIL import Image -from guidellm.backend import OpenAIHTTPBackend, ResponseSummary, StreamingTextResponse -from guidellm.settings import settings - - -@pytest.mark.smoke -def test_openai_http_backend_default_initialization(): - backend = OpenAIHTTPBackend() - assert backend.target == settings.openai.base_url - assert backend.model is None - assert backend.headers.get("Authorization") == settings.openai.bearer_token - assert backend.organization == settings.openai.organization - assert backend.project == settings.openai.project - assert backend.timeout == settings.request_timeout - assert backend.http2 is True - assert backend.follow_redirects is True - assert backend.max_output_tokens == settings.openai.max_output_tokens - assert backend.extra_query is None - - -@pytest.mark.smoke -def test_openai_http_backend_intialization(): - backend = OpenAIHTTPBackend( - target="http://test-target", - model="test-model", - api_key="test-key", - organization="test-org", - project="test-proj", - timeout=10, - http2=False, - follow_redirects=False, - max_output_tokens=100, - extra_query={"foo": "bar"}, - ) - assert backend.target == "http://test-target" - assert backend.model == "test-model" - assert backend.headers.get("Authorization") == "Bearer test-key" - assert backend.organization == "test-org" - assert backend.project == "test-proj" - assert backend.timeout == 10 - assert backend.http2 is False - assert backend.follow_redirects is False - assert backend.max_output_tokens == 100 - assert backend.extra_query == {"foo": "bar"} - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_available_models(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock") - models = await backend.available_models() - assert models == ["mock-model"] - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_validate(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") 
- await backend.validate() - - backend = OpenAIHTTPBackend(target="http://target.mock") - await backend.validate() - assert backend.model == "mock-model" - - backend = OpenAIHTTPBackend(target="http://target.mock", model="invalid-model") - with pytest.raises(ValueError): - await backend.validate() - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_text_completions(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") - - index = 0 - final_resp = None - async for response in backend.text_completions("Test Prompt", request_id="test-id"): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.request_args is not None - assert response.iterations > 0 - assert response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens is None - assert response.request_output_tokens is None - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens > 0 # type: ignore - assert response.request_id == "test-id" - - index += 1 - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_text_completions_counts(httpx_openai_mock): - backend = OpenAIHTTPBackend( - target="http://target.mock", - model="mock-model", - max_output_tokens=100, +from guidellm.backends.backend import Backend +from guidellm.backends.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.backends.openai import OpenAIHTTPBackend, UsageStats +from guidellm.scheduler import ScheduledRequestInfo + + +def async_timeout(delay): + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +def test_usage_stats(): + """Test that UsageStats is defined correctly as a dataclass.""" + stats = UsageStats() + assert stats.prompt_tokens is None + assert stats.output_tokens is None + + stats_with_values = UsageStats(prompt_tokens=10, output_tokens=5) + assert stats_with_values.prompt_tokens == 10 + assert stats_with_values.output_tokens == 5 + + +class TestOpenAIHTTPBackend: + """Test cases for OpenAIHTTPBackend.""" + + @pytest.fixture( + params=[ + {"target": "http://localhost:8000"}, + { + "target": "https://api.openai.com", + "model": "gpt-4", + "api_key": "test-key", + "timeout": 30.0, + "stream_response": False, + }, + { + "target": "http://test-server:8080", + "model": "test-model", + "api_key": "Bearer test-token", + "organization": "test-org", + "project": "test-proj", + "timeout": 120.0, + "http2": False, + "follow_redirects": False, + "max_output_tokens": 500, + "extra_query": {"param": "value"}, + "extra_body": {"setting": "test"}, + "remove_from_body": 
["unwanted"], + "headers": {"Custom": "header"}, + "verify": True, + }, + ] ) - final_resp = None - - async for response in backend.text_completions( - "Test Prompt", request_id="test-id", prompt_token_count=3, output_token_count=10 - ): - final_resp = response - - assert final_resp - assert isinstance(final_resp, ResponseSummary) - assert len(final_resp.value) > 0 - assert final_resp.request_args is not None - assert final_resp.request_prompt_tokens == 3 - assert final_resp.request_output_tokens == 10 - assert final_resp.response_prompt_tokens == 3 - assert final_resp.response_output_tokens == 10 - assert final_resp.request_id == "test-id" - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_chat_completions(httpx_openai_mock): - backend = OpenAIHTTPBackend(target="http://target.mock", model="mock-model") - - index = 0 - final_resp = None - async for response in backend.chat_completions("Test Prompt", request_id="test-id"): - assert isinstance(response, (StreamingTextResponse, ResponseSummary)) - - if index == 0: - assert isinstance(response, StreamingTextResponse) - assert response.type_ == "start" - assert response.iter_count == 0 - assert response.delta == "" - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" - elif not isinstance(response, ResponseSummary): - assert response.type_ == "iter" - assert response.iter_count == index - assert len(response.delta) > 0 - assert response.time == pytest.approx(time.time(), abs=0.01) - assert response.request_id == "test-id" + def valid_instances(self, request): + """Fixture providing valid OpenAIHTTPBackend instances.""" + constructor_args = request.param + instance = OpenAIHTTPBackend(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test OpenAIHTTPBackend inheritance and type relationships.""" + assert issubclass(OpenAIHTTPBackend, Backend) + assert hasattr(OpenAIHTTPBackend, "HEALTH_PATH") + assert OpenAIHTTPBackend.HEALTH_PATH == "/health" + assert hasattr(OpenAIHTTPBackend, "MODELS_PATH") + assert OpenAIHTTPBackend.MODELS_PATH == "/v1/models" + assert hasattr(OpenAIHTTPBackend, "TEXT_COMPLETIONS_PATH") + assert OpenAIHTTPBackend.TEXT_COMPLETIONS_PATH == "/v1/completions" + assert hasattr(OpenAIHTTPBackend, "CHAT_COMPLETIONS_PATH") + assert OpenAIHTTPBackend.CHAT_COMPLETIONS_PATH == "/v1/chat/completions" + assert hasattr(OpenAIHTTPBackend, "MODELS_KEY") + assert OpenAIHTTPBackend.MODELS_KEY == "models" + assert hasattr(OpenAIHTTPBackend, "TEXT_COMPLETIONS_KEY") + assert OpenAIHTTPBackend.TEXT_COMPLETIONS_KEY == "text_completions" + assert hasattr(OpenAIHTTPBackend, "CHAT_COMPLETIONS_KEY") + assert OpenAIHTTPBackend.CHAT_COMPLETIONS_KEY == "chat_completions" + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test OpenAIHTTPBackend initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, OpenAIHTTPBackend) + expected_target = constructor_args["target"].rstrip("/").removesuffix("/v1") + assert instance.target == expected_target + if "model" in constructor_args: + assert instance.model == constructor_args["model"] + if "timeout" in constructor_args: + assert instance.timeout == constructor_args["timeout"] else: - assert not final_resp - final_resp = response - assert isinstance(response, ResponseSummary) - assert len(response.value) > 0 - assert response.request_args is not None - assert response.iterations > 0 - assert 
response.start_time > 0 - assert response.end_time == pytest.approx(time.time(), abs=0.01) - assert response.request_prompt_tokens is None - assert response.request_output_tokens is None - assert response.response_prompt_tokens == 3 - assert response.response_output_tokens > 0 # type: ignore - assert response.request_id == "test-id" - - index += 1 - - assert final_resp - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_openai_http_backend_chat_completions_counts(httpx_openai_mock): - backend = OpenAIHTTPBackend( - target="http://target.mock", - model="mock-model", - max_output_tokens=100, + assert instance.timeout == 60.0 + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("target", ""), + ("timeout", -1.0), + ("http2", "invalid"), + ("verify", "invalid"), + ], ) - final_resp = None - - async for response in backend.chat_completions( - "Test Prompt", request_id="test-id", prompt_token_count=3, output_token_count=10 - ): - final_resp = response - - assert final_resp - assert isinstance(final_resp, ResponseSummary) - assert len(final_resp.value) > 0 - assert final_resp.request_args is not None - assert final_resp.request_prompt_tokens == 3 - assert final_resp.request_output_tokens == 10 - assert final_resp.response_prompt_tokens == 3 - assert final_resp.response_output_tokens == 10 - assert final_resp.request_id == "test-id" + def test_invalid_initialization_values(self, field, value): + """Test OpenAIHTTPBackend with invalid field values.""" + base_args = {"target": "http://localhost:8000"} + base_args[field] = value + # OpenAI backend doesn't validate types at init, accepts whatever is passed + backend = OpenAIHTTPBackend(**base_args) + assert getattr(backend, field) == value + + @pytest.mark.smoke + def test_factory_registration(self): + """Test that OpenAIHTTPBackend is registered with Backend factory.""" + assert Backend.is_registered("openai_http") + backend = Backend.create("openai_http", target="http://test") + assert isinstance(backend, OpenAIHTTPBackend) + assert backend.type_ == "openai_http" + + @pytest.mark.smoke + def test_initialization_minimal(self): + """Test minimal OpenAIHTTPBackend initialization.""" + backend = OpenAIHTTPBackend(target="http://localhost:8000") + + assert backend.target == "http://localhost:8000" + assert backend.model is None + assert backend.timeout == 60.0 + assert backend.http2 is True + assert backend.follow_redirects is True + assert backend.verify is False + assert backend.stream_response is True + assert backend._in_process is False + assert backend._async_client is None + + @pytest.mark.smoke + def test_initialization_full(self): + """Test full OpenAIHTTPBackend initialization.""" + extra_query = {"param": "value"} + extra_body = {"setting": "test"} + remove_from_body = ["unwanted"] + headers = {"Custom-Header": "value"} + + backend = OpenAIHTTPBackend( + target="https://localhost:8000/v1", + model="test-model", + api_key="test-key", + organization="test-org", + project="test-project", + timeout=120.0, + http2=False, + follow_redirects=False, + max_output_tokens=1000, + stream_response=False, + extra_query=extra_query, + extra_body=extra_body, + remove_from_body=remove_from_body, + headers=headers, + verify=True, + ) + + assert backend.target == "https://localhost:8000" + assert backend.model == "test-model" + assert backend.timeout == 120.0 + assert backend.http2 is False + assert backend.follow_redirects is False + assert backend.verify is True + assert backend.max_output_tokens == 1000 + assert 
backend.stream_response is False + assert backend.extra_query == extra_query + assert backend.extra_body == extra_body + assert backend.remove_from_body == remove_from_body + + @pytest.mark.sanity + def test_target_normalization(self): + """Test target URL normalization.""" + # Remove trailing slashes and /v1 + backend1 = OpenAIHTTPBackend(target="http://localhost:8000/") + assert backend1.target == "http://localhost:8000" + + backend2 = OpenAIHTTPBackend(target="http://localhost:8000/v1") + assert backend2.target == "http://localhost:8000" + + backend3 = OpenAIHTTPBackend(target="http://localhost:8000/v1/") + assert backend3.target == "http://localhost:8000" + + @pytest.mark.sanity + def test_header_building(self): + """Test header building logic.""" + # Test with API key + backend1 = OpenAIHTTPBackend(target="http://test", api_key="test-key") + assert "Authorization" in backend1.headers + assert backend1.headers["Authorization"] == "Bearer test-key" + + # Test with Bearer prefix already + backend2 = OpenAIHTTPBackend(target="http://test", api_key="Bearer test-key") + assert backend2.headers["Authorization"] == "Bearer test-key" + + # Test with organization and project + backend3 = OpenAIHTTPBackend( + target="http://test", organization="test-org", project="test-project" + ) + assert backend3.headers["OpenAI-Organization"] == "test-org" + assert backend3.headers["OpenAI-Project"] == "test-project" + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_info(self): + """Test info method.""" + backend = OpenAIHTTPBackend( + target="http://test", model="test-model", timeout=30.0 + ) + + info = backend.info() + + assert info["target"] == "http://test" + assert info["model"] == "test-model" + assert info["timeout"] == 30.0 + assert info["health_path"] == "/health" + assert info["models_path"] == "/v1/models" + assert info["text_completions_path"] == "/v1/completions" + assert info["chat_completions_path"] == "/v1/chat/completions" + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_process_startup(self): + """Test process startup.""" + backend = OpenAIHTTPBackend(target="http://test") + + assert not backend._in_process + assert backend._async_client is None + + await backend.process_startup() + + assert backend._in_process + assert backend._async_client is not None + assert isinstance(backend._async_client, httpx.AsyncClient) + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_process_startup_already_started(self): + """Test process startup when already started.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + with pytest.raises(RuntimeError, match="Backend already started up"): + await backend.process_startup() + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_process_shutdown(self): + """Test process shutdown.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + assert backend._in_process + assert backend._async_client is not None + + await backend.process_shutdown() + + assert not backend._in_process + assert backend._async_client is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_process_shutdown_not_started(self): + """Test process shutdown when not started.""" + backend = OpenAIHTTPBackend(target="http://test") + + with 
pytest.raises(RuntimeError, match="Backend not started up"): + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_check_in_process(self): + """Test _check_in_process method.""" + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + backend._check_in_process() + + await backend.process_startup() + backend._check_in_process() # Should not raise + + await backend.process_shutdown() + with pytest.raises(RuntimeError, match="Backend not started up"): + backend._check_in_process() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_available_models(self): + """Test available_models method.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + mock_response = Mock() + mock_response.json.return_value = { + "data": [{"id": "test-model1"}, {"id": "test-model2"}] + } + mock_response.raise_for_status = Mock() + + with patch.object(backend._async_client, "get", return_value=mock_response): + models = await backend.available_models() + + assert models == ["test-model1", "test-model2"] + backend._async_client.get.assert_called_once() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(5.0) + async def test_default_model(self): + """Test default_model method.""" + # Test when model is already set + backend1 = OpenAIHTTPBackend(target="http://test", model="test-model") + result1 = await backend1.default_model() + assert result1 == "test-model" + + # Test when not in process + backend2 = OpenAIHTTPBackend(target="http://test") + result2 = await backend2.default_model() + assert result2 is None + + # Test when in process but no model set + backend3 = OpenAIHTTPBackend(target="http://test") + await backend3.process_startup() + + with patch.object(backend3, "available_models", return_value=["test-model2"]): + result3 = await backend3.default_model() + assert result3 == "test-model2" + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + @async_timeout(10.0) + async def test_validate_with_model(self): + """Test validate method when model is set.""" + backend = OpenAIHTTPBackend(target="http://test", model="test-model") + await backend.process_startup() + + mock_response = Mock() + mock_response.raise_for_status = Mock() + + with patch.object(backend._async_client, "get", return_value=mock_response): + await backend.validate() # Should not raise + + backend._async_client.get.assert_called_once_with( + "http://test/health", headers={"Content-Type": "application/json"} + ) + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_validate_without_model(self): + """Test validate method when no model is set.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + with patch.object(backend, "available_models", return_value=["test-model"]): + await backend.validate() + assert backend.model == "test-model" + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_validate_fallback_to_text_completions(self): + """Test validate method fallback to text completions.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + # Mock health and models endpoints to fail + def mock_get(*args, **kwargs): + raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock()) + + # Mock 
+        async def mock_text_completions(*args, **kwargs):
+            yield "test", UsageStats()
+
+        with (
+            patch.object(backend._async_client, "get", side_effect=mock_get),
+            patch.object(
+                backend, "text_completions", side_effect=mock_text_completions
+            ),
+        ):
+            await backend.validate()  # Should not raise
+
+    @pytest.mark.regression
+    @pytest.mark.asyncio
+    @async_timeout(10.0)
+    async def test_validate_failure(self):
+        """Test validate method when all validation methods fail."""
+        backend = OpenAIHTTPBackend(target="http://test")
+        await backend.process_startup()
+
+        def mock_http_error(*args, **kwargs):
+            raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock())
+
+        with (
+            patch.object(backend._async_client, "get", side_effect=mock_http_error),
+            patch.object(backend, "text_completions", side_effect=mock_http_error),
+            pytest.raises(RuntimeError, match="Backend validation failed"),
+        ):
+            await backend.validate()
+
+    @pytest.mark.sanity
+    def test_get_headers(self):
+        """Test _get_headers method."""
+        backend = OpenAIHTTPBackend(
+            target="http://test", api_key="test-key", headers={"Custom": "value"}
+        )
+
+        headers = backend._get_headers()
+
+        expected = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer test-key",
+            "Custom": "value",
+        }
+        assert headers == expected
+
+    @pytest.mark.sanity
+    def test_get_params(self):
+        """Test _get_params method."""
+        extra_query = {
+            "general": "value",
+            "text_completions": {"specific": "text"},
+            "chat_completions": {"specific": "chat"},
+        }
+
+        backend = OpenAIHTTPBackend(target="http://test", extra_query=extra_query)
+
+        # Test endpoint-specific params
+        text_params = backend._get_params("text_completions")
+        assert text_params == {"specific": "text"}
+
+        # Test fallback to general params
+        other_params = backend._get_params("other")
+        assert other_params == extra_query
+
+    @pytest.mark.regression
+    def test_get_chat_messages_string(self):
+        """Test _get_chat_messages with string content."""
+        backend = OpenAIHTTPBackend(target="http://test")
+
+        messages = backend._get_chat_messages("Hello world")
+
+        expected = [{"role": "user", "content": "Hello world"}]
+        assert messages == expected
+
+    @pytest.mark.regression
+    def test_get_chat_messages_list(self):
+        """Test _get_chat_messages with list content."""
+        backend = OpenAIHTTPBackend(target="http://test")
+
+        content = [
+            "Hello",
+            {"type": "text", "text": "world"},
+            {"role": "assistant", "content": "existing message"},
+        ]
+
+        messages = backend._get_chat_messages(content)
+
+        expected = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Hello"},
+                    {"type": "text", "text": "world"},
+                    {"role": "assistant", "content": "existing message"},
+                ],
+            }
+        ]
+        assert messages == expected
+
+    @pytest.mark.regression
+    def test_get_chat_messages_invalid(self):
+        """Test _get_chat_messages with invalid content."""
+        backend = OpenAIHTTPBackend(target="http://test")
+
+        with pytest.raises(ValueError, match="Unsupported content type"):
+            backend._get_chat_messages(123)
+
+        with pytest.raises(ValueError, match="Unsupported content item type"):
+            backend._get_chat_messages([123])
+
+    @pytest.mark.regression
+    def test_get_chat_message_media_item_image(self):
+        """Test _get_chat_message_media_item with PIL Image."""
+        backend = OpenAIHTTPBackend(target="http://test")
+
+        # Create a mock PIL Image
+        mock_image = Mock(spec=Image.Image)
+
mock_image.tobytes.return_value = b"fake_image_data" + + result = backend._get_chat_message_media_item(mock_image) + + expected_data = base64.b64encode(b"fake_image_data").decode("utf-8") + expected = { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{expected_data}"}, + } + assert result == expected + + @pytest.mark.regression + def test_get_chat_message_media_item_path(self): + """Test _get_chat_message_media_item with file paths.""" + backend = OpenAIHTTPBackend(target="http://test") + + # Test unsupported file type + unsupported_path = Path("test.txt") + with pytest.raises(ValueError, match="Unsupported file type: .txt"): + backend._get_chat_message_media_item(unsupported_path) + + @pytest.mark.regression + def test_get_body(self): + """Test _get_body method.""" + extra_body = {"general": "value", "text_completions": {"temperature": 0.5}} + + backend = OpenAIHTTPBackend( + target="http://test", + model="test-model", + max_output_tokens=1000, + extra_body=extra_body, + ) + + request_kwargs = {"temperature": 0.7} + + body = backend._get_body( + endpoint_type="text_completions", + request_kwargs=request_kwargs, + max_output_tokens=500, + prompt="test", + ) + + # Check that max_tokens settings are applied + assert body["temperature"] == 0.7 # request_kwargs override extra_body + assert body["model"] == "test-model" + assert body["max_tokens"] == 500 + assert body["max_completion_tokens"] == 500 + assert body["ignore_eos"] is True + assert body["prompt"] == "test" + # stop: None is filtered out by the None filter + assert "stop" not in body + + @pytest.mark.regression + def test_get_completions_text_content(self): + """Test _get_completions_text_content method.""" + backend = OpenAIHTTPBackend(target="http://test") + + # Test with text field + data1 = {"choices": [{"text": "generated text"}]} + result1 = backend._get_completions_text_content(data1) + assert result1 == "generated text" + + # Test with delta content field + data2 = {"choices": [{"delta": {"content": "delta text"}}]} + result2 = backend._get_completions_text_content(data2) + assert result2 == "delta text" + + # Test with no choices + data3: dict[str, list] = {"choices": []} + result3 = backend._get_completions_text_content(data3) + assert result3 is None + + # Test with no choices key + data4: dict[str, str] = {} + result4 = backend._get_completions_text_content(data4) + assert result4 is None + + @pytest.mark.regression + def test_get_completions_usage_stats(self): + """Test _get_completions_usage_stats method.""" + backend = OpenAIHTTPBackend(target="http://test") + + # Test with usage data + data1 = {"usage": {"prompt_tokens": 50, "completion_tokens": 100}} + result1 = backend._get_completions_usage_stats(data1) + assert isinstance(result1, UsageStats) + assert result1.prompt_tokens == 50 + assert result1.output_tokens == 100 + + # Test with no usage data + data2: dict[str, str] = {} + result2 = backend._get_completions_usage_stats(data2) + assert result2 is None + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_resolve_not_implemented_history(self): + """Test resolve method raises error for conversation history.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + request = GenerationRequest(content="test") + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + history = 
[(request, GenerationResponse(request_id="test", request_args={}))] + + with pytest.raises(NotImplementedError, match="Multi-turn requests"): + async for _ in backend.resolve(request, request_info, history): + pass + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_resolve_text_completions(self): + """Test resolve method for text completions.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + request = GenerationRequest( + content="test prompt", + request_type="text_completions", + params={"temperature": 0.7}, + constraints={"output_tokens": 100}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Mock text_completions method + async def mock_text_completions(*args, **kwargs): + yield None, None # Start signal + yield "Hello", None # First token + yield " world", UsageStats(prompt_tokens=10, output_tokens=2) # Final + + with patch.object( + backend, "text_completions", side_effect=mock_text_completions + ): + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + assert len(responses) >= 2 + final_response = responses[-1][0] + assert final_response.value == "Hello world" + assert final_response.request_id == request.request_id + assert final_response.iterations == 2 + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_resolve_chat_completions(self): + """Test resolve method for chat completions.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + request = GenerationRequest( + content="test message", + request_type="chat_completions", + params={"temperature": 0.5}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Mock chat_completions method + async def mock_chat_completions(*args, **kwargs): + yield None, None # Start signal + yield "Response", UsageStats(prompt_tokens=5, output_tokens=1) + + with patch.object( + backend, "chat_completions", side_effect=mock_chat_completions + ): + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + final_response = responses[-1][0] + assert final_response.value == "Response" + assert final_response.request_id == request.request_id + + +class TestOpenAICompletions: + """Test cases for completion methods.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_text_completions_not_in_process(self): + """Test text_completions when backend not started.""" + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + async for _ in backend.text_completions("test", "req-id"): + pass + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_text_completions_basic(self): + """Test basic text_completions functionality.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"text": "Generated text"}], + "usage": {"prompt_tokens": 
10, "completion_tokens": 5}, + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ): + results = [] + async for result in backend.text_completions( + prompt="test prompt", request_id="req-123", stream_response=False + ): + results.append(result) + + assert len(results) == 2 + assert results[0] == (None, None) # Initial yield + assert results[1][0] == "Generated text" + assert isinstance(results[1][1], UsageStats) + assert results[1][1].prompt_tokens == 10 + assert results[1][1].output_tokens == 5 + finally: + await backend.process_shutdown() + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_chat_completions_not_in_process(self): + """Test chat_completions when backend not started.""" + backend = OpenAIHTTPBackend(target="http://test") + + with pytest.raises(RuntimeError, match="Backend not started up"): + async for _ in backend.chat_completions("test"): + pass + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_chat_completions_basic(self): + """Test basic chat_completions functionality.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"delta": {"content": "Chat response"}}], + "usage": {"prompt_tokens": 8, "completion_tokens": 3}, + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ): + results = [] + async for result in backend.chat_completions( + content="Hello", request_id="req-456", stream_response=False + ): + results.append(result) + + assert len(results) == 2 + assert results[0] == (None, None) + assert results[1][0] == "Chat response" + assert isinstance(results[1][1], UsageStats) + assert results[1][1].prompt_tokens == 8 + assert results[1][1].output_tokens == 3 + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_text_completions_with_parameters(self): + """Test text_completions with additional parameters.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"text": "response"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 1}, + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ) as mock_post: + async for _ in backend.text_completions( + prompt="test", + request_id="req-123", + output_token_count=50, + temperature=0.7, + stream_response=False, + ): + pass + + # Check that the request body contains expected parameters + call_args = mock_post.call_args + body = call_args[1]["json"] + assert body["max_tokens"] == 50 + assert body["temperature"] == 0.7 + assert body["model"] == "gpt-4" + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_chat_completions_content_formatting(self): + """Test chat_completions content formatting.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + mock_response = Mock() + mock_response.raise_for_status = Mock() + mock_response.json.return_value = { + "choices": [{"delta": {"content": "response"}}] + } + + with patch.object( + backend._async_client, "post", return_value=mock_response + ) as 
mock_post: + async for _ in backend.chat_completions( + content="Hello world", stream_response=False + ): + pass + + call_args = mock_post.call_args + body = call_args[1]["json"] + expected_messages = [{"role": "user", "content": "Hello world"}] + assert body["messages"] == expected_messages + finally: + await backend.process_shutdown() + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_validate_no_models_available(self): + """Test validate method when no models are available.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + try: + # Mock endpoints to fail, then available_models to return empty list + def mock_get_fail(*args, **kwargs): + raise httpx.HTTPStatusError("Error", request=Mock(), response=Mock()) + + with ( + patch.object(backend._async_client, "get", side_effect=mock_get_fail), + patch.object(backend, "available_models", return_value=[]), + patch.object(backend, "text_completions", side_effect=mock_get_fail), + pytest.raises( + RuntimeError, + match="No model available and could not set a default model", + ), + ): + await backend.validate() + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_text_completions_streaming(self): + """Test text_completions with streaming enabled.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + # Mock streaming response + mock_stream = Mock() + mock_stream.raise_for_status = Mock() + + async def mock_aiter_lines(): + lines = [ + 'data: {"choices":[{"text":"Hello"}], "usage":{"prompt_tokens":5,"completion_tokens":1}}', # noqa: E501 + 'data: {"choices":[{"text":" world"}], "usage":{"prompt_tokens":5,"completion_tokens":2}}', # noqa: E501 + 'data: {"choices":[{"text":"!"}], "usage":{"prompt_tokens":5,"completion_tokens":3}}', # noqa: E501 + "data: [DONE]", + ] + for line in lines: + yield line + + mock_stream.aiter_lines = mock_aiter_lines + + mock_client_stream = AsyncMock() + mock_client_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_client_stream.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + backend._async_client, "stream", return_value=mock_client_stream + ): + results = [] + async for result in backend.text_completions( + prompt="test prompt", request_id="req-123", stream_response=True + ): + results.append(result) + + # Should get initial None, then tokens, then final with usage + assert len(results) >= 3 + assert results[0] == (None, None) # Initial yield + assert all( + isinstance(result[0], str) for result in results[1:] + ) # Has text content + assert all( + isinstance(result[1], UsageStats) for result in results[1:] + ) # Has usage stats + assert all( + result[1].output_tokens == i for i, result in enumerate(results[1:], 1) + ) + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_chat_completions_streaming(self): + """Test chat_completions with streaming enabled.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + # Mock streaming response + mock_stream = Mock() + mock_stream.raise_for_status = Mock() + + async def mock_aiter_lines(): + lines = [ + 'data: {"choices":[{"delta":{"content":"Hi"}}]}', + 'data: {"choices":[{"delta":{"content":" there"}}]}', + 'data: {"choices":[{"delta":{"content":"!"}}]}', + 'data: 
{"usage":{"prompt_tokens":3,"completion_tokens":3}}', + "data: [DONE]", + ] + for line in lines: + yield line + + mock_stream.aiter_lines = mock_aiter_lines + + mock_client_stream = AsyncMock() + mock_client_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_client_stream.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + backend._async_client, "stream", return_value=mock_client_stream + ): + results = [] + async for result in backend.chat_completions( + content="Hello", request_id="req-456", stream_response=True + ): + results.append(result) + + # Should get initial None, then deltas, then final with usage + assert len(results) >= 3 + assert results[0] == (None, None) # Initial yield + assert any(result[0] for result in results if result[0]) # Has content + assert any(result[1] for result in results if result[1]) # Has usage stats + finally: + await backend.process_shutdown() + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_streaming_response_edge_cases(self): + """Test streaming response edge cases for line processing.""" + backend = OpenAIHTTPBackend(target="http://test", model="gpt-4") + await backend.process_startup() + + try: + # Mock streaming response with edge cases + mock_stream = Mock() + mock_stream.raise_for_status = Mock() + + async def mock_aiter_lines(): + lines = [ + "", # Empty line + " ", # Whitespace only + "not data line", # Line without data prefix + 'data: {"choices":[{"text":"Hello"}]}', # Valid data + "data: [DONE]", # End marker + ] + for line in lines: + yield line + + mock_stream.aiter_lines = mock_aiter_lines + + mock_client_stream = AsyncMock() + mock_client_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_client_stream.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + backend._async_client, "stream", return_value=mock_client_stream + ): + results = [] + async for result in backend.text_completions( + prompt="test", request_id="req-123", stream_response=True + ): + results.append(result) + + # Should get initial None and the valid response + assert len(results) == 2 + assert results[0] == (None, None) + assert results[1][0] == "Hello" + finally: + await backend.process_shutdown() + + @pytest.mark.sanity + def test_get_chat_message_media_item_jpeg_file(self): + """Test _get_chat_message_media_item with JPEG file path.""" + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock Path object for JPEG file + mock_jpeg_path = Mock(spec=Path) + mock_jpeg_path.suffix.lower.return_value = ".jpg" + + # Mock Image.open to return a mock image + mock_image = Mock(spec=Image.Image) + mock_image.tobytes.return_value = b"fake_jpeg_data" + + with patch("guidellm.backend.openai.Image.open", return_value=mock_image): + result = backend._get_chat_message_media_item(mock_jpeg_path) + + expected_data = base64.b64encode(b"fake_jpeg_data").decode("utf-8") + expected = { + "type": "image", + "image": {"url": f"data:image/jpeg;base64,{expected_data}"}, + } + assert result == expected + + @pytest.mark.sanity + def test_get_chat_message_media_item_wav_file(self): + """Test _get_chat_message_media_item with WAV file path.""" + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock Path object for WAV file + mock_wav_path = Mock(spec=Path) + mock_wav_path.suffix.lower.return_value = ".wav" + mock_wav_path.read_bytes.return_value = b"fake_wav_data" + + result = backend._get_chat_message_media_item(mock_wav_path) + + expected_data = 
base64.b64encode(b"fake_wav_data").decode("utf-8") + expected = { + "type": "input_audio", + "input_audio": {"data": expected_data, "format": "wav"}, + } + assert result == expected + + @pytest.mark.sanity + def test_get_chat_messages_with_pil_image(self): + """Test _get_chat_messages with PIL Image in content list.""" + backend = OpenAIHTTPBackend(target="http://test") + + # Create a mock PIL Image + mock_image = Mock(spec=Image.Image) + mock_image.tobytes.return_value = b"fake_image_bytes" + + content = ["Hello", mock_image, "world"] + + result = backend._get_chat_messages(content) + + # Should have one user message with mixed content + assert len(result) == 1 + assert result[0]["role"] == "user" + assert len(result[0]["content"]) == 3 + + # Check text items + assert result[0]["content"][0] == {"type": "text", "text": "Hello"} + assert result[0]["content"][2] == {"type": "text", "text": "world"} + + # Check image item + image_item = result[0]["content"][1] + assert image_item["type"] == "image" + assert "data:image/jpeg;base64," in image_item["image"]["url"] + + @pytest.mark.regression + @pytest.mark.asyncio + @async_timeout(10.0) + async def test_resolve_timing_edge_cases(self): + """Test resolve method timing edge cases.""" + backend = OpenAIHTTPBackend(target="http://test") + await backend.process_startup() + + try: + request = GenerationRequest( + content="test prompt", + request_type="text_completions", + constraints={"output_tokens": 50}, + ) + request_info = ScheduledRequestInfo( + request_id="test-id", + status="pending", + scheduler_node_id=1, + scheduler_process_id=1, + scheduler_start_time=123.0, + request_timings=GenerationRequestTimings(), + ) + + # Mock text_completions to test timing edge cases + async def mock_text_completions(*args, **kwargs): + yield None, None # Initial yield - tests line 343 + yield "token1", None # First token + yield "token2", UsageStats(prompt_tokens=10, output_tokens=2) # Final + + with patch.object( + backend, "text_completions", side_effect=mock_text_completions + ): + responses = [] + async for response, info in backend.resolve(request, request_info): + responses.append((response, info)) + + # Check that timing was properly set + final_response, final_info = responses[-1] + assert final_info.request_timings.request_start is not None + assert final_info.request_timings.first_iteration is not None + assert final_info.request_timings.last_iteration is not None + assert final_info.request_timings.request_end is not None + assert final_response.delta is None # Tests line 362 + + finally: + await backend.process_shutdown() diff --git a/tests/unit/backend/test_openai_backend_custom_configs.py b/tests/unit/backend/test_openai_backend_custom_configs.py deleted file mode 100644 index 5855152d..00000000 --- a/tests/unit/backend/test_openai_backend_custom_configs.py +++ /dev/null @@ -1,88 +0,0 @@ -import pytest - -from guidellm.backend import OpenAIHTTPBackend -from guidellm.settings import settings - - -@pytest.mark.smoke -def test_openai_http_backend_default_initialization(): - backend = OpenAIHTTPBackend() - assert backend.verify is True - - -@pytest.mark.smoke -def test_openai_http_backend_custom_ssl_verification(): - backend = OpenAIHTTPBackend(verify=False) - assert backend.verify is False - - -@pytest.mark.smoke -def test_openai_http_backend_custom_headers_override(): - # Set a default api_key, which would normally create an Authorization header - settings.openai.api_key = "default-api-key" - - # Set custom headers that override the default 
Authorization and add a new header - openshift_token = "Bearer sha256~xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" - override_headers = { - "Authorization": openshift_token, - "Custom-Header": "Custom-Value", - } - - # Initialize the backend - backend = OpenAIHTTPBackend(headers=override_headers) - - # Check that the override headers are used - assert backend.headers["Authorization"] == openshift_token - assert backend.headers["Custom-Header"] == "Custom-Value" - assert len(backend.headers) == 2 - - # Reset the settings - settings.openai.api_key = None - settings.openai.headers = None - - -@pytest.mark.smoke -def test_openai_http_backend_kwarg_headers_override_settings(): - # Set headers via settings (simulating environment variables) - settings.openai.headers = {"Authorization": "Bearer settings-token"} - - # Set different headers via kwargs (simulating --backend-args) - override_headers = { - "Authorization": "Bearer kwargs-token", - "Custom-Header": "Custom-Value", - } - - # Initialize the backend with kwargs - backend = OpenAIHTTPBackend(headers=override_headers) - - # Check that the kwargs headers took precedence - assert backend.headers["Authorization"] == "Bearer kwargs-token" - assert backend.headers["Custom-Header"] == "Custom-Value" - assert len(backend.headers) == 2 - - # Reset the settings - settings.openai.headers = None - - -@pytest.mark.smoke -def test_openai_http_backend_remove_header_with_none(): - # Set a default api_key, which would normally create an Authorization header - settings.openai.api_key = "default-api-key" - - # Set a custom header and explicitly set Authorization to None to remove it - override_headers = { - "Authorization": None, - "Custom-Header": "Custom-Value", - } - - # Initialize the backend - backend = OpenAIHTTPBackend(headers=override_headers) - - # Check that the Authorization header is removed and the custom header is present - assert "Authorization" not in backend.headers - assert backend.headers["Custom-Header"] == "Custom-Value" - assert len(backend.headers) == 1 - - # Reset the settings - settings.openai.api_key = None - settings.openai.headers = None diff --git a/tests/unit/backend/test_response.py b/tests/unit/backend/test_response.py deleted file mode 100644 index b3dc99c9..00000000 --- a/tests/unit/backend/test_response.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import get_args - -import pytest - -from guidellm.backend import ( - RequestArgs, - ResponseSummary, - StreamingResponseType, - StreamingTextResponse, -) - - -@pytest.mark.smoke -def test_streaming_response_types(): - valid_types = get_args(StreamingResponseType) - assert valid_types == ("start", "iter") - - -@pytest.mark.smoke -def test_streaming_text_response_default_initilization(): - response = StreamingTextResponse( - type_="start", - value="", - start_time=0.0, - first_iter_time=None, - iter_count=0, - delta="", - time=0.0, - ) - assert response.request_id is None - - -@pytest.mark.smoke -def test_streaming_text_response_initialization(): - response = StreamingTextResponse( - type_="start", - value="Hello, world!", - start_time=0.0, - first_iter_time=0.0, - iter_count=1, - delta="Hello, world!", - time=1.0, - request_id="123", - ) - assert response.type_ == "start" - assert response.value == "Hello, world!" - assert response.start_time == 0.0 - assert response.first_iter_time == 0.0 - assert response.iter_count == 1 - assert response.delta == "Hello, world!" 
- assert response.time == 1.0 - assert response.request_id == "123" - - -@pytest.mark.smoke -def test_streaming_text_response_marshalling(): - response = StreamingTextResponse( - type_="start", - value="Hello, world!", - start_time=0.0, - first_iter_time=0.0, - iter_count=0, - delta="Hello, world!", - time=1.0, - request_id="123", - ) - serialized = response.model_dump() - deserialized = StreamingTextResponse.model_validate(serialized) - - for key, value in vars(response).items(): - assert getattr(deserialized, key) == value - - -@pytest.mark.smoke -def test_request_args_default_initialization(): - args = RequestArgs( - target="http://example.com", - headers={}, - params={}, - payload={}, - ) - assert args.timeout is None - assert args.http2 is None - assert args.follow_redirects is None - - -@pytest.mark.smoke -def test_request_args_initialization(): - args = RequestArgs( - target="http://example.com", - headers={ - "Authorization": "Bearer token", - }, - params={}, - payload={ - "query": "Hello, world!", - }, - timeout=10.0, - http2=True, - follow_redirects=True, - ) - assert args.target == "http://example.com" - assert args.headers == {"Authorization": "Bearer token"} - assert args.payload == {"query": "Hello, world!"} - assert args.timeout == 10.0 - assert args.http2 is True - assert args.follow_redirects is True - - -@pytest.mark.smoke -def test_response_args_marshalling(): - args = RequestArgs( - target="http://example.com", - headers={"Authorization": "Bearer token"}, - params={}, - payload={"query": "Hello, world!"}, - timeout=10.0, - http2=True, - ) - serialized = args.model_dump() - deserialized = RequestArgs.model_validate(serialized) - - for key, value in vars(args).items(): - assert getattr(deserialized, key) == value - - -@pytest.mark.smoke -def test_response_summary_default_initialization(): - summary = ResponseSummary( - value="Hello, world!", - request_args=RequestArgs( - target="http://example.com", - headers={}, - params={}, - payload={}, - ), - start_time=0.0, - end_time=0.0, - first_iter_time=None, - last_iter_time=None, - ) - assert summary.value == "Hello, world!" - assert summary.request_args.target == "http://example.com" - assert summary.request_args.headers == {} - assert summary.request_args.payload == {} - assert summary.start_time == 0.0 - assert summary.end_time == 0.0 - assert summary.first_iter_time is None - assert summary.last_iter_time is None - assert summary.iterations == 0 - assert summary.request_prompt_tokens is None - assert summary.request_output_tokens is None - assert summary.response_prompt_tokens is None - assert summary.response_output_tokens is None - assert summary.request_id is None - - -@pytest.mark.smoke -def test_response_summary_initialization(): - summary = ResponseSummary( - value="Hello, world!", - request_args=RequestArgs( - target="http://example.com", - headers={}, - params={}, - payload={}, - ), - start_time=1.0, - end_time=2.0, - iterations=3, - first_iter_time=1.0, - last_iter_time=2.0, - request_prompt_tokens=5, - request_output_tokens=10, - response_prompt_tokens=5, - response_output_tokens=10, - request_id="123", - ) - assert summary.value == "Hello, world!" 
- assert summary.request_args.target == "http://example.com" - assert summary.request_args.headers == {} - assert summary.request_args.payload == {} - assert summary.start_time == 1.0 - assert summary.end_time == 2.0 - assert summary.iterations == 3 - assert summary.first_iter_time == 1.0 - assert summary.last_iter_time == 2.0 - assert summary.request_prompt_tokens == 5 - assert summary.request_output_tokens == 10 - assert summary.response_prompt_tokens == 5 - assert summary.response_output_tokens == 10 - assert summary.request_id == "123" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a0457b6f..92bb89e1 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,7 +7,7 @@ import pytest import respx -from guidellm.backend import ResponseSummary, StreamingTextResponse +from guidellm.backends import ResponseSummary, StreamingTextResponse from .mock_backend import MockBackend diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py index 27bfe382..6080a9d1 100644 --- a/tests/unit/mock_backend.py +++ b/tests/unit/mock_backend.py @@ -8,7 +8,7 @@ from lorem.text import TextLorem # type: ignore from PIL import Image -from guidellm.backend import ( +from guidellm.backends import ( Backend, RequestArgs, ResponseSummary, diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py index da1f63ee..cc4600cf 100644 --- a/tests/unit/utils/test_encoding.py +++ b/tests/unit/utils/test_encoding.py @@ -6,7 +6,7 @@ import pytest from pydantic import BaseModel, Field -from guidellm.backend.objects import ( +from guidellm.backends.objects import ( GenerationRequest, GenerationResponse, ) diff --git a/tests/unit/utils/test_messaging.py b/tests/unit/utils/test_messaging.py index d6627e88..d6b3283d 100644 --- a/tests/unit/utils/test_messaging.py +++ b/tests/unit/utils/test_messaging.py @@ -10,7 +10,7 @@ import pytest from pydantic import BaseModel -from guidellm.backend import ( +from guidellm.backends import ( GenerationRequest, GenerationResponse, )