diff --git a/src/guidellm/mock_server/__init__.py b/src/guidellm/mock_server/__init__.py new file mode 100644 index 00000000..f76e98fb --- /dev/null +++ b/src/guidellm/mock_server/__init__.py @@ -0,0 +1,8 @@ +""" +GuideLLM Mock Server for OpenAI and vLLM API compatibility. +""" + +from .config import MockServerConfig +from .server import MockServer + +__all__ = ["MockServer", "MockServerConfig"] diff --git a/src/guidellm/mock_server/config.py b/src/guidellm/mock_server/config.py new file mode 100644 index 00000000..27d1d742 --- /dev/null +++ b/src/guidellm/mock_server/config.py @@ -0,0 +1,84 @@ +""" +Configuration settings for the mock server component. + +Provides centralized configuration management for mock server behavior including +network binding, model identification, response timing characteristics, and token +generation parameters. Supports environment variable configuration for deployment +flexibility with automatic validation through Pydantic settings. +""" + +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings + +__all__ = ["MockServerConfig"] + + +class MockServerConfig(BaseSettings): + """ + Configuration settings for mock server behavior and deployment. + + Centralizes all configurable parameters for mock server operation including + network settings, model identification, response timing characteristics, and + token generation behavior. Environment variables with GUIDELLM_MOCK_SERVER_ + prefix override default values for deployment flexibility. + + Example: + :: + config = MockServerConfig(host="0.0.0.0", port=8080, model="custom-model") + # Use with environment variables: + # GUIDELLM_MOCK_SERVER_HOST=127.0.0.1 GUIDELLM_MOCK_SERVER_PORT=9000 + """ + + host: str = Field( + default="127.0.0.1", description="Host address to bind the server to" + ) + port: int = Field(default=8000, description="Port number to bind the server to") + workers: int = Field(default=1, description="Number of worker processes to spawn") + model: str = Field( + default="llama-3.1-8b-instruct", + description="Model name to present in API responses", + ) + processor: str | None = Field( + default=None, + description=( + "Processor type to use for token stats, tokenize, and detokenize. " + "If None, a mock one is created." 
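+            # Note (illustrative): like the other fields, this can be supplied via
+            # the environment prefix declared below, e.g.
+            # GUIDELLM_MOCK_SERVER_PROCESSOR=<hf-model-id-or-local-path>; the
+            # handlers hand the value to PreTrainedTokenizer.from_pretrained.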
+ ), + ) + request_latency: float = Field( + default=3.0, + description="Base request latency in seconds for non-streaming responses", + ) + request_latency_std: float = Field( + default=0.0, + description="Standard deviation for request latency variation", + ) + ttft_ms: float = Field( + default=150.0, + description="Time to first token in milliseconds for streaming responses", + ) + ttft_ms_std: float = Field( + default=0.0, + description="Standard deviation for time to first token variation", + ) + itl_ms: float = Field( + default=10.0, + description="Inter-token latency in milliseconds for streaming responses", + ) + itl_ms_std: float = Field( + default=0.0, + description="Standard deviation for inter-token latency variation", + ) + output_tokens: int = Field( + default=128, description="Number of output tokens to generate in responses" + ) + output_tokens_std: float = Field( + default=0.0, + description="Standard deviation for output token count variation", + ) + + class Config: + env_prefix = "GUIDELLM_MOCK_SERVER_" + case_sensitive = False diff --git a/src/guidellm/mock_server/handlers/__init__.py b/src/guidellm/mock_server/handlers/__init__.py new file mode 100644 index 00000000..7dbc209f --- /dev/null +++ b/src/guidellm/mock_server/handlers/__init__.py @@ -0,0 +1,17 @@ +""" +HTTP request handlers for the GuideLLM mock server. + +This module exposes request handlers that implement OpenAI-compatible API endpoints +for the mock server. The handlers provide realistic LLM simulation capabilities +including chat completions, legacy completions, and tokenization services with +configurable timing characteristics, token counting, and proper error handling to +support comprehensive benchmarking and testing scenarios. +""" + +from __future__ import annotations + +from .chat_completions import ChatCompletionsHandler +from .completions import CompletionsHandler +from .tokenizer import TokenizerHandler + +__all__ = ["ChatCompletionsHandler", "CompletionsHandler", "TokenizerHandler"] diff --git a/src/guidellm/mock_server/handlers/chat_completions.py b/src/guidellm/mock_server/handlers/chat_completions.py new file mode 100644 index 00000000..de2781b0 --- /dev/null +++ b/src/guidellm/mock_server/handlers/chat_completions.py @@ -0,0 +1,280 @@ +""" +OpenAI Chat Completions API endpoint handler for the mock server. + +Provides a complete implementation of the /v1/chat/completions endpoint that simulates +realistic LLM behavior with configurable timing characteristics. Supports both streaming +and non-streaming responses with proper token counting, latency simulation including +TTFT (Time To First Token) and ITL (Inter-Token Latency), and OpenAI-compatible error +handling for comprehensive benchmarking scenarios. 
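+
+The sketch below shows one way to exercise this endpoint with httpx against the
+default MockServerConfig address; the model name, prompt, and port are
+illustrative values rather than requirements.
+
+Example:
+    ::
+        import httpx
+
+        payload = {
+            "model": "llama-3.1-8b-instruct",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "max_tokens": 32,
+        }
+        resp = httpx.post(
+            "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=30.0
+        )
+        print(resp.json()["choices"][0]["message"]["content"])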
+""" + +from __future__ import annotations + +import asyncio +import json +import math +import time +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse, ResponseStream +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + ChatCompletionChoice, + ChatCompletionsRequest, + ChatCompletionsResponse, + ChatMessage, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import ( + MockTokenizer, + create_fake_text, + create_fake_tokens_str, + sample_number, + times_generator, +) + +__all__ = ["ChatCompletionsHandler"] + + +class ChatCompletionsHandler: + """ + Handles OpenAI Chat Completions API requests with realistic LLM simulation. + + Implements the /v1/chat/completions endpoint behavior including request validation, + response generation, and timing simulation. Supports both streaming and + non-streaming modes with configurable latency characteristics for comprehensive + benchmarking. Uses either a mock tokenizer or a real tokenizer for accurate token + counting and realistic text generation. + + Example: + :: + config = MockServerConfig(ttft_ms=100, itl_ms=50) + handler = ChatCompletionsHandler(config) + response = await handler.handle(request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the Chat Completions handler with server configuration. + + :param config: Mock server configuration containing timing and behavior settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process incoming chat completion requests with validation and routing. + + Validates the request payload, handles errors gracefully, and routes to + appropriate streaming or non-streaming response handlers based on the + request configuration. + + :param request: Sanic HTTP request containing chat completion parameters + :return: HTTP response with completion data or error information + :raises ValidationError: When request payload fails validation + :raises JSONDecodeError: When request contains invalid JSON + """ + try: + # Parse and validate request + req_data = ChatCompletionsRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + # Handle streaming vs non-streaming + if req_data.stream: + return await self._handle_stream(req_data) + else: + return await self._handle_non_stream(req_data) + + async def _handle_non_stream(self, req: ChatCompletionsRequest) -> HTTPResponse: + """ + Generate complete non-streaming chat completion response. + + Simulates realistic LLM behavior with TTFT and ITL delays, generates + appropriate token counts, and returns a complete response with usage + statistics and generated content. 
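+
+        As a rough model, assuming the configured means and ignoring the
+        standard-deviation jitter (a sketch, not an exact guarantee):
+
+        Example:
+            ::
+                def expected_latency_s(cfg: MockServerConfig) -> float:
+                    # TTFT once, then one ITL gap per generated token after the first
+                    tokens = cfg.output_tokens
+                    return (cfg.ttft_ms + max(tokens - 1, 0) * cfg.itl_ms) / 1000.0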
+ + :param req: Validated chat completion request parameters + :return: Complete HTTP response with generated completion data + """ + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_text = self.tokenizer.apply_chat_template(req.messages) + prompt_tokens = len(self.tokenizer(prompt_text)) + max_tokens = req.max_completion_tokens or req.max_tokens or math.inf + completion_tokens_count = min( + sample_number(self.config.output_tokens, self.config.output_tokens_std), + max_tokens, + ) + + # ITL delay + itl_delay = 0.0 + delays_iter = iter(times_generator(self.config.itl_ms, self.config.itl_ms_std)) + for _ in range(int(completion_tokens_count) - 1): + itl_delay += next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + # Response + chat_response = ChatCompletionsResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:29]}", + model=req.model, + choices=[ + ChatCompletionChoice( + index=0, + message=ChatMessage( + role="assistant", + content=create_fake_text( + int(completion_tokens_count), self.tokenizer + ), + ), + finish_reason="stop", + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=int(completion_tokens_count), + ), + system_fingerprint=f"fp_{uuid.uuid4().hex[:10]}", + ) + + return response.json(chat_response.model_dump()) + + async def _handle_stream(self, req: ChatCompletionsRequest) -> HTTPResponse: + """ + Generate streaming chat completion response with real-time token delivery. + + Creates a streaming response that delivers tokens incrementally with + realistic timing delays. Supports optional usage statistics in the final + stream chunk when requested via stream_options. + + :param req: Validated chat completion request with streaming enabled + :return: Streaming HTTP response delivering tokens with proper timing + """ + + async def generate_stream(stream_response): + completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_text = self.tokenizer.apply_chat_template(req.messages) + prompt_tokens = len(self.tokenizer(prompt_text)) + max_tokens = req.max_completion_tokens or req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number( + self.config.output_tokens, self.config.output_tokens_std + ), + max_tokens, + ) + ) + + # Send tokens + tokens = create_fake_tokens_str(completion_tokens_count, self.tokenizer) + delays_iter = iter( + times_generator(self.config.itl_ms, self.config.itl_ms_std) + ) + + for index, token in enumerate(tokens): + if index > 0: + itl_delay = next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + chunk_data = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "index": 0, + "delta": {"content": token}, + "finish_reason": None, + } + ], + } + await stream_response.write(f"data: {json.dumps(chunk_data)}\n\n") + + # Send final chunk with finish reason + final_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") + + # Send usage if requested + if req.stream_options and req.stream_options.include_usage: + usage_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", 
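+                    # Usage-only chunk: mirroring OpenAI's stream_options contract,
+                    # "choices" stays empty and the token totals are reported once,
+                    # just before the terminating [DONE] event.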
+ "created": int(time.time()), + "model": req.model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_count, + "total_tokens": prompt_tokens + completion_tokens_count, + }, + } + await stream_response.write(f"data: {json.dumps(usage_chunk)}\n\n") + + # End stream + await stream_response.write("data: [DONE]\n\n") + + return ResponseStream( # type: ignore[return-value] + generate_stream, + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/src/guidellm/mock_server/handlers/completions.py b/src/guidellm/mock_server/handlers/completions.py new file mode 100644 index 00000000..5a4fe27d --- /dev/null +++ b/src/guidellm/mock_server/handlers/completions.py @@ -0,0 +1,280 @@ +""" +Legacy OpenAI Completions API handler for the mock server. + +This module provides the CompletionsHandler class that implements the /v1/completions +endpoint for the guidellm mock server. It supports both streaming and non-streaming +completions with configurable timing parameters (TTFT, ITL) and token generation to +simulate realistic LLM behavior for benchmarking and testing purposes. +""" + +from __future__ import annotations + +import asyncio +import json +import math +import time +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse, ResponseStream +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + CompletionChoice, + CompletionsRequest, + CompletionsResponse, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import ( + MockTokenizer, + create_fake_text, + create_fake_tokens_str, + sample_number, + times_generator, +) + +__all__ = ["CompletionsHandler"] + + +class CompletionsHandler: + """ + Handler for the OpenAI /v1/completions endpoint in the mock server. + + This handler simulates the legacy OpenAI completions API by processing incoming + requests and generating responses with configurable timing and token generation + patterns. It supports both streaming and non-streaming modes, applying realistic + timing delays (TTFT and ITL) to mimic actual LLM behavior for benchmarking. + + Example: + :: + config = MockServerConfig(ttft_ms=100, itl_ms=50) + handler = CompletionsHandler(config) + response = await handler.handle(sanic_request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the completions handler with configuration settings. + + :param config: Mock server configuration containing timing parameters + and tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process a completions request and return the appropriate response. + + Validates the incoming request, determines whether to use streaming or + non-streaming mode, and delegates to the appropriate handler method. 
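+
+        A minimal request body for this legacy endpoint is sketched below
+        (values are illustrative); note that it takes a prompt string rather
+        than a chat messages list.
+
+        Example:
+            ::
+                payload = {
+                    "model": "llama-3.1-8b-instruct",
+                    "prompt": "Once upon a time",
+                    "max_tokens": 16,
+                }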
+ + :param request: Sanic request object containing the completions request data + :return: HTTP response with completion data or error information + :raises ValidationError: When request validation fails + :raises json.JSONDecodeError: When request JSON is malformed + """ + try: + # Parse and validate request + req_data = CompletionsRequest(**request.json) + except ValidationError as e: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(e)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + # Handle streaming vs non-streaming + if req_data.stream: + return await self._handle_stream(req_data) + else: + return await self._handle_non_stream(req_data) + + async def _handle_non_stream(self, req: CompletionsRequest) -> HTTPResponse: + """ + Generate a non-streaming completion response. + + Simulates TTFT and ITL delays, generates appropriate token counts, and returns + a complete response with the generated text and usage statistics. + + :param req: Validated completions request containing prompt and parameters + :return: JSON HTTP response with completion text and usage data + :raises NotImplementedError: When batch processing is requested + """ + if isinstance(req.prompt, list): + raise NotImplementedError("Batch processing is not supported.") + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_tokens = len(self.tokenizer(req.prompt)) + max_tokens = req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number(self.config.output_tokens, self.config.output_tokens_std), + max_tokens, + ) + if req.stop + else max_tokens + ) + + # ITL delay + itl_delay = 0.0 + delays_iter = iter(times_generator(self.config.itl_ms, self.config.itl_ms_std)) + for _ in range(int(completion_tokens_count) - 1): + itl_delay += next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + # Response + completion_response = CompletionsResponse( + id=f"cmpl-{uuid.uuid4().hex[:29]}", + model=req.model, + choices=[ + CompletionChoice( + text=create_fake_text(completion_tokens_count, self.tokenizer), + index=0, + finish_reason="stop", + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens_count, + ), + system_fingerprint=f"fp_{uuid.uuid4().hex[:10]}", + ) + + return response.json(completion_response.model_dump()) + + async def _handle_stream(self, req: CompletionsRequest) -> HTTPResponse: + """ + Generate a streaming completion response. + + Creates a server-sent events stream that delivers tokens incrementally with + realistic timing delays between each token. Includes usage statistics if + requested and properly terminates the stream. 
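+
+        One way a client can consume the resulting stream is sketched below
+        (httpx-based and illustrative; the default MockServerConfig host and
+        port are assumed):
+
+        Example:
+            ::
+                import json
+
+                import httpx
+
+                payload = {"model": "llama-3.1-8b-instruct", "prompt": "Hi", "stream": True}
+                with httpx.stream(
+                    "POST", "http://127.0.0.1:8000/v1/completions", json=payload
+                ) as resp:
+                    for line in resp.iter_lines():
+                        if not line.startswith("data: ") or line == "data: [DONE]":
+                            continue
+                        chunk = json.loads(line[len("data: "):])
+                        print(chunk["choices"][0]["text"], end="")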
+ + :param req: Validated completions request containing prompt and streaming + options + :return: ResponseStream object that generates server-sent events + """ + + async def generate_stream(stream_response): + completion_id = f"cmpl-{uuid.uuid4().hex[:29]}" + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_tokens = len(self.tokenizer(req.prompt)) + max_tokens = req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number( + self.config.output_tokens, self.config.output_tokens_std + ), + max_tokens, + ) + if req.stop + else max_tokens + ) + + # Send tokens + tokens = create_fake_tokens_str(completion_tokens_count, self.tokenizer) + delays_iter = iter( + times_generator(self.config.itl_ms, self.config.itl_ms_std) + ) + + for index, token in enumerate(tokens): + if index > 0: + itl_delay = next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + chunk_data = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "text": token, + "index": index, + "finish_reason": None, + } + ], + } + await stream_response.write(f"data: {json.dumps(chunk_data)}\n\n") + + # Send final chunk with finish reason + final_chunk = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "text": "", + "index": index, + "finish_reason": "stop", + } + ], + } + await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") + + # Send usage if requested + if req.stream_options and req.stream_options.include_usage: + usage_chunk = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_count, + "total_tokens": prompt_tokens + completion_tokens_count, + }, + } + await stream_response.write(f"data: {json.dumps(usage_chunk)}\n\n") + + # End stream + await stream_response.write("data: [DONE]\n\n") + + return ResponseStream( # type: ignore[return-value] + generate_stream, + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/src/guidellm/mock_server/handlers/tokenizer.py b/src/guidellm/mock_server/handlers/tokenizer.py new file mode 100644 index 00000000..430ac0ef --- /dev/null +++ b/src/guidellm/mock_server/handlers/tokenizer.py @@ -0,0 +1,142 @@ +""" +HTTP request handler for vLLM tokenization API endpoints in the mock server. + +This module provides the TokenizerHandler class that implements vLLM-compatible +tokenization and detokenization endpoints for testing and development purposes. +It handles text-to-token conversion, token-to-text reconstruction, request +validation, and error responses with proper HTTP status codes and JSON formatting. 
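+
+A round trip against a running server is sketched below (illustrative values;
+with the default MockTokenizer the reconstructed text is deterministic fake
+text derived from the token IDs, not the original input):
+
+Example:
+    ::
+        import httpx
+
+        base = "http://127.0.0.1:8000"
+        tok = httpx.post(f"{base}/tokenize", json={"text": "Hello world"}).json()
+        detok = httpx.post(f"{base}/detokenize", json={"tokens": tok["tokens"]}).json()
+        print(tok["count"], detok["text"])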
+""" + +from __future__ import annotations + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse +from transformers.tokenization_utils import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorDetail, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from guidellm.mock_server.utils import MockTokenizer + +__all__ = ["TokenizerHandler"] + + +class TokenizerHandler: + """ + HTTP request handler for vLLM tokenization and detokenization endpoints. + + Provides mock implementations of vLLM's tokenization API endpoints including + /tokenize for converting text to tokens and /detokenize for reconstructing + text from token sequences. Handles request validation, error responses, and + JSON serialization with proper HTTP status codes. + + Example: + :: + handler = TokenizerHandler(config) + response = await handler.tokenize(request) + response = await handler.detokenize(request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the tokenizer handler with configuration. + + :param config: Server configuration object containing tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def tokenize(self, request: Request) -> HTTPResponse: + """ + Convert input text to token IDs via the /tokenize endpoint. + + Validates the request payload, extracts text content, and returns a JSON + response containing the token sequence and count. Handles validation errors + and malformed JSON with appropriate HTTP error responses. + + :param request: Sanic HTTP request containing JSON payload with text field + :return: JSON response with tokens list and count, or error response + """ + try: + req_data = TokenizeRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (ValueError, TypeError, KeyError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + tokens = self.tokenizer.tokenize(req_data.text) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + + return response.json( + TokenizeResponse(tokens=token_ids, count=len(token_ids)).model_dump() + ) + + async def detokenize(self, request: Request) -> HTTPResponse: + """ + Convert token IDs back to text via the /detokenize endpoint. + + Validates the request payload, extracts token sequences, and returns a JSON + response containing the reconstructed text. Handles validation errors and + malformed JSON with appropriate HTTP error responses. 
+ + :param request: Sanic HTTP request containing JSON payload with tokens field + :return: JSON response with reconstructed text, or error response + """ + try: + req_data = DetokenizeRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (ValueError, TypeError, KeyError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + text = self.tokenizer.decode(req_data.tokens, skip_special_tokens=False) + + return response.json(DetokenizeResponse(text=text).model_dump()) diff --git a/src/guidellm/mock_server/models.py b/src/guidellm/mock_server/models.py new file mode 100644 index 00000000..cd342f7a --- /dev/null +++ b/src/guidellm/mock_server/models.py @@ -0,0 +1,510 @@ +""" +Pydantic models for OpenAI API and vLLM API request/response validation. + +This module defines comprehensive data models for validating and serializing API +requests and responses compatible with both OpenAI's API specification and vLLM's +extended parameters. It includes models for chat completions, legacy text completions, +tokenization operations, and error handling, supporting both streaming and non-streaming +responses with full type safety and validation. +""" + +from __future__ import annotations + +import time +from typing import Any, Literal + +from pydantic import BaseModel, Field + +__all__ = [ + "ChatCompletionChoice", + "ChatCompletionChunk", + "ChatCompletionsRequest", + "ChatCompletionsResponse", + "ChatMessage", + "CompletionChoice", + "CompletionsRequest", + "CompletionsResponse", + "DetokenizeRequest", + "DetokenizeResponse", + "ErrorDetail", + "ErrorResponse", + "StreamOptions", + "TokenizeRequest", + "TokenizeResponse", + "Usage", +] + + +class Usage(BaseModel): + """Token usage statistics for API requests and responses. + + Tracks the number of tokens consumed in prompts, completions, and total + usage for billing and monitoring purposes. + """ + + prompt_tokens: int = Field(description="Number of tokens in the input prompt") + completion_tokens: int = Field( + description="Number of tokens in the generated completion" + ) + total_tokens: int = Field(description="Total tokens used (prompt + completion)") + + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0, **kwargs): + """Initialize usage statistics. + + :param prompt_tokens: Number of tokens in the input prompt + :param completion_tokens: Number of tokens in the generated completion + :param kwargs: Additional keyword arguments passed to BaseModel + """ + super().__init__( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + **kwargs, + ) + + +class StreamOptions(BaseModel): + """Configuration options for streaming API responses. + + Controls the behavior and content of streamed responses including + whether to include usage statistics in the final chunk. + """ + + include_usage: bool | None = Field( + default=None, + description="Whether to include usage statistics in streaming responses", + ) + + +class ChatMessage(BaseModel): + """A single message in a chat conversation. + + Represents one exchange in a conversational interface with role-based + content and optional metadata for advanced features. 
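+
+    A minimal construction sketch:
+
+    Example:
+        ::
+            msg = ChatMessage(role="user", content="What is GuideLLM?")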
+ """ + + role: Literal["system", "user", "assistant", "tool"] = Field( + description="Role of the message sender in the conversation" + ) + content: str = Field(description="Text content of the message") + name: str | None = Field( + default=None, description="Optional name identifier for the message sender" + ) + + +class ChatCompletionsRequest(BaseModel): + """Request parameters for chat completion API endpoints. + + Comprehensive model supporting both OpenAI standard parameters and vLLM + extensions for advanced generation control, guided decoding, and performance + optimization. + """ + + model: str = Field(description="Model identifier to use for generation") + messages: list[ChatMessage] = Field( + description="List of messages in the conversation" + ) + max_tokens: int | None = Field( + default=None, description="Maximum number of tokens to generate" + ) + max_completion_tokens: int | None = Field( + default=None, description="Maximum tokens in completion (OpenAI naming)" + ) + temperature: float | None = Field( + default=1.0, description="Sampling temperature for randomness control" + ) + top_p: float | None = Field(default=1.0, description="Nucleus sampling parameter") + n: int | None = Field( + default=1, description="Number of completion choices to generate" + ) + stream: bool | None = Field( + default=False, description="Whether to stream response chunks" + ) + stream_options: StreamOptions | None = Field( + default=None, description="Configuration for streaming responses" + ) + stop: str | list[str] | None = Field( + default=None, description="Stop sequences to end generation" + ) + presence_penalty: float | None = Field( + default=0.0, description="Penalty for token presence to encourage diversity" + ) + frequency_penalty: float | None = Field( + default=0.0, description="Penalty for token frequency to reduce repetition" + ) + logit_bias: dict[str, float] | None = Field( + default=None, description="Bias values for specific tokens" + ) + seed: int | None = Field( + default=None, description="Random seed for reproducible outputs" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + # vLLM extensions + use_beam_search: bool | None = Field( + default=False, description="Enable beam search for better quality" + ) + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + min_p: float | None = Field( + default=None, description="Minimum probability threshold for sampling" + ) + repetition_penalty: float | None = Field( + default=None, description="Penalty for repeated tokens" + ) + length_penalty: float | None = Field( + default=1.0, description="Length penalty for sequence scoring" + ) + stop_token_ids: list[int] | None = Field( + default=None, description="Token IDs that trigger generation stop" + ) + include_stop_str_in_output: bool | None = Field( + default=False, description="Include stop sequence in output" + ) + ignore_eos: bool | None = Field( + default=False, description="Ignore end-of-sequence tokens" + ) + min_tokens: int | None = Field( + default=0, description="Minimum number of tokens to generate" + ) + skip_special_tokens: bool | None = Field( + default=True, description="Skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Add spaces between special tokens" + ) + truncate_prompt_tokens: int | None = Field( + default=None, description="Maximum prompt tokens before truncation" + ) + allowed_token_ids: list[int] 
| None = Field( + default=None, description="Restrict generation to specific token IDs" + ) + prompt_logprobs: int | None = Field( + default=None, description="Number of logprobs to return for prompt tokens" + ) + add_special_tokens: bool | None = Field( + default=True, description="Add special tokens during processing" + ) + guided_json: str | dict[str, Any] | None = Field( + default=None, description="JSON schema for guided generation" + ) + guided_regex: str | None = Field( + default=None, description="Regex pattern for guided generation" + ) + guided_choice: list[str] | None = Field( + default=None, description="List of choices for guided generation" + ) + guided_grammar: str | None = Field( + default=None, description="Grammar specification for guided generation" + ) + guided_decoding_backend: str | None = Field( + default=None, description="Backend to use for guided decoding" + ) + guided_whitespace_pattern: str | None = Field( + default=None, description="Whitespace pattern for guided generation" + ) + priority: int | None = Field( + default=0, description="Request priority for scheduling" + ) + + +class ChatCompletionChoice(BaseModel): + """A single completion choice from a chat completion response. + + Contains the generated message and metadata about why generation + stopped and the choice's position in the response. + """ + + index: int = Field(description="Index of this choice in the response") + message: ChatMessage = Field(description="Generated message content") + finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] | None = ( + Field(description="Reason why generation finished") + ) + + +class ChatCompletionsResponse(BaseModel): + """Response from chat completion API endpoints. + + Contains generated choices, usage statistics, and metadata for + non-streaming chat completion requests. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["chat.completion"] = Field( + default="chat.completion", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[ChatCompletionChoice] = Field( + description="Generated completion choices" + ) + usage: Usage | None = Field(default=None, description="Token usage statistics") + system_fingerprint: str | None = Field( + default=None, description="System configuration fingerprint" + ) + + +class ChatCompletionChunk(BaseModel): + """A single chunk in a streamed chat completion response. + + Represents one piece of a streaming response with delta content + and optional usage statistics in the final chunk. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["chat.completion.chunk"] = Field( + default="chat.completion.chunk", + description="Object type identifier for streaming chunks", + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[dict[str, Any]] = Field(description="Delta choices for streaming") + usage: Usage | None = Field( + default=None, description="Token usage statistics (typically in final chunk)" + ) + + +class CompletionsRequest(BaseModel): + """Request parameters for legacy text completion API endpoints. 
+ + Supports the older text completion format with prompt-based input + and the same extensive parameter set as chat completions for + backward compatibility. + """ + + model: str = Field(description="Model identifier to use for generation") + prompt: str | list[str] = Field(description="Input prompt(s) for completion") + max_tokens: int | None = Field( + default=16, description="Maximum number of tokens to generate" + ) + temperature: float | None = Field( + default=1.0, description="Sampling temperature for randomness control" + ) + top_p: float | None = Field(default=1.0, description="Nucleus sampling parameter") + n: int | None = Field( + default=1, description="Number of completion choices to generate" + ) + stream: bool | None = Field( + default=False, description="Whether to stream response chunks" + ) + stream_options: StreamOptions | None = Field( + default=None, description="Configuration for streaming responses" + ) + logprobs: int | None = Field( + default=None, description="Number of logprobs to return" + ) + echo: bool | None = Field( + default=False, description="Whether to echo the prompt in output" + ) + stop: str | list[str] | None = Field( + default_factory=lambda: ["<|endoftext|>"], + description="Stop sequences to end generation", + ) + presence_penalty: float | None = Field( + default=0.0, description="Penalty for token presence to encourage diversity" + ) + frequency_penalty: float | None = Field( + default=0.0, description="Penalty for token frequency to reduce repetition" + ) + best_of: int | None = Field( + default=1, description="Number of candidates to generate and return the best" + ) + logit_bias: dict[str, float] | None = Field( + default=None, description="Bias values for specific tokens" + ) + seed: int | None = Field( + default=None, description="Random seed for reproducible outputs" + ) + suffix: str | None = Field( + default=None, description="Suffix to append after completion" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + # vLLM extensions (same as chat completions) + use_beam_search: bool | None = Field( + default=False, description="Enable beam search for better quality" + ) + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + min_p: float | None = Field( + default=None, description="Minimum probability threshold for sampling" + ) + repetition_penalty: float | None = Field( + default=None, description="Penalty for repeated tokens" + ) + length_penalty: float | None = Field( + default=1.0, description="Length penalty for sequence scoring" + ) + stop_token_ids: list[int] | None = Field( + default=None, description="Token IDs that trigger generation stop" + ) + include_stop_str_in_output: bool | None = Field( + default=False, description="Include stop sequence in output" + ) + ignore_eos: bool | None = Field( + default=False, description="Ignore end-of-sequence tokens" + ) + min_tokens: int | None = Field( + default=0, description="Minimum number of tokens to generate" + ) + skip_special_tokens: bool | None = Field( + default=True, description="Skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Add spaces between special tokens" + ) + truncate_prompt_tokens: int | None = Field( + default=None, description="Maximum prompt tokens before truncation" + ) + allowed_token_ids: list[int] | None = Field( + default=None, description="Restrict generation to specific token IDs" + ) + 
prompt_logprobs: int | None = Field( + default=None, description="Number of logprobs to return for prompt tokens" + ) + add_special_tokens: bool | None = Field( + default=True, description="Add special tokens during processing" + ) + guided_json: str | dict[str, Any] | None = Field( + default=None, description="JSON schema for guided generation" + ) + guided_regex: str | None = Field( + default=None, description="Regex pattern for guided generation" + ) + guided_choice: list[str] | None = Field( + default=None, description="List of choices for guided generation" + ) + guided_grammar: str | None = Field( + default=None, description="Grammar specification for guided generation" + ) + guided_decoding_backend: str | None = Field( + default=None, description="Backend to use for guided decoding" + ) + guided_whitespace_pattern: str | None = Field( + default=None, description="Whitespace pattern for guided generation" + ) + priority: int | None = Field( + default=0, description="Request priority for scheduling" + ) + + +class CompletionChoice(BaseModel): + """A single completion choice from a text completion response. + + Contains the generated text and metadata about completion + quality and stopping conditions. + """ + + text: str = Field(description="Generated text content") + index: int = Field(description="Index of this choice in the response") + logprobs: dict[str, Any] | None = Field( + default=None, description="Log probabilities for generated tokens" + ) + finish_reason: Literal["stop", "length", "content_filter"] | None = Field( + description="Reason why generation finished" + ) + + +class CompletionsResponse(BaseModel): + """Response from legacy text completion API endpoints. + + Contains generated text choices, usage statistics, and metadata + for non-streaming text completion requests. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["text_completion"] = Field( + default="text_completion", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[CompletionChoice] = Field(description="Generated completion choices") + usage: Usage | None = Field(default=None, description="Token usage statistics") + system_fingerprint: str | None = Field( + default=None, description="System configuration fingerprint" + ) + + +class TokenizeRequest(BaseModel): + """Request for tokenizing text into token sequences. + + Converts input text into model-specific token representations + with optional special token handling. + """ + + text: str = Field(description="Text to tokenize") + add_special_tokens: bool | None = Field( + default=True, description="Whether to add model-specific special tokens" + ) + + +class TokenizeResponse(BaseModel): + """Response containing tokenized representation of input text. + + Provides both the token sequence and count for analysis + and token budget planning. + """ + + tokens: list[int] = Field(description="List of token IDs") + count: int = Field(description="Total number of tokens") + + +class DetokenizeRequest(BaseModel): + """Request for converting token sequences back to text. + + Reconstructs human-readable text from model token representations + with configurable special token handling. 
+ """ + + tokens: list[int] = Field(description="List of token IDs to convert") + skip_special_tokens: bool | None = Field( + default=True, description="Whether to skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Whether to add spaces between special tokens" + ) + + +class DetokenizeResponse(BaseModel): + """Response containing text reconstructed from tokens. + + Provides the human-readable text representation of the + input token sequence. + """ + + text: str = Field(description="Reconstructed text from tokens") + + +class ErrorDetail(BaseModel): + """Detailed error information for API failures. + + Provides structured error data including message, type classification, + and optional error codes for debugging and error handling. + """ + + message: str = Field(description="Human-readable error description") + type: str = Field(description="Error type classification") + code: str | None = Field( + default=None, description="Optional error code for programmatic handling" + ) + + +class ErrorResponse(BaseModel): + """Standardized error response structure for API failures. + + Wraps error details in a consistent format compatible with + OpenAI API error response conventions. + """ + + error: ErrorDetail = Field(description="Detailed error information") diff --git a/src/guidellm/mock_server/server.py b/src/guidellm/mock_server/server.py new file mode 100644 index 00000000..ff9d5fcd --- /dev/null +++ b/src/guidellm/mock_server/server.py @@ -0,0 +1,168 @@ +""" +High-performance mock server for OpenAI and vLLM API compatibility testing. + +This module provides a Sanic-based mock server that simulates OpenAI and vLLM APIs +with configurable latency, token generation patterns, and response characteristics. +The server supports both streaming and non-streaming endpoints, enabling realistic +performance testing and validation of GuideLLM benchmarking workflows without +requiring actual model deployments. +""" + +from __future__ import annotations + +import time + +from sanic import Sanic, response +from sanic.exceptions import NotFound +from sanic.log import logger +from sanic.request import Request +from sanic.response import HTTPResponse + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.handlers import ( + ChatCompletionsHandler, + CompletionsHandler, + TokenizerHandler, +) + +__all__ = ["MockServer"] + + +class MockServer: + """ + High-performance mock server implementing OpenAI and vLLM API endpoints. + + Provides a Sanic-based web server that simulates API responses with configurable + timing characteristics for testing and benchmarking purposes. Supports chat + completions, text completions, tokenization endpoints, and model listing with + realistic latency patterns to enable comprehensive performance validation. + + Example: + :: + config = ServerConfig(model="test-model", port=8080) + server = MockServer(config) + server.run() + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the mock server with configuration. 
+ + :param config: Server configuration containing network settings and response + timing parameters + """ + self.config = config + self.app = Sanic("guidellm-mock-server") + self.chat_handler = ChatCompletionsHandler(config) + self.completions_handler = CompletionsHandler(config) + self.tokenizer_handler = TokenizerHandler(config) + + self._setup_middleware() + self._setup_routes() + self._setup_error_handlers() + + def _setup_middleware(self): + """Setup middleware for CORS, logging, etc.""" + + @self.app.middleware("request") + async def add_cors_headers(_request: Request): + """Add CORS headers to all requests.""" + + @self.app.middleware("response") + async def add_response_headers(_request: Request, resp: HTTPResponse): + """Add standard response headers.""" + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" + resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + resp.headers["Server"] = "guidellm-mock-server" + + def _setup_routes(self): # noqa: C901 + @self.app.get("/health") + async def health_check(_request: Request): + return response.json({"status": "healthy", "timestamp": time.time()}) + + @self.app.get("/v1/models") + async def list_models(_request: Request): + return response.json( + { + "object": "list", + "data": [ + { + "id": self.config.model, + "object": "model", + "created": int(time.time()), + "owned_by": "guidellm-mock", + } + ], + } + ) + + @self.app.route("/v1/chat/completions", methods=["POST", "OPTIONS"]) + async def chat_completions(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.chat_handler.handle(request) + + @self.app.route("/v1/completions", methods=["POST", "OPTIONS"]) + async def completions(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.completions_handler.handle(request) + + @self.app.route("/tokenize", methods=["POST", "OPTIONS"]) + async def tokenize(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.tokenizer_handler.tokenize(request) + + @self.app.route("/detokenize", methods=["POST", "OPTIONS"]) + async def detokenize(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.tokenizer_handler.detokenize(request) + + def _setup_error_handlers(self): + """Setup error handlers.""" + + @self.app.exception(Exception) + async def generic_error_handler(_request: Request, exception: Exception): + logger.error(f"Unhandled exception: {exception}") + return response.json( + { + "error": { + "message": "Internal server error", + "type": type(exception).__name__, + "error": str(exception), + } + }, + status=500, + ) + + @self.app.exception(NotFound) + async def not_found_handler(_request: Request, _exception): + return response.json( + { + "error": { + "message": "Not Found", + "type": "not_found_error", + "code": "not_found", + } + }, + status=404, + ) + + def run(self) -> None: + """ + Start the mock server with configured settings. + + Runs the Sanic application in single-process mode with access logging enabled + for debugging and monitoring request patterns during testing. 
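+
+        Because signal handlers are disabled here, the server can also be
+        started from a background thread, e.g. in tests (a sketch; the port is
+        illustrative):
+
+        Example:
+            ::
+                import threading
+
+                server = MockServer(MockServerConfig(port=8012))
+                threading.Thread(target=server.run, daemon=True).start()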
+ """ + self.app.run( + host=self.config.host, + port=self.config.port, + debug=False, + single_process=True, + access_log=True, + register_sys_signals=False, # Disable signal handlers for threading + ) diff --git a/src/guidellm/mock_server/utils.py b/src/guidellm/mock_server/utils.py new file mode 100644 index 00000000..8348d0a6 --- /dev/null +++ b/src/guidellm/mock_server/utils.py @@ -0,0 +1,307 @@ +""" +Mock server utilities for text generation and tokenization testing. + +This module provides mock tokenization and text generation utilities for testing +guidellm's mock server functionality. It includes a mock tokenizer that simulates +tokenization processes, functions to generate reproducible fake text with specific +token counts, and timing generators for realistic benchmarking scenarios. +""" + +from __future__ import annotations + +import random +import re +from collections.abc import Generator + +from faker import Faker +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer, TextInput + +__all__ = [ + "MockTokenizer", + "create_fake_text", + "create_fake_tokens_str", + "sample_number", + "times_generator", +] + + +class MockTokenizer(PreTrainedTokenizer): + """ + Mock tokenizer implementation for testing text processing workflows. + + Provides a simplified tokenizer that splits text using regex patterns and + generates deterministic token IDs based on string hashing. Used for testing + guidellm components without requiring actual model tokenizers. + + :cvar VocabSize: Fixed vocabulary size for the mock tokenizer + """ + + VocabSize = 100000007 + + def __len__(self) -> int: + """ + Get the vocabulary size of the tokenizer. + + :return: The total number of tokens in the vocabulary + """ + return self.VocabSize + + def __call__(self, text: str | list[str], **kwargs) -> list[int]: # noqa: ARG002 + """ + Tokenize text and return token IDs (callable interface). + + :param text: Input text to tokenize + :return: List of token IDs + """ + if isinstance(text, str): + tokens = self.tokenize(text) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, list): + # Handle batch processing + return [self.__call__(t) for t in text] + else: + msg = f"text input must be of type `str` or `list[str]`, got {type(text)}" + raise ValueError(msg) + + def tokenize(self, text: TextInput, **_kwargs) -> list[str]: + """ + Tokenize input text into a list of token strings. + + Splits text using regex to separate words, punctuation, and whitespace + into individual tokens for processing. + + :param text: Input text to tokenize + :return: List of token strings from the input text + """ + # Split text into tokens: words, spaces, and punctuation + return re.findall(r"\w+|[^\w\s]|\s+", text) + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + """ + Convert token strings to numeric token IDs. + + Uses deterministic hashing to generate consistent token IDs for + reproducible testing scenarios. + + :param tokens: Single token string or list of token strings + :return: Single token ID or list of token IDs + """ + if isinstance(tokens, str): + return hash(tokens) % self.VocabSize + return [hash(token) % self.VocabSize for token in tokens] + + def convert_ids_to_tokens( + self, ids: int | list[int], _skip_special_tokens: bool = False + ) -> str | list[str]: + """ + Convert numeric token IDs back to token strings. + + Generates fake text tokens using Faker library seeded with token IDs + for deterministic and reproducible token generation. 
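+
+        A small determinism sketch (the recovered words are fake and will not
+        match whatever text originally produced the IDs):
+
+        Example:
+            ::
+                tok = MockTokenizer()
+                ids = tok.convert_tokens_to_ids(tok.tokenize("hello world"))
+                assert tok.convert_ids_to_tokens(ids) == tok.convert_ids_to_tokens(ids)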
+ + :param ids: Single token ID or list of token IDs to convert + :return: Single token string or list of token strings + """ + if not ids and not isinstance(ids, list): + return "" + elif not ids: + return [""] + + if isinstance(ids, int): + fake = Faker() + fake.seed_instance(ids % self.VocabSize) + + return fake.word() + + fake = Faker() + fake.seed_instance(sum(ids) % self.VocabSize) + + target_count = len(ids) + current_count = 0 + tokens = [] + + while current_count < target_count: + text = fake.text( + max_nb_chars=(target_count - current_count) * 10 # oversample + ) + new_tokens = self.tokenize(text) + + if current_count > 0: + new_tokens = [".", " "] + new_tokens + + new_tokens = ( + new_tokens[: target_count - current_count] + if len(new_tokens) > (target_count - current_count) + else new_tokens + ) + tokens += new_tokens + current_count += len(new_tokens) + + return tokens + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """ + Convert a list of token strings back to a single text string. + + :param tokens: List of token strings to concatenate + :return: Concatenated string from all tokens + """ + return "".join(tokens) + + def _add_tokens( + self, + new_tokens: list[str] | list[AddedToken], # noqa: ARG002 + special_tokens: bool = False, # noqa: ARG002 + ) -> int: + """ + Add new tokens to the tokenizer vocabulary (mock implementation). + + :param new_tokens: List of tokens to add to the vocabulary + :param special_tokens: Whether the tokens are special tokens + :return: Number of tokens actually added (always 0 for mock) + """ + return 0 + + def apply_chat_template( + self, + conversation: list, + tokenize: bool = False, # Changed default to False to match transformers + add_generation_prompt: bool = False, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> str | list[int]: + """ + Apply a chat template to format conversation messages. + + Mock implementation that concatenates all message content for testing. + + :param conversation: List of chat messages + :param tokenize: Whether to return tokens or string + :param add_generation_prompt: Whether to add generation prompt + :return: Formatted text string or token IDs + """ + # Simple concatenation of all message content + texts = [] + for message in conversation: + if isinstance(message, dict) and "content" in message: + texts.append(message["content"]) + elif hasattr(message, "content"): + texts.append(message.content) + + formatted_text = " ".join(texts) + + if tokenize: + return self.convert_tokens_to_ids(self.tokenize(formatted_text)) + return formatted_text + + def decode( + self, + token_ids: list[int], + skip_special_tokens: bool = True, + **kwargs, # noqa: ARG002 + ) -> str: + """ + Decode token IDs back to text string. + + :param token_ids: List of token IDs to decode + :param skip_special_tokens: Whether to skip special tokens + :return: Decoded text string + """ + tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens) + return self.convert_tokens_to_string(tokens) + + +def create_fake_text( + num_tokens: int, + processor: PreTrainedTokenizer, + seed: int = 42, + fake: Faker | None = None, +) -> str: + """ + Generate fake text using a tokenizer processor with specified token count. + + Creates text by generating fake tokens and joining them into a string, + ensuring the result has the exact number of tokens when processed by + the given tokenizer. 
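+
+    A usage sketch with the mock tokenizer (treat the count check as the
+    helper's goal rather than a strict invariant):
+
+    Example:
+        ::
+            tok = MockTokenizer()
+            text = create_fake_text(16, tok)
+            print(len(tok.tokenize(text)))  # expected: 16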
+ + :param num_tokens: Target number of tokens in the generated text + :param processor: Tokenizer to use for token generation and validation + :param seed: Random seed for reproducible text generation + :param fake: Optional Faker instance for text generation + :return: Generated text string with the specified token count + """ + return "".join(create_fake_tokens_str(num_tokens, processor, seed, fake)) + + +def create_fake_tokens_str( + num_tokens: int, + processor: PreTrainedTokenizer, + seed: int = 42, + fake: Faker | None = None, +) -> list[str]: + """ + Generate fake token strings using a tokenizer processor. + + Creates a list of token strings by generating fake text and tokenizing it + until the desired token count is reached. Uses the provided tokenizer + for accurate token boundary detection. + + :param num_tokens: Target number of tokens to generate + :param processor: Tokenizer to use for token generation and validation + :param seed: Random seed for reproducible token generation + :param fake: Optional Faker instance for text generation + :return: List of token strings with the specified count + """ + if not fake: + fake = Faker() + fake.seed_instance(seed) + + tokens = [] + + while len(tokens) < num_tokens: + text = fake.text( + max_nb_chars=(num_tokens - len(tokens)) * 30 # oversample + ) + new_tokens = processor.tokenize(text) + + if len(tokens) > 0: + new_tokens = [".", " "] + new_tokens + + new_tokens = ( + new_tokens[: num_tokens - len(tokens)] + if len(new_tokens) > (num_tokens - len(tokens)) + else new_tokens + ) + tokens += new_tokens + + return tokens + + +def times_generator(mean: float, standard_dev: float) -> Generator[float]: + """ + Generate infinite timing values from a normal distribution. + + Creates a generator that yields timing values sampled from a normal + distribution, useful for simulating realistic request timing patterns + in benchmarking scenarios. + + :param mean: Mean value for the normal distribution + :param standard_dev: Standard deviation for the normal distribution + :return: Generator yielding positive timing values from the distribution + """ + while True: + yield sample_number(mean, standard_dev) + + +def sample_number(mean: float, standard_dev: float) -> float: + """ + Generate a single timing value from a normal distribution. + + Samples one timing value from a normal distribution with the specified + parameters, ensuring the result is non-negative for realistic timing + simulation in benchmarking scenarios. 
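+
+    A quick sketch of the two timing helpers together (values are illustrative):
+
+    Example:
+        ::
+            sample_number(150.0, 25.0)  # >= 0.0, drawn from mean 150, sigma 25
+            next(times_generator(10.0, 2.0))  # one non-negative ITL sample in ms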
+ + :param mean: Mean value for the normal distribution + :param standard_dev: Standard deviation for the normal distribution + :return: Non-negative timing value from the distribution + """ + return max(0.0, random.gauss(mean, standard_dev)) diff --git a/tests/unit/mock_server/__init__.py b/tests/unit/mock_server/__init__.py new file mode 100644 index 00000000..e02d60bd --- /dev/null +++ b/tests/unit/mock_server/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the GuideLLM mock server package.""" diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py new file mode 100644 index 00000000..008103c3 --- /dev/null +++ b/tests/unit/mock_server/test_server.py @@ -0,0 +1,518 @@ +from __future__ import annotations + +import asyncio +import json +import multiprocessing + +import httpx +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.server import MockServer + + +# Start server in a separate process +def _start_server_process(config: MockServerConfig): + server = MockServer(config) + server.run() + + +@pytest_asyncio.fixture(scope="class") +async def mock_server_instance(): + """Instance-level fixture that provides a running server for HTTP testing.""" + + config = MockServerConfig( + host="127.0.0.1", + port=8012, + model="test-model", + ttft_ms=10.0, + itl_ms=1.0, + request_latency=0.1, + ) + base_url = f"http://{config.host}:{config.port}" + server_process = multiprocessing.Process( + target=_start_server_process, args=(config,) + ) + server_process.start() + + # Wait for server to start up and be ready + async def wait_for_startup(): + poll_frequency = 1.0 + async with httpx.AsyncClient() as client: + while True: + try: + response = await client.get(f"{base_url}/health", timeout=1.0) + if response.status_code == 200: + break + except (httpx.RequestError, httpx.TimeoutException): + pass + await asyncio.sleep(poll_frequency) + poll_frequency = min(poll_frequency * 1.5, 2.0) + + timeout = 30.0 + try: + await asyncio.wait_for(wait_for_startup(), timeout) + except TimeoutError: + # Server failed to start within timeout + server_process.terminate() + server_process.kill() + server_process.join(timeout=5) + pytest.fail(f"Server failed to start within {timeout} seconds") + + yield base_url, config + + # Cleanup: terminate the server process + server_process.terminate() + server_process.kill() + server_process.join(timeout=5) + + +class TestMockServerConfig: + """Test suite for MockServerConfig class.""" + + @pytest.mark.smoke + def test_default_initialization(self): + """Test MockServerConfig initialization with default values.""" + config = MockServerConfig() + assert config.host == "127.0.0.1" + assert config.port == 8000 + assert config.workers == 1 + assert config.model == "llama-3.1-8b-instruct" + assert config.processor is None + assert config.request_latency == 3.0 + assert config.request_latency_std == 0.0 + assert config.ttft_ms == 150.0 + assert config.ttft_ms_std == 0.0 + assert config.itl_ms == 10.0 + assert config.itl_ms_std == 0.0 + assert config.output_tokens == 128 + assert config.output_tokens_std == 0.0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("kwargs", "expected_values"), + [ + ( + {"host": "127.0.0.1", "port": 9000, "model": "custom-model"}, + {"host": "127.0.0.1", "port": 9000, "model": "custom-model"}, + ), + ( + {"request_latency": 1.5, "ttft_ms": 100.0, "output_tokens": 256}, + {"request_latency": 1.5, "ttft_ms": 100.0, 
"output_tokens": 256}, + ), + ], + ) + def test_custom_initialization(self, kwargs, expected_values): + """Test MockServerConfig initialization with custom values.""" + config = MockServerConfig(**kwargs) + for key, expected_value in expected_values.items(): + assert getattr(config, key) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("port", "not_int"), + ("request_latency", "not_float"), + ("output_tokens", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test MockServerConfig with invalid field values.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + MockServerConfig(**kwargs) + + +class TestMockServer: + """Test suite for MockServer class.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MockServer class signatures and attributes.""" + assert hasattr(MockServer, "__init__") + assert hasattr(MockServer, "run") + assert hasattr(MockServer, "_setup_middleware") + assert hasattr(MockServer, "_setup_routes") + assert hasattr(MockServer, "_setup_error_handlers") + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test MockServer initialization without required config.""" + with pytest.raises(TypeError): + MockServer() + + +class TestMockServerEndpoints: + """Test suite for MockServer HTTP endpoints with real server instances.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_health_endpoint(self, mock_server_instance): + """Test the health check endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "status" in data + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], (int, float)) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_models_endpoint(self, mock_server_instance): + """Test the models listing endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/v1/models", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "object" in data + assert data["object"] == "list" + assert "data" in data + assert isinstance(data["data"], list) + assert len(data["data"]) > 0 + + model = data["data"][0] + assert "id" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + assert model["object"] == "model" + assert model["owned_by"] == "guidellm-mock" + assert model["id"] == "test-model" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + { + "model": "test-model", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 10, + }, + ["choices", "usage", "model", "object"], + ), + ( + { + "model": "test-model", + "messages": [{"role": "user", "content": "Test"}], + "max_tokens": 5, + "temperature": 0.7, + }, + ["choices", "usage", "model", "object"], + ), + ], + ) + async def test_chat_completions_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the chat completions endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/v1/chat/completions", json=payload, timeout=10.0 + ) + assert response.status_code == 200 + + data = response.json() 
+ for field in expected_fields: + assert field in data + + assert len(data["choices"]) > 0 + choice = data["choices"][0] + assert "message" in choice + assert "content" in choice["message"] + assert "role" in choice["message"] + assert choice["message"]["role"] == "assistant" + assert isinstance(choice["message"]["content"], str) + assert len(choice["message"]["content"]) > 0 + + # Verify usage information + assert "prompt_tokens" in data["usage"] + assert "completion_tokens" in data["usage"] + assert "total_tokens" in data["usage"] + assert data["usage"]["total_tokens"] == ( + data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"] + ) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_streaming_chat_completions(self, mock_server_instance): + """Test streaming chat completions endpoint.""" + server_url, _ = mock_server_instance + + payload = { + "model": "test-model", + "messages": [{"role": "user", "content": "Hi!"}], + "max_tokens": 5, + "stream": True, + } + + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{server_url}/v1/chat/completions", + json=payload, + timeout=10.0, + ) as response, + ): + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + chunks = [] + async for line in response.aiter_lines(): + if line and line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk_data = json.loads(data_str) + chunks.append(chunk_data) + except json.JSONDecodeError: + continue + + assert len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + assert "delta" in chunk["choices"][0] + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + { + "model": "test-model", + "prompt": "Hello", + "max_tokens": 10, + }, + ["choices", "usage", "model", "object"], + ), + ( + { + "model": "test-model", + "prompt": "Test prompt", + "max_tokens": 5, + "temperature": 0.8, + }, + ["choices", "usage", "model", "object"], + ), + ], + ) + @pytest.mark.asyncio + async def test_completions_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the legacy completions endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/v1/completions", json=payload, timeout=10.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert len(data["choices"]) > 0 + choice = data["choices"][0] + assert "text" in choice + assert isinstance(choice["text"], str) + assert len(choice["text"]) > 0 + + # Verify usage information + assert "prompt_tokens" in data["usage"] + assert "completion_tokens" in data["usage"] + assert "total_tokens" in data["usage"] + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_streaming_completions(self, mock_server_instance): + """Test streaming completions endpoint.""" + server_url, _ = mock_server_instance + payload = { + "model": "test-model", + "prompt": "Hello", + "max_tokens": 5, + "stream": True, + } + + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{server_url}/v1/completions", + json=payload, + timeout=10.0, + ) as response, + ): + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + chunks = [] + async for line in response.aiter_lines(): + if 
line and line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk_data = json.loads(data_str) + chunks.append(chunk_data) + except json.JSONDecodeError: + continue + + assert len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + {"text": "Hello world!"}, + ["tokens", "count"], + ), + ( + {"text": "This is a test sentence."}, + ["tokens", "count"], + ), + ], + ) + @pytest.mark.asyncio + async def test_tokenize_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the tokenize endpoint.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/tokenize", json=payload, timeout=5.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert isinstance(data["tokens"], list) + assert isinstance(data["count"], int) + assert data["count"] == len(data["tokens"]) + assert len(data["tokens"]) > 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + {"tokens": [123, 456, 789]}, + ["text"], + ), + ( + {"tokens": [100, 200]}, + ["text"], + ), + ], + ) + @pytest.mark.asyncio + async def test_detokenize_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the detokenize endpoint.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/detokenize", json=payload, timeout=5.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert isinstance(data["text"], str) + assert len(data["text"]) > 0 + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_options_endpoint(self, mock_server_instance): + """Test the OPTIONS endpoint for CORS support.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.options( + f"{server_url}/v1/chat/completions", timeout=5.0 + ) + assert response.status_code == 204 + assert response.text == "" + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_cors_headers(self, mock_server_instance): + """Test CORS headers are properly set.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health", timeout=5.0) + assert response.status_code == 200 + + # Check for CORS headers + assert response.headers.get("Access-Control-Allow-Origin") == "*" + methods_header = response.headers.get("Access-Control-Allow-Methods", "") + assert "GET, POST, OPTIONS" in methods_header + headers_header = response.headers.get("Access-Control-Allow-Headers", "") + assert "Content-Type, Authorization" in headers_header + assert response.headers.get("Server") == "guidellm-mock-server" + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("endpoint", "method", "payload"), + [ + ("/v1/chat/completions", "POST", {"invalid": "payload"}), + ("/v1/completions", "POST", {"invalid": "payload"}), + ("/tokenize", "POST", {"invalid": "payload"}), + ("/detokenize", "POST", {"invalid": "payload"}), + ], + ) + async def test_invalid_request_handling( + self, mock_server_instance, endpoint, method, payload + ): + """Test handling of invalid 
requests.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + if method == "POST": + response = await client.post( + f"{server_url}{endpoint}", json=payload, timeout=5.0 + ) + else: + response = await client.get(f"{server_url}{endpoint}", timeout=5.0) + + # Should return an error response, not crash + assert response.status_code in [400, 422, 500] + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_nonexistent_endpoint(self, mock_server_instance): + """Test handling of requests to nonexistent endpoints.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/nonexistent", timeout=5.0) + assert response.status_code == 404