diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f60d0673..869abb3f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,13 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude: ^tests/?.*/assets/.+ - id: end-of-file-fixer exclude: ^tests/?.*/assets/.+ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.7 + rev: v0.12.10 hooks: - id: ruff name: run linter @@ -15,7 +15,7 @@ repos: - id: ruff-format name: run formatter - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.15.0 + rev: v1.17.1 hooks: - id: mypy args: [--check-untyped-defs] diff --git a/pyproject.toml b/pyproject.toml index 6c46da4e..27d76006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "culsans~=0.9.0", "datasets", "eval_type_backport", + "faker", "ftfy>=6.0.0", "httpx[http2]<1.0.0", "loguru", @@ -59,7 +60,9 @@ dependencies = [ "pyhumps>=3.8.0", "pyyaml>=6.0.0", "rich", + "sanic", "transformers", + "uvloop>=0.18", ] [project.optional-dependencies] @@ -78,11 +81,13 @@ dev = [ # testing "lorem~=0.1.1", "pytest~=8.2.2", - "pytest-asyncio~=0.23.8", + "pytest-asyncio~=1.1.0", "pytest-cov~=5.0.0", "pytest-mock~=3.14.0", "pytest-rerunfailures~=14.0", + "pytest-timeout~=2.3.1", "respx~=0.22.0", + "hypothesis~=6.138.3", # code quality "mypy~=1.15.0", diff --git a/tests/unit/objects/__init__.py b/research/__init__.py similarity index 100% rename from tests/unit/objects/__init__.py rename to research/__init__.py diff --git a/research/multiprocesssing_communication_perf/README.md b/research/multiprocesssing_communication_perf/README.md new file mode 100644 index 00000000..e69de29b diff --git a/research/multiprocesssing_communication_perf/__init__.py b/research/multiprocesssing_communication_perf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/research/multiprocesssing_communication_perf/requirements.txt b/research/multiprocesssing_communication_perf/requirements.txt new file mode 100644 index 00000000..e69de29b diff --git a/research/multiprocesssing_communication_perf/test_encoding_perf.py b/research/multiprocesssing_communication_perf/test_encoding_perf.py new file mode 100644 index 00000000..b955efc3 --- /dev/null +++ b/research/multiprocesssing_communication_perf/test_encoding_perf.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import csv +import io +import pickle +import random +import sys +import time +from typing import Any + +import click +import numpy as np +from pydantic import BaseModel + +from guidellm.utils import EncodingTypesAlias, MessageEncoding, SerializationTypesAlias + +from .utils import create_all_test_objects + + +def calculate_size(obj: Any) -> int: + if isinstance(obj, BaseModel): + return sys.getsizeof(obj.__dict__) + + if isinstance(obj, (tuple, list)) and any( + isinstance(item, BaseModel) for item in obj + ): + return sum( + sys.getsizeof(item.__dict__) + if isinstance(item, BaseModel) + else sys.getsizeof(item) + for item in obj + ) + elif isinstance(obj, dict) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + return sum( + sys.getsizeof(value.__dict__) + if isinstance(value, BaseModel) + else sys.getsizeof(value) + for value in obj.values() + if isinstance(value, BaseModel) + ) + + return sys.getsizeof(obj) + + +def time_encode_decode( + objects: list[Any], + serialization: SerializationTypesAlias, + encoding: EncodingTypesAlias, + pydantic_models: 
list[type[BaseModel]] | None, + num_iterations: int, +) -> tuple[float, float, float, float]: + message_encoding = MessageEncoding(serialization=serialization, encoding=encoding) + if pydantic_models: + for model in pydantic_models: + message_encoding.register_pydantic(model) + msg_sizes = [] + decoded = [] + encode_time = 0.0 + decode_time = 0.0 + + for _ in range(num_iterations): + for obj in objects: + start = time.perf_counter_ns() + message = message_encoding.encode(obj) + pickled_msg = pickle.dumps(message) + end = time.perf_counter_ns() + encode_time += end - start + + msg_sizes.append(calculate_size(pickled_msg)) + + start = time.perf_counter_ns() + message = pickle.loads(pickled_msg) + decoded.append(message_encoding.decode(message=message)) + end = time.perf_counter_ns() + decode_time += end - start + + correct = 0 + for obj, dec in zip(objects, decoded): + if ( + obj == dec + or type(obj) is type(dec) + and ( + ( + hasattr(obj, "model_dump") + and hasattr(dec, "model_dump") + and obj.model_dump() == dec.model_dump() + ) + or str(obj) == str(dec) + ) + ): + correct += 1 + + percent_differences = 100.0 * correct / len(objects) + avg_msg_size = np.mean(msg_sizes) + + return ( + encode_time / len(objects), + decode_time / len(objects), + avg_msg_size, + percent_differences, + ) + + +def run_benchmarks(objects_size: int, num_objects: int, num_iterations: int): + results = {} + + for obj_type, objects, pydantic_models in create_all_test_objects( + objects_size=objects_size, + num_objects=num_objects, + ): + for serialization in ("dict", "sequence", None): + for encoding in ("msgpack", "msgspec", None): + try: + encode_time, decode_time, avg_msg_size, percent_differences = ( + time_encode_decode( + objects=objects, + serialization=serialization, + encoding=encoding, + pydantic_models=pydantic_models, + num_iterations=num_iterations, + ) + ) + error = None + except Exception as err: + print( + f"Error occurred while benchmarking {obj_type} for " + f"serialization={serialization} and encoding={encoding}: {err}" + ) + error = err + encode_time = None + decode_time = None + avg_msg_size = None + percent_differences = None + + results[f"{obj_type}_{serialization}_{encoding}"] = { + "obj_type": obj_type, + "serialization": serialization, + "encoding": encoding, + "encode_time": encode_time, + "decode_time": decode_time, + "total_time": ( + encode_time + decode_time + if encode_time is not None and decode_time is not None + else None + ), + "avg_msg_size": avg_msg_size, + "percent_differences": percent_differences, + "err": error, + } + + # Print results as a CSV table + + # Create CSV output + output = io.StringIO() + writer = csv.writer(output) + + # Write header + writer.writerow( + [ + "Object Type", + "Serialization", + "Encoding", + "Encode Time (ns)", + "Decode Time (ns)", + "Total Time (ns)", + "Avg Message Size (bytes)", + "Accuracy (%)", + "Error", + ] + ) + + # Write data rows + for result in results.values(): + writer.writerow( + [ + result["obj_type"], + result["serialization"], + result["encoding"], + result["encode_time"], + result["decode_time"], + result["total_time"], + result["avg_msg_size"], + result["percent_differences"], + result["err"], + ] + ) + + # Print the CSV table + print(output.getvalue()) + + +@click.command() +@click.option("--size", default=1024, type=int, help="Size of each object in bytes") +@click.option( + "--objects", default=1000, type=int, help="Number of objects to benchmark" +) +@click.option("--iterations", default=5, type=int, help="Number of 
iterations to run") +def main(size, objects, iterations): + random.seed(42) + run_benchmarks(objects_size=size, num_objects=objects, num_iterations=iterations) + + +if __name__ == "__main__": + run_benchmarks(objects_size=1024, num_objects=10, num_iterations=5) diff --git a/research/multiprocesssing_communication_perf/test_multiprocess_messaging_perf.py b/research/multiprocesssing_communication_perf/test_multiprocess_messaging_perf.py new file mode 100644 index 00000000..e6a247ee --- /dev/null +++ b/research/multiprocesssing_communication_perf/test_multiprocess_messaging_perf.py @@ -0,0 +1,318 @@ +""" +Multiprocessing Communication Performance Benchmarking Tool + +This module benchmarks various multiprocessing communication mechanisms +for the guidellm project. + +FIXES APPLIED: +1. Fixed manager context creation - manager_fork and manager_spawn now correctly + create Manager() instances instead of passing raw contexts +2. Added comprehensive timeout handling to prevent hanging tests +3. Improved process cleanup with graceful termination, then kill if needed +4. Added better error handling in benchmark loops with specific exception types +5. Fixed response counting and metrics calculation to handle incomplete responses +6. Added timeout handling for individual test scenarios (60s each) +7. Enhanced process cleanup to avoid zombie processes +8. Added support for multiple serialization (None, pickle, json) and encoding (None, gzip) options +9. Improved error reporting to distinguish between timeouts and other failures + +KNOWN ISSUES: +- Pipe implementation tends to timeout, likely due to design issues in the messaging layer +- This is expected behavior and helps identify performance bottlenecks +""" + +from __future__ import annotations + +import asyncio +import csv +import io +import multiprocessing +import random +import time +from typing import Any, Literal + +import click +from pydantic import BaseModel +from utils import ( + calculate_size, + create_all_test_objects, +) + +from guidellm.utils import ( + EncodingTypesAlias, + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, + SerializationTypesAlias, +) + + +async def benchmark_process_loop( + messaging: InterProcessMessaging, +) -> tuple[float, float]: + await messaging.start() + start_time = time.perf_counter() + + try: + while True: + try: + received = await messaging.get(timeout=1.0) + if received is None: + break + await messaging.put(received, timeout=0.1) + except asyncio.TimeoutError: + # If we timeout waiting for a message, continue the loop + # This might happen during shutdown + continue + except Exception as e: + print(f"Error in benchmark loop: {e}") + break + except Exception as e: + print(f"Error in benchmark process: {e}") + finally: + try: + await messaging.stop() + except Exception as e: + print(f"Error stopping messaging: {e}") + + end_time = time.perf_counter() + + return start_time, end_time + + +def benchmark_process(messaging: InterProcessMessaging) -> tuple[float, float]: + try: + return asyncio.run(benchmark_process_loop(messaging)) + except Exception as e: + print(f"Error in benchmark_process: {e}") + return 0.0, 0.0 + + +async def time_multiprocessing_messaging( + objects: list[Any], + mp_messaging: Literal[ + "queue", "manager_queue", "manager_fork", "manager_spawn", "pipe" + ], + serialization: SerializationTypesAlias, + encoding: EncodingTypesAlias, + pydantic_models: list[type[BaseModel]] | None, + num_iterations: int, + num_processes: 
int, +) -> tuple[float, float]: + if mp_messaging == "queue": + messaging = InterProcessMessagingQueue( + serialization=serialization, + encoding=encoding, + pydantic_models=pydantic_models, + ) + elif mp_messaging in ("manager_queue", "manager_fork", "manager_spawn"): + messaging = InterProcessMessagingManagerQueue( + manager=( + multiprocessing.Manager() + if mp_messaging == "manager_queue" + else multiprocessing.get_context("fork").Manager() + if mp_messaging == "manager_fork" + else multiprocessing.get_context("spawn").Manager() + ), + serialization=serialization, + encoding=encoding, + pydantic_models=pydantic_models, + ) + elif mp_messaging == "pipe": + messaging = InterProcessMessagingPipe( + num_workers=num_processes, + serialization=serialization, + encoding=encoding, + pydantic_models=pydantic_models, + ) + else: + raise ValueError(f"Unknown messaging type: {mp_messaging}") + + processes = [] + responses = [] + for ind in range(num_processes): + process = multiprocessing.Process( + target=benchmark_process, args=(messaging.create_worker_copy(ind),) + ) + process.start() + processes.append(process) + + await messaging.start() + await asyncio.sleep(1) # process startup time + start_time = time.perf_counter() + + try: + # push messages + for _ in range(num_iterations): + for obj in objects: + await messaging.put(obj, timeout=5.0) + + # shut down processes + for _ in range(num_processes): + await messaging.put(None, timeout=5.0) + + # get results + for _ in range(num_iterations): + for _ in range(len(objects)): + response = await messaging.get(timeout=30.0) + responses.append(response) + + end_time = time.perf_counter() + + except asyncio.TimeoutError as e: + print(f"Timeout during messaging: {e}") + end_time = time.perf_counter() + except Exception as e: + print(f"Error during messaging: {e}") + end_time = time.perf_counter() + finally: + # Clean up processes more gracefully + for process in processes: + if process.is_alive(): + process.join(timeout=2) + if process.is_alive(): + print(f"Terminating process {process.pid}") + process.terminate() + process.join(timeout=2) + if process.is_alive(): + print(f"Force killing process {process.pid}") + process.kill() + process.join() + + # Clean up messaging + try: + await messaging.stop() + except Exception as e: + print(f"Error stopping messaging: {e}") + + # Calculate metrics + correct = 0 + size = 0.0 + expected_responses = num_iterations * len(objects) + + # Handle case where we didn't get all responses + if len(responses) < expected_responses: + print(f"Warning: Expected {expected_responses} responses, got {len(responses)}") + + # Compare responses with original objects (cycling through objects if needed) + for i, response in enumerate(responses): + obj_index = i % len(objects) + obj = objects[obj_index] + + if ( + obj == response + or type(obj) is type(response) + and ( + ( + hasattr(obj, "model_dump") + and hasattr(response, "model_dump") + and obj.model_dump() == response.model_dump() + ) + or str(obj) == str(response) + ) + ): + correct += 1 + size += calculate_size(obj) + + # If we don't have timing data, return zeros + if start_time >= end_time: + return 0.0, 0.0 + + # Calculate average time and size + actual_count = max(len(responses), 1) # Avoid division by zero + avg_time = (end_time - start_time) / actual_count + avg_size = size / len(objects) if len(objects) > 0 else 0.0 + + return avg_time, avg_size + + +def run_benchmarks(objects_size: int, num_objects: int, num_iterations: int): + results = [] + + for obj_type, objects, 
pydantic_models in create_all_test_objects( + objects_size=objects_size, + num_objects=num_objects, + ): + # Only test simple data types for now + if obj_type not in ["str", "list", "dict", "bytes"]: + continue + for mp_messaging in ( + "queue", + "manager_queue", + "manager_fork", + "manager_spawn", + "pipe", + ): + for serialization in (None, "pickle", "json"): # Expanded options + for encoding in (None,): # Only None available + try: + # Add timeout to prevent hanging + avg_time, avg_size = asyncio.run( + asyncio.wait_for( + time_multiprocessing_messaging( + objects=objects, + mp_messaging=mp_messaging, + serialization=serialization, + encoding=encoding, + pydantic_models=pydantic_models, + num_iterations=num_iterations, + num_processes=2, + ), + timeout=60.0, # 60 second timeout per test + ) + ) + results.append( + { + "object_type": obj_type, + "mp_messaging": mp_messaging, + "serialization": serialization + if serialization is not None + else "none", + "encoding": encoding + if encoding is not None + else "none", + "avg_time_sec": avg_time, + "avg_size_bytes": avg_size, + } + ) + print( + f"Completed: {obj_type}, {mp_messaging}, {serialization}, {encoding}" + ) + except asyncio.TimeoutError: + print( + f"Timeout: {obj_type}, {mp_messaging}, {serialization}, {encoding}" + ) + except Exception as e: + print( + f"Failed: {obj_type}, {mp_messaging}, {serialization}, {encoding} with error {e}" + ) + + output = io.StringIO() + writer = csv.DictWriter( + output, + fieldnames=[ + "object_type", + "mp_messaging", + "serialization", + "encoding", + "avg_time_sec", + "avg_size_bytes", + ], + ) + writer.writeheader() + writer.writerows(results) + print(output.getvalue()) + + +@click.command() +@click.option("--size", default=1024, type=int, help="Size of each object in bytes") +@click.option("--objects", default=100, type=int, help="Number of objects to benchmark") +@click.option("--iterations", default=5, type=int, help="Number of iterations to run") +def main(size, objects, iterations): + random.seed(42) + run_benchmarks(objects_size=size, num_objects=objects, num_iterations=iterations) + + +if __name__ == "__main__": + run_benchmarks(objects_size=1024, num_objects=10, num_iterations=5) diff --git a/research/multiprocesssing_communication_perf/utils.py b/research/multiprocesssing_communication_perf/utils.py new file mode 100644 index 00000000..a029f62b --- /dev/null +++ b/research/multiprocesssing_communication_perf/utils.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import random +import string +import sys +import time +import uuid +from typing import Any, Literal + +from pydantic import BaseModel, Field + +from guidellm.backend import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.scheduler import RequestSchedulerTimings, ScheduledRequestInfo + +__all__ = [ + "TestModel", + "calculate_size", + "create_all_test_objects", + "create_test_objects", + "generate_str", + "generate_strs_dict", + "generate_strs_list", +] + + +class TestModel(BaseModel): + test_str: str = Field(default="") + test_int: int = Field(default=0) + test_float: float = Field(default=0.0) + test_bool: bool = Field(default=True) + + +def generate_str(target_bytes: int) -> str: + chars = string.ascii_letters + string.digits + " " + return "".join(random.choice(chars) for _ in range(target_bytes)) + + +def generate_strs_list(target_bytes: int, num_strs: int) -> list[str]: + bytes_per_str = target_bytes // num_strs + + return [ + generate_str( + bytes_per_str + 1 if ind < 
target_bytes % num_strs else bytes_per_str + ) + for ind in range(num_strs) + ] + + +def generate_strs_dict(target_bytes: int, num_strs: int) -> dict[str, str]: + bytes_per_element = target_bytes // num_strs + bytes_per_key = bytes_per_element // 4 + bytes_per_value = bytes_per_element - bytes_per_key + + return { + generate_str(bytes_per_key): generate_str( + bytes_per_value + 1 if ind < num_strs - 1 else bytes_per_value + ) + for ind in range(num_strs) + } + + +def create_test_objects( + type_: Literal[ + "bytes", + "str", + "list", + "dict", + "pydantic", + "tuple(pydantic)", + "dict[pydantic]", + "tuple[GenerativeUpdate]", + "tuple[GenerationResponse]", + ], + objects_size: int, + num_objects: int, +) -> tuple[list[Any], list[type[BaseModel]] | None]: + if type_ == "bytes": + return [random.randbytes(objects_size) for _ in range(num_objects)], None + + if type_ == "str": + return [generate_str(objects_size) for _ in range(num_objects)], None + + if type_ == "list": + return [generate_strs_list(objects_size, 10) for _ in range(num_objects)], None + + if type_ == "dict": + return [generate_strs_dict(objects_size, 10) for _ in range(num_objects)], None + + if type_ == "pydantic": + return ( + [ + TestModel( + test_str=generate_str(objects_size), + test_int=random.randint(1, 100), + test_float=random.random(), + test_bool=random.choice([True, False]), + ) + for _ in range(num_objects) + ], + [TestModel], + ) + + if type_ == "tuple(pydantic)": + return [ + ( + TestModel( + test_str=generate_str(objects_size // 8), + test_int=random.randint(1, 100), + test_float=random.random(), + test_bool=random.choice([True, False]), + ), + TestModel( + test_str=generate_str(objects_size // 2), + test_int=random.randint(1, 100), + test_float=random.random(), + test_bool=random.choice([True, False]), + ), + TestModel( + test_str=generate_str(objects_size // 4 + objects_size // 8), + test_int=random.randint(1, 100), + test_float=random.random(), + test_bool=random.choice([True, False]), + ), + ) + ], [TestModel] + + if type_ == "dict[pydantic]": + return [ + { + generate_str(8): TestModel( + test_str=generate_str(objects_size // 4), + test_int=random.randint(1, 100), + test_float=random.random(), + test_bool=random.choice([True, False]), + ), + generate_str(8): TestModel( + test_str=generate_str(objects_size // 2 + objects_size // 4), + test_int=random.randint(1, 100), + test_float=random.random(), + test_bool=random.choice([True, False]), + ), + } + for _ in range(num_objects) + ], [TestModel] + + if type_ == "tuple[GenerativeUpdate]": + return [ + ( + None, + GenerationRequest( + content=generate_str(objects_size), + ), + ScheduledRequestInfo( + scheduler_timings=RequestSchedulerTimings( + targeted_start=time.time(), + queued=time.time(), + dequeued=time.time(), + scheduled_at=time.time(), + resolve_start=time.time(), + resolve_end=time.time(), + finalized=time.time(), + ), + request_timings=GenerationRequestTimings( + request_start=time.time(), + request_end=time.time(), + first_iteration=time.time(), + last_iteration=time.time(), + ), + ), + ) + for _ in range(num_objects) + ], [GenerationRequest, ScheduledRequestInfo] + + if type_ == "tuple[GenerationResponse]": + return [ + ( + GenerationResponse( + request_id=str(uuid.uuid4()), + request_args={}, + value=generate_str(objects_size // 2), + ), + GenerationRequest( + content=generate_str(objects_size // 2), + ), + ScheduledRequestInfo( + scheduler_timings=RequestSchedulerTimings( + targeted_start=time.time(), + queued=time.time(), + 
dequeued=time.time(), + scheduled_at=time.time(), + resolve_start=time.time(), + resolve_end=time.time(), + finalized=time.time(), + ), + request_timings=GenerationRequestTimings( + request_start=time.time(), + request_end=time.time(), + first_iteration=time.time(), + last_iteration=time.time(), + ), + ), + ) + for _ in range(num_objects) + ], [ + GenerationResponse, + GenerationRequest, + ScheduledRequestInfo, + ] + + raise ValueError(f"Unknown type_: {type_}") + + +def create_all_test_objects( + objects_size: int, num_objects: int +) -> list[tuple[str, list[Any], dict[str, type[BaseModel]] | None]]: + tests = [] + + for object_type in ( + "bytes", + "str", + "list", + "dict", + "pydantic", + "tuple(pydantic)", + "dict[pydantic]", + "tuple[GenerativeUpdate]", + "tuple[GenerationResponse]", + ): + tests.append( + (object_type, *create_test_objects(object_type, objects_size, num_objects)) + ) + + return tests + + +def calculate_size(obj: Any) -> int: + if isinstance(obj, BaseModel): + return sys.getsizeof(obj.__dict__) + + if isinstance(obj, (tuple, list)) and any( + isinstance(item, BaseModel) for item in obj + ): + return sum( + sys.getsizeof(item.__dict__) + if isinstance(item, BaseModel) + else sys.getsizeof(item) + for item in obj + ) + elif isinstance(obj, dict) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + return sum( + sys.getsizeof(value.__dict__) + if isinstance(value, BaseModel) + else sys.getsizeof(value) + for value in obj.values() + if isinstance(value, BaseModel) + ) + + return sys.getsizeof(obj) diff --git a/src/guidellm/__init__.py b/src/guidellm/__init__.py index 9333860e..f2206e94 100644 --- a/src/guidellm/__init__.py +++ b/src/guidellm/__init__.py @@ -20,7 +20,8 @@ hf_logging.set_verbosity_error() logging.getLogger("transformers").setLevel(logging.ERROR) -from .config import ( +from .logger import configure_logger, logger +from .settings import ( DatasetSettings, Environment, LoggingSettings, @@ -30,7 +31,6 @@ reload_settings, settings, ) -from .logger import configure_logger, logger __all__ = [ "DatasetSettings", diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 120f5264..f4630899 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -1,10 +1,50 @@ +""" +GuideLLM command-line interface providing benchmarking, dataset preprocessing, and +mock server functionality. + +This module serves as the primary entry point for the GuideLLM CLI application, +offering a comprehensive suite of tools for language model evaluation and testing. +It provides three main command groups: benchmark operations for performance testing +against generative models, dataset preprocessing utilities for data preparation and +transformation, and a mock server for testing and development scenarios. The CLI +supports various backends, output formats, and configuration options to accommodate +different benchmarking needs and deployment environments. 
+ +Example: +:: + # Run a benchmark against a model + guidellm benchmark run --target http://localhost:8000 --data dataset.json \\ + --profile sweep + + # Preprocess a dataset + guidellm preprocess dataset input.json output.json --processor gpt2 + + # Start a mock server for testing + guidellm mock-server --host 0.0.0.0 --port 8080 +""" + +from __future__ import annotations + import asyncio import codecs from pathlib import Path -from typing import get_args +from typing import Annotated, Union import click +try: + import uvloop + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = True +except ImportError: + uvloop = None + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = False + from guidellm.backend import BackendType from guidellm.benchmark import ( GenerativeConsoleBenchmarkerProgress, @@ -16,20 +56,62 @@ from guidellm.benchmark.scenario import ( GenerativeTextScenario, ) -from guidellm.config import print_config +from guidellm.mock_server import MockServer, MockServerConfig from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset from guidellm.scheduler import StrategyType -from guidellm.utils import DefaultGroupHandler +from guidellm.settings import print_config +from guidellm.utils import Console, DefaultGroupHandler, get_literal_vals from guidellm.utils import cli as cli_tools -STRATEGY_PROFILE_CHOICES = list( - set(list(get_args(ProfileType)) + list(get_args(StrategyType))) -) +__all__ = [ + "STRATEGY_PROFILE_CHOICES", + "benchmark", + "cli", + "config", + "dataset", + "decode_escaped_str", + "from_file", + "mock_server", + "preprocess", + "run", +] + +STRATEGY_PROFILE_CHOICES: Annotated[ + list[str], "Available strategy and profile choices for benchmark execution types" +] = list(get_literal_vals(Union[ProfileType, StrategyType])) + + +def decode_escaped_str(_ctx, _param, value): + """ + Decode escape sequences in Click option values. + + Click automatically escapes characters in option values, converting sequences + like "\\n" to "\\\\n". This function properly decodes these escape sequences + to their intended characters for use in CLI options. + + :param _ctx: Click context (unused) + :param _param: Click parameter (unused) + :param value: String value to decode escape sequences from + :return: Decoded string with proper escape sequences + :raises click.BadParameter: When escape sequence decoding fails + """ + if value is None: + return None + try: + return codecs.decode(value, "unicode_escape") + except Exception as e: + raise click.BadParameter(f"Could not decode escape sequences: {e}") from e @click.group() def cli(): - pass + """ + Main entry point for the GuideLLM command-line interface. + + This is the root command group that organizes all GuideLLM CLI functionality + into logical subgroups for benchmarking, preprocessing, configuration, and + mock server operations. + """ @cli.group( @@ -38,7 +120,13 @@ def cli(): default="run", ) def benchmark(): - pass + """ + Benchmark command group for running and managing performance tests. + + This command group provides functionality to execute new benchmarks against + generative models and load previously saved benchmark reports for analysis. + Supports various benchmarking strategies, output formats, and backend types. 
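+
+    For example, a new benchmark can be started through the default ``run``
+    subcommand (an illustrative invocation mirroring the module-level example
+    above; the full option set is documented on ``run`` below):
+    ::
+        guidellm benchmark run --target http://localhost:8000 --data dataset.json \\
+            --profile sweep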
+ """ @benchmark.command( @@ -93,10 +181,10 @@ def benchmark(): "--backend", "--backend-type", # legacy alias "backend", - type=click.Choice(list(get_args(BackendType))), + type=click.Choice(list(get_literal_vals(BackendType))), help=( "The type of backend to use to run requests against. Defaults to 'openai_http'." - f" Supported types: {', '.join(get_args(BackendType))}" + f" Supported types: {', '.join(get_literal_vals(BackendType))}" ), default="openai_http", ) @@ -266,9 +354,24 @@ def benchmark(): "If None, will run until max_seconds or the data is exhausted." ), ) -@click.option("--max-errors", type=int, default=None, help="") -@click.option("--max-error-rate", type=float, default=None, help="") -@click.option("--max-global-error-rate", type=float, default=None, help="") +@click.option( + "--max-errors", + type=int, + default=None, + help="Maximum number of errors allowed before stopping the benchmark", +) +@click.option( + "--max-error-rate", + type=float, + default=None, + help="Maximum error rate allowed before stopping the benchmark", +) +@click.option( + "--max-global-error-rate", + type=float, + default=None, + help="Maximum global error rate allowed across all benchmarks", +) def run( target, data, @@ -303,6 +406,16 @@ def run( max_error_rate, max_global_error_rate, ): + """ + Execute a generative text benchmark against a target model backend. + + Runs comprehensive performance testing using various strategies and profiles, + collecting metrics on latency, throughput, error rates, and resource usage. + Supports multiple backends, data sources, output formats, and constraint types + for flexible benchmark configuration. + """ + if HAS_UVLOOP: + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.run( benchmark_generative_text( target=target, @@ -377,21 +490,14 @@ def run( ), ) def from_file(path, output_path): - reimport_benchmarks_report(path, output_path) - - -def decode_escaped_str(_ctx, _param, value): """ - Click auto adds characters. For example, when using --pad-char "\n", - it parses it as "\\n". This method decodes the string to handle escape - sequences correctly. + Load and optionally re-export a previously saved benchmark report. + + Imports benchmark results from a saved file and provides optional conversion + to different output formats. Supports JSON, YAML, and CSV export formats + based on the output file extension. """ - if value is None: - return None - try: - return codecs.decode(value, "unicode_escape") - except Exception as e: - raise click.BadParameter(f"Could not decode escape sequences: {e}") from e + reimport_benchmarks_report(path, output_path) @cli.command( @@ -402,12 +508,25 @@ def decode_escaped_str(_ctx, _param, value): ), ) def config(): + """ + Display available GuideLLM configuration environment variables. + + Prints a comprehensive list of all environment variables that can be used + to configure GuideLLM behavior, including their current values, defaults, + and descriptions. + """ print_config() @cli.group(help="General preprocessing tools and utilities.") def preprocess(): - pass + """ + Preprocessing command group for dataset preparation and transformation. + + This command group provides utilities for converting, processing, and + optimizing datasets for use in GuideLLM benchmarks. Includes functionality + for token count adjustments, format conversions, and data validation. 
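+
+    For example, a dataset can be converted for a target processor (an
+    illustrative invocation mirroring the module-level example above; all
+    options are documented on the ``dataset`` subcommand below):
+    ::
+        guidellm preprocess dataset input.json output.json --processor gpt2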
+ """ @preprocess.command( @@ -523,6 +642,13 @@ def dataset( hub_dataset_id, random_seed, ): + """ + Convert and process datasets for specific prompt and output token requirements. + + Transforms datasets to meet target token length specifications using various + strategies for handling short prompts and output length adjustments. Supports + multiple input formats and can optionally push results to Hugging Face Hub. + """ process_dataset( data=data, output_path=output_path, @@ -540,5 +666,121 @@ def dataset( ) +@cli.command(help="Start the GuideLLM mock OpenAI/vLLM server for testing.") +@click.option("--host", default="127.0.0.1", help="Host to bind the server to") +@click.option("--port", default=8000, type=int, help="Port to bind the server to") +@click.option("--workers", default=1, type=int, help="Number of worker processes") +@click.option( + "--model", default="llama-3.1-8b-instruct", help="The name of the model to mock" +) +@click.option("--processor", default=None, help="The processor to use for requests") +@click.option( + "--request-latency", + default=3, + type=float, + help="Request latency in seconds for non-streaming requests", +) +@click.option( + "--request-latency-std", + default=0, + type=float, + help=( + "Request latency standard deviation (normal distribution) " + "in seconds for non-streaming requests" + ), +) +@click.option( + "--ttft-ms", + default=150, + type=float, + help="Time to first token in milliseconds for streaming requests", +) +@click.option( + "--ttft-ms-std", + default=0, + type=float, + help=( + "Time to first token standard deviation (normal distribution) in milliseconds" + ), +) +@click.option( + "--itl-ms", + default=10, + type=float, + help="Inter token latency in milliseconds for streaming requests", +) +@click.option( + "--itl-ms-std", + default=0, + type=float, + help=( + "Inter token latency standard deviation (normal distribution) " + "in milliseconds for streaming requests" + ), +) +@click.option( + "--output-tokens", + default=128, + type=int, + help="Output tokens for streaming requests", +) +@click.option( + "--output-tokens-std", + default=0, + type=float, + help=( + "Output tokens standard deviation (normal distribution) for streaming requests" + ), +) +def mock_server( + host: str, + port: int, + workers: int, + model: str, + processor: str | None, + request_latency: float, + request_latency_std: float, + ttft_ms: float, + ttft_ms_std: float, + itl_ms: float, + itl_ms_std: float, + output_tokens: int, + output_tokens_std: float, +): + """ + Start a GuideLLM mock OpenAI/vLLM-compatible server for testing and development. + + Launches a mock server that simulates model inference with configurable latency + characteristics, token generation patterns, and response timing. Useful for + testing GuideLLM benchmarks without requiring actual model deployment or for + development scenarios requiring predictable server behavior. 
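+
+    For example (an illustrative invocation mirroring the module-level example
+    above; latency and token options fall back to the defaults listed on the
+    options when not overridden):
+    ::
+        guidellm mock-server --host 0.0.0.0 --port 8080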
+ """ + + config = MockServerConfig( + host=host, + port=port, + workers=workers, + model=model, + processor=processor, + request_latency=request_latency, + request_latency_std=request_latency_std, + ttft_ms=ttft_ms, + ttft_ms_std=ttft_ms_std, + itl_ms=itl_ms, + itl_ms_std=itl_ms_std, + output_tokens=output_tokens, + output_tokens_std=output_tokens_std, + ) + + server = MockServer(config) + console = Console() + console.print_update( + title="GuideLLM mock server starting...", + details=f"Listening on http://{host}:{port} for model {model}", + status="success", + ) + server.run() + + if __name__ == "__main__": cli() diff --git a/src/guidellm/backend/backend.py b/src/guidellm/backend/backend.py index a69df07a..c9a73535 100644 --- a/src/guidellm/backend/backend.py +++ b/src/guidellm/backend/backend.py @@ -18,7 +18,6 @@ from guidellm.backend.objects import ( GenerationRequest, - GenerationRequestTimings, GenerationResponse, ) from guidellm.scheduler import BackendInterface @@ -35,7 +34,7 @@ class Backend( RegistryMixin["type[Backend]"], - BackendInterface[GenerationRequest, GenerationRequestTimings, GenerationResponse], + BackendInterface[GenerationRequest, GenerationResponse], ): """ Base class for generative AI backends with registry and lifecycle. diff --git a/src/guidellm/backend/objects.py b/src/guidellm/backend/objects.py index 125e5354..4e538684 100644 --- a/src/guidellm/backend/objects.py +++ b/src/guidellm/backend/objects.py @@ -11,7 +11,11 @@ from pydantic import Field -from guidellm.scheduler import MeasuredRequestTimings +from guidellm.scheduler import ( + MeasuredRequestTimings, + ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, +) from guidellm.utils import StandardBaseModel __all__ = [ @@ -21,6 +25,7 @@ ] +@SchedulerMessagingPydanticRegistry.register() class GenerationRequest(StandardBaseModel): """Request model for backend generation operations.""" @@ -59,6 +64,7 @@ class GenerationRequest(StandardBaseModel): ) +@SchedulerMessagingPydanticRegistry.register() class GenerationResponse(StandardBaseModel): """Response model for backend generation operations.""" @@ -135,9 +141,11 @@ def preferred_output_tokens( return self.response_output_tokens or self.request_output_tokens +@MeasuredRequestTimings.register("generation_request_timings") class GenerationRequestTimings(MeasuredRequestTimings): """Timing model for tracking generation request lifecycle events.""" + timings_type: Literal["generation_request_timings"] = "generation_request_timings" first_iteration: Optional[float] = Field( default=None, description="Unix timestamp when the first generation iteration began.", @@ -146,3 +154,9 @@ class GenerationRequestTimings(MeasuredRequestTimings): default=None, description="Unix timestamp when the last generation iteration completed.", ) + + +# Rebuild ScheduledRequestInfo to recognize MeasuredRequestTimings schema change +ScheduledRequestInfo.model_rebuild(force=True) + +SchedulerMessagingPydanticRegistry.register_decorator(GenerationRequestTimings) diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index d259f498..d616be6a 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -279,11 +279,9 @@ async def default_model(self) -> Optional[str]: async def resolve( self, request: GenerationRequest, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, - ) -> AsyncIterator[ - 
tuple[GenerationResponse, ScheduledRequestInfo[GenerationRequestTimings]] - ]: + ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index 1df6013b..29cf0316 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -34,12 +34,10 @@ runtime_checkable, ) -import numpy as np from pydantic import Field, PrivateAttr from guidellm.backend import ( GenerationRequest, - GenerationRequestTimings, GenerationResponse, ) from guidellm.benchmark.objects import ( @@ -47,14 +45,13 @@ GenerativeMetrics, GenerativeRequestStats, ) -from guidellm.config import settings from guidellm.scheduler import ( - MeasuredRequestTimingsT, RequestT, ResponseT, ScheduledRequestInfo, SchedulerState, ) +from guidellm.settings import settings from guidellm.utils import ( InfoMixin, PydanticClassRegistryMixin, @@ -154,7 +151,7 @@ def get_metric( @runtime_checkable -class Aggregator(Protocol[ResponseT, RequestT, MeasuredRequestTimingsT]): +class Aggregator(Protocol[ResponseT, RequestT]): """ Protocol for processing benchmark data updates during execution. @@ -168,7 +165,7 @@ def __call__( state: AggregatorState, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -184,7 +181,7 @@ def __call__( @runtime_checkable -class CompilableAggregator(Protocol[ResponseT, RequestT, MeasuredRequestTimingsT]): +class CompilableAggregator(Protocol[ResponseT, RequestT]): """ Protocol for aggregators that compile final results from aggregated state. @@ -197,7 +194,7 @@ def __call__( state: AggregatorState, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -226,7 +223,7 @@ def compile( class SerializableAggregator( PydanticClassRegistryMixin[type["SerializableAggregator"]], ABC, - Generic[ResponseT, RequestT, MeasuredRequestTimingsT], + Generic[ResponseT, RequestT], ): schema_discriminator: ClassVar[str] = "type_" @@ -287,7 +284,7 @@ def __call__( state: AggregatorState, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -315,9 +312,7 @@ def compile( @SerializableAggregator.register("inject_extras") -class InjectExtrasAggregator( - SerializableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], InfoMixin -): +class InjectExtrasAggregator(SerializableAggregator[ResponseT, RequestT], InfoMixin): """ Aggregator for injecting extra metadata into the output. 
""" @@ -334,7 +329,7 @@ def __call__( state: AggregatorState, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -356,9 +351,7 @@ def compile( @SerializableAggregator.register("scheduler_stats") -class SchedulerStatsAggregator( - SerializableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], InfoMixin -): +class SchedulerStatsAggregator(SerializableAggregator[ResponseT, RequestT], InfoMixin): """ Aggregates scheduler timing and performance metrics. @@ -377,7 +370,7 @@ def __call__( state: AggregatorState, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -477,7 +470,7 @@ def compile( key="worker_resolve_time", type_="avg", default=0.0 ), worker_resolve_end_delay_avg=state.get_metric( - key="worker_resolve_end_delay", type_="avg" + key="worker_resolve_end_delay", type_="avg", default=0.0 ), finalized_delay_avg=state.get_metric( key="finalized_delay", type_="avg", default=0.0 @@ -500,9 +493,7 @@ def compile( @SerializableAggregator.register("generative_stats_progress") class GenerativeStatsProgressAggregator( - SerializableAggregator[ - GenerationResponse, GenerationRequest, GenerationRequestTimings - ] + SerializableAggregator[GenerationResponse, GenerationRequest] ): """ Tracks generative model metrics during benchmark execution. @@ -524,7 +515,7 @@ def __call__( state: AggregatorState, response: GenerationResponse | None, request: GenerationRequest, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -668,9 +659,7 @@ def compile( @SerializableAggregator.register("generative_requests") class GenerativeRequestsAggregator( - SerializableAggregator[ - GenerationResponse, GenerationRequest, GenerationRequestTimings - ], + SerializableAggregator[GenerationResponse, GenerationRequest], ): """ Compiles complete generative benchmark results with warmup/cooldown filtering. 
@@ -713,7 +702,7 @@ def __call__( state: AggregatorState, response: GenerationResponse | None, request: GenerationRequest, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> dict[str, Any] | None: """ @@ -876,7 +865,7 @@ def compile( def _is_in_warmup( self, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> bool: """Check if the current request is within the warmup period.""" @@ -903,7 +892,7 @@ def _is_in_warmup( def _is_in_cooldown( self, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, scheduler_state: SchedulerState, ) -> bool: """Check if the current request is within the cooldown period.""" @@ -937,7 +926,7 @@ def _create_generative_request_stats( cls, response: GenerationResponse, request: GenerationRequest, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, ) -> GenerativeRequestStats: prompt_tokens = response.preferred_prompt_tokens( settings.preferred_prompt_tokens_source diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index ce035623..ae591c23 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -36,7 +36,6 @@ BackendInterface, Constraint, Environment, - MeasuredRequestTimingsT, NonDistributedEnvironment, RequestT, ResponseT, @@ -51,7 +50,7 @@ class Benchmarker( - Generic[BenchmarkT, RequestT, MeasuredRequestTimingsT, ResponseT], + Generic[BenchmarkT, RequestT, ResponseT], ABC, ThreadSafeSingletonMixin, ): @@ -69,13 +68,12 @@ class Benchmarker( async def run( self, requests: Iterable[RequestT | Iterable[RequestT | tuple[RequestT, float]]], - backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + backend: BackendInterface[RequestT, ResponseT], profile: Profile, benchmark_class: type[BenchmarkT], benchmark_aggregators: dict[ str, - Aggregator[ResponseT, RequestT, MeasuredRequestTimingsT] - | CompilableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], + Aggregator[ResponseT, RequestT] | CompilableAggregator[ResponseT, RequestT], ], environment: Environment | None = None, ) -> AsyncIterator[ @@ -121,7 +119,7 @@ async def run( request, request_info, scheduler_state, - ) in Scheduler[RequestT, MeasuredRequestTimingsT, ResponseT]().run( + ) in Scheduler[RequestT, ResponseT]().run( requests=requests, backend=backend, strategy=strategy, @@ -170,12 +168,11 @@ def _compile_benchmark_kwargs( run_index: int, profile: Profile, requests: Iterable[RequestT | Iterable[RequestT | tuple[RequestT, float]]], - backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + backend: BackendInterface[RequestT, ResponseT], environment: Environment, aggregators: dict[ str, - Aggregator[ResponseT, RequestT, MeasuredRequestTimingsT] - | CompilableAggregator[ResponseT, RequestT, MeasuredRequestTimingsT], + Aggregator[ResponseT, RequestT] | CompilableAggregator[ResponseT, RequestT], ], aggregators_state: dict[str, dict[str, Any]], strategy: SchedulingStrategy, diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index 250725f0..82f92ceb 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -13,7 +13,6 @@ Backend, BackendType, GenerationRequest, - GenerationRequestTimings, GenerationResponse, ) from 
guidellm.benchmark.aggregator import ( @@ -42,7 +41,7 @@ NonDistributedEnvironment, StrategyType, ) -from guidellm.utils import UNSET, Console, InfoMixin +from guidellm.utils import Console, InfoMixin __all__ = [ "benchmark_generative_text", @@ -103,8 +102,8 @@ async def benchmark_generative_text( # noqa: C901 print_updates: bool = False, # Aggregators configuration add_aggregators: ( - dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator] - ) = UNSET, + dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator] | None + ) = None, warmup: float | None = None, cooldown: float | None = None, request_samples: int | None = 20, @@ -209,7 +208,7 @@ async def benchmark_generative_text( # noqa: C901 ) elif constraints: raise ValueError( - "Constraints must be empty or unset when providing a Profile instance. " + "Constraints must be empty when providing a Profile instance. " f"Provided constraints: {constraints} ; provided profile: {profile}" ) console_step.finish( @@ -266,7 +265,6 @@ async def benchmark_generative_text( # noqa: C901 Benchmarker[ GenerativeBenchmark, GenerationRequest, - GenerationRequestTimings, GenerationResponse, ]().run( requests=request_loader, diff --git a/src/guidellm/benchmark/objects.py b/src/guidellm/benchmark/objects.py index 36d6a01a..8afabba9 100644 --- a/src/guidellm/benchmark/objects.py +++ b/src/guidellm/benchmark/objects.py @@ -31,7 +31,6 @@ import yaml from pydantic import Field, computed_field -from guidellm.backend import GenerationRequestTimings from guidellm.benchmark.profile import ( Profile, ) @@ -134,7 +133,7 @@ class BenchmarkMetrics(StandardBaseDict): class BenchmarkRequestStats(StandardBaseDict): """Individual request processing statistics and scheduling metadata.""" - scheduler_info: ScheduledRequestInfo[GenerationRequestTimings] = Field( + scheduler_info: ScheduledRequestInfo = Field( description="Scheduler metadata and timing information for the request" ) diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 2288de41..5816cd38 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -5,7 +5,6 @@ import math from abc import ABC, abstractmethod from collections import OrderedDict -from datetime import datetime from pathlib import Path from typing import Any, ClassVar @@ -25,9 +24,9 @@ SweepProfile, ThroughputProfile, ) -from guidellm.config import settings from guidellm.presentation import UIDataBuilder from guidellm.presentation.injector import create_report +from guidellm.settings import settings from guidellm.utils import ( Colors, DistributionSummary, @@ -36,6 +35,7 @@ safe_format_timestamp, split_text_list_by_length, ) +from guidellm.utils.general import safe_format_timestamp __all__ = [ "GenerativeBenchmarkerCSV", @@ -621,8 +621,8 @@ def _get_benchmark_desc_headers_and_values( benchmark.run_id, benchmark.id_, str(benchmark.scheduler.strategy), - datetime.fromtimestamp(benchmark.start_time).strftime("%Y-%m-%d %H:%M:%S"), - datetime.fromtimestamp(benchmark.end_time).strftime("%Y-%m-%d %H:%M:%S"), + safe_format_timestamp(benchmark.start_time, "%Y-%m-%d %H:%M:%S", "N/A"), + safe_format_timestamp(benchmark.end_time, "%Y-%m-%d %H:%M:%S", "N/A"), benchmark.duration, ] return headers, values diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 1f677c1c..b1cbdf5f 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -653,15 +653,22 @@ def next_strategy( :param prev_strategy: The previously 
completed strategy. :param prev_benchmark: Benchmark results from the previous strategy. :return: Next strategy in sweep sequence, or None if complete. + :raises RuntimeError: If synchronous or throughput benchmarks fail + (≤0 requests/second). :raises ValueError: If strategy_type is neither 'constant' nor 'poisson'. """ if prev_strategy is None: return SynchronousStrategy() if prev_strategy.type_ == "synchronous": - self.synchronous_rate = ( - prev_benchmark.metrics.requests_per_second.successful.mean - ) + sync_rate = prev_benchmark.metrics.requests_per_second.successful.mean + if sync_rate <= 0: + raise RuntimeError( + f"Synchronous benchmark failed with {sync_rate:.2f} " + "requests/second. Cannot proceed with sweep - check server " + "connectivity and constraints." + ) + self.synchronous_rate = sync_rate return ThroughputStrategy( max_concurrency=self.max_concurrency, @@ -669,9 +676,15 @@ def next_strategy( ) if prev_strategy.type_ == "throughput": - self.throughput_rate = ( - prev_benchmark.metrics.requests_per_second.successful.mean - ) + throughput_rate = prev_benchmark.metrics.requests_per_second.successful.mean + if throughput_rate <= 0: + raise RuntimeError( + f"Throughput benchmark failed with {throughput_rate:.2f} " + "requests/second. Cannot proceed with sweep - check server " + "connectivity and constraints." + ) + self.throughput_rate = throughput_rate + self.measured_rates = list( np.linspace( self.synchronous_rate, diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py index 17bfb605..edbb9f37 100644 --- a/src/guidellm/benchmark/progress.py +++ b/src/guidellm/benchmark/progress.py @@ -20,7 +20,6 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterable, AsyncIterator, Iterable from dataclasses import dataclass -from datetime import datetime from typing import Any, Generic, Literal from rich.console import Group @@ -46,6 +45,7 @@ StrategyType, ) from guidellm.utils import Colors, format_value_display +from guidellm.utils.general import safe_format_timestamp __all__ = [ "BenchmarkerProgress", @@ -624,7 +624,7 @@ def formatted_start_time(self) -> str: if self.start_time < 0.0: return "--:--:--" - return datetime.fromtimestamp(self.start_time).strftime("%H:%M:%S") + return safe_format_timestamp(self.start_time, "%H:%M:%S", "--:--:--") @property def formatted_progress_status(self) -> str: @@ -802,7 +802,11 @@ def start(self, strategy: SchedulingStrategy): def update( self, aggregator_update: AggregatorState, scheduler_state: SchedulerState ): - self.progress = scheduler_state.remaining_fraction + self.progress = ( + (1.0 - scheduler_state.remaining_fraction) + if scheduler_state.remaining_fraction is not None + else 0.0 + ) status: Literal["in_warmup", "in_progress", "in_cooldown"] | None = ( "in_progress" # Need to handle requests_in_* isn't in aggregator_update ) diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py index ac235c99..48b41a49 100644 --- a/src/guidellm/logger.py +++ b/src/guidellm/logger.py @@ -41,7 +41,7 @@ from loguru import logger -from guidellm.config import LoggingSettings, settings +from guidellm.settings import LoggingSettings, settings __all__ = ["configure_logger", "logger"] diff --git a/src/guidellm/mock_server/__init__.py b/src/guidellm/mock_server/__init__.py new file mode 100644 index 00000000..f76e98fb --- /dev/null +++ b/src/guidellm/mock_server/__init__.py @@ -0,0 +1,8 @@ +""" +GuideLLM Mock Server for OpenAI and vLLM API compatibility. 
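+
+Example (a minimal, illustrative usage sketch; the configuration fields are
+documented in ``config.py`` and the CLI wiring lives in ``guidellm.__main__``):
+::
+    from guidellm.mock_server import MockServer, MockServerConfig
+
+    # bind to the default host/port and serve the built-in mock model
+    config = MockServerConfig(host="127.0.0.1", port=8000)
+    server = MockServer(config)
+    server.run()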
+""" + +from .config import MockServerConfig +from .server import MockServer + +__all__ = ["MockServer", "MockServerConfig"] diff --git a/src/guidellm/mock_server/config.py b/src/guidellm/mock_server/config.py new file mode 100644 index 00000000..27d1d742 --- /dev/null +++ b/src/guidellm/mock_server/config.py @@ -0,0 +1,84 @@ +""" +Configuration settings for the mock server component. + +Provides centralized configuration management for mock server behavior including +network binding, model identification, response timing characteristics, and token +generation parameters. Supports environment variable configuration for deployment +flexibility with automatic validation through Pydantic settings. +""" + +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings + +__all__ = ["MockServerConfig"] + + +class MockServerConfig(BaseSettings): + """ + Configuration settings for mock server behavior and deployment. + + Centralizes all configurable parameters for mock server operation including + network settings, model identification, response timing characteristics, and + token generation behavior. Environment variables with GUIDELLM_MOCK_SERVER_ + prefix override default values for deployment flexibility. + + Example: + :: + config = MockServerConfig(host="0.0.0.0", port=8080, model="custom-model") + # Use with environment variables: + # GUIDELLM_MOCK_SERVER_HOST=127.0.0.1 GUIDELLM_MOCK_SERVER_PORT=9000 + """ + + host: str = Field( + default="127.0.0.1", description="Host address to bind the server to" + ) + port: int = Field(default=8000, description="Port number to bind the server to") + workers: int = Field(default=1, description="Number of worker processes to spawn") + model: str = Field( + default="llama-3.1-8b-instruct", + description="Model name to present in API responses", + ) + processor: str | None = Field( + default=None, + description=( + "Processor type to use for token stats, tokenize, and detokenize. " + "If None, a mock one is created." + ), + ) + request_latency: float = Field( + default=3.0, + description="Base request latency in seconds for non-streaming responses", + ) + request_latency_std: float = Field( + default=0.0, + description="Standard deviation for request latency variation", + ) + ttft_ms: float = Field( + default=150.0, + description="Time to first token in milliseconds for streaming responses", + ) + ttft_ms_std: float = Field( + default=0.0, + description="Standard deviation for time to first token variation", + ) + itl_ms: float = Field( + default=10.0, + description="Inter-token latency in milliseconds for streaming responses", + ) + itl_ms_std: float = Field( + default=0.0, + description="Standard deviation for inter-token latency variation", + ) + output_tokens: int = Field( + default=128, description="Number of output tokens to generate in responses" + ) + output_tokens_std: float = Field( + default=0.0, + description="Standard deviation for output token count variation", + ) + + class Config: + env_prefix = "GUIDELLM_MOCK_SERVER_" + case_sensitive = False diff --git a/src/guidellm/mock_server/handlers/__init__.py b/src/guidellm/mock_server/handlers/__init__.py new file mode 100644 index 00000000..7dbc209f --- /dev/null +++ b/src/guidellm/mock_server/handlers/__init__.py @@ -0,0 +1,17 @@ +""" +HTTP request handlers for the GuideLLM mock server. + +This module exposes request handlers that implement OpenAI-compatible API endpoints +for the mock server. 
The handlers provide realistic LLM simulation capabilities +including chat completions, legacy completions, and tokenization services with +configurable timing characteristics, token counting, and proper error handling to +support comprehensive benchmarking and testing scenarios. +""" + +from __future__ import annotations + +from .chat_completions import ChatCompletionsHandler +from .completions import CompletionsHandler +from .tokenizer import TokenizerHandler + +__all__ = ["ChatCompletionsHandler", "CompletionsHandler", "TokenizerHandler"] diff --git a/src/guidellm/mock_server/handlers/chat_completions.py b/src/guidellm/mock_server/handlers/chat_completions.py new file mode 100644 index 00000000..976901f9 --- /dev/null +++ b/src/guidellm/mock_server/handlers/chat_completions.py @@ -0,0 +1,280 @@ +""" +OpenAI Chat Completions API endpoint handler for the mock server. + +Provides a complete implementation of the /v1/chat/completions endpoint that simulates +realistic LLM behavior with configurable timing characteristics. Supports both streaming +and non-streaming responses with proper token counting, latency simulation including +TTFT (Time To First Token) and ITL (Inter-Token Latency), and OpenAI-compatible error +handling for comprehensive benchmarking scenarios. +""" + +from __future__ import annotations + +import asyncio +import json +import math +import time +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse, ResponseStream +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + ChatCompletionChoice, + ChatCompletionsRequest, + ChatCompletionsResponse, + ChatMessage, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import ( + MockTokenizer, + create_fake_text, + create_fake_tokens_str, + sample_number, + times_generator, +) + +__all__ = ["ChatCompletionsHandler"] + + +class ChatCompletionsHandler: + """ + Handles OpenAI Chat Completions API requests with realistic LLM simulation. + + Implements the /v1/chat/completions endpoint behavior including request validation, + response generation, and timing simulation. Supports both streaming and + non-streaming modes with configurable latency characteristics for comprehensive + benchmarking. Uses either a mock tokenizer or a real tokenizer for accurate token + counting and realistic text generation. + + Example: + :: + config = MockServerConfig(ttft_ms=100, itl_ms=50) + handler = ChatCompletionsHandler(config) + response = await handler.handle(request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the Chat Completions handler with server configuration. + + :param config: Mock server configuration containing timing and behavior settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process incoming chat completion requests with validation and routing. + + Validates the request payload, handles errors gracefully, and routes to + appropriate streaming or non-streaming response handlers based on the + request configuration. 
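+
+        For example, a payload that fails validation produces a 400 response in
+        the OpenAI-style error shape built below (message text illustrative):
+        ::
+            {"error": {"message": "Invalid request: ...",
+                       "type": "invalid_request_error",
+                       "code": "invalid_request"}}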
+ + :param request: Sanic HTTP request containing chat completion parameters + :return: HTTP response with completion data or error information + :raises ValidationError: When request payload fails validation + :raises JSONDecodeError: When request contains invalid JSON + """ + try: + # Parse and validate request + req_data = ChatCompletionsRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + # Handle streaming vs non-streaming + if req_data.stream: + return await self._handle_stream(req_data) + else: + return await self._handle_non_stream(req_data) + + async def _handle_non_stream(self, req: ChatCompletionsRequest) -> HTTPResponse: + """ + Generate complete non-streaming chat completion response. + + Simulates realistic LLM behavior with TTFT and ITL delays, generates + appropriate token counts, and returns a complete response with usage + statistics and generated content. + + :param req: Validated chat completion request parameters + :return: Complete HTTP response with generated completion data + """ + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_text = self.tokenizer.apply_chat_template(req.messages) + prompt_tokens = len(self.tokenizer(prompt_text)) + max_tokens = req.max_completion_tokens or req.max_tokens or math.inf + completion_tokens_count = min( + sample_number(self.config.output_tokens, self.config.output_tokens_std), + max_tokens, + ) + + # ITL delay + itl_delay = 0.0 + delays_iter = iter(times_generator(self.config.itl_ms, self.config.itl_ms_std)) + for _ in range(int(completion_tokens_count) - 1): + itl_delay += next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + # Response + chat_response = ChatCompletionsResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:29]}", + model=req.model, + choices=[ + ChatCompletionChoice( + index=0, + message=ChatMessage( + role="assistant", + content=create_fake_text( + int(completion_tokens_count), self.tokenizer + ), + ), + finish_reason="stop", + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=int(completion_tokens_count), + ), + system_fingerprint=f"fp_{uuid.uuid4().hex[:10]}", + ) + + return response.json(chat_response.model_dump()) + + async def _handle_stream(self, req: ChatCompletionsRequest) -> HTTPResponse: + """ + Generate streaming chat completion response with real-time token delivery. + + Creates a streaming response that delivers tokens incrementally with + realistic timing delays. Supports optional usage statistics in the final + stream chunk when requested via stream_options. 
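+
+        A sketch of the emitted wire format (one server-sent event per token;
+        chunk fields are abbreviated here):
+        ::
+            data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}
+            data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}
+            data: [DONE]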
+ + :param req: Validated chat completion request with streaming enabled + :return: Streaming HTTP response delivering tokens with proper timing + """ + + async def generate_stream(stream_response): + completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_text = self.tokenizer.apply_chat_template(req.messages) + prompt_tokens = len(self.tokenizer(prompt_text)) + max_tokens = req.max_completion_tokens or req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number( + self.config.output_tokens, self.config.output_tokens_std + ), + max_tokens, + ) + ) + + # Send tokens + tokens = create_fake_tokens_str(completion_tokens_count, self.tokenizer) + delays_iter = iter( + times_generator(self.config.itl_ms, self.config.itl_ms_std) + ) + + for index, token in enumerate(tokens): + if index > 0: + itl_delay = next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + chunk_data = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "index": 0, + "delta": {"content": token}, + "finish_reason": None, + } + ], + } + await stream_response.write(f"data: {json.dumps(chunk_data)}\n\n") + + # Send final chunk with finish reason + final_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") + + # Send usage if requested + if req.stream_options and req.stream_options.get("include_usage"): + usage_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_count, + "total_tokens": prompt_tokens + completion_tokens_count, + }, + } + await stream_response.write(f"data: {json.dumps(usage_chunk)}\n\n") + + # End stream + await stream_response.write("data: [DONE]\n\n") + + return ResponseStream( # type: ignore[return-value] + generate_stream, + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/src/guidellm/mock_server/handlers/completions.py b/src/guidellm/mock_server/handlers/completions.py new file mode 100644 index 00000000..418d2b3c --- /dev/null +++ b/src/guidellm/mock_server/handlers/completions.py @@ -0,0 +1,280 @@ +""" +Legacy OpenAI Completions API handler for the mock server. + +This module provides the CompletionsHandler class that implements the /v1/completions +endpoint for the guidellm mock server. It supports both streaming and non-streaming +completions with configurable timing parameters (TTFT, ITL) and token generation to +simulate realistic LLM behavior for benchmarking and testing purposes. 
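+
+Example:
+::
+    # A minimal sketch of wiring the handler into a Sanic app; the mock server
+    # itself registers it on /v1/completions in guidellm.mock_server.server.
+    from sanic import Sanic
+
+    from guidellm.mock_server.config import MockServerConfig
+
+    app = Sanic("example-mock")
+    handler = CompletionsHandler(MockServerConfig())
+
+    @app.post("/v1/completions")
+    async def completions(request):
+        return await handler.handle(request)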
+""" + +from __future__ import annotations + +import asyncio +import json +import math +import time +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse, ResponseStream +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + CompletionChoice, + CompletionsRequest, + CompletionsResponse, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import ( + MockTokenizer, + create_fake_text, + create_fake_tokens_str, + sample_number, + times_generator, +) + +__all__ = ["CompletionsHandler"] + + +class CompletionsHandler: + """ + Handler for the OpenAI /v1/completions endpoint in the mock server. + + This handler simulates the legacy OpenAI completions API by processing incoming + requests and generating responses with configurable timing and token generation + patterns. It supports both streaming and non-streaming modes, applying realistic + timing delays (TTFT and ITL) to mimic actual LLM behavior for benchmarking. + + Example: + :: + config = MockServerConfig(ttft_ms=100, itl_ms=50) + handler = CompletionsHandler(config) + response = await handler.handle(sanic_request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the completions handler with configuration settings. + + :param config: Mock server configuration containing timing parameters + and tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process a completions request and return the appropriate response. + + Validates the incoming request, determines whether to use streaming or + non-streaming mode, and delegates to the appropriate handler method. + + :param request: Sanic request object containing the completions request data + :return: HTTP response with completion data or error information + :raises ValidationError: When request validation fails + :raises json.JSONDecodeError: When request JSON is malformed + """ + try: + # Parse and validate request + req_data = CompletionsRequest(**request.json) + except ValidationError as e: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(e)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + # Handle streaming vs non-streaming + if req_data.stream: + return await self._handle_stream(req_data) + else: + return await self._handle_non_stream(req_data) + + async def _handle_non_stream(self, req: CompletionsRequest) -> HTTPResponse: + """ + Generate a non-streaming completion response. + + Simulates TTFT and ITL delays, generates appropriate token counts, and returns + a complete response with the generated text and usage statistics. 
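+
+        A sketch of the JSON body returned (values are illustrative; see
+        CompletionsResponse for the full schema):
+        ::
+            {"id": "cmpl-...", "object": "text_completion", "model": "...",
+             "choices": [{"text": "...", "index": 0, "finish_reason": "stop"}],
+             "usage": {"prompt_tokens": 4, "completion_tokens": 128,
+                       "total_tokens": 132}}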
+ + :param req: Validated completions request containing prompt and parameters + :return: JSON HTTP response with completion text and usage data + :raises NotImplementedError: When batch processing is requested + """ + if isinstance(req.prompt, list): + raise NotImplementedError("Batch processing is not supported.") + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_tokens = len(self.tokenizer(req.prompt)) + max_tokens = req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number(self.config.output_tokens, self.config.output_tokens_std), + max_tokens, + ) + if req.stop + else max_tokens + ) + + # ITL delay + itl_delay = 0.0 + delays_iter = iter(times_generator(self.config.itl_ms, self.config.itl_ms_std)) + for _ in range(int(completion_tokens_count) - 1): + itl_delay += next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + # Response + completion_response = CompletionsResponse( + id=f"cmpl-{uuid.uuid4().hex[:29]}", + model=req.model, + choices=[ + CompletionChoice( + text=create_fake_text(completion_tokens_count, self.tokenizer), + index=0, + finish_reason="stop", + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens_count, + ), + system_fingerprint=f"fp_{uuid.uuid4().hex[:10]}", + ) + + return response.json(completion_response.model_dump()) + + async def _handle_stream(self, req: CompletionsRequest) -> HTTPResponse: + """ + Generate a streaming completion response. + + Creates a server-sent events stream that delivers tokens incrementally with + realistic timing delays between each token. Includes usage statistics if + requested and properly terminates the stream. + + :param req: Validated completions request containing prompt and streaming + options + :return: ResponseStream object that generates server-sent events + """ + + async def generate_stream(stream_response): + completion_id = f"cmpl-{uuid.uuid4().hex[:29]}" + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_tokens = len(self.tokenizer(req.prompt)) + max_tokens = req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number( + self.config.output_tokens, self.config.output_tokens_std + ), + max_tokens, + ) + if req.stop + else max_tokens + ) + + # Send tokens + tokens = create_fake_tokens_str(completion_tokens_count, self.tokenizer) + delays_iter = iter( + times_generator(self.config.itl_ms, self.config.itl_ms_std) + ) + + for index, token in enumerate(tokens): + if index > 0: + itl_delay = next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + chunk_data = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "text": token, + "index": index, + "finish_reason": None, + } + ], + } + await stream_response.write(f"data: {json.dumps(chunk_data)}\n\n") + + # Send final chunk with finish reason + final_chunk = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "text": "", + "index": index, + "finish_reason": "stop", + } + ], + } + await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") + + # Send usage if requested + if req.stream_options and req.stream_options.get("include_usage"): + usage_chunk = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + 
"model": req.model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_count, + "total_tokens": prompt_tokens + completion_tokens_count, + }, + } + await stream_response.write(f"data: {json.dumps(usage_chunk)}\n\n") + + # End stream + await stream_response.write("data: [DONE]\n\n") + + return ResponseStream( # type: ignore[return-value] + generate_stream, + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/src/guidellm/mock_server/handlers/tokenizer.py b/src/guidellm/mock_server/handlers/tokenizer.py new file mode 100644 index 00000000..430ac0ef --- /dev/null +++ b/src/guidellm/mock_server/handlers/tokenizer.py @@ -0,0 +1,142 @@ +""" +HTTP request handler for vLLM tokenization API endpoints in the mock server. + +This module provides the TokenizerHandler class that implements vLLM-compatible +tokenization and detokenization endpoints for testing and development purposes. +It handles text-to-token conversion, token-to-text reconstruction, request +validation, and error responses with proper HTTP status codes and JSON formatting. +""" + +from __future__ import annotations + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse +from transformers.tokenization_utils import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorDetail, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from guidellm.mock_server.utils import MockTokenizer + +__all__ = ["TokenizerHandler"] + + +class TokenizerHandler: + """ + HTTP request handler for vLLM tokenization and detokenization endpoints. + + Provides mock implementations of vLLM's tokenization API endpoints including + /tokenize for converting text to tokens and /detokenize for reconstructing + text from token sequences. Handles request validation, error responses, and + JSON serialization with proper HTTP status codes. + + Example: + :: + handler = TokenizerHandler(config) + response = await handler.tokenize(request) + response = await handler.detokenize(request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the tokenizer handler with configuration. + + :param config: Server configuration object containing tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def tokenize(self, request: Request) -> HTTPResponse: + """ + Convert input text to token IDs via the /tokenize endpoint. + + Validates the request payload, extracts text content, and returns a JSON + response containing the token sequence and count. Handles validation errors + and malformed JSON with appropriate HTTP error responses. 
+ + :param request: Sanic HTTP request containing JSON payload with text field + :return: JSON response with tokens list and count, or error response + """ + try: + req_data = TokenizeRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (ValueError, TypeError, KeyError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + tokens = self.tokenizer.tokenize(req_data.text) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + + return response.json( + TokenizeResponse(tokens=token_ids, count=len(token_ids)).model_dump() + ) + + async def detokenize(self, request: Request) -> HTTPResponse: + """ + Convert token IDs back to text via the /detokenize endpoint. + + Validates the request payload, extracts token sequences, and returns a JSON + response containing the reconstructed text. Handles validation errors and + malformed JSON with appropriate HTTP error responses. + + :param request: Sanic HTTP request containing JSON payload with tokens field + :return: JSON response with reconstructed text, or error response + """ + try: + req_data = DetokenizeRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (ValueError, TypeError, KeyError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + text = self.tokenizer.decode(req_data.tokens, skip_special_tokens=False) + + return response.json(DetokenizeResponse(text=text).model_dump()) diff --git a/src/guidellm/mock_server/models.py b/src/guidellm/mock_server/models.py new file mode 100644 index 00000000..cd342f7a --- /dev/null +++ b/src/guidellm/mock_server/models.py @@ -0,0 +1,510 @@ +""" +Pydantic models for OpenAI API and vLLM API request/response validation. + +This module defines comprehensive data models for validating and serializing API +requests and responses compatible with both OpenAI's API specification and vLLM's +extended parameters. It includes models for chat completions, legacy text completions, +tokenization operations, and error handling, supporting both streaming and non-streaming +responses with full type safety and validation. +""" + +from __future__ import annotations + +import time +from typing import Any, Literal + +from pydantic import BaseModel, Field + +__all__ = [ + "ChatCompletionChoice", + "ChatCompletionChunk", + "ChatCompletionsRequest", + "ChatCompletionsResponse", + "ChatMessage", + "CompletionChoice", + "CompletionsRequest", + "CompletionsResponse", + "DetokenizeRequest", + "DetokenizeResponse", + "ErrorDetail", + "ErrorResponse", + "StreamOptions", + "TokenizeRequest", + "TokenizeResponse", + "Usage", +] + + +class Usage(BaseModel): + """Token usage statistics for API requests and responses. + + Tracks the number of tokens consumed in prompts, completions, and total + usage for billing and monitoring purposes. 
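+
+    Example:
+        ::
+            # total_tokens is derived from the two counts passed in
+            usage = Usage(prompt_tokens=10, completion_tokens=25)
+            assert usage.total_tokens == 35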
+ """ + + prompt_tokens: int = Field(description="Number of tokens in the input prompt") + completion_tokens: int = Field( + description="Number of tokens in the generated completion" + ) + total_tokens: int = Field(description="Total tokens used (prompt + completion)") + + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0, **kwargs): + """Initialize usage statistics. + + :param prompt_tokens: Number of tokens in the input prompt + :param completion_tokens: Number of tokens in the generated completion + :param kwargs: Additional keyword arguments passed to BaseModel + """ + super().__init__( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + **kwargs, + ) + + +class StreamOptions(BaseModel): + """Configuration options for streaming API responses. + + Controls the behavior and content of streamed responses including + whether to include usage statistics in the final chunk. + """ + + include_usage: bool | None = Field( + default=None, + description="Whether to include usage statistics in streaming responses", + ) + + +class ChatMessage(BaseModel): + """A single message in a chat conversation. + + Represents one exchange in a conversational interface with role-based + content and optional metadata for advanced features. + """ + + role: Literal["system", "user", "assistant", "tool"] = Field( + description="Role of the message sender in the conversation" + ) + content: str = Field(description="Text content of the message") + name: str | None = Field( + default=None, description="Optional name identifier for the message sender" + ) + + +class ChatCompletionsRequest(BaseModel): + """Request parameters for chat completion API endpoints. + + Comprehensive model supporting both OpenAI standard parameters and vLLM + extensions for advanced generation control, guided decoding, and performance + optimization. 
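+
+    Example:
+        ::
+            # A minimal request; unset fields fall back to their defaults
+            request = ChatCompletionsRequest(
+                model="llama-3.1-8b-instruct",
+                messages=[ChatMessage(role="user", content="Hello")],
+                max_completion_tokens=64,
+                stream=True,
+            )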
+ """ + + model: str = Field(description="Model identifier to use for generation") + messages: list[ChatMessage] = Field( + description="List of messages in the conversation" + ) + max_tokens: int | None = Field( + default=None, description="Maximum number of tokens to generate" + ) + max_completion_tokens: int | None = Field( + default=None, description="Maximum tokens in completion (OpenAI naming)" + ) + temperature: float | None = Field( + default=1.0, description="Sampling temperature for randomness control" + ) + top_p: float | None = Field(default=1.0, description="Nucleus sampling parameter") + n: int | None = Field( + default=1, description="Number of completion choices to generate" + ) + stream: bool | None = Field( + default=False, description="Whether to stream response chunks" + ) + stream_options: StreamOptions | None = Field( + default=None, description="Configuration for streaming responses" + ) + stop: str | list[str] | None = Field( + default=None, description="Stop sequences to end generation" + ) + presence_penalty: float | None = Field( + default=0.0, description="Penalty for token presence to encourage diversity" + ) + frequency_penalty: float | None = Field( + default=0.0, description="Penalty for token frequency to reduce repetition" + ) + logit_bias: dict[str, float] | None = Field( + default=None, description="Bias values for specific tokens" + ) + seed: int | None = Field( + default=None, description="Random seed for reproducible outputs" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + # vLLM extensions + use_beam_search: bool | None = Field( + default=False, description="Enable beam search for better quality" + ) + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + min_p: float | None = Field( + default=None, description="Minimum probability threshold for sampling" + ) + repetition_penalty: float | None = Field( + default=None, description="Penalty for repeated tokens" + ) + length_penalty: float | None = Field( + default=1.0, description="Length penalty for sequence scoring" + ) + stop_token_ids: list[int] | None = Field( + default=None, description="Token IDs that trigger generation stop" + ) + include_stop_str_in_output: bool | None = Field( + default=False, description="Include stop sequence in output" + ) + ignore_eos: bool | None = Field( + default=False, description="Ignore end-of-sequence tokens" + ) + min_tokens: int | None = Field( + default=0, description="Minimum number of tokens to generate" + ) + skip_special_tokens: bool | None = Field( + default=True, description="Skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Add spaces between special tokens" + ) + truncate_prompt_tokens: int | None = Field( + default=None, description="Maximum prompt tokens before truncation" + ) + allowed_token_ids: list[int] | None = Field( + default=None, description="Restrict generation to specific token IDs" + ) + prompt_logprobs: int | None = Field( + default=None, description="Number of logprobs to return for prompt tokens" + ) + add_special_tokens: bool | None = Field( + default=True, description="Add special tokens during processing" + ) + guided_json: str | dict[str, Any] | None = Field( + default=None, description="JSON schema for guided generation" + ) + guided_regex: str | None = Field( + default=None, description="Regex pattern for guided generation" + ) + guided_choice: list[str] | None = 
Field( + default=None, description="List of choices for guided generation" + ) + guided_grammar: str | None = Field( + default=None, description="Grammar specification for guided generation" + ) + guided_decoding_backend: str | None = Field( + default=None, description="Backend to use for guided decoding" + ) + guided_whitespace_pattern: str | None = Field( + default=None, description="Whitespace pattern for guided generation" + ) + priority: int | None = Field( + default=0, description="Request priority for scheduling" + ) + + +class ChatCompletionChoice(BaseModel): + """A single completion choice from a chat completion response. + + Contains the generated message and metadata about why generation + stopped and the choice's position in the response. + """ + + index: int = Field(description="Index of this choice in the response") + message: ChatMessage = Field(description="Generated message content") + finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] | None = ( + Field(description="Reason why generation finished") + ) + + +class ChatCompletionsResponse(BaseModel): + """Response from chat completion API endpoints. + + Contains generated choices, usage statistics, and metadata for + non-streaming chat completion requests. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["chat.completion"] = Field( + default="chat.completion", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[ChatCompletionChoice] = Field( + description="Generated completion choices" + ) + usage: Usage | None = Field(default=None, description="Token usage statistics") + system_fingerprint: str | None = Field( + default=None, description="System configuration fingerprint" + ) + + +class ChatCompletionChunk(BaseModel): + """A single chunk in a streamed chat completion response. + + Represents one piece of a streaming response with delta content + and optional usage statistics in the final chunk. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["chat.completion.chunk"] = Field( + default="chat.completion.chunk", + description="Object type identifier for streaming chunks", + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[dict[str, Any]] = Field(description="Delta choices for streaming") + usage: Usage | None = Field( + default=None, description="Token usage statistics (typically in final chunk)" + ) + + +class CompletionsRequest(BaseModel): + """Request parameters for legacy text completion API endpoints. + + Supports the older text completion format with prompt-based input + and the same extensive parameter set as chat completions for + backward compatibility. 
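+
+    Example:
+        ::
+            # A minimal legacy-completions request; unset fields use defaults
+            request = CompletionsRequest(
+                model="llama-3.1-8b-instruct",
+                prompt="Once upon a time",
+                max_tokens=32,
+            )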
+ """ + + model: str = Field(description="Model identifier to use for generation") + prompt: str | list[str] = Field(description="Input prompt(s) for completion") + max_tokens: int | None = Field( + default=16, description="Maximum number of tokens to generate" + ) + temperature: float | None = Field( + default=1.0, description="Sampling temperature for randomness control" + ) + top_p: float | None = Field(default=1.0, description="Nucleus sampling parameter") + n: int | None = Field( + default=1, description="Number of completion choices to generate" + ) + stream: bool | None = Field( + default=False, description="Whether to stream response chunks" + ) + stream_options: StreamOptions | None = Field( + default=None, description="Configuration for streaming responses" + ) + logprobs: int | None = Field( + default=None, description="Number of logprobs to return" + ) + echo: bool | None = Field( + default=False, description="Whether to echo the prompt in output" + ) + stop: str | list[str] | None = Field( + default_factory=lambda: ["<|endoftext|>"], + description="Stop sequences to end generation", + ) + presence_penalty: float | None = Field( + default=0.0, description="Penalty for token presence to encourage diversity" + ) + frequency_penalty: float | None = Field( + default=0.0, description="Penalty for token frequency to reduce repetition" + ) + best_of: int | None = Field( + default=1, description="Number of candidates to generate and return the best" + ) + logit_bias: dict[str, float] | None = Field( + default=None, description="Bias values for specific tokens" + ) + seed: int | None = Field( + default=None, description="Random seed for reproducible outputs" + ) + suffix: str | None = Field( + default=None, description="Suffix to append after completion" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + # vLLM extensions (same as chat completions) + use_beam_search: bool | None = Field( + default=False, description="Enable beam search for better quality" + ) + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + min_p: float | None = Field( + default=None, description="Minimum probability threshold for sampling" + ) + repetition_penalty: float | None = Field( + default=None, description="Penalty for repeated tokens" + ) + length_penalty: float | None = Field( + default=1.0, description="Length penalty for sequence scoring" + ) + stop_token_ids: list[int] | None = Field( + default=None, description="Token IDs that trigger generation stop" + ) + include_stop_str_in_output: bool | None = Field( + default=False, description="Include stop sequence in output" + ) + ignore_eos: bool | None = Field( + default=False, description="Ignore end-of-sequence tokens" + ) + min_tokens: int | None = Field( + default=0, description="Minimum number of tokens to generate" + ) + skip_special_tokens: bool | None = Field( + default=True, description="Skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Add spaces between special tokens" + ) + truncate_prompt_tokens: int | None = Field( + default=None, description="Maximum prompt tokens before truncation" + ) + allowed_token_ids: list[int] | None = Field( + default=None, description="Restrict generation to specific token IDs" + ) + prompt_logprobs: int | None = Field( + default=None, description="Number of logprobs to return for prompt tokens" + ) + add_special_tokens: bool | None = Field( + 
default=True, description="Add special tokens during processing" + ) + guided_json: str | dict[str, Any] | None = Field( + default=None, description="JSON schema for guided generation" + ) + guided_regex: str | None = Field( + default=None, description="Regex pattern for guided generation" + ) + guided_choice: list[str] | None = Field( + default=None, description="List of choices for guided generation" + ) + guided_grammar: str | None = Field( + default=None, description="Grammar specification for guided generation" + ) + guided_decoding_backend: str | None = Field( + default=None, description="Backend to use for guided decoding" + ) + guided_whitespace_pattern: str | None = Field( + default=None, description="Whitespace pattern for guided generation" + ) + priority: int | None = Field( + default=0, description="Request priority for scheduling" + ) + + +class CompletionChoice(BaseModel): + """A single completion choice from a text completion response. + + Contains the generated text and metadata about completion + quality and stopping conditions. + """ + + text: str = Field(description="Generated text content") + index: int = Field(description="Index of this choice in the response") + logprobs: dict[str, Any] | None = Field( + default=None, description="Log probabilities for generated tokens" + ) + finish_reason: Literal["stop", "length", "content_filter"] | None = Field( + description="Reason why generation finished" + ) + + +class CompletionsResponse(BaseModel): + """Response from legacy text completion API endpoints. + + Contains generated text choices, usage statistics, and metadata + for non-streaming text completion requests. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["text_completion"] = Field( + default="text_completion", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[CompletionChoice] = Field(description="Generated completion choices") + usage: Usage | None = Field(default=None, description="Token usage statistics") + system_fingerprint: str | None = Field( + default=None, description="System configuration fingerprint" + ) + + +class TokenizeRequest(BaseModel): + """Request for tokenizing text into token sequences. + + Converts input text into model-specific token representations + with optional special token handling. + """ + + text: str = Field(description="Text to tokenize") + add_special_tokens: bool | None = Field( + default=True, description="Whether to add model-specific special tokens" + ) + + +class TokenizeResponse(BaseModel): + """Response containing tokenized representation of input text. + + Provides both the token sequence and count for analysis + and token budget planning. + """ + + tokens: list[int] = Field(description="List of token IDs") + count: int = Field(description="Total number of tokens") + + +class DetokenizeRequest(BaseModel): + """Request for converting token sequences back to text. + + Reconstructs human-readable text from model token representations + with configurable special token handling. 
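+
+    Example:
+        ::
+            # Token IDs below are arbitrary placeholders
+            request = DetokenizeRequest(tokens=[101, 2023, 2003])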
+ """ + + tokens: list[int] = Field(description="List of token IDs to convert") + skip_special_tokens: bool | None = Field( + default=True, description="Whether to skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Whether to add spaces between special tokens" + ) + + +class DetokenizeResponse(BaseModel): + """Response containing text reconstructed from tokens. + + Provides the human-readable text representation of the + input token sequence. + """ + + text: str = Field(description="Reconstructed text from tokens") + + +class ErrorDetail(BaseModel): + """Detailed error information for API failures. + + Provides structured error data including message, type classification, + and optional error codes for debugging and error handling. + """ + + message: str = Field(description="Human-readable error description") + type: str = Field(description="Error type classification") + code: str | None = Field( + default=None, description="Optional error code for programmatic handling" + ) + + +class ErrorResponse(BaseModel): + """Standardized error response structure for API failures. + + Wraps error details in a consistent format compatible with + OpenAI API error response conventions. + """ + + error: ErrorDetail = Field(description="Detailed error information") diff --git a/src/guidellm/mock_server/server.py b/src/guidellm/mock_server/server.py new file mode 100644 index 00000000..e35acf75 --- /dev/null +++ b/src/guidellm/mock_server/server.py @@ -0,0 +1,168 @@ +""" +High-performance mock server for OpenAI and vLLM API compatibility testing. + +This module provides a Sanic-based mock server that simulates OpenAI and vLLM APIs +with configurable latency, token generation patterns, and response characteristics. +The server supports both streaming and non-streaming endpoints, enabling realistic +performance testing and validation of GuideLLM benchmarking workflows without +requiring actual model deployments. +""" + +from __future__ import annotations + +import time + +from sanic import Sanic, response +from sanic.exceptions import NotFound +from sanic.log import logger +from sanic.request import Request +from sanic.response import HTTPResponse + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.handlers import ( + ChatCompletionsHandler, + CompletionsHandler, + TokenizerHandler, +) + +__all__ = ["MockServer"] + + +class MockServer: + """ + High-performance mock server implementing OpenAI and vLLM API endpoints. + + Provides a Sanic-based web server that simulates API responses with configurable + timing characteristics for testing and benchmarking purposes. Supports chat + completions, text completions, tokenization endpoints, and model listing with + realistic latency patterns to enable comprehensive performance validation. + + Example: + :: + config = ServerConfig(model="test-model", port=8080) + server = MockServer(config) + server.run() + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the mock server with configuration. 
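+
+        Creates the Sanic application, instantiates the chat completions,
+        completions, and tokenizer endpoint handlers, and registers middleware,
+        routes, and error handlers on the app.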
+ + :param config: Server configuration containing network settings and response + timing parameters + """ + self.config = config + self.app = Sanic("guidellm-mock-server") + self.chat_handler = ChatCompletionsHandler(config) + self.completions_handler = CompletionsHandler(config) + self.tokenizer_handler = TokenizerHandler(config) + + self._setup_middleware() + self._setup_routes() + self._setup_error_handlers() + + def _setup_middleware(self): + """Setup middleware for CORS, logging, etc.""" + + @self.app.middleware("request") + async def add_cors_headers(_request: Request): + """Add CORS headers to all requests.""" + + @self.app.middleware("response") + async def add_response_headers(_request: Request, resp: HTTPResponse): + """Add standard response headers.""" + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" + resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + resp.headers["Server"] = "guidellm-mock-server" + + def _setup_routes(self): + @self.app.get("/health") + async def health_check(_request: Request): + return response.json({"status": "healthy", "timestamp": time.time()}) + + @self.app.get("/v1/models") + async def list_models(_request: Request): + return response.json( + { + "object": "list", + "data": [ + { + "id": self.config.model, + "object": "model", + "created": int(time.time()), + "owned_by": "guidellm-mock", + } + ], + } + ) + + @self.app.route("/v1/chat/completions", methods=["POST", "OPTIONS"]) + async def chat_completions(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.chat_handler.handle(request) + + @self.app.route("/v1/completions", methods=["POST", "OPTIONS"]) + async def completions(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.completions_handler.handle(request) + + @self.app.route("/tokenize", methods=["POST", "OPTIONS"]) + async def tokenize(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.tokenizer_handler.tokenize(request) + + @self.app.route("/detokenize", methods=["POST", "OPTIONS"]) + async def detokenize(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.tokenizer_handler.detokenize(request) + + def _setup_error_handlers(self): + """Setup error handlers.""" + + @self.app.exception(Exception) + async def generic_error_handler(_request: Request, exception: Exception): + logger.error(f"Unhandled exception: {exception}") + return response.json( + { + "error": { + "message": "Internal server error", + "type": type(exception).__name__, + "error": str(exception), + } + }, + status=500, + ) + + @self.app.exception(NotFound) + async def not_found_handler(_request: Request, _exception): + return response.json( + { + "error": { + "message": "Not Found", + "type": "not_found_error", + "code": "not_found", + } + }, + status=404, + ) + + def run(self) -> None: + """ + Start the mock server with configured settings. + + Runs the Sanic application in single-process mode with access logging enabled + for debugging and monitoring request patterns during testing. 
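+
+        Example:
+            ::
+                # Runs in the foreground until the process is stopped
+                config = MockServerConfig(host="127.0.0.1", port=8000)
+                MockServer(config).run()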
+ """ + self.app.run( + host=self.config.host, + port=self.config.port, + debug=False, + single_process=True, + access_log=True, + register_sys_signals=False, # Disable signal handlers for threading + ) diff --git a/src/guidellm/mock_server/utils.py b/src/guidellm/mock_server/utils.py new file mode 100644 index 00000000..8348d0a6 --- /dev/null +++ b/src/guidellm/mock_server/utils.py @@ -0,0 +1,307 @@ +""" +Mock server utilities for text generation and tokenization testing. + +This module provides mock tokenization and text generation utilities for testing +guidellm's mock server functionality. It includes a mock tokenizer that simulates +tokenization processes, functions to generate reproducible fake text with specific +token counts, and timing generators for realistic benchmarking scenarios. +""" + +from __future__ import annotations + +import random +import re +from collections.abc import Generator + +from faker import Faker +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer, TextInput + +__all__ = [ + "MockTokenizer", + "create_fake_text", + "create_fake_tokens_str", + "sample_number", + "times_generator", +] + + +class MockTokenizer(PreTrainedTokenizer): + """ + Mock tokenizer implementation for testing text processing workflows. + + Provides a simplified tokenizer that splits text using regex patterns and + generates deterministic token IDs based on string hashing. Used for testing + guidellm components without requiring actual model tokenizers. + + :cvar VocabSize: Fixed vocabulary size for the mock tokenizer + """ + + VocabSize = 100000007 + + def __len__(self) -> int: + """ + Get the vocabulary size of the tokenizer. + + :return: The total number of tokens in the vocabulary + """ + return self.VocabSize + + def __call__(self, text: str | list[str], **kwargs) -> list[int]: # noqa: ARG002 + """ + Tokenize text and return token IDs (callable interface). + + :param text: Input text to tokenize + :return: List of token IDs + """ + if isinstance(text, str): + tokens = self.tokenize(text) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, list): + # Handle batch processing + return [self.__call__(t) for t in text] + else: + msg = f"text input must be of type `str` or `list[str]`, got {type(text)}" + raise ValueError(msg) + + def tokenize(self, text: TextInput, **_kwargs) -> list[str]: + """ + Tokenize input text into a list of token strings. + + Splits text using regex to separate words, punctuation, and whitespace + into individual tokens for processing. + + :param text: Input text to tokenize + :return: List of token strings from the input text + """ + # Split text into tokens: words, spaces, and punctuation + return re.findall(r"\w+|[^\w\s]|\s+", text) + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + """ + Convert token strings to numeric token IDs. + + Uses deterministic hashing to generate consistent token IDs for + reproducible testing scenarios. + + :param tokens: Single token string or list of token strings + :return: Single token ID or list of token IDs + """ + if isinstance(tokens, str): + return hash(tokens) % self.VocabSize + return [hash(token) % self.VocabSize for token in tokens] + + def convert_ids_to_tokens( + self, ids: int | list[int], _skip_special_tokens: bool = False + ) -> str | list[str]: + """ + Convert numeric token IDs back to token strings. + + Generates fake text tokens using Faker library seeded with token IDs + for deterministic and reproducible token generation. 
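+
+        A short sketch of the deterministic behavior (token IDs are arbitrary):
+        ::
+            tokenizer = MockTokenizer()
+            first = tokenizer.convert_ids_to_tokens([11, 22, 33])
+            second = tokenizer.convert_ids_to_tokens([11, 22, 33])
+            assert first == second  # same IDs always map to the same fake tokens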
+ + :param ids: Single token ID or list of token IDs to convert + :return: Single token string or list of token strings + """ + if not ids and not isinstance(ids, list): + return "" + elif not ids: + return [""] + + if isinstance(ids, int): + fake = Faker() + fake.seed_instance(ids % self.VocabSize) + + return fake.word() + + fake = Faker() + fake.seed_instance(sum(ids) % self.VocabSize) + + target_count = len(ids) + current_count = 0 + tokens = [] + + while current_count < target_count: + text = fake.text( + max_nb_chars=(target_count - current_count) * 10 # oversample + ) + new_tokens = self.tokenize(text) + + if current_count > 0: + new_tokens = [".", " "] + new_tokens + + new_tokens = ( + new_tokens[: target_count - current_count] + if len(new_tokens) > (target_count - current_count) + else new_tokens + ) + tokens += new_tokens + current_count += len(new_tokens) + + return tokens + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """ + Convert a list of token strings back to a single text string. + + :param tokens: List of token strings to concatenate + :return: Concatenated string from all tokens + """ + return "".join(tokens) + + def _add_tokens( + self, + new_tokens: list[str] | list[AddedToken], # noqa: ARG002 + special_tokens: bool = False, # noqa: ARG002 + ) -> int: + """ + Add new tokens to the tokenizer vocabulary (mock implementation). + + :param new_tokens: List of tokens to add to the vocabulary + :param special_tokens: Whether the tokens are special tokens + :return: Number of tokens actually added (always 0 for mock) + """ + return 0 + + def apply_chat_template( + self, + conversation: list, + tokenize: bool = False, # Changed default to False to match transformers + add_generation_prompt: bool = False, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> str | list[int]: + """ + Apply a chat template to format conversation messages. + + Mock implementation that concatenates all message content for testing. + + :param conversation: List of chat messages + :param tokenize: Whether to return tokens or string + :param add_generation_prompt: Whether to add generation prompt + :return: Formatted text string or token IDs + """ + # Simple concatenation of all message content + texts = [] + for message in conversation: + if isinstance(message, dict) and "content" in message: + texts.append(message["content"]) + elif hasattr(message, "content"): + texts.append(message.content) + + formatted_text = " ".join(texts) + + if tokenize: + return self.convert_tokens_to_ids(self.tokenize(formatted_text)) + return formatted_text + + def decode( + self, + token_ids: list[int], + skip_special_tokens: bool = True, + **kwargs, # noqa: ARG002 + ) -> str: + """ + Decode token IDs back to text string. + + :param token_ids: List of token IDs to decode + :param skip_special_tokens: Whether to skip special tokens + :return: Decoded text string + """ + tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens) + return self.convert_tokens_to_string(tokens) + + +def create_fake_text( + num_tokens: int, + processor: PreTrainedTokenizer, + seed: int = 42, + fake: Faker | None = None, +) -> str: + """ + Generate fake text using a tokenizer processor with specified token count. + + Creates text by generating fake tokens and joining them into a string, + ensuring the result has the exact number of tokens when processed by + the given tokenizer. 
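+
+    Example:
+        ::
+            # Seeded generation is reproducible across calls
+            tokenizer = MockTokenizer()
+            text = create_fake_text(16, tokenizer, seed=123)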
+ + :param num_tokens: Target number of tokens in the generated text + :param processor: Tokenizer to use for token generation and validation + :param seed: Random seed for reproducible text generation + :param fake: Optional Faker instance for text generation + :return: Generated text string with the specified token count + """ + return "".join(create_fake_tokens_str(num_tokens, processor, seed, fake)) + + +def create_fake_tokens_str( + num_tokens: int, + processor: PreTrainedTokenizer, + seed: int = 42, + fake: Faker | None = None, +) -> list[str]: + """ + Generate fake token strings using a tokenizer processor. + + Creates a list of token strings by generating fake text and tokenizing it + until the desired token count is reached. Uses the provided tokenizer + for accurate token boundary detection. + + :param num_tokens: Target number of tokens to generate + :param processor: Tokenizer to use for token generation and validation + :param seed: Random seed for reproducible token generation + :param fake: Optional Faker instance for text generation + :return: List of token strings with the specified count + """ + if not fake: + fake = Faker() + fake.seed_instance(seed) + + tokens = [] + + while len(tokens) < num_tokens: + text = fake.text( + max_nb_chars=(num_tokens - len(tokens)) * 30 # oversample + ) + new_tokens = processor.tokenize(text) + + if len(tokens) > 0: + new_tokens = [".", " "] + new_tokens + + new_tokens = ( + new_tokens[: num_tokens - len(tokens)] + if len(new_tokens) > (num_tokens - len(tokens)) + else new_tokens + ) + tokens += new_tokens + + return tokens + + +def times_generator(mean: float, standard_dev: float) -> Generator[float]: + """ + Generate infinite timing values from a normal distribution. + + Creates a generator that yields timing values sampled from a normal + distribution, useful for simulating realistic request timing patterns + in benchmarking scenarios. + + :param mean: Mean value for the normal distribution + :param standard_dev: Standard deviation for the normal distribution + :return: Generator yielding positive timing values from the distribution + """ + while True: + yield sample_number(mean, standard_dev) + + +def sample_number(mean: float, standard_dev: float) -> float: + """ + Generate a single timing value from a normal distribution. + + Samples one timing value from a normal distribution with the specified + parameters, ensuring the result is non-negative for realistic timing + simulation in benchmarking scenarios. 
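+
+    Example:
+        ::
+            delay_ms = sample_number(150.0, 10.0)
+            assert delay_ms >= 0.0  # negative samples are clamped to zero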
+ + :param mean: Mean value for the normal distribution + :param standard_dev: Standard deviation for the normal distribution + :return: Non-negative timing value from the distribution + """ + return max(0.0, random.gauss(mean, standard_dev)) diff --git a/src/guidellm/presentation/injector.py b/src/guidellm/presentation/injector.py index 02d53b1d..bb1fd684 100644 --- a/src/guidellm/presentation/injector.py +++ b/src/guidellm/presentation/injector.py @@ -4,7 +4,7 @@ from loguru import logger -from guidellm.config import settings +from guidellm.settings import settings from guidellm.utils.text import load_text diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index a7f4a67b..e3f13d5d 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -12,8 +12,8 @@ from transformers import PreTrainedTokenizerBase # type: ignore[import] from guidellm.backend import GenerationRequest -from guidellm.config import settings from guidellm.dataset import ColumnInputTypes, load_dataset +from guidellm.settings import settings from guidellm.utils import StandardBaseModel __all__ = [ diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index a0f9dcfd..24d73df2 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -16,12 +16,12 @@ BackendInterface, BackendT, MeasuredRequestTimings, - MeasuredRequestTimingsT, MultiTurnRequestT, RequestSchedulerTimings, RequestT, ResponseT, ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, SchedulerState, SchedulerUpdateAction, SchedulerUpdateActionProgress, @@ -63,7 +63,6 @@ "MaxGlobalErrorRateConstraint", "MaxNumberConstraint", "MeasuredRequestTimings", - "MeasuredRequestTimingsT", "MultiTurnRequestT", "NoDelayRequestTimings", "NonDistributedEnvironment", @@ -75,6 +74,7 @@ "ScheduledRequestInfo", "ScheduledRequestTimings", "Scheduler", + "SchedulerMessagingPydanticRegistry", "SchedulerState", "SchedulerUpdateAction", "SchedulerUpdateActionProgress", diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py index 12d15b06..c724a74a 100644 --- a/src/guidellm/scheduler/constraints.py +++ b/src/guidellm/scheduler/constraints.py @@ -4,21 +4,8 @@ Provides flexible constraints for managing scheduler behavior with configurable thresholds based on time, error rates, and request counts. Constraints evaluate scheduler state and individual requests to determine whether processing should -continue or stop based on predefined limits. - -Example: -:: - from guidellm.scheduler.constraints import ConstraintsInitializerFactory - - # Create constraints from configuration - constraints = ConstraintsInitializerFactory.resolve_constraints({ - "max_number": 1000, - "max_duration": 300.0, - "max_error_rate": {"max_error_rate": 0.1, "window_size": 50} - }) - - # Evaluate constraint during scheduling - action = constraints["max_number"](scheduler_state, request_info) +continue or stop based on predefined limits. The constraint system enables +sophisticated benchmark stopping criteria through composable constraint types. 
""" from __future__ import annotations @@ -29,13 +16,13 @@ from pydantic import Field, field_validator -from guidellm.config import settings from guidellm.scheduler.objects import ( ScheduledRequestInfo, SchedulerState, SchedulerUpdateAction, SchedulerUpdateActionProgress, ) +from guidellm.settings import settings from guidellm.utils import InfoMixin, RegistryMixin, StandardBaseModel __all__ = [ @@ -48,6 +35,7 @@ "MaxGlobalErrorRateConstraint", "MaxNumberConstraint", "PydanticConstraintInitializer", + "RequestsExhaustedConstraint", "SerializableConstraintInitializer", "UnserializableConstraintInitializer", ] @@ -63,9 +51,9 @@ def __call__( """ Evaluate constraint against scheduler state and request information. - :param state: Current scheduler state with metrics and timing + :param state: Current scheduler state with metrics and timing information :param request: Individual request information and metadata - :return: Action indicating whether to continue or stop operations + :return: Action indicating whether to continue or stop scheduler operations """ @@ -127,28 +115,21 @@ class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): Provides centralized access to registered constraint types with support for creating constraints from configuration dictionaries, simple values, or - pre-configured instances. Handles constraint resolution and type validation. + pre-configured instances. Handles constraint resolution and type validation + for the scheduler constraint system. Example: :: - from guidellm.scheduler import ( - ConstraintsInitializerFactory, - SchedulerUpdateAction, - SchedulerState, - ScheduledRequestInfo - ) - + from guidellm.scheduler import ConstraintsInitializerFactory - # Register - ConstraintsInitializerFactory.register("new_constraint") + # Register new constraint type + @ConstraintsInitializerFactory.register("new_constraint") class NewConstraint: def create_constraint(self, **kwargs) -> Constraint: return lambda state, request: SchedulerUpdateAction() - - # Create constraint - constraint = factory.create_constraint("new_constraint") - print(constraint(SchedulerState(), ScheduledRequestInfo())) + # Create and use constraint + constraint = ConstraintsInitializerFactory.create_constraint("new_constraint") """ @classmethod @@ -159,7 +140,7 @@ def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: :param key: Registered constraint initializer key :param args: Positional arguments for initializer creation :param kwargs: Keyword arguments for initializer creation - :return: Configured constraint initializer function + :return: Configured constraint initializer instance :raises ValueError: If the key is not registered in the factory """ if cls.registry is None or key not in cls.registry: @@ -168,10 +149,11 @@ def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: initializer_class = cls.registry[key] return ( - initializer_class(*args, **kwargs) - if not isinstance(initializer_class, SerializableConstraintInitializer) - else initializer_class.model_validate( - initializer_class.validated_kwargs(*args, **kwargs) + initializer_class(*args, **kwargs) # type: ignore[operator] + if not isinstance(initializer_class, type) + or not issubclass(initializer_class, SerializableConstraintInitializer) + else initializer_class( + **initializer_class.validated_kwargs(*args, **kwargs) # type: ignore[misc] ) ) @@ -183,13 +165,13 @@ def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: :param initializer: Constraint initializer to 
serialize :return: Dictionary representation or unserializable placeholder """ - return ( - initializer.model_dump() - if isinstance(initializer, SerializableConstraintInitializer) - else UnserializableConstraintInitializer( + if isinstance(initializer, SerializableConstraintInitializer): + return initializer.model_dump() + else: + unserializable = UnserializableConstraintInitializer( orig_info=InfoMixin.extract_from_obj(initializer) ) - ) + return unserializable.model_dump() @classmethod def deserialize( @@ -211,10 +193,14 @@ def deserialize( and initializer_dict["type_"] in cls.registry ): initializer_class = cls.registry[initializer_dict["type_"]] - return initializer_class.model_validate(initializer_dict) + if hasattr(initializer_class, "model_validate"): + return initializer_class.model_validate(initializer_dict) # type: ignore[return-value] + else: + return initializer_class(**initializer_dict) # type: ignore[return-value,operator] raise ValueError( - f"Cannot deserialize unknown constraint initializer: {initializer_class}" + f"Cannot deserialize unknown constraint initializer: " + f"{initializer_dict.get('type_', 'unknown')}" ) @classmethod @@ -223,6 +209,7 @@ def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: Create a constraint instance for the specified key. :param key: Registered constraint initializer key + :param args: Positional arguments for constraint creation :param kwargs: Keyword arguments for constraint creation :return: Configured constraint function ready for evaluation :raises ValueError: If the key is not registered in the factory @@ -289,10 +276,10 @@ class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin): Provides standardized serialization, validation, and metadata handling for constraint initializers using Pydantic models. Subclasses implement specific - constraint creation logic while inheriting common functionality. + constraint creation logic while inheriting validation and persistence support. """ - type_: str = Field(description="Type identifier for the constraint") + type_: str = Field(description="Type identifier for the constraint initializer") @property def info(self) -> dict[str, Any]: @@ -309,7 +296,8 @@ def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: """ Validate and process arguments for constraint creation. - Must be implemented by subclasses to handle their specific parameter patterns. + Must be implemented by subclasses to handle their specific parameter patterns + and validation requirements. :param args: Positional arguments passed to the constraint :param kwargs: Keyword arguments passed to the constraint @@ -323,7 +311,8 @@ def create_constraint(self, **kwargs) -> Constraint: """ Create a constraint instance. - Must be implemented by subclasses to return their specific constraint type. + Must be implemented by subclasses to return their specific constraint type + with appropriate configuration and validation. 
:param kwargs: Additional keyword arguments (usually unused) :return: Configured constraint instance @@ -344,13 +333,13 @@ class UnserializableConstraintInitializer(PydanticConstraintInitializer): type_: Literal["unserializable"] = "unserializable" # type: ignore[assignment] orig_info: dict[str, Any] = Field( default_factory=dict, - description="Information about why this constraint is unserializable", + description="Original constraint information before serialization failure", ) @classmethod def validated_kwargs( cls, - orig_info: dict[str, Any] = None, + orig_info: dict[str, Any] | None = None, **kwargs, # noqa: ARG003 ) -> dict[str, Any]: """ @@ -396,7 +385,7 @@ def __call__( ) -@ConstraintsInitializerFactory.register( +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] ["max_number", "max_num", "max_requests", "max_req"] ) class MaxNumberConstraint(PydanticConstraintInitializer): @@ -430,7 +419,8 @@ def validated_kwargs( """ aliases = ["max_number", "max_num", "max_requests", "max_req"] for alias in aliases: - max_num = max_num or kwargs.get(alias) + if max_num is None: + max_num = kwargs.get(alias) return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)} @@ -443,7 +433,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 """ self.current_index += 1 - return self.model_copy() + return self.model_copy() # type: ignore[return-value] def __call__( self, @@ -451,7 +441,7 @@ def __call__( request_info: ScheduledRequestInfo, # noqa: ARG002 ) -> SchedulerUpdateAction: """ - Evaluate constraint against current scheduler state. + Evaluate constraint against current scheduler state and request count. :param state: Current scheduler state with request counts :param request_info: Individual request information (unused) @@ -466,10 +456,8 @@ def __call__( create_exceeded = state.created_requests >= max_num processed_exceeded = state.processed_requests >= max_num - remaining_fraction = min( - max(0.0, 1.0 - state.processed_requests / float(max_num)), 1.0 - ) - remaining_requests = max(0, max_num - state.processed_requests) + remaining_requests = min(max(0, max_num - state.processed_requests), max_num) + remaining_fraction = remaining_requests / float(max_num) return SchedulerUpdateAction( request_queuing="stop" if create_exceeded else "continue", @@ -509,7 +497,7 @@ def _validate_max_num( return value[0] if isinstance(value, list) and len(value) == 1 else value -@ConstraintsInitializerFactory.register( +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"] ) class MaxDurationConstraint(PydanticConstraintInitializer): @@ -529,7 +517,7 @@ class MaxDurationConstraint(PydanticConstraintInitializer): @classmethod def validated_kwargs( - cls, max_duration: int | float | list[int | float] = None, **kwargs + cls, max_duration: int | float | list[int | float] | None = None, **kwargs ) -> dict[str, Any]: """ Validate and process arguments for MaxDurationConstraint creation. 
@@ -541,12 +529,13 @@ def validated_kwargs( """ seconds_aliases = ["max_dur", "max_sec", "max_seconds"] for alias in seconds_aliases: - max_duration = max_duration or kwargs.get(alias) + if max_duration is None: + max_duration = kwargs.get(alias) minutes_aliases = ["max_min", "max_minutes"] for alias in minutes_aliases: minutes = kwargs.get(alias) - if minutes is not None: - max_duration = max_duration or minutes * 60 + if minutes is not None and max_duration is None: + max_duration = minutes * 60 return { "max_duration": max_duration, @@ -562,7 +551,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 """ self.current_index += 1 - return self.model_copy() + return self.model_copy() # type: ignore[return-value] def __call__( self, @@ -586,6 +575,8 @@ def __call__( current_time = time.time() elapsed = current_time - state.start_time duration_exceeded = elapsed >= max_duration + remaining_duration = min(max(0.0, max_duration - elapsed), max_duration) + remaining_fraction = remaining_duration / float(max_duration) return SchedulerUpdateAction( request_queuing="stop" if duration_exceeded else "continue", @@ -598,8 +589,8 @@ def __call__( "current_time": current_time, }, progress=SchedulerUpdateActionProgress( - remaining_fraction=max(0.0, 1.0 - elapsed / float(max_duration)), - remaining_duration=max(0.0, max_duration - elapsed), + remaining_fraction=remaining_fraction, + remaining_duration=remaining_duration, ), ) @@ -625,7 +616,7 @@ def _validate_max_duration( return value[0] if isinstance(value, list) and len(value) == 1 else value -@ConstraintsInitializerFactory.register( +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] ["max_errors", "max_err", "max_error", "max_errs"] ) class MaxErrorsConstraint(PydanticConstraintInitializer): @@ -634,7 +625,7 @@ class MaxErrorsConstraint(PydanticConstraintInitializer): Stops both request queuing and all request processing when the total number of errored requests reaches the maximum threshold. Uses global error tracking - across all requests. + across all requests for immediate constraint evaluation. """ type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] @@ -645,7 +636,7 @@ class MaxErrorsConstraint(PydanticConstraintInitializer): @classmethod def validated_kwargs( - cls, max_errors: int | float | list[int | float] = None, **kwargs + cls, max_errors: int | float | list[int | float] | None = None, **kwargs ) -> dict[str, Any]: """ Validate and process arguments for MaxErrorsConstraint creation. 
@@ -657,7 +648,8 @@ def validated_kwargs( """ aliases = ["max_errors", "max_err", "max_error", "max_errs"] for alias in aliases: - max_errors = max_errors or kwargs.get(alias) + if max_errors is None: + max_errors = kwargs.get(alias) return { "max_errors": max_errors, @@ -673,7 +665,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 """ self.current_index += 1 - return self.model_copy() + return self.model_copy() # type: ignore[return-value] def __call__( self, @@ -726,7 +718,7 @@ def _validate_max_errors( return value[0] if isinstance(value, list) and len(value) == 1 else value -@ConstraintsInitializerFactory.register( +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] ["max_error_rate", "max_err_rate", "max_errors_rate"] ) class MaxErrorRateConstraint(PydanticConstraintInitializer): @@ -735,7 +727,8 @@ class MaxErrorRateConstraint(PydanticConstraintInitializer): Tracks error status of recent requests in a sliding window and stops all processing when the error rate exceeds the threshold. Only applies the - constraint after processing enough requests to fill the minimum window size. + constraint after processing enough requests to fill the minimum window size + for statistical significance. """ type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] @@ -770,7 +763,8 @@ def validated_kwargs( """ aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] for alias in aliases: - max_error_rate = max_error_rate or kwargs.get(alias) + if max_error_rate is None: + max_error_rate = kwargs.get(alias) return { "max_error_rate": max_error_rate, @@ -790,7 +784,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 """ self.current_index += 1 - return self.model_copy() + return self.model_copy() # type: ignore[return-value] def __call__( self, state: SchedulerState, request_info: ScheduledRequestInfo @@ -865,7 +859,7 @@ def _validate_max_error_rate( return value[0] if isinstance(value, list) and len(value) == 1 else value -@ConstraintsInitializerFactory.register( +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"] ) class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer): @@ -874,7 +868,8 @@ class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer): Calculates error rate across all processed requests and stops all processing when the rate exceeds the threshold. Only applies the constraint after - processing the minimum number of requests to ensure statistical significance. + processing the minimum number of requests to ensure statistical significance + for global error rate calculations. 
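+
+    Example (illustrative usage sketch; scheduler_state and request_info are
+    assumed to come from an active scheduler run):
+    ::
+        initializer = MaxGlobalErrorRateConstraint(
+            max_error_rate=0.05, min_processed=50
+        )
+        constraint = initializer.create_constraint()
+        action = constraint(scheduler_state, request_info)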
""" type_: Literal["max_global_error_rate"] = "max_global_error_rate" # type: ignore[assignment] @@ -908,7 +903,8 @@ def validated_kwargs( "max_global_err_rate", "max_global_errors_rate", ]: - max_error_rate = max_error_rate or kwargs.get(alias) + if max_error_rate is None: + max_error_rate = kwargs.get(alias) return { "max_error_rate": max_error_rate, @@ -927,7 +923,7 @@ def create_constraint(self, **kwargs) -> Constraint: # noqa: ARG002 """ self.current_index += 1 - return self.model_copy() + return self.model_copy() # type: ignore[return-value] def __call__( self, @@ -948,7 +944,9 @@ def __call__( else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] ) - exceeded_min_processed = state.processed_requests >= self.min_processed + exceeded_min_processed = ( + self.min_processed is None or state.processed_requests >= self.min_processed + ) error_rate = ( state.errored_requests / float(state.processed_requests) if state.processed_requests > 0 @@ -991,3 +989,47 @@ def _validate_max_error_rate( ) return value[0] if isinstance(value, list) and len(value) == 1 else value + + +class RequestsExhaustedConstraint(StandardBaseModel, InfoMixin): + type_: Literal["requests_exhausted"] = "requests_exhausted" # type: ignore[assignment] + num_requests: int + + @property + def info(self) -> dict[str, Any]: + """ + Extract serializable information from this constraint initializer. + + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + def __call__( + self, + state: SchedulerState, + request_info: ScheduledRequestInfo, # noqa: ARG002 + ) -> SchedulerUpdateAction: + create_exceeded = state.created_requests >= self.num_requests + processed_exceeded = state.processed_requests >= self.num_requests + remaining_fraction = min( + max(0.0, 1.0 - state.processed_requests / float(self.num_requests)), 1.0 + ) + remaining_requests = max(0, self.num_requests - state.processed_requests) + + return SchedulerUpdateAction( + request_queuing="stop" if create_exceeded else "continue", + request_processing="stop_local" if processed_exceeded else "continue", + metadata={ + "num_requests": self.num_requests, + "create_exceeded": create_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": state.processed_requests, + "remaining_fraction": remaining_fraction, + "remaining_requests": remaining_requests, + }, + progress=SchedulerUpdateActionProgress( + remaining_fraction=remaining_fraction, + remaining_requests=remaining_requests, + ), + ) diff --git a/src/guidellm/scheduler/environment.py b/src/guidellm/scheduler/environment.py index 27f2881f..3bc29681 100644 --- a/src/guidellm/scheduler/environment.py +++ b/src/guidellm/scheduler/environment.py @@ -24,10 +24,8 @@ Generic, ) -from guidellm.config import settings from guidellm.scheduler.constraints import Constraint from guidellm.scheduler.objects import ( - MeasuredRequestTimingsT, MultiTurnRequestT, RequestT, ResponseT, @@ -35,6 +33,7 @@ SchedulerState, ) from guidellm.scheduler.strategy import SchedulingStrategy +from guidellm.settings import settings from guidellm.utils import InfoMixin __all__ = ["Environment", "NonDistributedEnvironment"] @@ -94,7 +93,7 @@ async def update_run_iteration( self, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, state: SchedulerState, ): """ @@ -132,7 +131,7 @@ async def sync_run_end( tuple[ ResponseT, 
RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], + ScheduledRequestInfo, SchedulerState, ] ]: @@ -225,7 +224,7 @@ async def update_run_iteration( self, response: ResponseT | None, request: RequestT, - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, state: SchedulerState, ): """ @@ -252,7 +251,7 @@ async def sync_run_end( tuple[ ResponseT, RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], + ScheduledRequestInfo, SchedulerState, ] ]: diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py index 8b6437f0..b7f2efc3 100644 --- a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/objects.py @@ -11,12 +11,13 @@ import time import uuid -from abc import ABC, abstractmethod from collections.abc import AsyncIterator from typing import ( Any, + ClassVar, Generic, Literal, + Protocol, TypeVar, Union, ) @@ -24,18 +25,23 @@ from pydantic import Field, computed_field from typing_extensions import TypeAliasType, TypedDict -from guidellm.utils import StandardBaseModel +from guidellm.utils import ( + PydanticClassRegistryMixin, + RegistryMixin, + StandardBaseModel, +) +from guidellm.utils.registry import RegistryObjT __all__ = [ "BackendInterface", "BackendT", "MeasuredRequestTimings", - "MeasuredRequestTimingsT", "MultiTurnRequestT", "RequestSchedulerTimings", "RequestT", "ResponseT", "ScheduledRequestInfo", + "SchedulerMessagingPydanticRegistry", "SchedulerState", "SchedulerUpdateAction", "SchedulerUpdateActionProgress", @@ -58,8 +64,19 @@ """Multi-turn request structure supporting conversation history with optional delays.""" +class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): + """ + Registry for enabling a generic interface to define the pydantic class types used + for inter-process messaging within the scheduler. + """ + + +@SchedulerMessagingPydanticRegistry.register() class RequestSchedulerTimings(StandardBaseModel): - """Scheduler-level timing measurements for request lifecycle tracking.""" + """ + Scheduler-level timing measurements for request lifecycle tracking. + All timestamps are expected to be in Unix time (seconds since epoch). + """ targeted_start: float | None = Field( default=None, @@ -88,9 +105,26 @@ class RequestSchedulerTimings(StandardBaseModel): ) -class MeasuredRequestTimings(StandardBaseModel): - """Base timing measurements for backend request processing.""" +@SchedulerMessagingPydanticRegistry.register() +class MeasuredRequestTimings(PydanticClassRegistryMixin["MeasuredRequestTimings"]): + """ + Base timing measurements for backend request processing. + All timestamps are expected to be in Unix time (seconds since epoch). 
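+
+    Example (illustrative sketch; ExampleBackendTimings and its extra field are
+    hypothetical, shown only to outline how a backend-specific subclass can
+    extend this base):
+    ::
+        @MeasuredRequestTimings.register("example_backend_timings")
+        class ExampleBackendTimings(MeasuredRequestTimings):
+            timings_type: Literal["example_backend_timings"] = (
+                "example_backend_timings"
+            )
+            first_token_at: float | None = None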
+ """ + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[MeasuredRequestTimings]: + if cls.__name__ == "MeasuredRequestTimings": + return cls + + return MeasuredRequestTimings + schema_discriminator: ClassVar[str] = "timings_type" + + timings_type: Literal["measured_request_timings"] = Field( + default="measured_request_timings", + description="Type identifier for the timing measurement", + ) request_start: float | None = Field( default=None, description="When the backend began processing the request" ) @@ -99,13 +133,8 @@ class MeasuredRequestTimings(StandardBaseModel): ) -MeasuredRequestTimingsT = TypeVar( - "MeasuredRequestTimingsT", bound=MeasuredRequestTimings -) -"""Generic timing measurements type for backend-specific request processing.""" - - -class ScheduledRequestInfo(StandardBaseModel, Generic[MeasuredRequestTimingsT]): +@SchedulerMessagingPydanticRegistry.register() +class ScheduledRequestInfo(StandardBaseModel): """ Complete request information including status, timings, and metadata. @@ -155,12 +184,12 @@ class ScheduledRequestInfo(StandardBaseModel, Generic[MeasuredRequestTimingsT]): default_factory=RequestSchedulerTimings, description="Scheduler-level timing measurements for request lifecycle", ) - request_timings: MeasuredRequestTimingsT | None = Field( + request_timings: MeasuredRequestTimings | None = Field( default=None, description="Backend-specific timing measurements for request processing", ) - @computed_field + @computed_field # type: ignore[misc] @property def started_at(self) -> float | None: """ @@ -174,7 +203,7 @@ def started_at(self) -> float | None: return request_start or self.scheduler_timings.resolve_start - @computed_field + @computed_field # type: ignore[misc] @property def completed_at(self) -> float | None: """ @@ -186,7 +215,12 @@ def completed_at(self) -> float | None: return request_end or self.scheduler_timings.resolve_end - def model_copy(self) -> ScheduledRequestInfo: + def model_copy(self, **kwargs) -> ScheduledRequestInfo: # type: ignore[override] # noqa: ARG002 + """ + Create a deep copy of the request info with copied timing objects. + + :return: New ScheduledRequestInfo instance with independent timing objects + """ return super().model_copy( update={ "scheduler_timings": self.scheduler_timings.model_copy(), @@ -198,7 +232,7 @@ def model_copy(self) -> ScheduledRequestInfo: ) -class BackendInterface(ABC, Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): +class BackendInterface(Protocol, Generic[RequestT, ResponseT]): """ Abstract interface for request processing backends. @@ -222,28 +256,23 @@ async def resolve(self, request, request_info, history=None): """ @property - @abstractmethod def processes_limit(self) -> int | None: """ - :return: The maximum worker processes supported, or None if unlimited + :return: Maximum worker processes supported, or None if unlimited """ @property - @abstractmethod def requests_limit(self) -> int | None: """ - :return: The maximum concurrent requests supported, or None if unlimited + :return: Maximum concurrent requests supported, or None if unlimited """ @property - @abstractmethod def info(self) -> dict[str, Any]: """ - :return: The backend metadata including model initialization and configuration. + :return: Backend metadata including model initialization and configuration """ - ... - @abstractmethod async def process_startup(self) -> None: """ Perform backend initialization and startup procedures. 
@@ -251,7 +280,6 @@ async def process_startup(self) -> None:
:raises: Implementation-specific exceptions for startup failures.
"""
- @abstractmethod
async def validate(self) -> None:
"""
Validate backend configuration and operational status.
@@ -259,7 +287,6 @@ async def validate(self) -> None:
:raises: Implementation-specific exceptions for validation failures.
"""
- @abstractmethod
async def process_shutdown(self) -> None:
"""
Perform backend cleanup and shutdown procedures.
@@ -267,13 +294,12 @@ async def process_shutdown(self) -> None:
:raises: Implementation-specific exceptions for shutdown failures.
"""
- @abstractmethod
async def resolve(
self,
request: RequestT,
- request_info: ScheduledRequestInfo[MeasuredRequestTimingsT],
+ request_info: ScheduledRequestInfo,
history: list[tuple[RequestT, ResponseT]] | None = None,
- ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo[MeasuredRequestTimingsT]]]:
+ ) -> AsyncIterator[tuple[ResponseT, ScheduledRequestInfo]]:
"""
Process a request and yield incremental response updates.
@@ -298,14 +324,9 @@ class SchedulerUpdateActionProgress(TypedDict, total=False):
track execution progress and make termination decisions.
"""
- remaining_fraction: float | None = None
"""Estimated fraction of work remaining (0.0 to 1.0), if known."""
-
- remaining_requests: float | None = None
"""Estimated number of requests remaining to be processed, if known."""
-
- remaining_duration: float | None = None
"""Estimated time remaining in seconds for completion, if known."""
+ remaining_fraction: float | None
+ remaining_requests: float | None
+ remaining_duration: float | None
class SchedulerUpdateAction(StandardBaseModel):
@@ -340,7 +361,7 @@ class SchedulerUpdateAction(StandardBaseModel):
description="Additional context and data for the scheduler action",
)
progress: SchedulerUpdateActionProgress = Field(
- default_factory=SchedulerUpdateActionProgress,
+ default_factory=lambda: SchedulerUpdateActionProgress(),
description="Progress information for the scheduler action",
)
@@ -394,7 +415,7 @@ class SchedulerState(StandardBaseModel):
)
end_processing_constraints: dict[str, SchedulerUpdateAction] = Field(
default_factory=dict,
- description="Constraints that triggered processing termination",
+ description="Constraints that triggered processing termination",
)
scheduler_constraints: dict[str, SchedulerUpdateAction] = Field(
default_factory=dict,
@@ -409,7 +430,7 @@ class SchedulerState(StandardBaseModel):
"Estimated fraction for the remaining progress of the run, if known"
),
)
- remaining_requests: int | None = Field(
+ remaining_requests: float | None = Field(
default=None,
description="Estimated number of requests remaining to be processed, if known",
)
@@ -427,7 +448,8 @@ class SchedulerState(StandardBaseModel):
default=0, description="Total number of requests queued for processing"
)
pending_requests: int = Field(
- default=0, description="Number of requests currently pending processing"
+ default=0,
+ description="Total number of requests pending processing within a worker",
)
processing_requests: int = Field(
default=0, description="Number of requests currently being processed"
diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py
index e4e9f4f6..8089c64c 100644
--- a/src/guidellm/scheduler/scheduler.py
+++ b/src/guidellm/scheduler/scheduler.py
@@ -20,7 +20,6 @@ from guidellm.scheduler.environment import Environment, NonDistributedEnvironment
from guidellm.scheduler.objects import (
BackendInterface,
-
MeasuredRequestTimingsT, MultiTurnRequestT, RequestT, ResponseT, @@ -35,7 +34,7 @@ class Scheduler( - Generic[RequestT, MeasuredRequestTimingsT, ResponseT], + Generic[RequestT, ResponseT], ThreadSafeSingletonMixin, ): """ @@ -68,7 +67,7 @@ class Scheduler( async def run( self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], - backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, env: Environment | None, **constraints: dict[str, Any | dict[str, Any] | Constraint], @@ -76,7 +75,7 @@ async def run( tuple[ ResponseT | None, RequestT, - ScheduledRequestInfo[MeasuredRequestTimingsT], + ScheduledRequestInfo, SchedulerState, ] ]: @@ -107,9 +106,7 @@ async def run( if env is None: env = NonDistributedEnvironment() - worker_group: ( - WorkerProcessGroup[RequestT, MeasuredRequestTimingsT, ResponseT] | None - ) = None + worker_group: WorkerProcessGroup[RequestT, ResponseT] | None = None # Any issues during the run will raise an error (local or remote), # be caught and passed to the environment, @@ -126,11 +123,10 @@ async def run( ) = await env.sync_run_params(requests, strategy, constraints) # Setup the worker group, sync start with the environment - worker_group = WorkerProcessGroup[ - RequestT, MeasuredRequestTimingsT, ResponseT - ]( + worker_group = WorkerProcessGroup[RequestT, ResponseT]( + requests=None, + cycle_requests=local_requests, backend=backend, - requests=local_requests, strategy=local_strategy, constraints=local_constraints, ) diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py index 15e15e7c..8c791671 100644 --- a/src/guidellm/scheduler/strategy.py +++ b/src/guidellm/scheduler/strategy.py @@ -1,30 +1,10 @@ """ -Request scheduling strategies for the GuideLLM toolkit. - -This module provides a comprehensive set of scheduling strategies that control how -requests are processed and timed within the GuideLLM benchmarking system. These -strategies enable fine-grained control over request concurrency, timing patterns, -and throughput characteristics to simulate various real-world usage scenarios. - -The scheduling system is built around abstract timing implementations that define -when requests should be executed, and concrete strategy classes that combine -timing behaviors with process and concurrency limits. - -Classes: - ScheduledRequestTimings: Abstract base class for request timing implementations - LastCompletionRequestTimings: Timing implementation for synchronous/concurrent - strategies - NoDelayRequestTimings: Timing implementation for throughput-maximizing strategies - ConstantRateRequestTimings: Timing implementation for constant-rate request - scheduling - PoissonRateRequestTimings: Timing implementation for Poisson-distributed request - scheduling - SchedulingStrategy: Abstract base class for all scheduling strategies - SynchronousStrategy: Sequential request processing with maximum throughput - ConcurrentStrategy: Parallel request processing with limited concurrency - ThroughputStrategy: Unrestricted request processing for maximum system throughput - AsyncConstantStrategy: Asynchronous request scheduling at a constant rate - AsyncPoissonStrategy: Asynchronous request scheduling with Poisson distribution +Request scheduling strategies for controlling how benchmark requests are processed. 
+ +This module provides timing implementations and concrete strategies that control request +concurrency, timing patterns, and throughput characteristics to simulate real-world +usage scenarios. The scheduling system separates timing logic from strategy constraints, +enabling flexible combination of timing behaviors with process and concurrency limits. """ from __future__ import annotations @@ -33,7 +13,7 @@ import random import time from abc import ABC, abstractmethod -from typing import ClassVar, Literal, TypeVar +from typing import Annotated, ClassVar, Literal, TypeVar from pydantic import Field, PrivateAttr @@ -57,23 +37,29 @@ ] -StrategyType = Literal["synchronous", "concurrent", "throughput", "constant", "poisson"] +StrategyType = Annotated[ + Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], + "Valid strategy type identifiers for scheduling request patterns", +] def _exponential_decay_tau(max_progress: float, convergence: float = 0.99) -> float: """ + Calculate tau value for exponential decay to reach target progress level. + :param max_progress: The max progress value to reach - :param convergence: The target convergence level for reaching max_progress. - Default 0.99 represents at 99% exponential decay reach max_progress. - :return: The calculated tau value for the given max_progress and convergence. + :param convergence: The target convergence level for reaching max_progress + :return: The calculated tau value for the given max_progress and convergence """ return max_progress / (-math.log(1 - convergence)) def _exponential_decay_fraction(progress: float, tau: float = 1.0) -> float: """ + Calculate completion fraction based on exponential decay curve. + :param progress: The current progress value (>=0) - :param tau: The scale factor for the exponential decay (default: 1.0) + :param tau: The scale factor for the exponential decay :return: The fraction of completion based on exponential decay (0 -> 1) """ return 1 - math.exp(-progress / tau) @@ -81,15 +67,11 @@ def _exponential_decay_fraction(progress: float, tau: float = 1.0) -> float: class ScheduledRequestTimings(StandardBaseModel, ABC): """ - Abstract base class for request timing implementations in scheduling strategies. - - This class defines the interface for controlling when requests are scheduled - and how timing offsets are calculated. Different implementations provide - various timing behaviors such as synchronous, constant-rate, or stochastic - request scheduling patterns. + Abstract base class for controlling when requests are scheduled. - Implementations must provide logic for calculating the next request offset - and handling request completion events that may affect future timing decisions. + Defines the interface for timing implementations that determine request scheduling + behavior. Different implementations provide various patterns like synchronous, + constant-rate, or stochastic scheduling to simulate real-world scenarios. """ @abstractmethod @@ -97,21 +79,16 @@ def next_offset(self) -> float: """ Calculate the time offset for the next request to be scheduled. - :return: The offset in seconds from the scheduler start time when the - next request should be scheduled. + :return: The offset in seconds from scheduler start time for next request """ @abstractmethod def request_completed(self, request_info: ScheduledRequestInfo): """ - Handle the completion of a request and update internal timing state. 
- - This method is called when a request completes (successfully or with error) - and allows the timing implementation to update its internal state based on - the completion information. + Handle request completion and update internal timing state. :param request_info: Information about the completed request including - timing details and completion status. + timing details and completion status """ @@ -119,37 +96,31 @@ class LastCompletionRequestTimings(ScheduledRequestTimings): """ Timing implementation for synchronous and concurrent scheduling strategies. - This implementation schedules the next request immediately after the last - request has completed, enabling sequential or limited concurrent processing. - It maintains an internal offset based on completion times to ensure proper - scheduling behavior. + Schedules the next request immediately after the last request completes, enabling + sequential or limited concurrent processing with completion-based timing control. """ offset: float = Field( default=0.0, - description="The current time offset in seconds from scheduler start time.", + description="Current time offset in seconds from scheduler start time", ) startup_requests: int = Field( default=0, - description=( - "Number of initial requests to schedule during startup phase with equal " - "spacing of startup_requests_delay before going to last request times." - ), + description="Number of initial requests to schedule with equal spacing", ge=0, ) startup_requests_delay: float = Field( default=0.0, - description=( - "Delay in seconds used to add to the offset for each request " - "within the startup phase (_requests_count <= startup_requests)." - ), + description="Delay in seconds between startup requests", ge=0, ) _requests_count: int = PrivateAttr(0) def next_offset(self) -> float: """ - :return: The current offset value in seconds from scheduler start time. + Get the current offset value and apply startup delay if applicable. + + :return: The current offset value in seconds from scheduler start time """ self._requests_count += 1 @@ -160,10 +131,9 @@ def next_offset(self) -> float: def request_completed(self, request_info: ScheduledRequestInfo): """ - Update timing state and offset based on the completed request. + Update timing state based on the completed request. - :param request_info: Information about the completed request including - timing details and completion status. + :param request_info: Information about the completed request """ if ( self._requests_count > self.startup_requests @@ -177,42 +147,37 @@ class NoDelayRequestTimings(ScheduledRequestTimings): """ Timing implementation for throughput-maximizing scheduling strategies. - This implementation schedules requests with no delay, allowing the system - to process requests as quickly as possible. It always returns a zero offset, - enabling maximum throughput by scheduling requests immediately without - waiting for previous requests to complete. + Schedules requests with minimal delay to achieve maximum throughput, with optional + startup ramping to gradually increase request processing during initialization. """ offset: float = Field( default=0.0, - description="The time offset to apply in seconds from scheduler start time.", + description="Base time offset in seconds from scheduler start time", ge=0, ) startup_duration: float = Field( default=0.0, - description=( - "The duration of the startup phase in seconds to gradually ramp up " - "request processing." 
- ), + description="Duration in seconds for gradual startup ramp", ge=0, ) startup_target_requests: int = Field( default=1, - description=( - "The target number of requests to converge to in the startup phase." - ), + description="Target number of requests to converge to during startup", gt=0, ) startup_convergence: float = Field( default=0.99, - description=("The target convergence rate during the startup phase."), + description="Target convergence rate during startup phase", ) _start_time: float | None = PrivateAttr(None) _requests_count: int = PrivateAttr(0) def next_offset(self) -> float: """ - :return: Static offset plus any startup adjustment. + Calculate offset with optional startup adjustment. + + :return: Static offset plus any startup adjustment """ if self._start_time is None: self._start_time = time.time() @@ -236,7 +201,7 @@ def request_completed(self, request_info: ScheduledRequestInfo): """ Handle request completion (no action needed for throughput strategy). - :param request_info: Information about the completed request (unused). + :param request_info: Information about the completed request (unused) """ @@ -244,18 +209,17 @@ class ConstantRateRequestTimings(ScheduledRequestTimings): """ Timing implementation for constant-rate scheduling strategies. - This implementation schedules requests at a constant rate defined in requests - per second. The offset for each subsequent request is calculated as a multiple - of the interval between requests, ensuring evenly spaced request scheduling. + Schedules requests at a fixed rate with evenly spaced intervals to provide + predictable timing behavior for steady-state load simulation. """ rate: float = Field( - description="The target rate in requests per second. Must be positive.", + description="Target rate in requests per second", gt=0, ) offset: float = Field( default=0.0, - description="The time offset to apply in seconds from scheduler start time.", + description="Base time offset in seconds from scheduler start time", ge=0, ) _requests_count: int = PrivateAttr(0) @@ -264,10 +228,7 @@ def next_offset(self) -> float: """ Calculate the offset for the next request at a constant rate. - Each request is scheduled at a fixed interval based on the target rate, - with offsets increasing linearly: 0, 1/rate, 2/rate, 3/rate, etc. - - :return: The offset in seconds for the next request. + :return: The offset in seconds for the next request """ num_requests = self._requests_count self._requests_count += 1 @@ -279,7 +240,7 @@ def request_completed(self, request_info: ScheduledRequestInfo): """ Handle request completion (no action needed for constant rate strategy). - :param request_info: Information about the completed request (unused). + :param request_info: Information about the completed request (unused) """ @@ -287,25 +248,21 @@ class PoissonRateRequestTimings(ScheduledRequestTimings): """ Timing implementation for Poisson-distributed scheduling strategies. - This implementation schedules requests following a Poisson process with - exponentially distributed inter-arrival times. The average rate is specified - in requests per second, but individual intervals vary randomly according to - the exponential distribution, simulating realistic traffic patterns. + Schedules requests following a Poisson process with exponentially distributed + inter-arrival times to simulate realistic traffic patterns with random variance. """ rate: float = Field( - description="The target average rate in requests per second. 
Must be positive.", + description="Target average rate in requests per second", gt=0, ) random_seed: int = Field( default=42, - description=( - "Seed for the random number generator to ensure reproducible behavior." - ), + description="Seed for random number generator for reproducible behavior", ) offset: float = Field( default=0.0, - description="The time offset to apply in seconds from scheduler start time.", + description="Base time offset in seconds from scheduler start time", ) _requests_count: int = PrivateAttr(0) _random: random.Random | None = PrivateAttr(None) @@ -314,11 +271,7 @@ def next_offset(self) -> float: """ Calculate the offset for the next request using Poisson distribution. - Uses exponential distribution to generate inter-arrival times that - follow a Poisson process. Each call advances the cumulative offset - by a randomly generated delay. - - :return: The cumulative offset in seconds for the next request. + :return: The cumulative offset in seconds for the next request """ self._requests_count += 1 @@ -334,16 +287,16 @@ def request_completed(self, request_info: ScheduledRequestInfo): """ Handle request completion (no action needed for Poisson rate strategy). - :param request_info: Information about the completed request (unused). + :param request_info: Information about the completed request (unused) """ -class SchedulingStrategy( - PydanticClassRegistryMixin["type[SchedulingStrategy]"], InfoMixin -): +class SchedulingStrategy(PydanticClassRegistryMixin["SchedulingStrategy"], InfoMixin): """ - An abstract base class for scheduling strategies enabling control over how - requests are processed by the scheduler. + Abstract base class for scheduling strategies controlling request processing. + + Defines the interface for strategies that combine timing implementations with + process and concurrency constraints to enable various benchmark scenarios. """ schema_discriminator: ClassVar[str] = "type_" @@ -356,22 +309,24 @@ def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]: return SchedulingStrategy type_: Literal["strategy"] = Field( - description="The type of scheduling strategy to schedule requests with.", + description="The type of scheduling strategy to schedule requests with", ) @property def processes_limit(self) -> int | None: """ - :return: The maximum number of worker processes supported by the - scheduling strategy. None if not limited. + Get the maximum number of worker processes supported by this strategy. + + :return: Maximum number of worker processes, None if unlimited """ return None @property def requests_limit(self) -> int | None: """ - :return: The maximum number of concurrent requests that can be processed - at once by the scheduling strategy. None if not limited. + Get the maximum number of concurrent requests supported by this strategy. + + :return: Maximum number of concurrent requests, None if unlimited """ return None @@ -379,14 +334,13 @@ def create_request_timings( self, local_rank: int, local_world_size: int, local_max_concurrency: int ) -> ScheduledRequestTimings: """ - Create a ScheduledRequestTimings instance to define the timing behavior - for the worker process to schedule requests. + Create a timing instance to define scheduling behavior for a worker process. - :param local_rank: The rank of the worker process within the local world size. - :param local_world_size: The total num of worker processes in the local world. - :param local_max_concurrency: The maximum number of concurrent requests - for the worker process. 
- :return: A ScheduledRequestTimings instance for the worker process. + :param local_rank: The rank of the worker process within local world size + :param local_world_size: Total number of worker processes in local world + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: A ScheduledRequestTimings instance for the worker process + :raises NotImplementedError: Must be implemented by subclasses """ raise NotImplementedError( "create_worker_timings method must be implemented by subclasses." @@ -399,53 +353,55 @@ def create_request_timings( @SchedulingStrategy.register("synchronous") class SynchronousStrategy(SchedulingStrategy): """ - Sequential request processing strategy with maximum throughput constraints. + Sequential request processing strategy with single-process constraint. - This strategy processes requests one at a time in strict sequential order, - waiting for each request to complete before starting the next. It provides - the most predictable timing behavior and is useful for measuring maximum - achievable throughput under sequential processing constraints. - - The strategy enforces a limit of one worker process and one concurrent request, - making it ideal for scenarios where request ordering and isolation are critical. + Processes requests one at a time in strict sequential order, providing predictable + timing behavior ideal for measuring maximum sequential throughput and ensuring + request isolation. """ type_: Literal["synchronous"] = "synchronous" # type: ignore[assignment] def __str__(self) -> str: - """Return string representation of the strategy.""" + """ + Return string representation of the strategy. + + :return: String identifier for synchronous strategy + """ return "synchronous" @property def processes_limit(self) -> int | None: """ - Get the maximum number of worker processes for synchronous scheduling. + Get maximum number of worker processes for synchronous scheduling. - :return: Always returns 1 to enforce single-process constraint. + :return: Always returns 1 to enforce single-process constraint """ return 1 @property def requests_limit(self) -> int | None: """ - Get the maximum number of concurrent requests for synchronous scheduling. + Get maximum number of concurrent requests for synchronous scheduling. - :return: Always returns 1 to enforce single-request constraint. + :return: Always returns 1 to enforce single-request constraint """ return 1 def create_request_timings( - self, local_rank: int, local_world_size: int, local_max_concurrency: int + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 ) -> ScheduledRequestTimings: """ - Create timing implementation for synchronous request scheduling. + Create timing implementation for synchronous request scheduling. - :param local_rank: The rank of the worker process. Must be 0. - :param local_world_size: Total number of worker processes. Must be 1. - :param local_max_concurrency: The maximum number of concurrent requests - for the worker process. Unused in this strategy. - :return: LastCompletionRequestTimings instance for sequential processing. - :raises ValueError: If multiple workers or non-zero rank is specified. 
+ :param local_rank: The rank of the worker process (must be 0) + :param local_world_size: Total number of worker processes (must be 1) + :param local_max_concurrency: Maximum concurrent requests (unused) + :return: LastCompletionRequestTimings instance for sequential processing + :raises ValueError: If multiple workers or non-zero rank specified """ if local_world_size > 1 or local_rank != 0: raise ValueError( @@ -460,69 +416,62 @@ class ConcurrentStrategy(SchedulingStrategy): """ Parallel request processing strategy with controlled concurrency limits. - This strategy enables concurrent request processing up to a specified number - of streams, allowing multiple requests to be processed simultaneously while - maintaining predictable resource usage. It provides a balance between - throughput and resource control. - - The number of concurrent streams determines both the maximum number of worker - processes and the maximum number of requests that can be processed in parallel. - Each worker process handles one stream and waits for request completion before - processing the next request in that stream. + Enables concurrent request processing up to a specified number of streams, + providing balanced throughput while maintaining predictable resource usage + and completion-based timing coordination. """ type_: Literal["concurrent"] = "concurrent" # type: ignore[assignment] streams: int = Field( - description=( - "The number of concurrent streams to use for scheduling requests. " - "This must be a positive integer." - ), + description="Number of concurrent streams for scheduling requests", gt=0, ) startup_duration: float = Field( default=0.0, - description=( - "Duration in seconds over which startup requests are distributed " - "before switching to completion-based timing." - ), + description="Duration in seconds for distributing startup requests", ge=0, ) def __str__(self) -> str: - """Return string representation of the strategy.""" + """ + Return string representation of the strategy. + + :return: String identifier with stream count + """ return f"concurrent@{self.streams}" @property def processes_limit(self) -> int: """ - Get the maximum number of worker processes for concurrent scheduling. + Get maximum number of worker processes for concurrent scheduling. - :return: The number of streams, which equals the maximum worker processes. + :return: Number of streams as maximum worker processes """ return self.streams @property def requests_limit(self) -> int: """ - Get the maximum number of concurrent requests for concurrent scheduling. + Get maximum number of concurrent requests for concurrent scheduling. - :return: The number of streams, which equals the maximum concurrent requests. + :return: Number of streams as maximum concurrent requests """ return self.streams def create_request_timings( - self, local_rank: int, local_world_size: int, local_max_concurrency: int + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 ) -> LastCompletionRequestTimings: """ - Create timing implementation for concurrent request scheduling. + Create timing implementation for concurrent request scheduling. - :param local_rank: The rank of the worker process. Must be less than streams. - :param local_world_size: Total number of worker processes. Must not exceed - streams. - :param local_max_concurrency: The maximum number of concurrent requests - for the worker process. Unused in this strategy. - :return: LastCompletionRequestTimings instance for stream-based processing. 
- :raises ValueError: If worker configuration exceeds stream limits. + :param local_rank: The rank of the worker process (must be < streams) + :param local_world_size: Total worker processes (must not exceed streams) + :param local_max_concurrency: Maximum concurrent requests (unused) + :return: LastCompletionRequestTimings instance for stream-based processing + :raises ValueError: If worker configuration exceeds stream limits """ if local_world_size > self.streams: raise ValueError( @@ -567,54 +516,45 @@ class ThroughputStrategy(SchedulingStrategy): """ Maximum throughput strategy with optional concurrency limits. - This strategy schedules requests to maximize system throughput by allowing - unlimited concurrent request processing. Requests are scheduled immediately - without waiting for previous requests to complete, enabling the system to - achieve its maximum processing capacity. - - An optional maximum concurrency limit can be set to prevent resource - exhaustion while still allowing high-throughput processing patterns. + Schedules requests to maximize system throughput by allowing unlimited concurrent + processing with optional constraints and startup ramping for controlled ramp-up. """ type_: Literal["throughput"] = "throughput" # type: ignore[assignment] max_concurrency: int | None = Field( default=None, - description=( - "The maximum number of concurrent requests to schedule. " - "This must be a positive integer greater than 0." - ), + description="Maximum number of concurrent requests to schedule", gt=0, ) startup_duration: float = Field( default=0.0, - description=( - "Duration in seconds over which startup requests are distributed " - "before switching to full throughput scheduling." - ), + description="Duration in seconds for startup request distribution", ge=0, ) def __str__(self) -> str: - """Return string representation of the strategy.""" + """ + Return string representation of the strategy. + + :return: String identifier for throughput strategy + """ return "throughput" @property def processes_limit(self) -> int | None: """ - Get the maximum number of worker processes for throughput scheduling. + Get maximum number of worker processes for throughput scheduling. :return: The max_concurrency value if set, otherwise None for unlimited - worker processes. """ return self.max_concurrency @property def requests_limit(self) -> int | None: """ - Get the maximum number of concurrent requests for throughput scheduling. + Get maximum number of concurrent requests for throughput scheduling. :return: The max_concurrency value if set, otherwise None for unlimited - concurrent requests. """ return self.max_concurrency @@ -624,12 +564,10 @@ def create_request_timings( """ Create timing implementation for throughput request scheduling. - :param local_rank: The rank of the worker process (unused for throughput). - :param local_world_size: Total number of worker processes (unused for - throughput). - :param local_max_concurrency: The maximum number of concurrent requests - for the worker process. - :return: NoDelayRequestTimings instance for immediate request scheduling. 
+ :param local_rank: The rank of the worker process + :param local_world_size: Total number of worker processes + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: NoDelayRequestTimings instance for immediate request scheduling """ if self.startup_duration > 0: # Vary offset by up to 5% of the startup duration for a bit of variance @@ -652,58 +590,52 @@ class AsyncConstantStrategy(ThroughputStrategy): """ Asynchronous constant-rate scheduling strategy for predictable load patterns. - This strategy schedules requests at a fixed rate specified in requests per - second, distributed evenly across all worker processes. It provides predictable - timing behavior while allowing asynchronous processing, making it ideal for - simulating steady-state load conditions and measuring system performance - under consistent request rates. - - The total rate is divided equally among all worker processes, ensuring the - aggregate rate matches the specified value regardless of the number of workers. + Schedules requests at a fixed rate distributed evenly across worker processes, + providing predictable timing behavior for steady-state load simulation and + consistent system performance measurement. """ type_: Literal["constant"] = "constant" # type: ignore[assignment] rate: float = Field( - description=( - "The rate at which to schedule requests asynchronously in " - "requests per second. This must be a positive float." - ), + description="Rate for scheduling requests asynchronously in requests/second", gt=0, ) startup_duration: float = Field( default=0.0, - description=( - "Duration in seconds over which startup requests are distributed " - "to converge quickly to the desired rate before switching to " - "constant-rate scheduling." - ), + description="Duration in seconds for startup request distribution", ge=0, ) def __str__(self) -> str: - """Return string representation of the strategy.""" + """ + Return string representation of the strategy. + + :return: String identifier with rate value + """ return f"constant@{self.rate:.2f}" def create_request_timings( - self, local_rank: int, local_world_size: int, local_max_concurrency: int + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 ) -> ScheduledRequestTimings: """ - Create timing implementation for constant-rate request scheduling. - - Divides the total rate evenly across all worker processes to maintain - the specified aggregate rate. + Create timing implementation for constant-rate request scheduling. - :param local_rank: The rank of the worker process (unused). - :param local_world_size: Total number of worker processes for rate division. - :param local_max_concurrency: The maximum number of concurrent requests - for the worker process. - :return: ConstantRateRequestTimings instance with per-worker rate. 
+ :param local_rank: The rank of the worker process + :param local_world_size: Total number of worker processes for rate division + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: ConstantRateRequestTimings instance with per-worker rate """ # Divide the rate evenly across all worker processes worker_rate = self.rate / local_world_size + # Start each worker with an offset to interleave rates + worker_offset = (1 / self.rate) * local_rank return ConstantRateRequestTimings( rate=worker_rate, + offset=worker_offset, ) @@ -712,63 +644,57 @@ class AsyncPoissonStrategy(ThroughputStrategy): """ Asynchronous Poisson-distributed scheduling strategy for realistic load simulation. - This strategy schedules requests following a Poisson process with exponentially - distributed inter-arrival times. The average rate is specified in requests per - second, but individual intervals vary randomly, providing a more realistic - simulation of user behavior and network traffic patterns. - - The total rate is divided equally among all worker processes, with each worker - using a different random seed to ensure independent request streams that - collectively achieve the target rate. + Schedules requests following a Poisson process with exponentially distributed + inter-arrival times, providing realistic simulation of user behavior and network + traffic patterns with random variance around the target rate. """ type_: Literal["poisson"] = "poisson" # type: ignore[assignment] rate: float = Field( - description=( - "The rate at which to schedule requests asynchronously in " - "requests per second. This must be a positive float." - ), + description="Rate for scheduling requests asynchronously in requests/second", gt=0, ) startup_duration: float = Field( default=0.0, - description=( - "Duration in seconds over which startup requests are distributed " - "to converge quickly to the desired rate before switching to " - "constant-rate scheduling." - ), + description="Duration in seconds for startup request distribution", ge=0, ) random_seed: int = Field( default=42, - description=("The random seed to use for the Poisson distribution."), + description="Random seed to use for Poisson distribution", ) def __str__(self) -> str: - """Return string representation of the strategy.""" + """ + Return string representation of the strategy. + + :return: String identifier with rate value + """ return f"poisson@{self.rate:.2f}" def create_request_timings( - self, local_rank: int, local_world_size: int, local_max_concurrency: int + self, + local_rank: int, + local_world_size: int, + local_max_concurrency: int, # noqa: ARG002 ) -> ScheduledRequestTimings: """ - Create timing implementation for Poisson-distributed request scheduling. + Create timing implementation for Poisson-distributed request scheduling. - Divides the total rate evenly across all worker processes and assigns - unique random seeds to ensure independent but coordinated request streams. - - :param local_rank: The rank of the worker process for seed generation. - :param local_world_size: Total number of worker processes for rate division. - :param local_max_concurrency: The maximum number of concurrent requests - for the worker process. - :return: PoissonRateRequestTimings instance with per-worker rate and - unique seed. 
+ :param local_rank: The rank of the worker process for seed generation + :param local_world_size: Total number of worker processes for rate division + :param local_max_concurrency: Maximum concurrent requests for the worker + :return: PoissonRateRequestTimings instance with per-worker rate and unique seed """ # Divide the rate evenly across all worker processes worker_rate = self.rate / local_world_size # Use a different seed for each worker to ensure different sequences worker_seed = self.random_seed + local_rank + # Start each worker with an offset to interleave rates + worker_offset = (1 / self.rate) * local_rank + return PoissonRateRequestTimings( rate=worker_rate, random_seed=worker_seed, + offset=worker_offset, ) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 5f9e4f3c..bf7537d0 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -1,144 +1,143 @@ """ -Worker process management for multi-process request scheduling and execution. +Individual worker process management for multi-process request execution. -Provides infrastructure for managing individual worker processes that handle -request scheduling, processing, and coordination in multi-process environments. - -Classes: - WorkerProcess: Individual worker process for request processing and coordination. +Manages worker processes that handle request scheduling, backend processing, and +coordination in distributed benchmark environments. Workers consume requests from +queues, apply timing strategies, process requests through backends, and publish +status updates while maintaining synchronization across the process group. """ from __future__ import annotations import asyncio import time -from collections.abc import Generator -from multiprocessing import Queue from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from queue import Empty as QueueEmpty -from threading import Event as ThreadingEvent -from typing import Generic, Literal +from typing import Annotated, Generic, Literal + +try: + import uvloop + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = True +except ImportError: + uvloop = None + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = False -import culsans from guidellm.scheduler.objects import ( BackendInterface, - MeasuredRequestTimingsT, MultiTurnRequestT, RequestT, ResponseT, ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, ) from guidellm.scheduler.strategy import ScheduledRequestTimings -from guidellm.utils import MsgpackEncoding, synchronous_to_exitable_async +from guidellm.utils import ( + InterProcessMessaging, + wait_for_sync_barrier, + wait_for_sync_event, + wait_for_sync_objects, +) __all__ = ["WorkerProcess"] -class WorkerProcess(Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): +class WorkerProcess(Generic[RequestT, ResponseT]): """ - Individual worker process for request processing and coordination. - - Manages the complete lifecycle of requests from queue consumption through backend - processing and updates publication, maintaining synchronization with other - processes in the group. + Individual worker process for distributed request execution and coordination. + + Manages the complete request lifecycle from queue consumption through backend + processing and status publication. 
Coordinates with other workers through + barriers and events while maintaining configurable concurrency limits and + timing strategies for request scheduling. + + Example: + :: + worker = WorkerProcess( + messaging=messaging_interface, + async_limit=10, + startup_barrier=barrier, + shutdown_event=shutdown, + error_event=error, + backend=backend_instance, + request_timings=timing_strategy + ) + worker.run() """ def __init__( self, - local_rank: int, - local_world_size: int, - async_limit: int, - startup_barrier: ProcessingBarrier, - shutdown_event: ProcessingEvent, - error_event: ProcessingEvent, - requests_queue: Queue[ - tuple[ - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] - ], - updates_queue: Queue[ + messaging: InterProcessMessaging[ tuple[ ResponseT | None, RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] + ScheduledRequestInfo, + ], ], - backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + backend: BackendInterface[RequestT, ResponseT], request_timings: ScheduledRequestTimings, - poll_intervals: float = 0.1, - max_requests_queue_buffer: int = 2, + async_limit: int, + startup_barrier: ProcessingBarrier, + requests_generated_event: ProcessingEvent, + constraint_reached_event: ProcessingEvent, + shutdown_event: ProcessingEvent, + error_event: ProcessingEvent, ): """ Initialize worker process instance. - :param local_rank: Process rank within the worker group. - :param local_world_size: Total number of worker processes in the group. - :param async_limit: Maximum concurrent requests this worker can handle. - :param startup_barrier: Multiprocessing barrier for coordinated startup. - :param shutdown_event: Event for signaling graceful shutdown. - :param error_event: Event for signaling error conditions across processes. - :param requests_queue: Queue for receiving requests to process. - :param updates_queue: Queue for publishing processing updates. - :param backend: Backend instance for processing requests. - :param request_timings: Timing strategy for request scheduling. - :param poll_intervals: Time interval for polling operations. 
+ :param messaging: Inter-process communication interface for request coordination + :param backend: Backend instance for processing requests + :param request_timings: Timing strategy for request scheduling + :param async_limit: Maximum concurrent requests this worker can handle + :param startup_barrier: Multiprocessing barrier for coordinated startup + :param requests_generated_event: Event signaling when request generation is + complete + :param constraint_reached_event: Event signaling when processing constraints + are met + :param shutdown_event: Event for signaling graceful shutdown + :param error_event: Event for signaling error conditions across processes """ - # Worker info - self.local_rank = local_rank - self.local_world_size = local_world_size + self.messaging = messaging + self.backend = backend + self.request_timings = request_timings self.async_limit = async_limit - - # Process synchronization self.startup_barrier = startup_barrier + self.requests_generated_event = requests_generated_event + self.constraint_reached_event = constraint_reached_event self.shutdown_event = shutdown_event self.error_event = error_event - self.requests_queue = requests_queue - self.updates_queue = updates_queue - # Local synchronization (initialized during start up) - self.pending_requests_queue: culsans.Queue[ - tuple[ - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] - ] = None - self.pending_updates_queue: culsans.Queue[ - tuple[ - ResponseT | None, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] - ] = None - self.requests_canceled: ThreadingEvent = None - self.pull_requests_stopped: ThreadingEvent = None - self.pull_task: asyncio.Task = None - self.push_task: asyncio.Task = None - - # Request processing - self.backend = backend - self.request_timings = request_timings - self.poll_intervals = poll_intervals - self.max_requests_queue_buffer = max_requests_queue_buffer - self.startup_completed: bool = False + # Internal states + self.startup_completed = False + self.backend_started = False + self.messaging_started = False def run(self): """ Main entry point for worker process execution. - Initializes asyncio event loop and starts worker async operations. + Initializes asyncio event loop with optional uvloop optimization and starts + worker async operations. Handles event loop cleanup for forked processes. - :raises RuntimeError: If worker encounters unrecoverable error during execution. + :raises RuntimeError: If worker encounters unrecoverable error during execution """ try: + if HAS_UVLOOP: + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.run(self.run_async()) except Exception as err: - print(f"******EXCEPTION in worker {self.local_rank} run: {err}") + print(f"******EXCEPTION in worker {self.messaging.worker_index} run: {err}") self.error_event.set() raise RuntimeError( - f"Worker process {self.local_rank} encountered an error: {err}" + f"Worker process {self.messaging.worker_index} encountered an " + f"error: {err}" ) from err async def run_async(self): @@ -146,393 +145,246 @@ async def run_async(self): Execute main asynchronous worker process logic. Orchestrates concurrent execution of request processing and shutdown monitoring - tasks, handling cleanup and error propagation when tasks complete. + tasks. Handles task cleanup, error propagation, and cancellation coordination + when any task completes or fails. - :raises RuntimeError: If worker tasks encounter unrecoverable errors. 
+ :raises RuntimeError: If worker tasks encounter unrecoverable errors + :raises asyncio.CancelledError: If worker process was cancelled """ - # Start both shutdown monitoring and request processing concurrently - tasks = [ - asyncio.create_task(self.run_async_stop_processing()), - asyncio.create_task(self.run_async_requests_processing()), - ] + stop_task = asyncio.create_task(self._stop_monitor()) + request_proc_task = asyncio.create_task(self._process_requests()) + caller_cancelled = False try: - # Wait for the first task to complete (shut down or error) - completed, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED + await asyncio.wait( + [stop_task, request_proc_task], + return_when=asyncio.FIRST_COMPLETED, ) + except asyncio.CancelledError: + caller_cancelled = True - # Cancel remaining tasks - if pending: - for task in pending: - task.cancel() - await asyncio.gather(*pending, return_exceptions=True) + stop_task.cancel() + request_proc_task.cancel() - # Check for exceptions in completed tasks - for task in completed: - if not task.cancelled() and (exception := task.exception()): - raise exception + try: + # Ensure all child tasks cancel correctly + await asyncio.wait( + [stop_task, request_proc_task], return_when=asyncio.ALL_COMPLETED + ) except asyncio.CancelledError: - # Ensure all tasks are canceled before re-raising - for task in tasks: - if not task.done(): - task.cancel() - if any(not task.done() for task in tasks): - await asyncio.gather(*tasks, return_exceptions=True) - raise - - async def run_async_stop_processing(self): - """ - Monitor for shutdown and error signals. + caller_cancelled = True + + if ( + task_err := ( + request_proc_task.exception() + if not request_proc_task.cancelled() + else stop_task.exception() + if not stop_task.cancelled() + else None + ) + ) is not None: + raise RuntimeError( + f"Worker process {self.messaging.worker_index} encountered an " + f"error: {task_err}" + ) from task_err - Runs in parallel with request processing to monitor for shutdown or error - events and trigger appropriate cleanup procedures. + if caller_cancelled: + raise asyncio.CancelledError("Worker process was cancelled") - :raises RuntimeError: If error event is signaled or unexpected exit occurs. - :raises asyncio.CancelledError: If shutdown event is signaled. - """ - exit_reason, _ = await synchronous_to_exitable_async( - synchronous=None, - exit_events={ + async def _stop_monitor( + self, + ) -> Literal["error_event", "shutdown_event"]: + exit_key = await wait_for_sync_objects( + { "error_event": self.error_event, "shutdown_event": self.shutdown_event, }, - poll_interval=self.poll_intervals, + poll_interval=self.messaging.poll_interval, ) - if exit_reason == "error_event": - raise RuntimeError( - f"Worker process {self.local_rank} received error signal." - ) - elif exit_reason == "shutdown_event": - raise asyncio.CancelledError( - f"Worker process {self.local_rank} received shutdown signal." - ) - else: + if exit_key == "error_event": raise RuntimeError( - f"Worker process {self.local_rank} received unexpected exit reason: " - f"{exit_reason}" + f"Worker process {self.messaging.worker_index} received error signal." ) - async def run_async_requests_processing(self): - """ - Process incoming requests from the queue. - - Handles backend initialization, process synchronization, concurrent request - processing with semaphore limiting, and graceful shutdown with task cleanup. 
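The new run_async body follows a small supervisor pattern: race the stop monitor against request processing, cancel whatever is still pending, drain both tasks, then surface the first real exception. A self-contained sketch of that pattern, using stand-in _monitor/_work coroutines rather than the worker's actual tasks:

    import asyncio

    async def _monitor(stop: asyncio.Event) -> None:
        await stop.wait()

    async def _work() -> None:
        await asyncio.sleep(0.01)
        raise RuntimeError("worker failed")

    async def supervise() -> None:
        stop = asyncio.Event()
        tasks = [asyncio.create_task(_monitor(stop)), asyncio.create_task(_work())]
        # Race the tasks, then cancel and drain whatever is still pending
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in tasks:
            task.cancel()
        await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
        # Re-raise the first real failure from a task that was not cancelled
        for task in tasks:
            if not task.cancelled() and task.exception() is not None:
                raise RuntimeError("supervised task failed") from task.exception()

    try:
        asyncio.run(supervise())
    except RuntimeError as err:
        print(err, "<-", err.__cause__)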
- - :raises RuntimeError: If backend initialization or startup synchronization - fails. - :raises asyncio.CancelledError: If shutdown is requested during processing. - :raises NotImplementedError: If multi-turn requests are encountered. - """ + async def _process_requests(self): try: - await self._initialize_requests_processing() - await self._start_ready_requests_processing() - await self._loop_requests_processing() - except asyncio.CancelledError: - await self._shutdown_requests_processing() + # 1. Start up synchronization (backend, messaging, and other processes) + # 2. Messaging startup, receive requests until requests_generated event + await self._processing_startup() + + # 3. Run process requests loop until constraint_reached event + processing_task = asyncio.create_task(self._process_requests_loop()) + await wait_for_sync_event( + self.constraint_reached_event, + poll_interval=self.messaging.poll_interval, + ) + processing_task.cancel() - raise + # 4. Cancel pending requests until proc canceled (manual, shutdown, error) + await self._cancel_requests_loop() + finally: + # 5. On cancel, shut down event, error event, or internal error: + # attempt to shut down this worker cleanly (stop backend and messaging) + await self._processing_shutdown() - async def _initialize_requests_processing(self): - # Ensure backend is ready on this worker + async def _processing_startup(self): + # Get backend ready await self.backend.process_startup() + self.backend_started = True await self.backend.validate() - # Setup local queues - self.pending_requests_queue = culsans.Queue( - maxsize=self.max_requests_queue_buffer - ) - self.pending_updates_queue = culsans.Queue() - self.requests_canceled = ThreadingEvent() - self.pull_requests_stopped = ThreadingEvent() - - # Start background tasks for queue management - self.pull_task = asyncio.create_task( - synchronous_to_exitable_async( - self._pull_requests_generator(), - poll_interval=0, # no delays on thread for checking queue - ) - ) - self.push_task = asyncio.create_task( - synchronous_to_exitable_async( - self._push_updates_generator(), - poll_interval=0, # no delays on thread for checking queue - ) + # Get messaging system ready + await self.messaging.start( + receive_stop_criteria=[self.requests_generated_event], + pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()), ) + self.messaging_started = True - async def _start_ready_requests_processing(self): # Wait for all processes to be ready - barrier_exit_reason, _ = await synchronous_to_exitable_async( - synchronous=None, - exit_barrier=self.startup_barrier, - poll_interval=self.poll_intervals, + await wait_for_sync_barrier( + self.startup_barrier, + poll_interval=self.messaging.poll_interval, ) - if barrier_exit_reason not in ["barrier", "canceled"]: - raise RuntimeError( - f"Worker process {self.local_rank} failed to synchronize at " - f"startup: {barrier_exit_reason}" - ) - self.startup_completed = True - async def _loop_requests_processing(self): - async_semaphore = asyncio.Semaphore(self.async_limit) - pending_tasks = set() + async def _processing_shutdown(self): + if self.backend_started: + await self.backend.process_shutdown() + self.backend_started = False - def _task_done(task): - pending_tasks.discard(task) - async_semaphore.release() + if self.messaging_started: + await self.messaging.stop() + self.messaging_started = False - if not task.cancelled() and (exception := task.exception()): - raise exception + self.startup_completed = False + async def 
_process_requests_loop(self): try: + # Run request processing + async_semaphore = asyncio.Semaphore(self.async_limit) + pending_tasks: set[asyncio.Task] = set() + + def _task_done(task): + pending_tasks.discard(task) + async_semaphore.release() + + if not task.cancelled() and (exception := task.exception()): + raise exception + # Main loop; loop until canceled while True: await async_semaphore.acquire() request_task = asyncio.create_task(self._process_next_request()) pending_tasks.add(request_task) request_task.add_done_callback(_task_done) - await asyncio.sleep(0) - except asyncio.CancelledError: - # Shut down requests queuing - self.requests_canceled.set() - - # Cancel pending requests - if pending_tasks: - for task in list(pending_tasks): - task.cancel() - await asyncio.gather(*pending_tasks, return_exceptions=True) - raise + except asyncio.CancelledError as err: + for task in pending_tasks: + task.cancel() + await asyncio.gather(*pending_tasks, return_exceptions=True) + + raise err - async def _shutdown_requests_processing(self): - if self.requests_canceled is not None: - # Queues have been constructed, cancel pending and ensure updates - self.requests_canceled.set() - await self._cancel_pending_requests() - await self.pending_updates_queue.async_join() - await self.pending_requests_queue.aclose() - await self.pending_updates_queue.aclose() - - # Cancel background tasks - tasks = [] - if self.push_task is not None and not self.push_task.done(): - self.push_task.cancel() - tasks.append(self.push_task) - if self.pull_task is not None and not self.pull_task.done(): - self.pull_task.cancel() - tasks.append(self.pull_task) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - # Shut down backend - await self.backend.process_shutdown() - - # Reset state - self.pending_requests_queue = None - self.pending_updates_queue = None - self.pull_task = None - self.push_task = None - self.requests_canceled = None + async def _cancel_requests_loop(self): + while True: + try: + request: RequestT + request_info: ScheduledRequestInfo + request, request_info = await self.messaging.get( + timeout=self.messaging.poll_interval + ) + except asyncio.TimeoutError: + continue + + request_info.scheduler_node_id = self.messaging.worker_index + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", None, request, request_info) async def _process_next_request(self): request: RequestT | MultiTurnRequestT[RequestT] | None = None - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT] | None = None + request_info: ScheduledRequestInfo | None = None response: ResponseT | None = None try: - # get next request to send - request, request_info = await self.pending_requests_queue.async_get() - current_time = time.time() - request_info.scheduler_timings.dequeued = current_time - await self._handle_request_update( - new_status="pending", - response=response, - request=request, - request_info=request_info, - ) + # Pull request from the queue + request, request_info = await self.messaging.get() if isinstance(request, (list, tuple)): raise NotImplementedError("Multi-turn requests are not yet supported") - # Calculate when to start processing request - timings_offset = self.request_timings.next_offset() - target_start = request_info.scheduler_start_time + timings_offset + # Calculate targeted start and set pending state for request + request_info.scheduler_node_id = self.messaging.worker_index + 
request_info.scheduler_timings.dequeued = time.time() + target_start = ( + request_info.scheduler_start_time + self.request_timings.next_offset() + ) request_info.scheduler_timings.targeted_start = target_start + self._send_update("pending", response, request, request_info) + # Schedule the request + current_time = time.time() + request_info.scheduler_timings.scheduled_at = current_time if target_start > current_time: await asyncio.sleep(target_start - current_time) + # Adapt delay so that scheduled at reflects the sleep time request_info.scheduler_timings.scheduled_at = target_start - else: - request_info.scheduler_timings.scheduled_at = current_time - # Process the request + # Process the request with the backend request_info.scheduler_timings.resolve_start = time.time() - await self._handle_request_update( - new_status="in_progress", - response=response, - request=request, - request_info=request_info, - ) - async for resp, updated_request_info in self.backend.resolve( - request, request_info, None - ): + self._send_update("in_progress", response, request, request_info) + async for resp, info in self.backend.resolve(request, request_info, None): response = resp - request_info = updated_request_info + request_info = info - # Complete + # Complete the request request_info.scheduler_timings.resolve_end = time.time() - await self._handle_request_update( - new_status="completed", - response=response, - request=request, - request_info=request_info, - ) + self._send_update("completed", response, request, request_info) + + response = request = request_info = None except asyncio.CancelledError: # Handle cancellation if request is not None and request_info is not None: request_info.error = "Request was cancelled" request_info.scheduler_timings.resolve_end = time.time() - await self._handle_request_update( - new_status="cancelled", - response=response, - request=request, - request_info=request_info, - ) + self._send_update("cancelled", response, request, request_info) raise except Exception as exc: # noqa: BLE001 if request is not None and request_info is not None: request_info.error = str(exc) request_info.scheduler_timings.resolve_end = time.time() - await self._handle_request_update( - new_status="errored", - response=response, - request=request, - request_info=request_info, - ) + self._send_update("errored", response, request, request_info) - async def _handle_request_update( + def _send_update( self, new_status: Literal[ "pending", "in_progress", "completed", "errored", "cancelled" ], response: ResponseT | None, request: RequestT | MultiTurnRequestT[RequestT], - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT], + request_info: ScheduledRequestInfo, ): - status_orders = { - "queued": -2, # does not send event - "pending": -1, # does not send event - "in_progress": 1, - "completed": 2, - "errored": 2, - "cancelled": 2, - } prev_status = request_info.status - try: - if ( - status_orders[new_status] >= status_orders["in_progress"] - and status_orders[prev_status] < status_orders["in_progress"] - ): - # Haven't sent start update yet - request_info.status = "in_progress" - await self.pending_updates_queue.async_put( - (None, request, request_info.model_copy()) - ) - prev_status = "in_progress" - - if ( - status_orders[new_status] > status_orders["in_progress"] - and status_orders[new_status] > status_orders[prev_status] - ): - # Haven't sent resolved update yet - request_info.status = new_status - await self.pending_updates_queue.async_put( - (response, request, request_info.model_copy()) 
- ) - prev_status = new_status - # Notify instance states - self.request_timings.request_completed(request_info) - self.pending_requests_queue.task_done() + if new_status == prev_status: + # already sent this update, don't send again + return + + try: + request_info.status = new_status + request_info = ( + request_info.model_copy() + if new_status not in {"completed", "errored", "cancelled"} + else request_info # last update, don't need to copy + ) + self.messaging.put_sync( + (response, request, request_info), + timeout=-1, + ) + prev_status = new_status except Exception as exc: # Reset status to last one that succeeded or started function with # Calling logic can retry after handling error, if possible request_info.status = prev_status raise exc - - async def _cancel_pending_requests(self): - while True: - try: - request, request_info = await asyncio.wait_for( - self.pending_requests_queue.async_get(), timeout=self.poll_intervals - ) - request_info.error = "Request was cancelled" - request_info.scheduler_timings.resolve_end = time.time() - await self._handle_request_update( - new_status="cancelled", - response=None, - request=request, - request_info=request_info, - ) - except (culsans.QueueEmpty, asyncio.TimeoutError): - if self.pull_requests_stopped.is_set(): - # No more requests will be put on the Queue - break - - def _pull_requests_generator(self) -> Generator: - last_check = time.time() - - while True: - if self.requests_canceled.is_set(): - break - - try: - message = self.requests_queue.get(timeout=self.poll_intervals) - request_tuple = MsgpackEncoding.decode(message) - self.pending_requests_queue.sync_put(request_tuple) - except QueueEmpty: - pass # No update available, continue polling - except culsans.QueueShutDown: - break - except Exception: # noqa: BLE001, S110 - pass - - if time.time() - last_check > self.poll_intervals: - # Yield to allow cancel/error/stop checks in wrapper - last_check = time.time() - yield None - - self.pull_requests_stopped.set() - - def _push_updates_generator(self) -> Generator: - last_check = time.time() - - while True: - try: - update_tuple = self.pending_updates_queue.sync_get( - timeout=self.poll_intervals - ) - response: ResponseT | None = update_tuple[0] - request: RequestT | MultiTurnRequestT[RequestT] = update_tuple[1] - request_info: ScheduledRequestInfo[MeasuredRequestTimingsT] = ( - update_tuple[2] - ) - - message = MsgpackEncoding.encode((response, request, request_info)) - self.updates_queue.put(message) - self.pending_updates_queue.task_done() - except culsans.QueueEmpty: - pass # No update available, continue polling - except culsans.QueueShutDown: - break - except Exception: # noqa: BLE001, S110 - pass - - if time.time() - last_check > self.poll_intervals: - # Yield to allow cancel/error/stop checks in wrapper - last_check = time.time() - yield None diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index 52a711fd..15172f49 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -2,126 +2,172 @@ Multi-process worker group orchestration for distributed request scheduling. Provides infrastructure for coordinating worker processes with shared state -management, inter-process communication, and lifecycle coordination. - -Classes: - WorkerProcessGroup: Orchestrates multiple worker processes for distributed - request processing with centralized coordination. +management, inter-process communication, and lifecycle coordination. 
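One detail worth noting in _send_update above: non-terminal updates are queued as model_copy() snapshots because the worker keeps mutating the same request_info afterwards, while the terminal update can ship the object itself. A tiny illustration of why the copy matters (the Info model here is hypothetical, not part of this patch):

    from pydantic import BaseModel

    class Info(BaseModel):
        status: str = "pending"

    info = Info()
    snapshot = info.model_copy()   # independent copy of the in-flight state
    info.status = "completed"      # later mutation does not touch the snapshot
    print(snapshot.status, info.status)  # pending completed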
Handles +dynamic scaling, load balancing, constraint evaluation, and graceful shutdown +across distributed workers processing concurrent requests. """ from __future__ import annotations import asyncio -import contextlib import math -import queue import threading import time import uuid -from asyncio import Task -from collections.abc import AsyncIterator, Iterable, Iterator -from multiprocessing import Queue, get_context +from collections.abc import AsyncIterator, Generator, Iterable, Iterator +from multiprocessing import get_context +from multiprocessing.context import BaseContext +from multiprocessing.managers import BaseManager from multiprocessing.process import BaseProcess from multiprocessing.synchronize import Barrier, Event -from threading import Event as ThreadingEvent -from typing import Generic - -import culsans +from typing import Generic, NamedTuple -from guidellm.config import settings -from guidellm.scheduler.constraints import Constraint +from guidellm.scheduler.constraints import Constraint, RequestsExhaustedConstraint from guidellm.scheduler.objects import ( BackendInterface, - MeasuredRequestTimingsT, MultiTurnRequestT, RequestT, ResponseT, ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, SchedulerState, + SchedulerUpdateAction, ) from guidellm.scheduler.strategy import SchedulingStrategy from guidellm.scheduler.worker import WorkerProcess -from guidellm.utils import MsgpackEncoding, synchronous_to_exitable_async +from guidellm.settings import settings +from guidellm.utils import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, + wait_for_sync_objects, +) -__all__ = ["WorkerProcessGroup"] +__all__ = ["WorkerGroupState", "WorkerProcessGroup"] -class WorkerProcessGroup(Generic[RequestT, MeasuredRequestTimingsT, ResponseT]): +class WorkerProcessGroup(Generic[RequestT, ResponseT]): """ Orchestrates multiple worker processes for distributed request processing. Manages process lifecycle, request distribution, response collection, and state synchronization across workers. Handles dynamic scaling, load balancing, and - constraint evaluation with graceful shutdown coordination. + constraint evaluation with graceful shutdown coordination for high-throughput + request processing workloads. + + Example: + :: + from guidellm.scheduler.worker_group import WorkerProcessGroup + + group = WorkerProcessGroup( + requests=request_iterable, + cycle_requests=None, + backend=backend_instance, + strategy=scheduling_strategy, + constraints={"max_time": time_constraint} + ) + + await group.create_processes() + await group.start(time.time()) + + async for response, request, info, state in group.request_updates(): + if response is not None: + # Process completed request + handle_response(response) + + await group.shutdown() """ def __init__( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], - backend: BackendInterface[RequestT, MeasuredRequestTimingsT, ResponseT], + requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, + cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, + backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], - infinite_requests: bool | None = None, ): + """ + Initialize a worker process group for distributed request processing. 
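The constructor below rejects a plain Iterator for cycle_requests because cycling requires re-iteration: a one-shot iterator is exhausted after a single pass, while a re-iterable container can be walked again on every cycle. A short illustration with hypothetical request values:

    from itertools import islice

    def cycle(iterable):
        while True:
            yield from iterable              # restarts iteration over a container

    requests = ["req-a", "req-b"]            # re-iterable container: can cycle
    one_shot = iter(requests)                # Iterator: exhausted after one pass

    print(list(one_shot), list(one_shot))    # ['req-a', 'req-b'] []
    print(list(islice(cycle(requests), 5)))  # ['req-a', 'req-b', 'req-a', 'req-b', 'req-a']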
+ + :param requests: Finite iterable of requests to process sequentially + :param cycle_requests: Iterable of requests to cycle through indefinitely + :param backend: Backend interface for processing requests + :param strategy: Scheduling strategy for request timing and distribution + :param constraints: Named constraints for controlling execution behavior + :raises ValueError: If neither requests nor cycle_requests are provided, + or if cycle_requests is an Iterator rather than Iterable + """ + if not requests and not cycle_requests: + raise ValueError( + "At least one of 'requests' or 'cycle_requests' must be provided. " + f"Got requests: {requests}, cycle_requests: {cycle_requests}" + ) + + if isinstance(cycle_requests, Iterator): + raise ValueError( + f"cycle_requests must be an Iterable or None, not an Iterator. " + f"Got {type(cycle_requests)}" + ) + self.requests = requests + self.cycle_requests = cycle_requests self.backend = backend self.strategy = strategy self.constraints = constraints - self.infinite_requests = infinite_requests # Multiprocessing contexts and primitives, created in create_processes - self.mp_context = None + self.mp_context: BaseContext = None + self.mp_manager: BaseManager = None self.processes: list[BaseProcess] = None self.startup_barrier: Barrier = None + self.requests_generated_event: Event = None + self.constraint_reached_event: Event = None self.shutdown_event: Event = None self.error_event: Event = None - self.requests_queue: Queue[ + + # Scheduler and messaging state, created in start + self.state: WorkerGroupState[ResponseT, RequestT] = None + self.messaging: InterProcessMessaging[ tuple[ RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] - ] = None - self.updates_queue: Queue[ - tuple[ - ResponseT | None, - RequestT, - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] - ] = None - - # Local process async/threading bridges + signals - self.pending_updates_queue: culsans.Queue[ + ScheduledRequestInfo, + ], tuple[ ResponseT | None, RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo[MeasuredRequestTimingsT], - ] + ScheduledRequestInfo, + SchedulerState, + ], ] = None - self.pending_requests_complete: ThreadingEvent = None - self.pending_updates_complete: ThreadingEvent = None - self.populate_requests_task: Task = None - self.populate_updates_task: Task = None - - # Scheduler state - self.state_update_lock: threading.Lock = None - self.scheduler_state: SchedulerState = None async def create_processes(self): """ - Initialize and start the worker process group. + Create and initialize worker processes for distributed request processing. Sets up multiprocessing infrastructure and worker processes based on strategy constraints, backend capabilities, and system configuration. + Determines optimal process count and concurrency limits, then spawns + worker processes with distributed request handling capabilities. - :param backend: Backend instance for processing requests. - :param requests: Iterable of requests to process. - :param strategy: Scheduling strategy configuration. - :param constraints: Dictionary of named constraints for controlling execution. - :raises RuntimeError: If process initialization or startup fails. 
+ :raises RuntimeError: If process initialization or startup fails """ # Processes limits and params + max_conc: int = min( + self.strategy.requests_limit or math.inf, + self.backend.requests_limit or math.inf, + ) + if max_conc == math.inf: + # if concurrency not specified, use settings + max_conc = settings.max_concurrency + if max_conc <= 0: + raise RuntimeError("max_concurrency resolved to 0; increase limits/config") + + # Calculate number of processes, ensure we don't exceed the max concurrency, + # or limits from the backend, strategy, or user settings num_processes = int( min( + max_conc, self.strategy.processes_limit or math.inf, self.backend.processes_limit or math.inf, settings.max_worker_processes, @@ -130,69 +176,96 @@ async def create_processes(self): if num_processes <= 0: raise RuntimeError("num_processes resolved to 0; increase limits/config") - max_conc = int( - min( - self.strategy.requests_limit or math.inf, - self.backend.requests_limit or math.inf, - settings.max_concurrency, - ) + per_proc_max_conc = max_conc // num_processes + max_pending_size = max( + 1, math.floor(max_conc * settings.mp_max_pending_buffer_percent) ) - if max_conc <= 0: - raise RuntimeError("max_concurrency resolved to 0; increase limits/config") - - per_proc_max_conc = math.ceil(max_conc / num_processes) - per_proc_max_queue = min(2, per_proc_max_conc) - max_queued_requests = ( # Add queue buffer for each process - max_conc + (num_processes * per_proc_max_queue) + per_proc_max_buffer_size = max( + 1, math.floor(per_proc_max_conc * settings.mp_max_worker_buffer_percent) ) # Initialize multiprocessing components - self.mp_context = get_context("fork") + self.mp_context: BaseContext = get_context(settings.mp_context_type) + self.mp_manager = self.mp_context.Manager() self.startup_barrier = self.mp_context.Barrier(num_processes + 1) + self.requests_generated_event = self.mp_context.Event() + self.constraint_reached_event = self.mp_context.Event() self.shutdown_event = self.mp_context.Event() self.error_event = self.mp_context.Event() - self.requests_queue = self.mp_context.Queue(maxsize=max_queued_requests) - self.updates_queue = self.mp_context.Queue() + + if settings.mp_messaging_object == "queue": + self.messaging = InterProcessMessagingQueue( + mp_context=self.mp_context, + serialization=settings.mp_serialization, + encoding=settings.mp_encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=settings.mp_requests_send_buffer_size, + poll_interval=settings.mp_poll_interval, + ) + elif settings.mp_messaging_object == "manager_queue": + self.messaging = InterProcessMessagingManagerQueue( + manager=self.mp_manager, + mp_context=self.mp_context, + serialization=settings.mp_serialization, + encoding=settings.mp_encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=settings.mp_requests_send_buffer_size, + poll_interval=settings.mp_poll_interval, + ) + elif settings.mp_messaging_object == "pipe": + self.messaging = InterProcessMessagingPipe( + num_workers=num_processes, + mp_context=self.mp_context, + serialization=settings.mp_serialization, + encoding=settings.mp_encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=settings.mp_requests_send_buffer_size, + poll_interval=settings.mp_poll_interval, + ) # Initialize worker processes self.processes = [] for rank in range(num_processes): + # Distribute any remainder across the first N ranks async_limit = per_proc_max_conc + ( 1 if rank < (max_conc % num_processes) else 0 ) - worker = WorkerProcess[RequestT, 
MeasuredRequestTimingsT, ResponseT]( - local_rank=rank, - local_world_size=num_processes, - async_limit=async_limit, - startup_barrier=self.startup_barrier, - shutdown_event=self.shutdown_event, - error_event=self.error_event, - requests_queue=self.requests_queue, - updates_queue=self.updates_queue, + + worker = WorkerProcess[RequestT, ResponseT]( + messaging=self.messaging.create_worker_copy( + worker_index=rank, + max_buffer_send_size=None, + max_buffer_receive_size=per_proc_max_buffer_size, + ), backend=self.backend, request_timings=self.strategy.create_request_timings( local_rank=rank, local_world_size=num_processes, local_max_concurrency=async_limit, ), - poll_intervals=settings.scheduler_poll_interval, + async_limit=async_limit, + startup_barrier=self.startup_barrier, + requests_generated_event=self.requests_generated_event, + constraint_reached_event=self.constraint_reached_event, + shutdown_event=self.shutdown_event, + error_event=self.error_event, ) proc = self.mp_context.Process(target=worker.run, daemon=False) proc.start() self.processes.append(proc) - reason, _ = await synchronous_to_exitable_async( - synchronous=None, - exit_events={ - "error_event": self.error_event, + wait_key = await wait_for_sync_objects( + { + "startup_barrier": self.startup_barrier, "shutdown_event": self.shutdown_event, + "error_event": self.error_event, }, - exit_barrier=self.startup_barrier, - poll_interval=settings.scheduler_poll_interval, + poll_interval=settings.mp_poll_interval, ) - if reason != "barrier": + + if wait_key == "error_event": raise RuntimeError( - f"Worker process group startup failed with exit reason: {reason}" + "Worker process group startup failed: error_event is set" ) async def start(self, start_time: float): @@ -200,40 +273,42 @@ async def start(self, start_time: float): Begin request processing at the specified start time. Initializes scheduler state and background tasks, then waits until the - specified start time before beginning operations. + specified start time before beginning operations. Sets up inter-process + communication and coordinates synchronized startup across all workers. - :param start_time: Unix timestamp when processing should begin. - :raises RuntimeError: If workers encounter errors during startup. 
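The async_limit expression above spreads the remainder of max_conc over the first ranks so the per-worker limits always sum to the total. With made-up numbers (10 concurrent requests over 4 processes):

    max_conc, num_processes = 10, 4
    per_proc = max_conc // num_processes
    limits = [
        per_proc + (1 if rank < (max_conc % num_processes) else 0)
        for rank in range(num_processes)
    ]
    print(limits, sum(limits))  # [3, 3, 2, 2] 10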
+ :param start_time: Unix timestamp when processing should begin + :raises RuntimeError: If workers encounter errors during startup or + if create_processes() was not called first """ - if self.processes is None: + if not self.processes: raise RuntimeError("create_processes() must be called before start()") - self.state_update_lock = threading.Lock() - self.scheduler_state = SchedulerState( - node_id=0, # Process group node identifier - num_processes=len(self.processes), + stop_send_requests_event = threading.Event() + send_requests_stopped_event = threading.Event() + self.state = WorkerGroupState[RequestT, ResponseT]( start_time=start_time, + processes=self.processes, + constraints=self.constraints, + stop_send_requests_event=stop_send_requests_event, + send_requests_stopped_event=send_requests_stopped_event, + requests_generated_event=self.requests_generated_event, + constraint_reached_event=self.constraint_reached_event, + shutdown_event=self.shutdown_event, + error_event=self.error_event, ) - self.pending_updates_queue = culsans.Queue() - self.pending_requests_complete = ThreadingEvent() - self.pending_updates_complete = ThreadingEvent() - - self.populate_requests_task = asyncio.create_task( - synchronous_to_exitable_async( - self._populate_requests_generator(start_time), - exit_events={"error_event": self.error_event}, - poll_interval=0.0, - ) - ) - self.populate_updates_task = asyncio.create_task( - synchronous_to_exitable_async( - self._populate_updates_generator(), - exit_events={"error_event": self.error_event}, - poll_interval=0.0, - ) + await self.messaging.start( + send_items=self.state.requests_generator( + self.requests, self.cycle_requests + ), + receive_callback=self.state.received_callback, + send_stopped_event=send_requests_stopped_event, + send_stop_criteria=[stop_send_requests_event], + receive_stop_criteria=[self.shutdown_event], + pydantic_models=list(SchedulerMessagingPydanticRegistry.registry.values()), ) - await asyncio.sleep(max(0, start_time - time.time())) + if (wait_time := start_time - time.time()) > 0: + await asyncio.sleep(wait_time) if self.error_event.is_set(): raise RuntimeError( "error_event is set in WorkerProcessGroup, " @@ -246,7 +321,7 @@ async def request_updates( tuple[ ResponseT | None, RequestT, - ScheduledRequestInfo[MeasuredRequestTimingsT], + ScheduledRequestInfo, SchedulerState, ] ]: @@ -254,365 +329,355 @@ async def request_updates( Yield request processing updates as they become available. Returns an async iterator of request updates including response, request, - scheduling metadata, and scheduler state. Updates occur on request queued, - processing start, and completion. + request scheduling info, and scheduler state. Updates occur on request queued, + processing start, and completion. Response is None until processing completes. :return: Async iterator yielding (response, request, request_info, state) - tuples; response is None until processing is complete. - :raises RuntimeError: If workers encounter unrecoverable errors. + tuples where response is None until processing is complete + :raises RuntimeError: If workers encounter unrecoverable errors """ - last_check_time = -1 * math.inf + while True: + if self.error_event.is_set(): + raise RuntimeError( + "error_event is set in WorkerProcessGroup, " + "indicating an error occurred in one of the worker processes." 
+ ) - while ( - not self.pending_updates_complete.is_set() - or not self.pending_updates_queue.empty() - ): try: ( response, request, request_info, scheduler_state, - ) = await asyncio.wait_for( - self.pending_updates_queue.async_get(), - timeout=settings.scheduler_poll_interval, - ) + ) = await self.messaging.get(timeout=settings.mp_poll_interval) yield response, request, request_info, scheduler_state except asyncio.TimeoutError: - pass - - if (time.time() - last_check_time) >= settings.scheduler_poll_interval: - if self.error_event.is_set(): - raise RuntimeError( - "error_event is set in WorkerProcessGroup, " - "indicating an error occurred in one of the worker processes." - ) - last_check_time = time.time() + if self.shutdown_event.is_set(): + # Everything yielded, exit + break async def shutdown(self) -> list[Exception]: # noqa: C901 """ Gracefully shut down the worker process group and clean up resources. Performs safe shutdown of worker processes, background tasks, and - multiprocessing resources. + multiprocessing resources. Coordinates orderly termination across + all workers and collects any exceptions encountered during shutdown. - :return: List of exceptions encountered during shutdown; empty if no errors. + :return: List of exceptions encountered during shutdown; empty if no errors """ exceptions: list[Exception] = [] - if self.shutdown_event is not None: self.shutdown_event.set() - cancel_tasks = [ - task - for task in (self.populate_requests_task, self.populate_updates_task) - if task and not task.done() - ] - for task in cancel_tasks: - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - if cancel_tasks: + # Clear out start values + if self.messaging is not None: + try: + await asyncio.wait_for(self.messaging.stop(), timeout=5.0) + except Exception as err: + exceptions.append(err) + self.messaging = None + self.state = None + + # Clear out create processes values + if self.processes is not None: + for proc in self.processes: try: - await asyncio.gather(*cancel_tasks, return_exceptions=True) + await asyncio.to_thread(proc.join, timeout=5.0) + if proc.exitcode is not None and proc.exitcode > 0: + exceptions.append( + RuntimeError( + f"Worker {proc.pid} exited with code {proc.exitcode}" + ) + ) except Exception as err: # noqa: BLE001 exceptions.append(err) - self.populate_requests_task = None - self.populate_updates_task = None - - if self.processes: - for proc in self.processes: - await asyncio.to_thread(proc.join, 5) - if proc.exitcode not in (0, None): - exceptions.append( - RuntimeError( - f"Worker {proc.pid} exited with code {proc.exitcode}" - ) - ) self.processes = None - self.mp_context = None - self.startup_barrier = None + self.requests_generated_event = None + self.constraint_reached_event = None self.shutdown_event = None self.error_event = None - self.requests_queue = None - self.updates_queue = None - self.pending_updates_queue = None + if self.mp_manager is not None: + try: + self.mp_manager.shutdown() + except Exception as err: + exceptions.append(err) + self.mp_manager = None + self.mp_context = None return exceptions - def _update_state( - self, info: ScheduledRequestInfo[MeasuredRequestTimingsT] - ) -> tuple[SchedulerState, bool, bool]: - if not self.scheduler_state or not self.state_update_lock: - raise RuntimeError("workerProcessGroup not started") - - with self.state_update_lock: - state = self.scheduler_state - if info.status == "queued": - state.created_requests += 1 - state.queued_requests += 1 - elif info.status == "in_progress": - 
state.queued_requests -= 1 - state.processing_requests += 1 - elif info.status in ("completed", "errored", "cancelled"): - state.processing_requests -= 1 - state.processed_requests += 1 - state.successful_requests += 1 if info.status == "completed" else 0 - state.errored_requests += 1 if info.status == "errored" else 0 - state.cancelled_requests += 1 if info.status == "cancelled" else 0 - else: - raise ValueError( - f"Unknown request status: {info.status}. " - "Supported statuses are: queued, pending, in_progress, " - "completed, errored, cancelled." - ) - state.end_time = time.time() # Always update for last time update received - actions = { - name: const(state, info) for name, const in self.constraints.items() - } - state.scheduler_constraints = actions - - if state.end_queuing_time is None and ( - stop_queueing_actions := { - key: action - for key, action in actions.items() - if action.request_queuing == "stop" - } - ): - # Queuing not stopped and actions returned to stop it - state.end_queuing_constraints.update(stop_queueing_actions) - state.end_queuing_time = time.time() - - if state.end_processing_time is None and ( - stop_processing_actions := { - key: action - for key, action in actions.items() - if action.request_processing in ("stop_local", "stop_all") - } - ): - # Processing not stopped and actions returned to stop it - state.end_processing_constraints.update(stop_processing_actions) - state.end_processing_time = time.time() +class _StateUpdate(NamedTuple): + state: SchedulerState + stop_queueing: bool + stop_processing: bool - state_copy: SchedulerState = state.model_copy() - return ( - state_copy, - state_copy.end_queuing_time is None, - state_copy.end_processing_time is None, - ) +class WorkerGroupState(Generic[RequestT, ResponseT]): + """ + Manages scheduler state and synchronization for worker process groups. - def _populate_requests_generator(self, scheduler_start_time: float): - last_check_time: float = time.time() - continue_requests: bool = True - message: bytes | None = None - request_iter: Iterator[RequestT] | None = ( - self._populate_requests_create_iterator(first=True) + Handles request generation, state updates, constraint evaluation, and + coordination between worker processes. Provides thread-safe state management + with request lifecycle tracking and constraint-based termination logic. + """ + + def __init__( + self, + start_time: float, + processes: list[BaseProcess], + constraints: dict[str, Constraint], + stop_send_requests_event: threading.Event, + send_requests_stopped_event: threading.Event, + requests_generated_event: Event, + constraint_reached_event: Event, + shutdown_event: Event, + error_event: Event, + ): + """ + Initialize worker group state management. 
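WorkerGroupState relies on a lock-and-snapshot pattern: updates mutate the shared counters under a lock and hand callers an independent copy (the _locked_update method further down), so readers never observe a half-applied update. A simplified sketch of that pattern, where the Snapshot/State names are stand-ins rather than classes from this patch:

    import threading
    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class Snapshot:
        created: int = 0
        processed: int = 0

    class State:
        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._snap = Snapshot()

        def update(self, *, created: int = 0, processed: int = 0) -> Snapshot:
            with self._lock:
                self._snap = replace(
                    self._snap,
                    created=self._snap.created + created,
                    processed=self._snap.processed + processed,
                )
                return self._snap        # callers get an immutable snapshot

    state = State()
    print(state.update(created=1))       # Snapshot(created=1, processed=0)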
+ + :param start_time: Unix timestamp when processing should begin + :param processes: List of worker process instances + :param constraints: Named constraints for controlling execution behavior + :param send_requests_stopped_event: Threading event for request coordination + :param requests_generated_event: Multiprocessing event for generation completion + :param constraint_reached_event: Multiprocessing event for constraint stopping + :param shutdown_event: Multiprocessing event for coordinated shutdown + :param error_event: Multiprocessing event for error condition signaling + """ + self.start_time = start_time + self.processes = processes + self.constraints = constraints + self.stop_send_requests_event = stop_send_requests_event + self.send_requests_stopped_event = send_requests_stopped_event + self.requests_generated_event = requests_generated_event + self.constraint_reached_event = constraint_reached_event + self.shutdown_event = shutdown_event + self.error_event = error_event + + self._update_lock: threading.Lock = threading.Lock() + self._state: SchedulerState = SchedulerState( + node_id=0, + num_processes=len(processes), + start_time=start_time, ) + self._queued_requests = set() + self._pending_requests = set() + self._processing_requests = set() - try: - while continue_requests or message is not None: - if request_iter is None: - request_iter = self._populate_requests_create_iterator(first=False) - - if request_iter is None and continue_requests: - # Out of requests so stop - continue_requests = False - # Update scheduler state that requests were exhausted - with self.state_update_lock: - self.scheduler_state.end_queuing_constraints["request_iter"] = { - "status": "exhausted", - "time": time.time(), - } - self.scheduler_state.end_queuing_time = time.time() - - if continue_requests and message is None: - message, continue_requests = self._populate_requests_next_message( - request_iter, scheduler_start_time - ) - if message is None: - # No message returned because request_iter is exhausted - request_iter = None - - if message is not None: - with contextlib.suppress(queue.Full): - self.requests_queue.put( - message[0], timeout=settings.scheduler_poll_interval - ) - self.pending_updates_queue.sync_put(message[1]) - message = None - - if (time.time() - last_check_time) >= settings.scheduler_poll_interval: - last_check_time = time.time() - continue_requests = ( - continue_requests and not self.shutdown_event.is_set() - ) - yield None # Yield to check for error in wrapper to stop - except Exception as err: # noqa: BLE001 - print(f"******EXCEPTION in _populate_requests_generator: {err}") - self.error_event.set() - raise err - finally: - self.pending_requests_complete.set() - - def _populate_requests_create_iterator( - self, first: bool = False - ) -> Iterator[RequestT] | None: - if first: - # First invocation, get a new iterator if not already one - return ( - iter(self.requests) - if not isinstance(self.requests, Iterator) - else self.requests - ) + def requests_generator( + self, + requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, + cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, + ) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]: + """ + Generate request-info pairs for worker processing with constraint evaluation. 
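The generator body that follows makes one pass over the finite requests, then cycles the repeating ones indefinitely, and stops as soon as a constraint flags that queueing should end. A stripped-down sketch of that shape, with a hypothetical should_stop predicate standing in for the constraint check:

    def generate(finite, cycling, should_stop):
        def _iter():
            if finite:
                yield from finite
            if cycling:
                while True:
                    yield from cycling

        for count, request in enumerate(_iter(), start=1):
            yield request
            if should_stop(count):       # constraint said to stop queueing
                return

    print(list(generate(["a", "b"], ["c"], lambda n: n >= 5)))
    # ['a', 'b', 'c', 'c', 'c']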
- if self.infinite_requests is True and isinstance(self.requests, Iterator): - # Out of requests and infinite set to True, but request_iter is Iterator - # Cannot create new, raise RuntimeError - raise RuntimeError( - f"Requests iterator {self.requests} exhausted and " - "infinite_requests is set to True" - ) + Processes finite requests sequentially then cycles through repeating requests + indefinitely. Creates scheduling metadata for each request and evaluates + constraints to determine when to stop request generation. - if self.infinite_requests is not False and isinstance(self.requests, Iterable): - # Out of requests and infinite set to True or set to default - # Create new iterator out of the Iterable - return iter(self.requests) - - # Either infinite is False for Iterable or Iterator - # or infinite is None (default) for Iterator - # So, return None to stop - return None - - def _populate_requests_next_message( - self, request_iter: Iterator[RequestT], scheduler_start_time: float - ) -> tuple[tuple[bytes, bytes] | None, bool]: - try: - request = next(request_iter) - request_id = ( - request.request_id or request.id or request.id_ or str(uuid.uuid4()) - ) - request_info = ScheduledRequestInfo[MeasuredRequestTimingsT]( + :param requests: Finite iterable of requests to process sequentially + :param cycle_requests: Iterable of requests to cycle through indefinitely + :return: Generator yielding (request, request_info) tuples + """ + + def _iter(): + if requests: + yield from requests + + if cycle_requests: + while True: + yield from cycle_requests + + count = 0 + request_info: ScheduledRequestInfo = None + for request in _iter(): + count += 1 + + if hasattr(request, "request_id"): + request_id = request.request_id + elif hasattr(request, "id"): + request_id = request.id + else: + request_id = str(uuid.uuid4()) + request_info: ScheduledRequestInfo = ScheduledRequestInfo( request_id=request_id, status="queued", - scheduler_node_id=-1, scheduler_process_id=0, - scheduler_start_time=scheduler_start_time, + scheduler_start_time=self.start_time, ) - state, continue_requests, _ = self._update_state(request_info) - - request_msg = MsgpackEncoding.encode((request, request_info)) - update_msg = (None, request, request_info, state) - - return (request_msg, update_msg), continue_requests - except StopIteration: - return None, True - - def _populate_updates_generator(self): - """Generator for populating updates from workers.""" - last_check_time = time.time() - last_state: SchedulerState = None - continue_processing = True - shutdown_set = False - canceled_remaining = False - - try: - while ( - continue_processing - or last_state is None - or (last_state.processed_requests < last_state.created_requests) - ): - next_state, continue_updates = self._populate_updates_process_next() - if next_state is not None: - last_state = next_state - continue_processing = continue_processing and continue_updates - - if not continue_processing and not shutdown_set: - self.shutdown_event.set() - shutdown_set = True - time.sleep( - settings.scheduler_poll_interval - ) # Ensure shut down propagates - - if not continue_processing and not canceled_remaining: - # We've shut down, no more requests will be added, cancel remaining - next_state = self._populate_updates_cancel_remaining() - if next_state is not None: - last_state = next_state - canceled_remaining = True - - if (time.time() - last_check_time) >= settings.scheduler_poll_interval: - last_check_time = time.time() - if not shutdown_set and 
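The callback body below drives a small cascade of events: once sending has stopped and the queue drains, requests_generated is set; a triggered constraint sets constraint_reached; and when every created request has been processed, the group-wide shutdown fires. A simplified sketch of those conditions using plain threading events and hypothetical counters:

    from threading import Event

    requests_generated, constraint_reached, shutdown = Event(), Event(), Event()

    def on_update(queued, processed, created, sending_stopped, constrained):
        if queued == 0 and sending_stopped and not requests_generated.is_set():
            requests_generated.set()
        if constrained and not constraint_reached.is_set():
            constraint_reached.set()
        if (processed == created and sending_stopped
                and requests_generated.is_set() and constraint_reached.is_set()):
            shutdown.set()

    on_update(queued=0, processed=3, created=3, sending_stopped=True, constrained=True)
    print(requests_generated.is_set(), constraint_reached.is_set(), shutdown.is_set())
    # True True True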
self.shutdown_event.is_set(): - shutdown_set = True - continue_processing = False - with self.state_update_lock: - self.scheduler_state.end_queuing_constraints[ - "shutdown_event" - ] = { - "status": "set", - "time": time.time(), - } - self.scheduler_state.end_processing_time = time.time() - - yield None # Yield to check for error in wrapper to stop - except Exception as err: # noqa: BLE001 - print(f"******EXCEPTION in _populate_updates_generator: {err}") - self.error_event.set() - raise err - finally: - self.pending_updates_complete.set() - - def _populate_updates_process_next( + state_update = self._locked_update(request_info) + yield (request, request_info) + + if state_update.stop_queueing: + self.stop_send_requests_event.set() + return + + # Reached the end, inject a RequestsExhaustedConstraint to record + self._locked_update( + info=None, + requests_exhausted=RequestsExhaustedConstraint(num_requests=count), + ) + self.stop_send_requests_event.set() + + def received_callback( self, - ) -> tuple[SchedulerState | None, bool]: - try: - message = self.updates_queue.get(timeout=settings.scheduler_poll_interval) - response, request, request_info = MsgpackEncoding.decode(message) - - scheduler_state, _, continue_updates = self._update_state(request_info) - self.pending_updates_queue.sync_put( - (response, request, request_info, scheduler_state) - ) + update: tuple[ + ResponseT | None, + RequestT | MultiTurnRequestT, + ScheduledRequestInfo, + ], + ) -> tuple[ + ResponseT | None, + RequestT | MultiTurnRequestT, + ScheduledRequestInfo, + SchedulerState, + ]: + """ + Process received request updates and inject current scheduler state. + + Updates internal state tracking based on request status changes and + evaluates constraints to determine if processing should be terminated. + Triggers shutdown when stop conditions are met. 
+ + :param update: Tuple containing response, request, and request info + :return: Updated tuple with injected scheduler state + """ + response, request, request_info = update + state_update = self._locked_update(info=request_info) + + # Check if we need to tell workers to stop pulling new requests + # based on no more requests sent and all requests removed from queue + if ( + state_update.state.queued_requests == 0 + and self.send_requests_stopped_event.is_set() + and not self.requests_generated_event.is_set() + ): + self.requests_generated_event.set() + + # Check if we need to tell workers to stop processing requests (constraints) + if state_update.stop_processing and not self.constraint_reached_event.is_set(): + self.constraint_reached_event.set() + + # Check if all requests have been processed and can shutdown + if ( + state_update.state.processed_requests == state_update.state.created_requests + and self.send_requests_stopped_event.is_set() + and self.requests_generated_event.is_set() + and self.constraint_reached_event.is_set() + and not self.shutdown_event.is_set() + ): + self.shutdown_event.set() - return scheduler_state, continue_updates - except queue.Empty: - return None, True + return ( + response, + request, + request_info, + state_update.state, # inject state for updates to be yielded back + ) - def _populate_updates_cancel_remaining( + def _locked_update( self, - ) -> SchedulerState | None: - last_state = None + info: ScheduledRequestInfo | None = None, + **add_constraints: dict[str, Constraint], + ) -> _StateUpdate: + with self._update_lock: + if add_constraints: + self.constraints.update(add_constraints) - while True: - try: - message = self.requests_queue.get( - timeout=settings.scheduler_poll_interval - ) - request, request_info = MsgpackEncoding.decode(message) + if info is not None: + self._state.end_time = time.time() # Always update in case last update + self._update_state_request_counts(info) + self._update_with_constraints(info) - # Send start first - request_info.status = "in_progress" - scheduler_state, _, _ = self._update_state(request_info) - self.pending_updates_queue.sync_put( - (None, request, request_info.model_copy(), scheduler_state) - ) + state_copy: SchedulerState = self._state.model_copy() - # Send canceled - request_info.status = "cancelled" - request_info.error = "Request was cancelled" - request_info.scheduler_timings.resolve_end = time.time() - scheduler_state, _, _ = self._update_state(request_info) - self.pending_updates_queue.sync_put( - (None, request, request_info, scheduler_state) - ) + return _StateUpdate( + state_copy, + state_copy.end_queuing_time is not None, + state_copy.end_processing_time is not None, + ) - last_state = scheduler_state - except queue.Empty: - if self.pending_requests_complete.is_set(): - # no more requests being pushed to queue, safe to exit - break + def _update_state_request_counts(self, info: ScheduledRequestInfo): + if info.status == "queued": + self._queued_requests.add(info.request_id) + self._state.queued_requests = len(self._queued_requests) + self._state.created_requests += 1 + elif info.status == "pending": + self._queued_requests.remove(info.request_id) + self._state.queued_requests = len(self._queued_requests) + self._pending_requests.add(info.request_id) + self._state.pending_requests = len(self._pending_requests) + elif info.status == "in_progress": + self._pending_requests.remove(info.request_id) + self._state.pending_requests = len(self._pending_requests) + self._processing_requests.add(info.request_id) 
+ self._state.processing_requests = len(self._processing_requests) + elif info.status == "completed": + self._processing_requests.remove(info.request_id) + self._state.processing_requests = len(self._processing_requests) + self._state.processed_requests += 1 + self._state.successful_requests += 1 + elif info.status in ("errored", "cancelled"): + if info.request_id in self._queued_requests: + self._queued_requests.remove(info.request_id) + self._state.queued_requests = len(self._queued_requests) + elif info.request_id in self._pending_requests: + self._pending_requests.remove(info.request_id) + self._state.pending_requests = len(self._pending_requests) + elif info.request_id in self._processing_requests: + self._processing_requests.remove(info.request_id) + self._state.processing_requests = len(self._processing_requests) + else: + print(f"WARNING: Request was not present in state request sets: {info}") + + self._state.processed_requests += 1 + self._state.errored_requests += 1 if info.status == "errored" else 0 + self._state.cancelled_requests += 1 if info.status == "cancelled" else 0 + else: + raise ValueError(f"Unknown request_info status {info.status} for {info}") + + def _update_with_constraints(self, info: ScheduledRequestInfo): + actions: dict[str, SchedulerUpdateAction] = { + name: const(self._state, info) for name, const in self.constraints.items() + } + self._state.scheduler_constraints = actions + stop_queuing_actions = {} + stop_processing_actions = {} + + for key, action in actions.items(): + # Action updates + if ( + self._state.end_queuing_time is None + and action.request_queuing == "stop" + ): + stop_queuing_actions[key] = action + if ( + self._state.end_processing_time is None + and action.request_processing in ("stop_local", "stop_all") + ): + stop_processing_actions[key] = action - return last_state + for progress_key in ( + "remaining_fraction", + "remaining_requests", + "remaining_duration", + ): + if (new_val := action.progress.get(progress_key)) is not None and ( + getattr(self._state, progress_key) is None + or new_val < getattr(self._state, progress_key) + ): + setattr(self._state, progress_key, new_val) + + if stop_queuing_actions: + self._state.end_queuing_constraints = stop_queuing_actions + self._state.end_queuing_time = time.time() + + if stop_processing_actions: + self._state.end_processing_constraints = stop_processing_actions + self._state.end_processing_time = time.time() diff --git a/src/guidellm/config.py b/src/guidellm/settings.py similarity index 87% rename from src/guidellm/config.py rename to src/guidellm/settings.py index 9dd9b0dc..20d9ff96 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/settings.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import json from collections.abc import Sequence from enum import Enum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -45,8 +47,8 @@ class LoggingSettings(BaseModel): disabled: bool = False clear_loggers: bool = True console_log_level: str = "WARNING" - log_file: Optional[str] = None - log_file_level: Optional[str] = None + log_file: str | None = None + log_file_level: str | None = None class DatasetSettings(BaseModel): @@ -79,11 +81,11 @@ class OpenAISettings(BaseModel): for OpenAI server based pathways """ - api_key: Optional[str] = None - bearer_token: Optional[str] = None - headers: Optional[dict[str, str]] = None - organization: Optional[str] = None - 
project: Optional[str] = None + api_key: str | None = None + bearer_token: str | None = None + headers: dict[str, str] | None = None + organization: str | None = None + project: str | None = None base_url: str = "http://localhost:8000" max_output_tokens: int = 16384 verify: bool = True @@ -130,11 +132,21 @@ class Settings(BaseSettings): request_http2: bool = True # Scheduler settings + mp_context_type: Literal["spawn", "fork", "forkserver"] | None = "fork" + mp_serialization: Literal["dict", "sequence"] | None = "dict" + mp_encoding: ( + Literal["msgpack", "msgspec"] + | None + | list[Literal["msgpack", "msgspec"] | None] + ) = ["msgspec", "msgpack", None] + mp_messaging_object: Literal["queue", "manager_queue", "pipe"] = "queue" + mp_requests_send_buffer_size: int = 1 + mp_poll_interval: float = 0.1 + mp_max_pending_buffer_percent: float = 0.5 + mp_max_worker_buffer_percent: float = 0.2 max_concurrency: int = 512 max_worker_processes: int = 10 - max_add_requests_per_loop: int = 20 - scheduler_start_delay_non_distributed: float = 0.1 - scheduler_poll_interval: float = 0.05 + scheduler_start_delay_non_distributed: float = 1.0 constraint_error_window_size: float = 30 constraint_error_min_processed: float = 30 diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index eee17bbf..83a276b2 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,17 +1,20 @@ from .auto_importer import AutoImporterMixin from .console import Colors, Console, ConsoleUpdateStep, StatusIcons, StatusStyles from .default_group import DefaultGroupHandler -from .encoding import MsgpackEncoding -from .general import ( - UNSET, - UnsetType, +from .encoding import ( + Encoder, + EncodingTypesAlias, + MessageEncoding, + SerializationTypesAlias, + Serializer, +) +from .functions import ( all_defined, safe_add, safe_divide, safe_format_timestamp, safe_getattr, safe_multiply, - safe_subtract, ) from .hf_datasets import ( SUPPORTED_TYPES, @@ -20,6 +23,13 @@ from .hf_transformers import ( check_load_processor, ) +from .messaging import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, + SendMessageT, +) from .mixins import InfoMixin from .pydantic_utils import ( PydanticClassRegistryMixin, @@ -29,7 +39,7 @@ StatusBreakdown, ) from .random import IntegerRangeSampler -from .registry import RegistryMixin +from .registry import RegistryMixin, RegistryObjT from .singleton import SingletonMixin, ThreadSafeSingletonMixin from .statistics import ( DistributionSummary, @@ -38,6 +48,11 @@ StatusDistributionSummary, TimeRunningStats, ) +from .synchronous import ( + wait_for_sync_barrier, + wait_for_sync_event, + wait_for_sync_objects, +) from .text import ( EndlessTextCreator, clean_text, @@ -48,11 +63,10 @@ split_text, split_text_list_by_length, ) -from .threading import synchronous_to_exitable_async +from .typing import get_literal_vals __all__ = [ "SUPPORTED_TYPES", - "UNSET", "AutoImporterMixin", "Colors", "Colors", @@ -60,15 +74,26 @@ "ConsoleUpdateStep", "DefaultGroupHandler", "DistributionSummary", + "Encoder", + "EncodingTypesAlias", "EndlessTextCreator", "InfoMixin", "IntegerRangeSampler", - "MsgpackEncoding", + "InterProcessMessaging", + "InterProcessMessagingManagerQueue", + "InterProcessMessagingPipe", + "InterProcessMessagingQueue", + "MessageEncoding", + "MessageEncoding", "Percentiles", "PydanticClassRegistryMixin", "RegistryMixin", + "RegistryObjT", "ReloadableBaseModel", "RunningStats", + 
"SendMessageT", + "SerializationTypesAlias", + "Serializer", "SingletonMixin", "StandardBaseDict", "StandardBaseModel", @@ -78,12 +103,12 @@ "StatusStyles", "ThreadSafeSingletonMixin", "TimeRunningStats", - "UnsetType", "all_defined", "check_load_processor", "clean_text", "filter_text", "format_value_display", + "get_literal_vals", "is_puncutation", "load_text", "safe_add", @@ -91,9 +116,10 @@ "safe_format_timestamp", "safe_getattr", "safe_multiply", - "safe_subtract", "save_dataset_to_file", "split_text", "split_text_list_by_length", - "synchronous_to_exitable_async", + "wait_for_sync_barrier", + "wait_for_sync_event", + "wait_for_sync_objects", ] diff --git a/src/guidellm/utils/auto_importer.py b/src/guidellm/utils/auto_importer.py index 3b3240d3..5b939014 100644 --- a/src/guidellm/utils/auto_importer.py +++ b/src/guidellm/utils/auto_importer.py @@ -9,56 +9,54 @@ The AutoImporterMixin can be combined with registration mechanisms to create extensible systems where new implementations are automatically discovered and registered when they are placed in the correct package structure. - -Classes: - - AutoImporterMixin: A mixin class that provides functionality to automatically - import all modules within a specified package or list of packa """ +from __future__ import annotations + import importlib import pkgutil import sys -from typing import ClassVar, Optional, Union +from typing import ClassVar __all__ = ["AutoImporterMixin"] class AutoImporterMixin: """ - A mixin class that provides functionality to automatically import all modules - within a specified package or list of packages. - - This mixin is designed to be used with class registration mechanisms to enable - automatic discovery and registration of classes without explicit imports. When - a class inherits from AutoImporterMixin, it can define the package(s) to scan - for modules by setting the `auto_package` class variable. - - Usage Example: - ```python - from speculators.utils import AutoImporterMixin - class MyRegistry(AutoImporterMixin): - auto_package = "my_package.implementations" - - MyRegistry.auto_import_package_modules() - ``` - - :cvar auto_package: The package name or tuple of names to import modules from. - :cvar auto_ignore_modules: Optional tuple of module names to ignore during import. - :cvar auto_imported_modules: List tracking which modules have been imported. + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations without + explicit imports by automatically importing all modules within specified + packages. It is designed for use with class registration mechanisms to enable + automatic discovery and registration of classes when they are placed in the + correct package structure. + + Example: + :: + from guidellm.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported """ - auto_package: ClassVar[Optional[Union[str, tuple[str, ...]]]] = None - auto_ignore_modules: ClassVar[Optional[tuple[str, ...]]] = None - auto_imported_modules: ClassVar[Optional[list]] = None + auto_package: ClassVar[str | tuple[str, ...] | None] = None + auto_ignore_modules: ClassVar[tuple[str, ...] 
| None] = None + auto_imported_modules: ClassVar[list[str] | None] = None @classmethod - def auto_import_package_modules(cls): + def auto_import_package_modules(cls) -> None: """ - Automatically imports all modules within the specified package(s). + Automatically import all modules within the specified package(s). - This method scans the package(s) defined in the `auto_package` class variable - and imports all modules found, tracking them in `auto_imported_modules`. It - skips packages (directories) and any modules listed in `auto_ignore_modules`. + Scans the package(s) defined in the `auto_package` class variable and imports + all modules found, tracking them in `auto_imported_modules`. Skips packages + (directories) and any modules listed in `auto_ignore_modules`. :raises ValueError: If the `auto_package` class variable is not set """ diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index e54e8c1a..ccd26982 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -1,153 +1,787 @@ """ -MessagePack encoding utilities with Pydantic model support. +Message encoding utilities for multiprocess communication with Pydantic model support. -Provides binary serialization and deserialization of Python objects using MessagePack, -with special handling for Pydantic models to preserve type information and generic -parameters for accurate reconstruction. - -Classes: - MsgpackEncoding: MessagePack encoder/decoder with Pydantic support. +Provides binary serialization and deserialization of Python objects using various +serialization formats and encoding packages to enable performance configurations +for distributed scheduler operations. Supports configurable two-stage processing +pipeline: object serialization (to dict/sequence) followed by binary encoding +(msgpack/msgspec) with specialized Pydantic model handling for type preservation. """ -import importlib -from typing import Any, get_args, get_origin +from __future__ import annotations -import msgpack -from pydantic import BaseModel +import json +from collections.abc import Mapping +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar + +try: + import msgpack + from msgpack import Packer, Unpacker + + HAS_MSGPACK = True +except ImportError: + msgpack = Packer = Unpacker = None + HAS_MSGPACK = False -__all__ = ["MsgpackEncoding"] +try: + from msgspec.msgpack import Decoder as MsgspecDecoder + from msgspec.msgpack import Encoder as MsgspecEncoder + HAS_MSGSPEC = True +except ImportError: + MsgspecDecoder = MsgspecEncoder = None + HAS_MSGSPEC = False -class MsgpackEncoding: +try: + import orjson + + HAS_ORJSON = True +except ImportError: + orjson = None + HAS_ORJSON = False + +from pydantic import BaseModel +from typing_extensions import TypeAlias + +__all__ = [ + "Encoder", + "EncodingTypesAlias", + "MessageEncoding", + "MsgT", + "ObjT", + "SerializationTypesAlias", + "Serializer", +] + +ObjT = TypeVar("ObjT") +MsgT = TypeVar("MsgT") + +SerializationTypesAlias: TypeAlias = Annotated[ + Optional[Literal["dict", "sequence"]], + "Type alias for available serialization strategies", +] +EncodingTypesAlias: TypeAlias = Annotated[ + Optional[Literal["msgpack", "msgspec"]], + "Type alias for available binary encoding formats", +] + + +class MessageEncoding(Generic[ObjT, MsgT]): """ - MessagePack encoder/decoder with Pydantic model support. + High-performance message encoding and decoding for multiprocessing communication. 
+ + Supports configurable object serialization and binary encoding with specialized + handling for Pydantic models. Provides a two-stage pipeline of serialization + (object to dict/str) followed by encoding (dict/str to binary) for optimal + performance and compatibility across different transport mechanisms used in + distributed scheduler operations. - Provides binary serialization of Python objects with special handling - for Pydantic models to preserve type information and generic parameters. + Example: + :: + from guidellm.utils.encoding import MessageEncoding + from pydantic import BaseModel + + class DataModel(BaseModel): + name: str + value: int + + # Configure with dict serialization and msgpack encoding + encoding = MessageEncoding(serialization="dict", encoding="msgpack") + encoding.register_pydantic(DataModel) + + # Encode and decode objects + data = DataModel(name="test", value=42) + encoded_msg = encoding.encode(data) + decoded_data = encoding.decode(encoded_msg) + + :cvar DEFAULT_ENCODING_PREFERENCE: Preferred encoding formats in priority order """ - PYDANTIC_TAG = "__pydantic__" - PYDANTIC_DATA = "data" - PYDANTIC_ARGS = "args" + DEFAULT_ENCODING_PREFERENCE: ClassVar[list[str]] = ["msgspec", "msgpack"] @classmethod - def encode(cls, obj: Any) -> bytes: + def encode_message( + cls, + obj: ObjT, + serializer: Serializer | None, + encoder: Encoder | None, + ) -> MsgT: """ - Encode a Python object to MessagePack binary format. + Encode object using specified serializer and encoder. - :param obj: The object to encode (supports Pydantic models, dicts, lists, etc.). - :return: Binary MessagePack representation. + :param obj: Object to encode + :param serializer: Serializer for object conversion, None for no serialization + :param encoder: Encoder for binary conversion, None for no encoding + :return: Encoded message ready for transport """ - return msgpack.packb(cls.to_primitive(obj), use_bin_type=True) + serialized = serializer.serialize(obj) if serializer else obj + + return encoder.encode(serialized) if encoder else serialized @classmethod - def decode(cls, data: bytes) -> Any: + def decode_message( + cls, + message: MsgT, + serializer: Serializer | None, + encoder: Encoder | None, + ) -> ObjT: + """ + Decode message using specified serializer and encoder. + Must match the encoding configuration originally used. + + :param message: Encoded message to decode + :param serializer: Serializer for object reconstruction, None for no + serialization + :param encoder: Encoder for binary decoding, None for no encoding + :return: Reconstructed object """ - Decode MessagePack binary data back to Python objects. + serialized = encoder.decode(message) if encoder else message - :param data: Binary MessagePack data to decode. - :return: Reconstructed Python object with original types preserved. + return serializer.deserialize(serialized) if serializer else serialized + + def __init__( + self, + serialization: SerializationTypesAlias = None, + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, + pydantic_models: list[type[BaseModel]] | None = None, + ) -> None: """ - return cls.from_primitive(msgpack.unpackb(data, raw=False)) + Initialize MessageEncoding with serialization and encoding strategies. 
- @classmethod - def to_primitive(cls, obj: Any) -> Any: + :param serialization: Serialization strategy (None, "dict", or "sequence") + :param encoding: Encoding strategy (None, "msgpack", "msgspec", or + preference list) """ - Convert objects to primitive types for MessagePack serialization. + self.serializer = Serializer(serialization, pydantic_models) + self.encoder = Encoder(encoding) - Recursively converts complex objects to primitives. Pydantic models are - converted to tagged dictionaries with type metadata for reconstruction. + def register_pydantic(self, model: type[BaseModel]) -> None: + """ + Register Pydantic model for specialized serialization handling. - :param obj: The object to convert. - :return: Primitive representation suitable for MessagePack. + :param model: Pydantic model class to register for type preservation """ - if isinstance(obj, BaseModel): - # Get the module, class, and any generics for reconstruction later - model_cls = obj.__class__ - origin = get_origin(model_cls) or model_cls - args = tuple(get_args(model_cls)) - if not args and hasattr(model_cls, "__pydantic_generic_metadata__"): - meta = model_cls.__pydantic_generic_metadata__ - origin = meta.get("origin", origin) or origin - args = tuple(meta.get("args") or []) - - # Construct data by manually running model_dump and encoding BaseModel - data: dict[str, Any] = {} - for name in origin.model_fields: - value = getattr(obj, name, None) - data[name] = cls.to_primitive(value) - extras = getattr(obj, "__pydantic_extras__", {}) - for name, value in extras.items(): - data[name] = cls.to_primitive(value) - - encoded = { - cls.PYDANTIC_TAG: f"{origin.__module__}.{origin.__name__}", - cls.PYDANTIC_DATA: data, - } + self.serializer.register_pydantic(model) - if args: - encoded[cls.PYDANTIC_ARGS] = [ - f"{arg.__module__}.{arg.__qualname__}" - for arg in args - if isinstance(arg, type) - ] + def encode(self, obj: ObjT) -> MsgT: + """ + Encode object using instance configuration. - return encoded + :param obj: Object to encode using configured serialization and encoding + :return: Encoded message ready for transport + """ + return self.encode_message( + obj=obj, + serializer=self.serializer, + encoder=self.encoder, + ) - if isinstance(obj, dict): - return { - cls.to_primitive(key): cls.to_primitive(val) for key, val in obj.items() - } + def decode(self, message: MsgT) -> ObjT: + """ + Decode message using instance configuration. + + :param message: Encoded message to decode using configured strategies + :return: Reconstructed object + """ + return self.decode_message( + message=message, + serializer=self.serializer, + encoder=self.encoder, + ) + + +class Encoder: + """ + Binary encoding and decoding using MessagePack or msgspec formats. + + Handles binary serialization of Python objects using configurable encoding + strategies with automatic fallback when dependencies are unavailable. Supports + both standalone instances and pooled encoder/decoder pairs for performance + optimization in high-throughput scenarios. + """ + + def __init__( + self, encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None + ) -> None: + """ + Initialize encoder with specified encoding strategy. + + :param encoding: Encoding format preference (None, "msgpack", "msgspec", or + preference list) + """ + self.encoding, self.encoder, self.decoder = self._resolve_encoding(encoding) + + def encode(self, obj: Any) -> bytes | Any: + """ + Encode object to binary format using configured encoding strategy. 
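+
+        Example (illustrative sketch, assuming msgpack is installed):
+        ::
+            encoder = Encoder(encoding="msgpack")
+            payload = encoder.encode({"name": "test"})  # binary msgpack bytes
+            assert encoder.decode(payload) == {"name": "test"}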
+ + :param obj: Object to encode (must be serializable by chosen format) + :return: Encoded bytes or original object if no encoding configured + :raises ImportError: If required encoding library is not available + """ + if self.encoding == "msgpack": + if not HAS_MSGPACK: + raise ImportError("msgpack is not available") + + return self.encoder.pack(obj) if self.encoder else msgpack.packb(obj) - if isinstance(obj, list): - return [cls.to_primitive(val) for val in obj] + if self.encoding == "msgspec": + if not HAS_MSGSPEC: + raise ImportError("msgspec is not available") - if isinstance(obj, tuple): - return tuple(cls.to_primitive(val) for val in obj) + return ( + self.encoder.encode(obj) + if self.encoder + else MsgspecEncoder().encode(obj) + ) return obj - @classmethod - def from_primitive(cls, obj: Any) -> Any: + def decode(self, data: bytes | Any) -> Any: + """ + Decode binary data using configured encoding strategy. + + :param data: Binary data to decode or object if no encoding configured + :return: Decoded Python object + :raises ImportError: If required encoding library is not available + """ + if self.encoding == "msgpack": + if not HAS_MSGPACK: + raise ImportError("msgpack is not available") + + if self.decoder is not None: + self.decoder.feed(data) + return self.decoder.unpack() + + return msgpack.unpackb(data, raw=False) + + if self.encoding == "msgspec": + if not HAS_MSGSPEC: + raise ImportError("msgspec is not available") + + if self.decoder is not None: + return self.decoder.decode(data) + + return MsgspecDecoder().decode(data) + + return data + + def _resolve_encoding( + self, encoding: EncodingTypesAlias | list[EncodingTypesAlias] | None + ) -> tuple[EncodingTypesAlias, Any, Any]: + def _get_available_encoder_decoder( + encoding: EncodingTypesAlias, + ) -> tuple[Any, Any]: + if encoding == "msgpack" and HAS_MSGPACK: + return Packer(), Unpacker(raw=False) + if encoding == "msgspec" and HAS_MSGSPEC: + return MsgspecEncoder(), MsgspecDecoder() + return None, None + + if not isinstance(encoding, list): + if encoding is None: + return None, None, None + + encoder, decoder = _get_available_encoder_decoder(encoding) + if encoder is None or decoder is None: + raise ImportError(f"Encoding '{encoding}' is not available.") + + return encoding, encoder, decoder + + for test_encoding in encoding: + encoder, decoder = _get_available_encoder_decoder(test_encoding) + if encoder is not None and decoder is not None: + return test_encoding, encoder, decoder + + return None, None, None + + +class Serializer: + """ + Object serialization with specialized Pydantic model support. + + Converts Python objects to serializable formats (dict/sequence) with type + preservation for Pydantic models. Maintains object integrity through + encoding/decoding cycles by storing class metadata and enabling proper + reconstruction of complex objects. Supports both dictionary-based and + sequence-based serialization strategies for different use cases. + """ + + def __init__( + self, + serialization: SerializationTypesAlias = None, + pydantic_models: list[type[BaseModel]] | None = None, + ): """ - Reconstruct objects from their primitive MessagePack representation. + Initialize serializer with strategy and Pydantic registry. - Recursively converts primitives back to original objects. Tagged dictionaries - are restored to Pydantic models with proper types and generic parameters. 
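+
+        Example (illustrative sketch; `DataModel` is a placeholder Pydantic model):
+        ::
+            serializer = Serializer(serialization="dict", pydantic_models=[DataModel])
+            message = serializer.serialize(DataModel(name="test", value=1))
+            restored = serializer.deserialize(message)  # DataModel instance again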
+ :param serialization: Default serialization strategy for this instance + """ + self.serialization = serialization + self.pydantic_registry: dict[tuple[str, str], type[BaseModel]] = {} + if pydantic_models: + for model in pydantic_models: + self.register_pydantic(model) - :param obj: The primitive representation to convert. - :return: Reconstructed object with original types. - :raises ImportError: If a Pydantic model's module cannot be imported. - :raises AttributeError: If a class reference cannot be found. + def register_pydantic(self, model: type[BaseModel]) -> None: """ - if isinstance(obj, dict) and cls.PYDANTIC_TAG in obj: - origin_path = obj[cls.PYDANTIC_TAG] - module_name, class_name = origin_path.rsplit(".", 1) - origin_cls = getattr(importlib.import_module(module_name), class_name) + Register Pydantic model for specialized serialization handling. - type_args = [] - if cls.PYDANTIC_ARGS in obj: - for arg_path in obj[cls.PYDANTIC_ARGS]: - mod, clazz = arg_path.rsplit(".", 1) - type_args.append(getattr(importlib.import_module(mod), clazz)) + :param model: Pydantic model class to register for type preservation + """ + key = (model.__module__, model.__name__) + self.pydantic_registry[key] = model - model_cls = origin_cls[tuple(type_args)] if type_args else origin_cls - payload = { - key: cls.from_primitive(value) - for key, value in obj[cls.PYDANTIC_DATA].items() + def load_pydantic(self, type_name: str, module_name: str) -> type[BaseModel]: + """ + Load Pydantic class by name with registry fallback to dynamic import. + + :param type_name: Class name to load + :param module_name: Module containing the class + :return: Loaded Pydantic model class + """ + key = (module_name, type_name) + + if key in self.pydantic_registry: + return self.pydantic_registry[key] + + # Dynamic import fallback; need to update to better handle generics + module = __import__(module_name, fromlist=[type_name]) + pydantic_class = getattr(module, type_name) + self.pydantic_registry[key] = pydantic_class + + return pydantic_class + + def serialize(self, obj: Any) -> Any: + """ + Serialize object using specified or configured strategy. + + :param obj: Object to serialize + :return: Serialized representation (dict, str, or original object) + """ + if self.serialization == "dict": + return self.to_dict(obj) + elif self.serialization == "sequence": + return self.to_sequence(obj) + + return obj + + def deserialize(self, msg: Any) -> Any: + """ + Deserialize object using specified or configured strategy. + + :param msg: Serialized message to deserialize + :return: Reconstructed object + """ + if self.serialization == "dict": + return self.from_dict(msg) + elif self.serialization == "sequence": + return self.from_sequence(msg) + + return msg + + def to_dict(self, obj: Any) -> Any: + """ + Convert object to dictionary with Pydantic model type preservation. 
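+
+        Pydantic models are wrapped in a tagged dictionary so they can be rebuilt
+        on the receiving side, roughly of the form (illustrative names):
+        ::
+            # {"*PYD*": True, "typ": "DataModel", "mod": "my_module",
+            #  "dat": {"name": "test", "value": 1}}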
+ + :param obj: Object to convert (BaseModel, collections, or primitive) + :return: Dictionary representation with type metadata for Pydantic models + """ + if isinstance(obj, BaseModel): + return self.to_dict_pydantic(obj) + + if isinstance(obj, (list, tuple)) and any( + isinstance(item, BaseModel) for item in obj + ): + return [ + self.to_dict_pydantic(item) if isinstance(item, BaseModel) else item + for item in obj + ] + + if isinstance(obj, dict) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + return { + key: self.to_dict_pydantic(value) + if isinstance(value, BaseModel) + else value + for key, value in obj.items() } - return model_cls.model_validate(payload) + return obj + + def from_dict(self, data: Any) -> Any: + """ + Reconstruct object from dictionary with Pydantic model type restoration. + + :param data: Dictionary representation possibly containing type metadata + :return: Reconstructed object with proper types restored + """ + if isinstance(data, (list, tuple)): + return [ + self.from_dict_pydantic(item) + if isinstance(item, dict) and "*PYD*" in item + else item + for item in data + ] + elif isinstance(data, dict) and data: + if "*PYD*" in data: + return self.from_dict_pydantic(data) - if isinstance(obj, dict): return { - cls.from_primitive(k): cls.from_primitive(v) for k, v in obj.items() + key: self.from_dict_pydantic(value) + if isinstance(value, dict) and "*PYD*" in value + else value + for key, value in data.items() } - if isinstance(obj, list): - return [cls.from_primitive(v) for v in obj] + return data - if isinstance(obj, tuple): - return tuple(cls.from_primitive(v) for v in obj) + def to_dict_pydantic(self, item: Any) -> Any: + """ + Convert item to dictionary with Pydantic type metadata. - return obj + :param item: Item to convert (may or may not be a Pydantic model) + :return: Dictionary with type preservation metadata + """ + return { + "*PYD*": True, + "typ": item.__class__.__name__, + "mod": item.__class__.__module__, + "dat": item.model_dump(mode="python"), + } + + def from_dict_pydantic(self, item: dict[str, Any]) -> Any: + """ + Reconstruct object from dictionary with Pydantic type metadata. + + :param item: Dictionary containing type metadata and data + :return: Reconstructed Pydantic model or original data + """ + type_name = item["typ"] + module_name = item["mod"] + model_class = self.load_pydantic(type_name, module_name) + + return model_class.model_validate(item["dat"]) + + def to_sequence(self, obj: Any) -> str | Any: + """ + Convert object to sequence format with type-aware serialization. + + Handles Pydantic models, collections, and mappings with proper type + preservation through structured sequence encoding. 
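+
+        Illustrative wire format in string mode (orjson unavailable); each packed
+        frame is a type character, delimiter, payload length, delimiter, payload:
+        ::
+            # to_sequence({"a": 1}) -> 'p|8|{"a": 1}'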
+ + :param obj: Object to serialize to sequence format + :return: Serialized sequence string or bytes + """ + if isinstance(obj, BaseModel): + payload_type = "pydantic" + payload = self.to_sequence_pydantic(obj) + elif isinstance(obj, (list, tuple)) and any( + isinstance(item, BaseModel) for item in obj + ): + payload_type = "collection_sequence" + payload = None + + for item in obj: + is_pydantic = isinstance(item, BaseModel) + payload = self.pack_next_sequence( + type_="pydantic" if is_pydantic else "python", + payload=( + self.to_sequence_pydantic(item) + if is_pydantic + else self.to_sequence_python(item) + ), + current=payload, + ) + elif isinstance(obj, Mapping) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + payload_type = "collection_mapping" + keys = ",".join(str(key) for key in obj) + payload = keys.encode() + b"|" if HAS_ORJSON else keys + "|" + for item in obj.values(): + is_pydantic = isinstance(item, BaseModel) + payload = self.pack_next_sequence( + type_="pydantic" if is_pydantic else "python", + payload=( + self.to_sequence_pydantic(item) + if is_pydantic + else self.to_sequence_python(item) + ), + current=payload, + ) + else: + payload_type = "python" + payload = self.to_sequence_python(obj) + + return self.pack_next_sequence(payload_type, payload, None) + + def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912 + """ + Reconstruct object from sequence format with type restoration. + + Handles deserialization of objects encoded with to_sequence, properly + restoring Pydantic models and collection structures. + + :param data: Serialized sequence data to reconstruct + :return: Reconstructed object with proper types + :raises ValueError: If sequence format is invalid or contains multiple + packed sequences + """ + type_, payload, remaining = self.unpack_next_sequence(data) + if remaining is not None: + raise ValueError("Data contains multiple packed sequences; expected one.") + + if type_ == "pydantic": + return self.from_sequence_pydantic(payload) + + if type_ == "python": + return self.from_sequence_python(payload) + + if type_ in {"collection_sequence", "collection_tuple"}: + items = [] + while payload: + type_, item_payload, payload = self.unpack_next_sequence(payload) + if type_ == "pydantic": + items.append(self.from_sequence_pydantic(item_payload)) + elif type_ == "python": + items.append(self.from_sequence_python(item_payload)) + else: + raise ValueError("Invalid type in collection sequence") + return items + + if type_ != "collection_mapping": + raise ValueError(f"Invalid type for mapping sequence: {type_}") + + if isinstance(payload, bytes): + keys_end = payload.index(b"|") + keys = payload[:keys_end].decode().split(",") + payload = payload[keys_end + 1 :] + else: + keys_end = payload.index("|") + keys = payload[:keys_end].split(",") + payload = payload[keys_end + 1 :] + + items = {} + index = 0 + while payload: + type_, item_payload, payload = self.unpack_next_sequence(payload) + if type_ == "pydantic": + items[keys[index]] = self.from_sequence_pydantic(item_payload) + elif type_ == "python": + items[keys[index]] = self.from_sequence_python(item_payload) + else: + raise ValueError("Invalid type in mapping sequence") + index += 1 + return items + + def to_sequence_pydantic(self, obj: BaseModel) -> str | bytes: + """ + Serialize Pydantic model to sequence format with class metadata. 
+ + :param obj: Pydantic model instance to serialize + :return: Sequence string or bytes containing class info and JSON data + """ + class_name: str = obj.__class__.__name__ + class_module: str = obj.__class__.__module__ + json_data = obj.__pydantic_serializer__.to_json(obj) + + return ( + (class_name.encode() + b"|" + class_module.encode() + b"|" + json_data) + if HAS_ORJSON + else ( + class_name + "|" + class_module + "|" + json_data.decode() + if isinstance(json_data, bytes) + else json_data + ) + ) + + def from_sequence_pydantic(self, data: str | bytes) -> BaseModel: + """ + Reconstruct Pydantic model from sequence format. + + :param data: Sequence data containing class metadata and JSON + :return: Reconstructed Pydantic model instance + """ + if isinstance(data, bytes): + class_name_end = data.index(b"|") + class_name = data[:class_name_end].decode() + module_name_end = data.index(b"|", class_name_end + 1) + module_name = data[class_name_end + 1 : module_name_end].decode() + json_data = data[module_name_end + 1 :] + else: + class_name_end = data.index("|") + class_name = data[:class_name_end] + module_name_end = data.index("|", class_name_end + 1) + module_name = data[class_name_end + 1 : module_name_end] + json_data = data[module_name_end + 1 :] + + model_class = self.load_pydantic(class_name, module_name) + + return model_class.model_validate_json(json_data) + + def to_sequence_python(self, obj: Any) -> str | bytes: + """ + Serialize Python object to JSON format. + + :param obj: Python object to serialize + :return: JSON string or bytes representation + """ + return orjson.dumps(obj) if HAS_ORJSON else json.dumps(obj) + + def from_sequence_python(self, data: str | bytes) -> Any: + """ + Deserialize Python object from JSON format. + + :param data: JSON string or bytes to deserialize + :return: Reconstructed Python object + :raises ImportError: If orjson is required but not available + """ + if isinstance(data, bytes): + if not HAS_ORJSON: + raise ImportError("orjson is not available, cannot deserialize bytes") + return orjson.loads(data) + + return json.loads(data) + + def pack_next_sequence( # noqa: C901, PLR0912 + self, + type_: Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", + ], + payload: str | bytes, + current: str | bytes | None, + ) -> str | bytes: + """ + Pack payload into sequence format with type and length metadata. 
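+
+        Example (illustrative, string mode): consecutive calls append frames, so a
+        collection becomes one concatenated sequence:
+        ::
+            first = serializer.pack_next_sequence("python", '{"a": 1}', None)
+            both = serializer.pack_next_sequence("python", '{"b": 2}', first)
+            # both == 'p|8|{"a": 1}p|8|{"b": 2}'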
+ + :param type_: Type identifier for the payload + :param payload: Data to pack into sequence + :param current: Current sequence data to append to (unused but maintained + for signature compatibility) + :return: Packed sequence with type, length, and payload + :raises ValueError: If payload type doesn't match current type or unknown + type specified + """ + if current is not None and type(payload) is not type(current): + raise ValueError("Payload and current must be of the same type") + + payload_len = len(payload) + + if isinstance(payload, bytes): + payload_len = payload_len.to_bytes( + length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1, + byteorder="big", + ) + if type_ == "pydantic": + payload_type = b"P" + elif type_ == "python": + payload_type = b"p" + elif type_ == "collection_tuple": + payload_type = b"T" + elif type_ == "collection_sequence": + payload_type = b"S" + elif type_ == "collection_mapping": + payload_type = b"M" + else: + raise ValueError(f"Unknown type for packing: {type_}") + delimiter = b"|" + else: + payload_len = str(payload_len) + if type_ == "pydantic": + payload_type = "P" + elif type_ == "python": + payload_type = "p" + elif type_ == "collection_tuple": + payload_type = "T" + elif type_ == "collection_sequence": + payload_type = "S" + elif type_ == "collection_mapping": + payload_type = "M" + else: + raise ValueError(f"Unknown type for packing: {type_}") + delimiter = "|" + + next_sequence = payload_type + delimiter + payload_len + delimiter + payload + + return current + next_sequence if current else next_sequence + + def unpack_next_sequence( # noqa: C901, PLR0912 + self, data: str | bytes + ) -> tuple[ + Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", + ], + str | bytes, + str | bytes | None, + ]: + """ + Unpack sequence format to extract type, payload, and remaining data. 
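+
+        Example (illustrative, string mode) showing the inverse of
+        `pack_next_sequence`:
+        ::
+            type_, payload, rest = serializer.unpack_next_sequence('p|8|{"a": 1}')
+            # -> ("python", '{"a": 1}', None)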
+ + :param data: Packed sequence data to unpack + :return: Tuple of (type, payload, remaining_data) + :raises ValueError: If sequence format is invalid or unknown type character + """ + if isinstance(data, bytes): + if len(data) < len(b"T|N") or data[1:2] != b"|": + raise ValueError("Invalid packed data format") + + type_char = data[0:1] + if type_char == b"P": + type_ = "pydantic" + elif type_char == b"p": + type_ = "python" + elif type_char == b"T": + type_ = "collection_tuple" + elif type_char == b"S": + type_ = "collection_sequence" + elif type_char == b"M": + type_ = "collection_mapping" + else: + raise ValueError("Unknown type character in packed data") + + len_end = data.index(b"|", 2) + payload_len = int.from_bytes(data[2:len_end], "big") + payload = data[len_end + 1 : len_end + 1 + payload_len] + remaining = ( + data[len_end + 1 + payload_len :] + if len_end + 1 + payload_len < len(data) + else None + ) + + return type_, payload, remaining + + if len(data) < len("T|N") or data[1] != "|": + raise ValueError("Invalid packed data format") + + type_char = data[0] + if type_char == "P": + type_ = "pydantic" + elif type_char == "p": + type_ = "python" + elif type_char == "S": + type_ = "collection_sequence" + elif type_char == "M": + type_ = "collection_mapping" + else: + raise ValueError("Unknown type character in packed data") + + len_end = data.index("|", 2) + payload_len = int(data[2:len_end]) + payload = data[len_end + 1 : len_end + 1 + payload_len] + remaining = ( + data[len_end + 1 + payload_len :] + if len_end + 1 + payload_len < len(data) + else None + ) + + return type_, payload, remaining diff --git a/src/guidellm/utils/functions.py b/src/guidellm/utils/functions.py new file mode 100644 index 00000000..6343cbf2 --- /dev/null +++ b/src/guidellm/utils/functions.py @@ -0,0 +1,133 @@ +""" +Utility functions for safe operations and value handling. + +Provides defensive programming utilities for common operations that may encounter +None values, invalid inputs, or edge cases. Includes safe arithmetic operations, +attribute access, and timestamp formatting. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +__all__ = [ + "all_defined", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", +] + + +def safe_getattr(obj: Any | None, attr: str, default: Any = None) -> Any: + """ + Safely get an attribute from an object with None handling. + + :param obj: Object to get the attribute from, or None + :param attr: Name of the attribute to retrieve + :param default: Value to return if object is None or attribute doesn't exist + :return: Attribute value or default if not found or object is None + """ + if obj is None: + return default + + return getattr(obj, attr, default) + + +def all_defined(*values: Any | None) -> bool: + """ + Check if all provided values are defined (not None). + + :param values: Variable number of values to check for None + :return: True if all values are not None, False otherwise + """ + return all(value is not None for value in values) + + +def safe_divide( + numerator: int | float | None, + denominator: int | float | None, + num_default: float = 0.0, + den_default: float = 1.0, +) -> float: + """ + Safely divide two numbers with None handling and zero protection. 
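+
+    Example (illustrative):
+    ::
+        safe_divide(10, 4)     # 2.5
+        safe_divide(None, 4)   # 0.0, numerator falls back to num_default
+        safe_divide(10, 0)     # large finite value, denominator guarded by 1e-10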
+ + :param numerator: Number to divide, or None to use num_default + :param denominator: Number to divide by, or None to use den_default + :param num_default: Default value for numerator if None + :param den_default: Default value for denominator if None + :return: Division result with protection against division by zero + """ + numerator = numerator if numerator is not None else num_default + denominator = denominator if denominator is not None else den_default + + return numerator / (denominator or 1e-10) + + +def safe_multiply(*values: int | float | None, default: float = 1.0) -> float: + """ + Safely multiply multiple numbers with None handling. + + :param values: Variable number of values to multiply, None values treated as 1.0 + :param default: Starting value for multiplication + :return: Product of all non-None values multiplied by default + """ + result = default + for val in values: + result *= val if val is not None else 1.0 + return result + + +def safe_add( + *values: int | float | None, signs: list[int] | None = None, default: float = 0.0 +) -> float: + """ + Safely add multiple numbers with None handling and optional signs. + + :param values: Variable number of values to add, None values use default + :param signs: Optional list of 1 (add) or -1 (subtract) for each value. + If None, all values are added. Must match length of values. + :param default: Value to substitute for None values + :return: Result of adding all values safely (default used when value is None) + """ + if not values: + return default + + values = list(values) + + if signs is None: + signs = [1] * len(values) + + if len(signs) != len(values): + raise ValueError("Length of signs must match length of values") + + result = values[0] if values[0] is not None else default + + for ind in range(1, len(values)): + val = values[ind] if values[ind] is not None else default + result += signs[ind] * val + + return result + + +def safe_format_timestamp( + timestamp: float | None, format_: str = "%H:%M:%S", default: str = "N/A" +) -> str: + """ + Safely format a timestamp with error handling and validation. 
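+
+    Example (illustrative; the formatted value depends on the local timezone):
+    ::
+        safe_format_timestamp(None)           # "N/A"
+        safe_format_timestamp(1_700_000_000)  # e.g. "22:13:20"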
+ + :param timestamp: Unix timestamp to format, or None + :param format_: Strftime format string for timestamp formatting + :param default: Value to return if timestamp is invalid or None + :return: Formatted timestamp string or default value + """ + if timestamp is None or timestamp < 0 or timestamp > 2**31: + return default + + try: + return datetime.fromtimestamp(timestamp).strftime(format_) + except (ValueError, OverflowError, OSError): + return default diff --git a/src/guidellm/utils/general.py b/src/guidellm/utils/general.py index 64e6c753..e61acd3b 100644 --- a/src/guidellm/utils/general.py +++ b/src/guidellm/utils/general.py @@ -5,11 +5,11 @@ __all__ = [ "UNSET", - "Safe_format_timestamp", "UnsetType", "all_defined", "safe_add", "safe_divide", + "safe_format_timestamp", "safe_getattr", "safe_multiply", "safe_subtract", @@ -89,7 +89,7 @@ def safe_subtract(*values: int | float | None, default: float = 0.0) -> float: def safe_format_timestamp( timestamp: float | None, format_: str = "%H:%M:%S", default: str = "N/A" ) -> str: - if timestamp is None or timestamp < 0 or timestamp > 2**31: + if timestamp is not None and timestamp >= 0 and timestamp <= 2**31: try: return datetime.fromtimestamp(timestamp).strftime(format_) except (ValueError, OverflowError, OSError): diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py new file mode 100644 index 00000000..c56ec29a --- /dev/null +++ b/src/guidellm/utils/messaging.py @@ -0,0 +1,1029 @@ +""" +Inter-process messaging abstractions for distributed scheduler coordination. + +Provides high-level interfaces for asynchronous message passing between worker +processes using various transport mechanisms including queues and pipes. Supports +configurable encoding, serialization, error handling, and flow control with +buffering and stop event coordination for distributed scheduler operations. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import multiprocessing +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections.abc import Iterable +from multiprocessing.connection import Connection +from multiprocessing.context import BaseContext +from multiprocessing.managers import SyncManager +from multiprocessing.synchronize import Event as ProcessingEvent +from threading import Event as ThreadingEvent +from typing import Any, Callable, Generic, Protocol, TypeVar + +import culsans +from pydantic import BaseModel + +from guidellm.utils.encoding import ( + EncodingTypesAlias, + MessageEncoding, + SerializationTypesAlias, +) + +__all__ = [ + "InterProcessMessaging", + "InterProcessMessagingManagerQueue", + "InterProcessMessagingPipe", + "InterProcessMessagingQueue", + "MessagingStopCallback", + "ReceiveMessageT", + "SendMessageT", +] + +SendMessageT = TypeVar("SendMessageT", bound=Any) +"""Generic type variable for messages sent through the messaging system""" +ReceiveMessageT = TypeVar("ReceiveMessageT", bound=Any) +"""Generic type variable for messages received through the messaging system""" + + +class MessagingStopCallback(Protocol): + """Protocol for evaluating stop conditions in messaging operations.""" + + def __call__( + self, messaging: InterProcessMessaging, pending: bool, queue_empty: int + ) -> bool: + """ + Evaluate whether messaging operations should stop. 
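+
+        Example (an illustrative implementation of this protocol; the threshold of
+        three consecutive empty polls is an arbitrary choice):
+        ::
+            def stop_when_idle(messaging, pending, queue_empty) -> bool:
+                # stop once nothing is pending and the queue has stayed empty
+                return not pending and queue_empty >= 3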
+ + :param messaging: The messaging instance to evaluate + :param pending: Whether there are pending operations + :param queue_empty: The number of times in a row the queue has been empty + :return: True if operations should stop, False otherwise + """ + ... + + +class InterProcessMessaging(Generic[SendMessageT, ReceiveMessageT], ABC): + """ + Abstract base for inter-process messaging in distributed scheduler coordination. + + Provides unified interface for asynchronous message passing between scheduler + components using configurable transport mechanisms, encoding schemes, and + flow control policies. Manages buffering, serialization, error handling, + and coordinated shutdown across worker processes for distributed operations. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingQueue + + messaging = InterProcessMessagingQueue( + serialization="pickle", + max_pending_size=100 + ) + + await messaging.start() + await messaging.put(request_data) + response = await messaging.get(timeout=5.0) + await messaging.stop() + """ + + STOP_REQUIRED_QUEUE_EMPTY: int = 3 + + def __init__( + self, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + ): + """ + Initialize inter-process messaging coordinator. + + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in done queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + """ + self.worker_index: int | None = worker_index + self.mp_context = mp_context or multiprocessing.get_context() + self.serialization = serialization + self.encoding = encoding + self.max_pending_size = max_pending_size + self.max_buffer_send_size = max_buffer_send_size + self.max_done_size = max_done_size + self.max_buffer_receive_size = max_buffer_receive_size + self.poll_interval = poll_interval + + self.send_stopped_event: ThreadingEvent | ProcessingEvent = None + self.receive_stopped_event: ThreadingEvent | ProcessingEvent = None + self.shutdown_event: ThreadingEvent = None + self.buffer_send_queue: culsans.Queue[SendMessageT] = None + self.buffer_receive_queue: culsans.Queue[ReceiveMessageT] = None + self.send_task: asyncio.Task = None + self.receive_task: asyncio.Task = None + self.running = False + + @abstractmethod + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessaging[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for distributed process coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured messaging instance for the specified worker + """ + ... 
+ + @abstractmethod + def create_send_messages_threads( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create send message processing threads for transport implementation. + + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + ... + + @abstractmethod + def create_receive_messages_threads( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create receive message processing threads for transport implementation. + + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + ... + + async def start( + self, + send_items: Iterable[Any] | None = None, + receive_callback: Callable[[Any], Any] | None = None, + send_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ) = None, + send_stopped_event: ThreadingEvent | ProcessingEvent | None = None, + receive_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ) = None, + receive_stopped_event: ThreadingEvent | ProcessingEvent | None = None, + pydantic_models: list[type[BaseModel]] | None = None, + ): + """ + Start asynchronous message processing tasks with buffering. 
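+
+        Example (illustrative sketch; `stop_event` and the pass-through callback are
+        assumptions of this example, not part of the class):
+        ::
+            stop_event = threading.Event()
+            await messaging.start(
+                receive_callback=lambda msg: msg,
+                receive_stop_criteria=[stop_event],
+            )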
+ + :param send_items: Optional collection of items to send during processing + :param receive_callback: Optional callback for processing received messages + :param send_stop_criteria: Events and callables that trigger send task shutdown + :param send_stopped_event: Event set when send task has fully stopped + :param receive_stop_criteria: Events and callables that trigger receive shutdown + :param receive_stopped_event: Event set when receive task has fully stopped + :param pydantic_models: Optional list of Pydantic models for serialization + """ + self.running = True + self.send_stopped_event = send_stopped_event or ThreadingEvent() + self.receive_stopped_event = receive_stopped_event or ThreadingEvent() + self.shutdown_event = ThreadingEvent() + self.buffer_send_queue = culsans.Queue[SendMessageT]( + maxsize=self.max_buffer_send_size or 0 + ) + self.buffer_receive_queue = culsans.Queue[ReceiveMessageT]( + maxsize=self.max_buffer_receive_size or 0 + ) + self.tasks_lock = threading.Lock() + + message_encoding = MessageEncoding( + serialization=self.serialization, + encoding=self.encoding, + pydantic_models=pydantic_models, + ) + send_stop_criteria = send_stop_criteria or [] + receive_stop_events = receive_stop_criteria or [] + + self.send_task = asyncio.create_task( + self.send_messages_coroutine( + send_items=send_items, + message_encoding=message_encoding, + send_stop_criteria=send_stop_criteria, + ) + ) + self.receive_task = asyncio.create_task( + self.receive_messages_coroutine( + receive_callback=receive_callback, + message_encoding=message_encoding, + receive_stop_criteria=receive_stop_events, + ) + ) + + async def stop(self): + """ + Stop message processing tasks and clean up resources. + """ + self.shutdown_event.set() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather( + self.send_task, self.receive_task, return_exceptions=True + ) + self.send_task = None + self.receive_task = None + if self.worker_index is None: + self.buffer_send_queue.clear() + await self.buffer_send_queue.aclose() + self.buffer_receive_queue.clear() + await self.buffer_receive_queue.aclose() + self.buffer_send_queue = None + self.buffer_receive_queue = None + self.send_stopped_event = None + self.receive_stopped_event = None + self.shutdown_event = None + self.running = False + + async def send_messages_coroutine( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + send_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ), + ): + """ + Execute send message processing with encoding and stop condition handling. 
+ + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param send_stop_criteria: Events and callables that trigger send task shutdown + """ + canceled_event = ThreadingEvent() + + try: + await asyncio.gather( + *[ + asyncio.to_thread(thread, *args) + for (thread, args) in self.create_send_messages_threads( + send_items=send_items, + message_encoding=message_encoding, + check_stop=self._create_check_stop_callable( + send_stop_criteria, canceled_event + ), + ) + ] + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.send_stopped_event.set() + + async def receive_messages_coroutine( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + receive_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ), + ): + """ + Execute receive message processing with decoding and callback handling. + + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param receive_stop_criteria: Events and callables that trigger receive shutdown + """ + canceled_event = ThreadingEvent() + + try: + await asyncio.gather( + *[ + asyncio.to_thread(thread, *args) + for thread, args in self.create_receive_messages_threads( + receive_callback=receive_callback, + message_encoding=message_encoding, + check_stop=self._create_check_stop_callable( + receive_stop_criteria, canceled_event + ), + ) + ] + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.receive_stopped_event.set() + + async def get(self, timeout: float | None = None) -> ReceiveMessageT: + """ + Retrieve message from receive buffer with optional timeout. + + :param timeout: Maximum time to wait for a message + :return: Decoded message from the receive buffer + """ + return await asyncio.wait_for( + self.buffer_receive_queue.async_get(), timeout=timeout + ) + + def get_sync(self, timeout: float | None = None) -> ReceiveMessageT: + """ + Retrieve message from receive buffer synchronously with optional timeout. + + :param timeout: Maximum time to wait for a message, if <=0 uses get_nowait + :return: Decoded message from the receive buffer + """ + if timeout is not None and timeout <= 0: + return self.buffer_receive_queue.get_nowait() + else: + return self.buffer_receive_queue.sync_get(timeout=timeout) + + async def put(self, item: SendMessageT, timeout: float | None = None): + """ + Add message to send buffer with optional timeout. + + :param item: Message item to add to the send buffer + :param timeout: Maximum time to wait for buffer space + """ + await asyncio.wait_for(self.buffer_send_queue.async_put(item), timeout=timeout) + + def put_sync(self, item: SendMessageT, timeout: float | None = None): + """ + Add message to send buffer synchronously with optional timeout. 
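+
+        Example (illustrative sketch; ``worker_messaging`` is a hypothetical,
+        already-started worker copy used from a worker thread, and ``result`` is
+        the payload being sent back):
+        ::
+            worker_messaging.put_sync(result, timeout=0)    # non-blocking attempt
+            worker_messaging.put_sync(result, timeout=1.0)  # wait up to one second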
+ + :param item: Message item to add to the send buffer + :param timeout: Maximum time to wait for buffer space, if <=0 uses put_nowait + """ + if timeout is not None and timeout <= 0: + self.buffer_send_queue.put_nowait(item) + else: + self.buffer_send_queue.sync_put(item, timeout=timeout) + + def _create_check_stop_callable( + self, + stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ), + canceled_event: ThreadingEvent, + ): + stop_events = tuple( + item + for item in stop_criteria or [] + if isinstance(item, (ThreadingEvent, ProcessingEvent)) + ) + stop_callbacks = tuple(item for item in stop_criteria or [] if callable(item)) + + def check_stop(pending: bool, queue_empty: int) -> bool: + if canceled_event.is_set(): + return True + + if stop_callbacks and any( + cb(self, pending, queue_empty) for cb in stop_callbacks + ): + return True + + return ( + not pending + and queue_empty >= self.STOP_REQUIRED_QUEUE_EMPTY + and ( + self.shutdown_event.is_set() + or any(event.is_set() for event in stop_events) + ) + ) + + return check_stop + + +class InterProcessMessagingQueue(InterProcessMessaging[SendMessageT, ReceiveMessageT]): + """ + Queue-based inter-process messaging for distributed scheduler coordination. + + Provides message passing using multiprocessing.Queue objects for communication + between scheduler workers and main process. Handles message encoding, buffering, + flow control, and coordinated shutdown with configurable queue behavior and + error handling policies for distributed operations. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingQueue + + messaging = InterProcessMessagingQueue( + serialization="pickle", + max_pending_size=100 + ) + + # Create worker copy for distributed processing + worker_messaging = messaging.create_worker_copy(worker_index=0) + """ + + def __init__( + self, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + pending_queue: multiprocessing.Queue | None = None, + done_queue: multiprocessing.Queue | None = None, + ): + """ + Initialize queue-based messaging for inter-process communication. 
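+
+        Example (illustrative sketch; ``run_worker`` is a hypothetical worker
+        entry point that receives the worker-specific copy):
+        ::
+            import multiprocessing
+
+            ctx = multiprocessing.get_context("spawn")
+            messaging = InterProcessMessagingQueue(
+                mp_context=ctx, encoding="msgpack", max_pending_size=100
+            )
+            worker_copy = messaging.create_worker_copy(worker_index=0)
+            ctx.Process(target=run_worker, args=(worker_copy,)).start()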
+ + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pending_queue: Multiprocessing queue for sending messages + :param done_queue: Multiprocessing queue for receiving completed messages + :param context: Multiprocessing context for creating queues + """ + super().__init__( + mp_context=mp_context, + serialization=serialization, + encoding=encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=max_buffer_send_size, + max_done_size=max_done_size, + max_buffer_receive_size=max_buffer_receive_size, + poll_interval=poll_interval, + worker_index=worker_index, + ) + self.pending_queue = pending_queue or self.mp_context.Queue( + maxsize=max_pending_size or 0 + ) + self.done_queue = done_queue or self.mp_context.Queue( + maxsize=max_done_size or 0 + ) + + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessagingQueue[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for distributed queue-based coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured queue messaging instance for the specified worker + """ + copy_args = { + "mp_context": self.mp_context, + "serialization": self.serialization, + "encoding": self.encoding, + "max_pending_size": self.max_pending_size, + "max_buffer_send_size": self.max_buffer_send_size, + "max_done_size": self.max_done_size, + "max_buffer_receive_size": self.max_buffer_receive_size, + "poll_interval": self.poll_interval, + "worker_index": worker_index, + "pending_queue": self.pending_queue, + "done_queue": self.done_queue, + } + copy_args.update(kwargs) + + return InterProcessMessagingQueue[ReceiveMessageT, SendMessageT](**copy_args) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await super().stop() + if self.worker_index is None: + # only main process should close the queues + with contextlib.suppress(queue.Empty): + while True: + self.pending_queue.get_nowait() + self.pending_queue.close() + + with contextlib.suppress(queue.Empty): + while True: + self.done_queue.get_nowait() + self.done_queue.close() + + self.pending_queue = None + self.done_queue = None + + def create_send_messages_threads( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create send message processing threads for queue-based transport. 
+ + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + return [ + ( + self._send_messages_task_thread, + (send_items, message_encoding, check_stop), + ) + ] + + def create_receive_messages_threads( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create receive message processing threads for queue-based transport. + + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + return [ + ( + self._receive_messages_task_thread, + (receive_callback, message_encoding, check_stop), + ) + ] + + def _send_messages_task_thread( # noqa: C901, PLR0912 + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + send_items_iter = iter(send_items) if send_items is not None else None + pending_item = None + queue_empty = 0 + + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if send_items_iter is not None: + item = next(send_items_iter) + else: + item = self.buffer_send_queue.sync_get( + timeout=self.poll_interval + ) + pending_item = message_encoding.encode(item) + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty, StopIteration): + queue_empty += 1 + + if pending_item is not None: + try: + if self.worker_index is None: + # Main publisher + self.pending_queue.put(pending_item, timeout=self.poll_interval) + else: + # Worker + self.done_queue.put(pending_item, timeout=self.poll_interval) + if send_items_iter is None: + self.buffer_send_queue.task_done() + pending_item = None + except (culsans.QueueFull, queue.Full): + pass + + def _receive_messages_task_thread( # noqa: C901 + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + pending_item = None + received_item = None + queue_empty = 0 + + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if self.worker_index is None: + # Main publisher + item = self.done_queue.get(timeout=self.poll_interval) + else: + # Worker + item = self.pending_queue.get(timeout=self.poll_interval) + pending_item = message_encoding.decode(item) + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty): + queue_empty += 1 + + if pending_item is not None or received_item is not None: + try: + if received_item is None: + received_item = ( + pending_item + if not receive_callback + else receive_callback(pending_item) + ) + + self.buffer_receive_queue.sync_put(received_item) + pending_item = None + received_item = None + except (culsans.QueueFull, queue.Full): + pass + + +class InterProcessMessagingManagerQueue( + InterProcessMessagingQueue[SendMessageT, ReceiveMessageT] +): + """ + Manager-based queue messaging for inter-process scheduler coordination. + + Extends queue-based messaging with multiprocessing.Manager support for + shared state coordination across worker processes. 
Provides managed queues + for reliable message passing in distributed scheduler environments with + enhanced process synchronization and resource management capabilities. + + Example: + :: + import multiprocessing + from guidellm.utils.messaging import InterProcessMessagingManagerQueue + + manager = multiprocessing.Manager() + messaging = InterProcessMessagingManagerQueue( + manager=manager, + serialization="pickle" + ) + """ + + def __init__( + self, + manager: SyncManager, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + pending_queue: multiprocessing.Queue | None = None, + done_queue: multiprocessing.Queue | None = None, + ): + """ + Initialize manager-based queue messaging for inter-process communication. + + :param manager: Multiprocessing manager for shared queue creation + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pending_queue: Managed multiprocessing queue for sending messages + :param done_queue: Managed multiprocessing queue for receiving completed + messages + """ + super().__init__( + mp_context=mp_context, + serialization=serialization, + encoding=encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=max_buffer_send_size, + max_done_size=max_done_size, + max_buffer_receive_size=max_buffer_receive_size, + poll_interval=poll_interval, + worker_index=worker_index, + pending_queue=pending_queue or manager.Queue(maxsize=max_pending_size or 0), # type: ignore [assignment] + done_queue=done_queue or manager.Queue(maxsize=max_done_size or 0), # type: ignore [assignment] + ) + + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessagingManagerQueue[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for managed queue-based coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured manager queue messaging instance for the specified worker + """ + copy_args = { + "manager": None, + "mp_context": self.mp_context, + "serialization": self.serialization, + "encoding": self.encoding, + "max_pending_size": self.max_pending_size, + "max_buffer_send_size": self.max_buffer_send_size, + "max_done_size": self.max_done_size, + "max_buffer_receive_size": self.max_buffer_receive_size, + "poll_interval": self.poll_interval, + "worker_index": worker_index, + "pending_queue": self.pending_queue, + "done_queue": self.done_queue, + } + copy_args.update(kwargs) + + return InterProcessMessagingManagerQueue(**copy_args) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. 
+ """ + await InterProcessMessaging.stop(self) + self.pending_queue = None + self.done_queue = None + + +class InterProcessMessagingPipe(InterProcessMessaging[SendMessageT, ReceiveMessageT]): + """ + Pipe-based inter-process messaging for distributed scheduler coordination. + + Provides message passing using multiprocessing.Pipe objects for direct + communication between scheduler workers and main process. Offers lower + latency than queue-based messaging with duplex communication channels + for high-performance distributed operations. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingPipe + + messaging = InterProcessMessagingPipe( + num_workers=4, + serialization="pickle", + poll_interval=0.05 + ) + + # Create worker copy for specific worker process + worker_messaging = messaging.create_worker_copy(worker_index=0) + """ + + def __init__( + self, + num_workers: int, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + pipe: tuple[Connection, Connection] | None = None, + ): + """ + Initialize pipe-based messaging for inter-process communication. + + :param num_workers: Number of worker processes requiring pipe connections + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pipe: Existing pipe connection for worker-specific instances + """ + super().__init__( + mp_context=mp_context, + serialization=serialization, + encoding=encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=max_buffer_send_size, + max_done_size=max_done_size, + max_buffer_receive_size=max_buffer_receive_size, + poll_interval=poll_interval, + worker_index=worker_index, + ) + self.num_workers = num_workers + + if pipe is None: + self.pipes: list[tuple[Connection, Connection]] = [ + self.mp_context.Pipe(duplex=True) for _ in range(num_workers) + ] + else: + self.pipes: list[tuple[Connection, Connection]] = [pipe] + + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessagingPipe[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for pipe-based coordination. 
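+
+        Example (illustrative sketch; ``messaging`` is the main-process instance
+        created with one pipe per worker):
+        ::
+            copies = [
+                messaging.create_worker_copy(worker_index=index)
+                for index in range(messaging.num_workers)
+            ]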
+ + :param worker_index: Index of the worker process for pipe routing + :return: Configured pipe messaging instance for the specified worker + """ + copy_args = { + "num_workers": self.num_workers, + "mp_context": self.mp_context, + "serialization": self.serialization, + "encoding": self.encoding, + "max_pending_size": self.max_pending_size, + "max_buffer_send_size": self.max_buffer_send_size, + "max_done_size": self.max_done_size, + "max_buffer_receive_size": self.max_buffer_receive_size, + "poll_interval": self.poll_interval, + "worker_index": worker_index, + "pipe": self.pipes[worker_index], + } + copy_args.update(kwargs) + + return InterProcessMessagingPipe(**copy_args) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await super().stop() + if self.worker_index is None: + # Only main process should close the pipes + for main_con, worker_con in self.pipes: + main_con.close() + worker_con.close() + + def create_send_messages_threads( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create send message processing threads for pipe-based transport. + + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + if self.worker_index is None: + # Create a separate task for each worker's pipe + return [ + ( + self._send_messages_task_thread, + (self.pipes[index], send_items, message_encoding, check_stop), + ) + for index in range(self.num_workers) + ] + else: + return [ + ( + self._send_messages_task_thread, + (self.pipes[0], send_items, message_encoding, check_stop), + ) + ] + + def create_receive_messages_threads( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create receive message processing threads for pipe-based transport. 
+ + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + if self.worker_index is None: + # Create a separate task for each worker's pipe + return [ + ( + self._receive_messages_task_thread, + (self.pipes[index], receive_callback, message_encoding, check_stop), + ) + for index in range(self.num_workers) + ] + else: + return [ + ( + self._receive_messages_task_thread, + (self.pipes[0], receive_callback, message_encoding, check_stop), + ) + ] + + def _send_messages_task_thread( # noqa: C901, PLR0912 + self, + pipe: tuple[Connection, Connection], + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + local_stop = ThreadingEvent() + send_connection: Connection = pipe[0] if self.worker_index is None else pipe[1] + send_items_iter = iter(send_items) if send_items is not None else None + pending_item = None + queue_empty = 0 + pipe_item = None + pipe_lock = threading.Lock() + + def _background_pipe_recv(): + nonlocal pipe_item + + while not local_stop.is_set(): + try: + with pipe_lock: + pending = pipe_item + pipe_item = None + + if pending is not None: + send_connection.send(pending) + except (EOFError, ConnectionResetError): + break + + if send_items_iter is None: + threading.Thread(target=_background_pipe_recv, daemon=True).start() + + try: + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if send_items_iter is not None: + item = next(send_items_iter) + else: + item = self.buffer_send_queue.sync_get( + timeout=self.poll_interval + ) + pending_item = message_encoding.encode(item) + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty, StopIteration): + queue_empty += 1 + + if pending_item is not None: + try: + with pipe_lock: + if pipe_item is not None: + time.sleep(self.poll_interval / 100) + raise queue.Full + else: + pipe_item = pending_item + if send_items_iter is None: + self.buffer_send_queue.task_done() + pending_item = None + except (culsans.QueueFull, queue.Full): + pass + finally: + local_stop.set() + + def _receive_messages_task_thread( # noqa: C901 + self, + pipe: tuple[Connection, Connection], + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + receive_connection: Connection = ( + pipe[0] if self.worker_index is not None else pipe[1] + ) + pending_item = None + received_item = None + queue_empty = 0 + + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if receive_connection.poll(self.poll_interval): + item = receive_connection.recv() + pending_item = message_encoding.decode(item) + else: + raise queue.Empty + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty): + queue_empty += 1 + + if pending_item is not None or received_item is not None: + try: + if received_item is None: + received_item = ( + pending_item + if not receive_callback + else receive_callback(pending_item) + ) + + self.buffer_receive_queue.sync_put(received_item) + pending_item = None + received_item = None + except (culsans.QueueFull, queue.Full): + pass diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py index c71067a4..b001ff2d 100644 --- a/src/guidellm/utils/mixins.py +++ 
b/src/guidellm/utils/mixins.py @@ -3,18 +3,41 @@ Provides reusable mixins for extracting structured metadata from objects, enabling consistent information exposure across different class hierarchies. - -Classes: - InfoMixin: Mixin providing standardized metadata extraction capabilities. """ +from __future__ import annotations + from typing import Any __all__ = ["InfoMixin"] +PYTHON_PRIMITIVES = (str, int, float, bool, list, tuple, dict) +"""Type alias for serialized object representations""" + + class InfoMixin: - """Mixin class providing standardized metadata extraction for introspection.""" + """ + Mixin class providing standardized metadata extraction for introspection. + + Enables consistent object metadata extraction patterns across different + class hierarchies for debugging, serialization, and runtime analysis. + Provides both instance and class-level methods for extracting structured + information from arbitrary objects with fallback handling for objects + without built-in info capabilities. + + Example: + :: + from guidellm.utils.mixins import InfoMixin + + class ConfiguredClass(InfoMixin): + def __init__(self, setting: str): + self.setting = setting + + obj = ConfiguredClass("value") + # Returns {'str': 'ConfiguredClass(...)', 'type': 'ConfiguredClass', ...} + print(obj.info) + """ @classmethod def extract_from_obj(cls, obj: Any) -> dict[str, Any]: @@ -23,10 +46,11 @@ def extract_from_obj(cls, obj: Any) -> dict[str, Any]: Attempts to use the object's own `info` method or property if available, otherwise constructs metadata from object attributes and type information. + Provides consistent metadata format across different object types. - :param obj: Object to extract metadata from. + :param obj: Object to extract metadata from :return: Dictionary containing object metadata including type, class, - module, and public attributes. + module, and public attributes """ if hasattr(obj, "info"): return obj.info() if callable(obj.info) else obj.info @@ -38,9 +62,7 @@ def extract_from_obj(cls, obj: Any) -> dict[str, Any]: "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, "attributes": ( { - key: val - if isinstance(val, (str, int, float, bool, list, dict)) - else str(val) + key: val if isinstance(val, PYTHON_PRIMITIVES) else repr(val) for key, val in obj.__dict__.items() if not key.startswith("_") } @@ -54,8 +76,12 @@ def create_info_dict(cls, obj: Any) -> dict[str, Any]: """ Create a structured info dictionary for the given object. - :param obj: Object to extract info from. - :return: Dictionary containing structured metadata about the object. + Builds standardized metadata dictionary containing object identification, + type information, and accessible attributes. Used internally by other + info extraction methods and available for direct metadata construction. + + :param obj: Object to extract info from + :return: Dictionary containing structured metadata about the object """ return { "str": str(obj), @@ -66,7 +92,7 @@ def create_info_dict(cls, obj: Any) -> dict[str, Any]: { key: val if isinstance(val, (str, int, float, bool, list, dict)) - else str(val) + else repr(val) for key, val in obj.__dict__.items() if not key.startswith("_") } @@ -80,6 +106,10 @@ def info(self) -> dict[str, Any]: """ Return structured metadata about this instance. - :return: Dictionary containing class name, module, and public attributes. + Provides consistent access to object metadata for debugging, serialization, + and introspection. 
Uses the create_info_dict method to generate standardized + metadata format including class information and public attributes. + + :return: Dictionary containing class name, module, and public attributes """ return self.create_info_dict(self) diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py index 8d329eb6..f06614f8 100644 --- a/src/guidellm/utils/pydantic_utils.py +++ b/src/guidellm/utils/pydantic_utils.py @@ -3,18 +3,19 @@ Provides integration between Pydantic and the registry system, enabling polymorphic serialization and deserialization of Pydantic models using -a discriminator field and dynamic class registry. - -Classes: - ReloadableBaseModel: Base model with schema reloading capabilities. - PydanticClassRegistryMixin: Polymorphic Pydantic models with registry support. +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, ClassVar, Generic, Optional, TypeVar +from typing import Any, ClassVar, Generic, TypeVar from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema +from typing_extensions import get_args, get_origin from guidellm.utils.registry import RegistryMixin @@ -28,103 +29,224 @@ BaseModelT = TypeVar("BaseModelT", bound=BaseModel) -T = TypeVar("T", bound=BaseModel) +RegisterClassT = TypeVar("RegisterClassT") +SuccessfulT = TypeVar("SuccessfulT") +ErroredT = TypeVar("ErroredT") +IncompleteT = TypeVar("IncompleteT") +TotalT = TypeVar("TotalT") class ReloadableBaseModel(BaseModel): - """Base Pydantic model with schema reloading capabilities.""" + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ model_config = ConfigDict( extra="ignore", use_enum_values=True, - validate_assignment=True, from_attributes=True, arbitrary_types_allowed=True, ) @classmethod - def reload_schema(cls): + def reload_schema(cls, parents: bool = True) -> None: """ Reload the class schema with updated registry information. - :return: None + Forces a complete rebuild of the Pydantic model schema to incorporate + any changes made to associated registries or validation rules. + + :param parents: Whether to also rebuild schemas for any pydantic parent + types that reference this model. """ cls.model_rebuild(force=True) + if parents: + cls.reload_parent_schemas() + + @classmethod + def reload_parent_schemas(cls): + """ + Recursively reload schemas for all parent Pydantic models. + + Traverses the inheritance hierarchy to find all parent classes that + are Pydantic models and triggers schema rebuilding on each to ensure + that any changes in child models are reflected in parent schemas. 
+ """ + potential_parents: set[BaseModel] = {BaseModel} + stack: list[BaseModel] = [BaseModel] + + while stack: + current = stack.pop() + for subclass in current.__subclasses__(): + if ( + issubclass(subclass, BaseModel) + and subclass is not cls + and subclass not in potential_parents + ): + potential_parents.add(subclass) + stack.append(subclass) + + for check in cls.__mro__: + if isinstance(check, type) and issubclass(check, BaseModel): + cls._reload_schemas_depending_on(check, potential_parents) + + @classmethod + def _reload_schemas_depending_on(cls, target: type[BaseModel], types: set[type]): + changed = True + while changed: + changed = False + for candidate in types: + if ( + isinstance(candidate, type) + and issubclass(candidate, BaseModel) + and any( + cls._uses_type(target, field_info.annotation) + for field_info in candidate.model_fields.values() + ) + ): + before = candidate.model_json_schema() + candidate.model_rebuild(force=True) + after = candidate.model_json_schema() + if before != after: + changed = True + + @classmethod + def _uses_type(cls, target: type, candidate: type) -> bool: + if target is candidate: + return True + + origin = get_origin(candidate) + + if origin is None: + return isinstance(candidate, type) and issubclass(candidate, target) + + if isinstance(origin, type) and ( + target is origin or issubclass(origin, target) + ): + return True + + for arg in get_args(candidate) or []: + if isinstance(arg, type) and cls._uses_type(target, arg): + return True + + return False + class StandardBaseModel(BaseModel): """ - A base class for Pydantic models throughout GuideLLM enabling standard - configuration and logging. + Base Pydantic model with standardized configuration for GuideLLM. + + Provides consistent validation behavior and configuration settings across + all Pydantic models in the application, including field validation, + attribute conversion, and default value handling. + + Example: + :: + class MyModel(StandardBaseModel): + name: str + value: int = 42 + + # Access default values + default_value = MyModel.get_default("value") # Returns 42 """ model_config = ConfigDict( extra="ignore", use_enum_values=True, - validate_assignment=True, from_attributes=True, ) @classmethod - def get_default(cls: type[T], field: str) -> Any: - """Get default values for model fields""" + def get_default(cls: type[BaseModel], field: str) -> Any: + """ + Get default value for a model field. + + :param field: Name of the field to get the default value for + :return: Default value of the specified field + :raises KeyError: If the field does not exist in the model + """ return cls.model_fields[field].default class StandardBaseDict(StandardBaseModel): + """ + Base Pydantic model allowing arbitrary additional fields. + + Extends StandardBaseModel to accept extra fields beyond those explicitly + defined in the model schema. Useful for flexible data structures that + need to accommodate varying or unknown field sets while maintaining + type safety for known fields. + """ + model_config = ConfigDict( extra="allow", use_enum_values=True, - validate_assignment=True, from_attributes=True, arbitrary_types_allowed=True, ) -SuccessfulT = TypeVar("SuccessfulT") -ErroredT = TypeVar("ErroredT") -IncompleteT = TypeVar("IncompleteT") -TotalT = TypeVar("TotalT") - - class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): """ - A base class for Pydantic models that are separated by statuses including - successful, incomplete, and errored. 
It additionally enables the inclusion - of total, which is intended as the combination of all statuses. - Total may or may not be used depending on if it duplicates information. + Generic model for organizing results by processing status. + + Provides structured categorization of results into successful, errored, + incomplete, and total status groups. Supports flexible typing for each + status category to accommodate different result types while maintaining + consistent organization patterns across the application. + + Example: + :: + from guidellm.utils import StatusBreakdown + + # Define a breakdown for request counts + breakdown = StatusBreakdown[int, int, int, int]( + successful=150, + errored=5, + incomplete=10, + total=165 + ) """ successful: SuccessfulT = Field( - description="The results with a successful status.", + description="Results or metrics for requests with successful completion status", default=None, # type: ignore[assignment] ) errored: ErroredT = Field( - description="The results with an errored status.", + description="Results or metrics for requests with error completion status", default=None, # type: ignore[assignment] ) incomplete: IncompleteT = Field( - description="The results with an incomplete status.", + description="Results or metrics for requests with incomplete processing status", default=None, # type: ignore[assignment] ) total: TotalT = Field( - description="The combination of all statuses.", + description="Aggregated results or metrics combining all status categories", default=None, # type: ignore[assignment] ) class PydanticClassRegistryMixin( - ReloadableBaseModel, ABC, RegistryMixin[BaseModelT], Generic[BaseModelT] + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] ): """ - Polymorphic Pydantic models with registry-based dynamic instantiation. + Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. Integrates Pydantic validation with the registry system to enable polymorphic serialization and deserialization based on a discriminator field. Automatically - instantiates the correct subclass during validation based on registry mappings. + instantiates the correct subclass during validation based on registry mappings, + providing a foundation for extensible plugin-style architectures. 
Example: :: + from speculators.utils import PydanticClassRegistryMixin + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): schema_discriminator: ClassVar[str] = "config_type" config_type: str = Field(description="Configuration type identifier") @@ -133,28 +255,37 @@ class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: return BaseConfig - @BaseConfig.register("type_a") - class ConfigA(BaseConfig): - config_type: str = "type_a" - value: str = Field(description="Configuration value") + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") - # Dynamic instantiation - config = BaseConfig.model_validate({"config_type": "type_a", "value": "test"}) + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name used for polymorphic type discrimination """ schema_discriminator: ClassVar[str] = "model_type" @classmethod def register_decorator( - cls, clazz: type[BaseModel], name: Optional[str] = None - ) -> type[BaseModel]: + cls, clazz: RegisterClassT, name: str | list[str] | None = None + ) -> RegisterClassT: """ - Register a Pydantic model class with type validation. + Register a Pydantic model class with type validation and schema reload. + + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. - :param clazz: The Pydantic model class to register. - :param name: Optional registry name. Defaults to class name if None. - :return: The registered class. - :raises TypeError: If clazz is not a Pydantic BaseModel subclass. + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass """ if not issubclass(clazz, BaseModel): raise TypeError( @@ -162,21 +293,25 @@ def register_decorator( "Pydantic BaseModel" ) - dec_clazz = super().register_decorator(clazz, name=name) + super().register_decorator(clazz, name=name) cls.reload_schema() - return dec_clazz + return clazz @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: """ - Generate polymorphic validation schema for dynamic instantiation. + Generate polymorphic validation schema for dynamic type instantiation. - :param source_type: The type for schema generation. - :param handler: Core schema generation handler. - :return: Tagged union schema for polymorphic validation. + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. 
+ + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema """ if source_type == cls.__pydantic_schema_base_type__(): if not cls.registry: @@ -197,9 +332,12 @@ def __get_pydantic_core_schema__( @abstractmethod def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: """ - Define the base type for polymorphic validation. + Define the base type for polymorphic validation hierarchy. + + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. - :return: The base class type for the polymorphic hierarchy. + :return: Base class type for the polymorphic model hierarchy """ ... @@ -208,22 +346,52 @@ def __pydantic_generate_base_schema__( cls, handler: GetCoreSchemaHandler ) -> CoreSchema: """ - Generate base schema for polymorphic models without registry. + Generate fallback schema for polymorphic models without registry. - :param handler: Core schema generation handler. - :return: Base CoreSchema accepting any valid input. + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. Used as fallback during + schema generation when the registry has not been populated. + + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input """ return core_schema.any_schema() @classmethod def auto_populate_registry(cls) -> bool: """ - Initialize registry and reload schema for validation readiness. + Initialize registry with auto-discovery and reload validation schema. + + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. - :return: True if registry was populated, False if already populated. - :raises ValueError: If called when registry_auto_discovery is False. + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled """ populated = super().auto_populate_registry() cls.reload_schema() return populated + + @classmethod + def registered_classes(cls) -> tuple[type[BaseModelT], ...]: + """ + Get all registered pydantic classes from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered classes including auto-discovered ones + :raises ValueError: If called before any objects have been registered + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "ClassRegistryMixin.registered_classes() must be called after " + "registering classes with ClassRegistryMixin.register()." + ) + + return tuple(cls.registry.values()) diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py index 3a93c787..b9e3faf5 100644 --- a/src/guidellm/utils/registry.py +++ b/src/guidellm/utils/registry.py @@ -3,32 +3,34 @@ Provides a flexible object registration system with optional auto-discovery capabilities through decorators and module imports. Enables dynamic discovery -and instantiation of implementations based on configuration parameters. 
- -Classes: - RegistryMixin: Generic mixin for creating object registries with decorators - and optional auto-discovery capabilities. - -Type Variables: - RegistryObjT: Generic registry object type. +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. """ -from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union +from __future__ import annotations + +from typing import Callable, ClassVar, Generic, TypeVar, cast from guidellm.utils.auto_importer import AutoImporterMixin -__all__ = ["RegistryMixin"] +__all__ = ["RegisterT", "RegistryMixin", "RegistryObjT"] -RegistryObjT = TypeVar("RegistryObjT", bound=Any) +RegistryObjT = TypeVar("RegistryObjT") +"""Generic type variable for objects managed by the registry system.""" +RegisterT = TypeVar("RegisterT") +"""Generic type variable for the args and return values within the registry.""" class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): """ Generic mixin for creating object registries with optional auto-discovery. - Enables classes to maintain separate registries of objects that can be - dynamically discovered and instantiated through decorators and module imports. + Enables classes to maintain separate registries of objects that can be dynamically + discovered and instantiated through decorators and module imports. Supports both + manual registration via decorators and automatic discovery through package scanning + for extensible plugin architectures. Example: :: @@ -54,47 +56,50 @@ class TokenProposal(RegistryMixin): # Automatically imports and registers decorated objects proposals = TokenProposal.registered_objects() + + :cvar registry: Dictionary mapping names to registered objects + :cvar registry_auto_discovery: Enable automatic package-based discovery + :cvar registry_populated: Track whether auto-discovery has completed """ - registry: ClassVar[Optional[dict[str, RegistryObjT]]] = None + registry: ClassVar[dict[str, RegistryObjT] | None] = None registry_auto_discovery: ClassVar[bool] = False registry_populated: ClassVar[bool] = False @classmethod def register( - cls, name: Optional[Union[str, list[str]]] = None - ) -> Callable[[RegistryObjT], RegistryObjT]: + cls, name: str | list[str] | None = None + ) -> Callable[[RegisterT], RegisterT]: """ - Decorator that registers an object with the registry. + Decorator for registering objects with the registry. :param name: Optional name(s) to register the object under. - If None, the object name is used as the registry key. - :return: A decorator function that registers the decorated object. - :raises ValueError: If name is provided but is not a string or list of strings. + If None, uses the object's __name__ attribute + :return: Decorator function that registers the decorated object + :raises ValueError: If name is not a string, list of strings, or None """ - if name is not None and not isinstance(name, (str, list)): - raise ValueError( - "RegistryMixin.register() name must be a string, list of strings, " - f"or None. Got {name}." 
- ) - return lambda obj: cls.register_decorator(obj, name=name) + def _decorator(obj: RegisterT) -> RegisterT: + cls.register_decorator(obj, name=name) + return obj + + return _decorator @classmethod def register_decorator( - cls, obj: RegistryObjT, name: Optional[Union[str, list[str]]] = None - ) -> RegistryObjT: + cls, obj: RegisterT, name: str | list[str] | None = None + ) -> RegisterT: """ - Direct decorator that registers an object with the registry. + Register an object directly with the registry. - :param obj: The object to register. + :param obj: The object to register :param name: Optional name(s) to register the object under. - If None, the object name is used as the registry key. - :return: The registered object. - :raises ValueError: If the object is already registered or if name is invalid. + If None, uses the object's __name__ attribute + :return: The registered object + :raises ValueError: If the object is already registered or name is invalid """ - if not name: + if name is None: name = obj.__name__ elif not isinstance(name, (str, list)): raise ValueError( @@ -121,20 +126,20 @@ def register_decorator( "registered." ) - cls.registry[register_name.lower()] = obj + cls.registry[register_name] = cast("RegistryObjT", obj) return obj @classmethod def auto_populate_registry(cls) -> bool: """ - Import and register all modules from the specified auto_package. + Import and register all modules from the auto_package. Automatically called by registered_objects when registry_auto_discovery is True - to ensure all available implementations are discovered before returning results. + to ensure all available implementations are discovered. - :return: True if the registry was populated, False if already populated. - :raises ValueError: If called when registry_auto_discovery is False. + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is False """ if not cls.registry_auto_discovery: raise ValueError( @@ -159,8 +164,8 @@ def registered_objects(cls) -> tuple[RegistryObjT, ...]: Automatically triggers auto-discovery if registry_auto_discovery is enabled to ensure all available implementations are included. - :return: Tuple of all registered objects including auto-discovered ones. - :raises ValueError: If called before any objects have been registered. + :return: Tuple of all registered objects including auto-discovered ones + :raises ValueError: If called before any objects have been registered """ if cls.registry_auto_discovery: cls.auto_populate_registry() @@ -177,6 +182,7 @@ def registered_objects(cls) -> tuple[RegistryObjT, ...]: def is_registered(cls, name: str) -> bool: """ Check if an object is registered under the given name. + It matches first by exact name, then by str.lower(). :param name: The name to check for registration. :return: True if the object is registered, False otherwise. @@ -184,12 +190,15 @@ def is_registered(cls, name: str) -> bool: if cls.registry is None: return False - return name.lower() in cls.registry + return name in cls.registry or name.lower() in [ + key.lower() for key in cls.registry + ] @classmethod - def get_registered_object(cls, name: str) -> Optional[RegistryObjT]: + def get_registered_object(cls, name: str) -> RegistryObjT | None: """ - Get a registered object by its name. + Get a registered object by its name. It matches first by exact name, + then by str.lower(). :param name: The name of the registered object. :return: The registered object if found, None otherwise. 
@@ -197,4 +206,9 @@ def get_registered_object(cls, name: str) -> Optional[RegistryObjT]: if cls.registry is None: return None - return cls.registry.get(name.lower()) + if name in cls.registry: + return cls.registry[name] + + lower_key_map = {key.lower(): key for key in cls.registry} + + return cls.registry.get(lower_key_map.get(name.lower())) diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py index 48f039cf..3ec10f79 100644 --- a/src/guidellm/utils/singleton.py +++ b/src/guidellm/utils/singleton.py @@ -3,15 +3,13 @@ Provides singleton mixins for creating classes that maintain a single instance throughout the application lifecycle, with support for both basic and thread-safe -implementations. - -Classes: - SingletonMixin: Basic singleton implementation using class variables. - ThreadSafeSingletonMixin: Thread-safe singleton using locking mechanisms. +implementations. These mixins integrate with the scheduler and other system components +to ensure consistent state management and prevent duplicate resource allocation. """ +from __future__ import annotations + import threading -from typing import ClassVar __all__ = ["SingletonMixin", "ThreadSafeSingletonMixin"] @@ -22,29 +20,49 @@ class SingletonMixin: Implements the singleton pattern using class variables to control instance creation. Subclasses must call super().__init__() for proper initialization - state management. + state management. Suitable for single-threaded environments or when external + synchronization is provided. + + Example: + :: + class ConfigManager(SingletonMixin): + def __init__(self, config_path: str): + super().__init__() + if not self.initialized: + self.config = load_config(config_path) + + manager1 = ConfigManager("config.json") + manager2 = ConfigManager("config.json") + assert manager1 is manager2 """ - singleton_instance: ClassVar["SingletonMixin"] = None - - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, **kwargs): # noqa: ARG004 """ Create or return the singleton instance. - :param args: Positional arguments passed to the constructor. - :param kwargs: Keyword arguments passed to the constructor. - :return: The singleton instance of the class. + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class """ - if cls.singleton_instance is None: - cls.singleton_instance = super().__new__(cls, *args, **kwargs) - cls.singleton_instance.initialized = False - return cls.singleton_instance + # Use class-specific attribute name to avoid inheritance issues + attr_name = f"_singleton_instance_{cls.__name__}" + + if not hasattr(cls, attr_name) or getattr(cls, attr_name) is None: + instance = super().__new__(cls) + setattr(cls, attr_name, instance) + instance._singleton_initialized = False + return getattr(cls, attr_name) def __init__(self): """Initialize the singleton instance exactly once.""" - if self.initialized: + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: return - self.initialized = True + self._singleton_initialized = True + + @property + def initialized(self): + """Return True if the singleton has been initialized.""" + return getattr(self, "_singleton_initialized", False) class ThreadSafeSingletonMixin(SingletonMixin): @@ -52,27 +70,61 @@ class ThreadSafeSingletonMixin(SingletonMixin): Thread-safe singleton mixin with locking mechanisms. 
Extends SingletonMixin with thread safety using locks to prevent race - conditions during instance creation in multi-threaded environments. + conditions during instance creation in multi-threaded environments. Essential + for scheduler components and other shared resources accessed concurrently. + + Example: + :: + class SchedulerResource(ThreadSafeSingletonMixin): + def __init__(self): + super().__init__() + if not self.initialized: + self.resource_pool = initialize_resources() """ - singleton_lock: ClassVar[threading.Lock] = threading.Lock() - - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, **kwargs): # noqa: ARG004 """ Create or return the singleton instance with thread safety. - :param args: Positional arguments passed to the constructor. - :param kwargs: Keyword arguments passed to the constructor. - :return: The singleton instance of the class. + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class """ - with cls.singleton_lock: - if cls.singleton_instance is None: - cls.singleton_instance = super().__new__(cls, *args, **kwargs) - cls.singleton_instance.initialized = False - return cls.singleton_instance + # Use class-specific lock and instance names to avoid inheritance issues + lock_attr_name = f"_singleton_lock_{cls.__name__}" + instance_attr_name = f"_singleton_instance_{cls.__name__}" + + with getattr(cls, lock_attr_name): + instance_exists = ( + hasattr(cls, instance_attr_name) + and getattr(cls, instance_attr_name) is not None + ) + if not instance_exists: + instance = super(SingletonMixin, cls).__new__(cls) + setattr(cls, instance_attr_name, instance) + instance._singleton_initialized = False + instance._init_lock = threading.Lock() + return getattr(cls, instance_attr_name) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + lock_attr_name = f"_singleton_lock_{cls.__name__}" + setattr(cls, lock_attr_name, threading.Lock()) def __init__(self): - """Initialize the singleton instance with thread-local lock.""" - if not self.initialized: - self.thread_lock = threading.Lock() - super().__init__() + """Initialize the singleton instance with thread-safe initialization.""" + with self._init_lock: + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def thread_lock(self): + """Return the thread lock for this singleton instance.""" + return getattr(self, "_init_lock", None) + + @classmethod + def get_singleton_lock(cls): + """Get the class-specific singleton creation lock.""" + lock_attr_name = f"_singleton_lock_{cls.__name__}" + return getattr(cls, lock_attr_name, None) diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py index defbd93e..c820de9d 100644 --- a/src/guidellm/utils/statistics.py +++ b/src/guidellm/utils/statistics.py @@ -1,7 +1,19 @@ +""" +Statistical analysis utilities for distribution calculations and running metrics. + +Provides comprehensive statistical computation tools for analyzing numerical +distributions, percentiles, and streaming data. Includes specialized support for +request timing analysis, concurrency measurement, and rate calculations. Integrates +with Pydantic for serializable statistical models and supports both weighted and +unweighted distributions with cumulative distribution function (CDF) generation. 
+""" + +from __future__ import annotations + import math import time as timer from collections import defaultdict -from typing import Any, Literal, Optional +from typing import Any, Literal import numpy as np from pydantic import Field, computed_field @@ -19,7 +31,11 @@ class Percentiles(StandardBaseModel): """ - A pydantic model representing the standard percentiles of a distribution. + Standard percentiles model for statistical distribution analysis. + + Provides complete percentile coverage from 0.1th to 99.9th percentiles for + statistical distribution characterization. Used as a component within + DistributionSummary to provide detailed distribution shape analysis. """ p001: float = Field( @@ -59,8 +75,25 @@ class Percentiles(StandardBaseModel): class DistributionSummary(StandardBaseModel): """ - A pydantic model representing a statistical summary for a given - distribution of numerical values. + Comprehensive statistical summary for numerical value distributions. + + Calculates and stores complete statistical metrics including central tendency, + dispersion, extremes, and percentiles for any numerical distribution. Supports + both weighted and unweighted data with optional cumulative distribution function + generation. Primary statistical analysis tool for request timing, performance + metrics, and benchmark result characterization. + + Example: + :: + # Create from simple values + summary = DistributionSummary.from_values([1.0, 2.0, 3.0, 4.0, 5.0]) + print(f"Mean: {summary.mean}, P95: {summary.percentiles.p95}") + + # Create from request timings for concurrency analysis + requests = [(0.0, 1.0), (0.5, 2.0), (1.0, 2.5)] + concurrency = DistributionSummary.from_request_times( + requests, "concurrency" + ) """ mean: float = Field( @@ -93,7 +126,7 @@ class DistributionSummary(StandardBaseModel): percentiles: Percentiles = Field( description="The percentiles of the distribution.", ) - cumulative_distribution_function: Optional[list[tuple[float, float]]] = Field( + cumulative_distribution_function: list[tuple[float, float]] | None = Field( description="The cumulative distribution function (CDF) of the distribution.", default=None, ) @@ -102,22 +135,19 @@ class DistributionSummary(StandardBaseModel): def from_distribution_function( distribution: list[tuple[float, float]], include_cdf: bool = False, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of weighted numerical - values or a probability distribution function (PDF). - 1. If the distribution is a PDF, it is expected to be a list of tuples - where each tuple contains (value, probability). The sum of the - probabilities should be 1. If it is not, it will be normalized. - 2. If the distribution is a values distribution function, it is expected - to be a list of tuples where each tuple contains (value, weight). - The weights are normalized to a probability distribution function. - - :param distribution: A list of tuples representing the distribution. - Each tuple contains (value, weight) or (value, probability). - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - :return: An instance of DistributionSummary with calculated values. + ) -> DistributionSummary: + """ + Create statistical summary from weighted distribution or probability function. + + Converts weighted numerical values or probability distribution function (PDF) + into comprehensive statistical summary. 
Normalizes weights to probabilities + and calculates all statistical metrics including percentiles. + + :param distribution: List of (value, weight) or (value, probability) tuples + representing the distribution + :param include_cdf: Whether to include cumulative distribution function + in the output + :return: DistributionSummary instance with calculated statistical metrics """ values, weights = zip(*distribution) if distribution else ([], []) values = np.array(values) # type: ignore[assignment] @@ -190,20 +220,23 @@ def from_distribution_function( @staticmethod def from_values( values: list[float], - weights: Optional[list[float]] = None, + weights: list[float] | None = None, include_cdf: bool = False, - ) -> "DistributionSummary": + ) -> DistributionSummary: """ - Create a statistical summary for a given distribution of numerical values. - This is a wrapper around from_distribution_function to handle the optional case - of including weights for the values. If weights are not provided, they are - automatically set to 1.0 for each value, so each value is equally weighted. + Create statistical summary from numerical values with optional weights. + + Wrapper around from_distribution_function for simple value lists. If weights + are not provided, all values are equally weighted. Enables statistical + analysis of any numerical dataset. - :param values: A list of numerical values representing the distribution. - :param weights: A list of weights for each value in the distribution. - If not provided, all values are equally weighted. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. + :param values: Numerical values representing the distribution + :param weights: Optional weights for each value. If not provided, all values + are equally weighted + :param include_cdf: Whether to include cumulative distribution function in + the output DistributionSummary + :return: DistributionSummary instance with calculated statistical metrics + :raises ValueError: If values and weights lists have different lengths """ if weights is None: weights = [1.0] * len(values) @@ -224,22 +257,21 @@ def from_request_times( distribution_type: Literal["concurrency", "rate"], include_cdf: bool = False, epsilon: float = 1e-6, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of request times. - Specifically, this is used to measure concurrency or rate of requests - given an input list containing the start and end time of each request. - This will first convert the request times into a distribution function - and then calculate the statistics with from_distribution_function. - - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...] - :param distribution_type: The type of distribution to calculate. - Either "concurrency" or "rate". - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of DistributionSummary with calculated values. + ) -> DistributionSummary: + """ + Create statistical summary from request timing data. + + Analyzes request start/end times to calculate concurrency or rate + distributions. Converts timing events into statistical metrics for + performance analysis and load characterization. 
+ + :param requests: List of (start_time, end_time) tuples for each request + :param distribution_type: Type of analysis - "concurrency" for simultaneous + requests or "rate" for completion rates + :param include_cdf: Whether to include cumulative distribution function + :param epsilon: Threshold for merging close timing events + :return: DistributionSummary with timing-based statistical metrics + :raises ValueError: If distribution_type is not "concurrency" or "rate" """ if distribution_type == "concurrency": # convert to delta changes based on when requests were running @@ -309,34 +341,28 @@ def from_iterable_request_times( requests: list[tuple[float, float]], first_iter_times: list[float], iter_counts: list[int], - first_iter_counts: Optional[list[int]] = None, + first_iter_counts: list[int] | None = None, include_cdf: bool = False, epsilon: float = 1e-6, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of request times - for a request with iterable responses between the start and end. - For example, this is used to measure auto regressive requests where - a request is started and at some later point, iterative responses are - received. This will convert the request times and iterable values into - a distribution function and then calculate the statistics with - from_distribution_function. - - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...] - :param first_iter_times: A list of times when the first iteration of - each request was received. Must be the same length as requests. - :param iter_counts: A list of the total number of iterations for each - request that occurred starting at the first iteration and ending - at the request end time. Must be the same length as requests. - :param first_iter_counts: A list of the number of iterations to log - for the first iteration of each request. For example, when calculating - total number of tokens processed, this is set to the prompt tokens number. - If not provided, defaults to 1 for each request. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of DistributionSummary with calculated values. + ) -> DistributionSummary: + """ + Create statistical summary from iterative request timing data. + + Analyzes autoregressive or streaming requests with multiple iterations + between start and end times. Calculates rate distributions based on + iteration timing patterns for LLM token generation analysis. + + :param requests: List of (start_time, end_time) tuples for each request + :param first_iter_times: Times when first iteration was received for + each request + :param iter_counts: Total iteration counts for each request from first + iteration to end + :param first_iter_counts: Iteration counts for first iteration (defaults + to 1 for each request) + :param include_cdf: Whether to include cumulative distribution function + :param epsilon: Threshold for merging close timing events + :return: DistributionSummary with iteration rate statistical metrics + :raises ValueError: If input lists have mismatched lengths """ if first_iter_counts is None: @@ -415,36 +441,45 @@ class StatusDistributionSummary( ] ): """ - A pydantic model representing a statistical summary for a given - distribution of numerical values grouped by status. 
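A sketch of the DistributionSummary.from_iterable_request_times path described above, e.g. for token-rate style analysis of streaming responses; all numbers are illustrative:

::

    from guidellm.utils.statistics import DistributionSummary

    # Two streaming requests: (start, end) times in seconds
    requests = [(0.0, 2.0), (0.5, 3.0)]
    first_iter_times = [0.4, 1.0]    # when the first iteration/token arrived
    iter_counts = [40, 60]           # iterations from first iteration to request end
    first_iter_counts = [12, 8]      # extra weight for the first iteration (e.g. prompt tokens)

    rate_summary = DistributionSummary.from_iterable_request_times(
        requests, first_iter_times, iter_counts, first_iter_counts=first_iter_counts
    )
    print(rate_summary.mean, rate_summary.percentiles.p95)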
- Specifically used to represent the total, successful, incomplete, - and errored values for a benchmark or other statistical summary. + Status-grouped statistical summary for request processing analysis. + + Provides comprehensive statistical analysis grouped by request status (total, + successful, incomplete, errored). Enables performance analysis across different + request outcomes for benchmarking and monitoring applications. Each status + category maintains complete DistributionSummary metrics. + + Example: + :: + status_summary = StatusDistributionSummary.from_values( + value_types=["successful", "error", "successful"], + values=[1.5, 10.0, 2.1] + ) + print(f"Success mean: {status_summary.successful.mean}") + print(f"Error rate: {status_summary.errored.count}") """ @staticmethod def from_values( value_types: list[Literal["successful", "incomplete", "error"]], values: list[float], - weights: Optional[list[float]] = None, + weights: list[float] | None = None, include_cdf: bool = False, - ) -> "StatusDistributionSummary": - """ - Create a statistical summary by status for a given distribution of numerical - values. This is used to measure the distribution of values for different - statuses (e.g., successful, incomplete, error) and calculate the statistics - for each status. Weights are optional to weight the probability distribution - for each value by. If not provided, all values are equally weighted. - - :param value_types: A list of status types for each value in the distribution. - Must be one of 'successful', 'incomplete', or 'error'. - :param values: A list of numerical values representing the distribution. - Must be the same length as value_types. - :param weights: A list of weights for each value in the distribution. - If not provided, all values are equally weighted (set to 1). - Must be the same length as value_types. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output StatusDistributionSummary. - :return: An instance of StatusDistributionSummary with calculated values. + ) -> StatusDistributionSummary: + """ + Create status-grouped statistical summary from values and status types. + + Groups numerical values by request status and calculates complete + statistical summaries for each category. Enables performance analysis + across different request outcomes. + + :param value_types: Status type for each value ("successful", "incomplete", + or "error") + :param values: Numerical values representing the distribution + :param weights: Optional weights for each value (defaults to equal weighting) + :param include_cdf: Whether to include cumulative distribution functions + :return: StatusDistributionSummary with statistics grouped by status + :raises ValueError: If input lists have mismatched lengths or invalid + status types """ if any( type_ not in {"successful", "incomplete", "error"} for type_ in value_types @@ -529,25 +564,22 @@ def from_request_times( distribution_type: Literal["concurrency", "rate"], include_cdf: bool = False, epsilon: float = 1e-6, - ) -> "StatusDistributionSummary": - """ - Create a statistical summary by status for given distribution of request times. - This is used to measure the distribution of request times for different statuses - (e.g., successful, incomplete, error) for concurrency and rates. - This will call into DistributionSummary.from_request_times to calculate - the statistics for each status. - - :param request_types: List of status types for each request in the distribution. 
- Must be one of 'successful', 'incomplete', or 'error'. - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...]. - Must be the same length as request_types. - :param distribution_type: The type of distribution to calculate. - Either "concurrency" or "rate". - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output StatusDistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of StatusDistributionSummary with calculated values. + ) -> StatusDistributionSummary: + """ + Create status-grouped statistical summary from request timing data. + + Analyzes request timings grouped by status to calculate concurrency or + rate distributions for each outcome category. Enables comparative + performance analysis across successful, incomplete, and errored requests. + + :param request_types: Status type for each request ("successful", + "incomplete", or "error") + :param requests: List of (start_time, end_time) tuples for each request + :param distribution_type: Analysis type - "concurrency" or "rate" + :param include_cdf: Whether to include cumulative distribution functions + :param epsilon: Threshold for merging close timing events + :return: StatusDistributionSummary with timing statistics by status + :raises ValueError: If input lists have mismatched lengths or invalid types """ if distribution_type not in {"concurrency", "rate"}: raise ValueError( @@ -639,38 +671,31 @@ def from_iterable_request_times( request_types: list[Literal["successful", "incomplete", "error"]], requests: list[tuple[float, float]], first_iter_times: list[float], - iter_counts: Optional[list[int]] = None, - first_iter_counts: Optional[list[int]] = None, + iter_counts: list[int] | None = None, + first_iter_counts: list[int] | None = None, include_cdf: bool = False, epsilon: float = 1e-6, - ) -> "StatusDistributionSummary": - """ - Create a statistical summary by status for given distribution of request times - for a request with iterable responses between the start and end. - For example, this is used to measure auto regressive requests where - a request is started and at some later point, iterative responses are - received. This will call into DistributionSummary.from_iterable_request_times - to calculate the statistics for each status. - - :param request_types: List of status types for each request in the distribution. - Must be one of 'successful', 'incomplete', or 'error'. - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...]. - Must be the same length as request_types. - :param first_iter_times: A list of times when the first iteration of - each request was received. Must be the same length as requests. - :param iter_counts: A list of the total number of iterations for each - request that occurred starting at the first iteration and ending - at the request end time. Must be the same length as requests. - If not provided, defaults to 1 for each request. - :param first_iter_counts: A list of the number of iterations to log - for the first iteration of each request. For example, when calculating - total number of tokens processed, this is set to the prompt tokens number. - If not provided, defaults to 1 for each request. 
- :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output StatusDistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of StatusDistributionSummary with calculated values. + ) -> StatusDistributionSummary: + """ + Create status-grouped statistical summary from iterative request timing data. + + Analyzes autoregressive request timings grouped by status to calculate + iteration rate distributions for each outcome category. Enables comparative + analysis of token generation or streaming response performance across + different request statuses. + + :param request_types: Status type for each request ("successful", + "incomplete", or "error") + :param requests: List of (start_time, end_time) tuples for each request + :param first_iter_times: Times when first iteration was received for + each request + :param iter_counts: Total iteration counts for each request (defaults to 1) + :param first_iter_counts: Iteration counts for first iteration (defaults + to 1) + :param include_cdf: Whether to include cumulative distribution functions + :param epsilon: Threshold for merging close timing events + :return: StatusDistributionSummary with iteration statistics by status + :raises ValueError: If input lists have mismatched lengths or invalid types """ if any( type_ not in {"successful", "incomplete", "error"} @@ -812,13 +837,19 @@ def from_iterable_request_times( class RunningStats(StandardBaseModel): """ - Create a running statistics object to track the mean, rate, and other - statistics of a stream of values. - 1. The start time is set to the time the object is created. - 2. The count is set to 0. - 3. The total is set to 0. - 4. The last value is set to 0. - 5. The mean is calculated as the total / count. + Real-time statistics tracking for streaming numerical data. + + Maintains mean, rate, and cumulative statistics for continuous data streams + without storing individual values. Optimized for memory efficiency in + long-running monitoring applications. Supports arithmetic operators for + convenient value addition and provides computed properties for derived metrics. + + Example: + :: + stats = RunningStats() + stats += 10.5 # Add value using operator + stats.update(20.0, count=3) # Add value with custom count + print(f"Mean: {stats.mean}, Rate: {stats.rate}") """ start_time: float = Field( @@ -866,10 +897,11 @@ def rate(self) -> float: def __add__(self, value: Any) -> float: """ - Enable the use of the + operator to add a value to the running statistics. + Add value using + operator and return current mean. - :param value: The value to add to the running statistics. - :return: The mean of the running statistics. + :param value: Numerical value to add to the running statistics + :return: Updated mean after adding the value + :raises ValueError: If value is not numeric (int or float) """ if not isinstance(value, (int, float)): raise ValueError( @@ -880,12 +912,13 @@ def __add__(self, value: Any) -> float: return self.mean - def __iadd__(self, value: Any) -> "RunningStats": + def __iadd__(self, value: Any) -> RunningStats: """ - Enable the use of the += operator to add a value to the running statistics. + Add value using += operator and return updated instance. - :param value: The value to add to the running statistics. - :return: The running statistics object. 
+ :param value: Numerical value to add to the running statistics + :return: Self reference for method chaining + :raises ValueError: If value is not numeric (int or float) """ if not isinstance(value, (int, float)): raise ValueError( @@ -898,11 +931,10 @@ def __iadd__(self, value: Any) -> "RunningStats": def update(self, value: float, count: int = 1) -> None: """ - Update the running statistics with a new value. + Update running statistics with new value and count. - :param value: The new value to add to the running statistics. - :param count: The number of times to 'count' for the value. - If not provided, defaults to 1. + :param value: Numerical value to add to the running statistics + :param count: Number of occurrences to count for this value (defaults to 1) """ self.count += count self.total += value @@ -911,11 +943,17 @@ def update(self, value: float, count: int = 1) -> None: class TimeRunningStats(RunningStats): """ - Create a running statistics object to track the mean, rate, and other - statistics of a stream of time values. This is used to track time values - in milliseconds and seconds. + Specialized running statistics for time-based measurements. + + Extends RunningStats with time-specific computed properties for millisecond + conversions. Designed for tracking latency, duration, and timing metrics in + performance monitoring applications. - Adds time specific computed_fields such as measurements in milliseconds and seconds. + Example: + :: + time_stats = TimeRunningStats() + time_stats += 0.125 # Add 125ms in seconds + print(f"Mean: {time_stats.mean_ms}ms, Total: {time_stats.total_ms}ms") """ @computed_field # type: ignore[misc] diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py new file mode 100644 index 00000000..aeb7d800 --- /dev/null +++ b/src/guidellm/utils/synchronous.py @@ -0,0 +1,159 @@ +""" +Async utilities for waiting on synchronization objects. + +This module provides async-compatible wrappers for threading and multiprocessing +synchronization primitives (Events and Barriers). These utilities enable async code +to wait for synchronization objects without blocking the event loop, essential for +coordinating between async and sync code or between processes in the guidellm system. +""" + +from __future__ import annotations + +import asyncio +import contextlib +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from threading import Barrier as ThreadingBarrier +from threading import Event as ThreadingEvent +from typing import Annotated, Union + +from typing_extensions import TypeAlias + +__all__ = [ + "SyncObjectTypesAlias", + "wait_for_sync_barrier", + "wait_for_sync_event", + "wait_for_sync_objects", +] + + +SyncObjectTypesAlias: TypeAlias = Annotated[ + Union[ThreadingEvent, ProcessingEvent, ThreadingBarrier, ProcessingBarrier], + "Type alias for threading and multiprocessing synchronization object types", +] + + +async def wait_for_sync_event( + event: ThreadingEvent | ProcessingEvent, + poll_interval: float, +) -> None: + """ + Asynchronously wait for a threading or multiprocessing Event to be set. + + This function polls the event at regular intervals without blocking the async + event loop, allowing other async tasks to continue executing while waiting. 
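A minimal sketch of waiting on a threading Event from async code without blocking the event loop; the worker function is illustrative:

::

    import asyncio
    import threading
    import time

    from guidellm.utils.synchronous import wait_for_sync_event

    def worker(event: threading.Event) -> None:
        time.sleep(0.5)  # simulate blocking work in another thread
        event.set()

    async def main() -> None:
        event = threading.Event()
        threading.Thread(target=worker, args=(event,), daemon=True).start()
        # Checks the event every 100ms; other async tasks keep running meanwhile
        await wait_for_sync_event(event, poll_interval=0.1)

    asyncio.run(main())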
+ + :param event: The Event object to wait for (threading or multiprocessing) + :param poll_interval: Time in seconds between polling checks + :raises asyncio.CancelledError: If the async task is cancelled + """ + stop = ThreadingEvent() + + def _watch(): + try: + while not stop.is_set(): + if event.wait(timeout=poll_interval): + return + except Exception as err: # noqa: BLE001 + if stop.is_set(): + return # Ignore error if we should have stopped + raise err + + try: + await asyncio.to_thread(_watch) + except asyncio.CancelledError: + stop.set() + raise + + +async def wait_for_sync_barrier( + barrier: ThreadingBarrier | ProcessingBarrier, + poll_interval: float, +) -> None: + """ + Asynchronously wait for a threading or multiprocessing Barrier to be reached. + + This function polls the barrier at regular intervals without blocking the async + event loop, allowing other async tasks to continue executing while waiting. + + :param barrier: The Barrier object to wait for (threading or multiprocessing) + :param poll_interval: Time in seconds between polling checks + :raises asyncio.CancelledError: If the async task is cancelled + """ + stop = ThreadingEvent() + barrier_broken = ThreadingEvent() + + def _wait_indefinite(): + try: + # wait forever, count on barrier broken event to exit + barrier.wait() + barrier_broken.set() + except Exception as err: + if stop.is_set(): + return # Ignore error if we should have stopped + raise err + + def _watch(): + while not barrier_broken.is_set(): + if stop.is_set(): + with contextlib.suppress(Exception): + if not barrier.broken: + barrier.abort() + break + + try: + await asyncio.gather( + asyncio.to_thread(_wait_indefinite), + asyncio.to_thread(_watch), + ) + except asyncio.CancelledError: + stop.set() + raise + + +async def wait_for_sync_objects( + objects: SyncObjectTypesAlias + | list[SyncObjectTypesAlias] + | dict[str, SyncObjectTypesAlias], + poll_interval: float = 0.1, +) -> int | str: + """ + Asynchronously wait for the first synchronization object to complete. + + This function waits for the first Event to be set or Barrier to be reached + from a collection of synchronization objects. It returns immediately when + any object completes and cancels waiting on the remaining objects. + + :param objects: Single sync object, list of objects, or dict mapping names + to objects + :param poll_interval: Time in seconds between polling checks for each object + :return: Index (for list/single) or key name (for dict) of the first + completed object + :raises asyncio.CancelledError: If the async task is cancelled + """ + if isinstance(objects, dict): + keys = list(objects.keys()) + objects = list(objects.values()) + elif isinstance(objects, list): + keys = list(range(len(objects))) + else: + keys = [0] + objects = [objects] + + tasks = [ + asyncio.create_task( + wait_for_sync_barrier(obj, poll_interval) + if isinstance(obj, (ThreadingBarrier, ProcessingBarrier)) + else wait_for_sync_event(obj, poll_interval) + ) + for obj in objects + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # Cancel the remaining pending tasks + for pend in pending: + pend.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + return keys[tasks.index(list(done)[0])] diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index d14da3eb..fbbc6d91 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -1,19 +1,32 @@ +""" +Text processing utilities for content manipulation and formatting operations. 
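A sketch of racing several named synchronization objects with wait_for_sync_objects (shown above) and branching on whichever completes first; the event names are illustrative:

::

    import asyncio
    import threading

    from guidellm.utils.synchronous import wait_for_sync_objects

    async def main() -> None:
        shutdown = threading.Event()
        error = threading.Event()

        # Something elsewhere eventually sets one of these events
        threading.Timer(0.2, shutdown.set).start()

        completed = await wait_for_sync_objects(
            {"shutdown": shutdown, "error": error}, poll_interval=0.05
        )
        print(f"first completed: {completed}")  # -> "shutdown"

    asyncio.run(main())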
+ +Provides comprehensive text processing capabilities including cleaning, filtering, +splitting, loading from various sources, and formatting utilities. Supports loading +text from URLs, compressed files, package resources, and local files with automatic +encoding detection. Includes specialized formatting for display values and text +wrapping operations for consistent presentation across the system. +""" + +from __future__ import annotations + import gzip import re import textwrap from importlib.resources import as_file, files # type: ignore[attr-defined] from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import ftfy import httpx from loguru import logger from guidellm import data as package_data -from guidellm.config import settings +from guidellm.settings import settings from guidellm.utils.console import Colors __all__ = [ + "MAX_PATH_LENGTH", "EndlessTextCreator", "clean_text", "filter_text", @@ -24,17 +37,32 @@ "split_text_list_by_length", ] -MAX_PATH_LENGTH = 4096 +MAX_PATH_LENGTH: int = 4096 def format_value_display( value: float, label: str, units: str = "", - total_characters: Optional[int] = None, - digits_places: Optional[int] = None, - decimal_places: Optional[int] = None, + total_characters: int | None = None, + digits_places: int | None = None, + decimal_places: int | None = None, ) -> str: + """ + Format a numeric value with units and label for consistent display output. + + Creates standardized display strings for metrics and measurements with + configurable precision, width, and color formatting. Supports both + fixed-width and variable-width output for tabular displays. + + :param value: Numeric value to format and display + :param label: Descriptive label for the value + :param units: Units string to append after the value + :param total_characters: Total width for right-aligned output formatting + :param digits_places: Total number of digits for numeric formatting + :param decimal_places: Number of decimal places for numeric precision + :return: Formatted string with value, units, and colored label + """ if decimal_places is None and digits_places is None: formatted_number = f"{value}:.0f" elif digits_places is None: @@ -57,19 +85,24 @@ def format_value_display( def split_text_list_by_length( text_list: list[Any], - max_characters: Union[int, list[int]], + max_characters: int | list[int], pad_horizontal: bool = True, pad_vertical: bool = True, ) -> list[list[str]]: """ - Split a list of strings into a list of strings, - each with a maximum length of max_characters - - :param text_list: the list of strings to split - :param max_characters: the maximum length of each string - :param pad_horizontal: whether to pad the strings horizontally, defaults to True - :param pad_vertical: whether to pad the strings vertically, defaults to True - :return: a list of strings + Split text strings into wrapped lines with specified maximum character limits. + + Processes each string in the input list by wrapping text to fit within character + limits, with optional padding for consistent formatting in tabular displays. + Supports different character limits per string and uniform padding across results. 
+ + :param text_list: List of strings to process and wrap + :param max_characters: Maximum characters per line, either single value or + per-string limits + :param pad_horizontal: Right-align lines within their character limits + :param pad_vertical: Pad shorter results to match the longest wrapped result + :return: List of wrapped line lists, one per input string + :raises ValueError: If max_characters list length doesn't match text_list length """ if not isinstance(max_characters, list): max_characters = [max_characters] * len(text_list) @@ -105,16 +138,21 @@ def split_text_list_by_length( def filter_text( text: str, - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ) -> str: """ - Filter text by start and end strings or indices + Extract text substring using start and end markers or indices. + + Filters text content by locating string markers or using numeric indices + to extract specific portions. Supports flexible filtering for content + extraction and preprocessing operations. - :param text: the text to filter - :param filter_start: the start string or index to filter from - :param filter_end: the end string or index to filter to - :return: the filtered text + :param text: Source text to filter and extract from + :param filter_start: Starting marker string or index position + :param filter_end: Ending marker string or index position + :return: Filtered text substring between specified boundaries + :raises ValueError: If filter indices are invalid or markers not found """ filter_start_index = -1 filter_end_index = -1 @@ -142,10 +180,29 @@ def filter_text( def clean_text(text: str) -> str: + """ + Normalize text by fixing encoding issues and standardizing whitespace. + + Applies Unicode normalization and whitespace standardization for consistent + text processing. Removes excessive whitespace and fixes common encoding problems. + + :param text: Raw text string to clean and normalize + :return: Cleaned text with normalized encoding and whitespace + """ return re.sub(r"\s+", " ", ftfy.fix_text(text)).strip() def split_text(text: str, split_punctuation: bool = False) -> list[str]: + """ + Split text into tokens with optional punctuation separation. + + Tokenizes text into words and optionally separates punctuation marks + for detailed text analysis and processing operations. + + :param text: Text string to tokenize and split + :param split_punctuation: Separate punctuation marks as individual tokens + :return: List of text tokens + """ text = clean_text(text) if split_punctuation: @@ -154,16 +211,20 @@ def split_text(text: str, split_punctuation: bool = False) -> list[str]: return text.split() -def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: +def load_text(data: str | Path, encoding: str | None = None) -> str: """ - Load an HTML file from a path or URL - - :param data: the path or URL to load the HTML file from - :type data: Union[str, Path] - :param encoding: the encoding to use when reading the file - :type encoding: str - :return: the HTML content - :rtype: str + Load text content from various sources including URLs, files, and package data. + + Supports loading from HTTP/FTP URLs, local files, compressed archives, package + resources, and raw text strings. Automatically detects source type and applies + appropriate loading strategy with encoding support. 
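A brief sketch chaining the text helpers above; load_text is assumed to pass raw (non-path, non-URL) strings through unchanged, as its docstring describes, and the BEGIN/END markers are illustrative:

::

    from guidellm.utils.text import clean_text, filter_text, load_text, split_text

    raw = load_text("BEGIN  The  quick\u00a0brown fox, jumps!  END")  # raw text passthrough
    body = filter_text(raw, filter_start="BEGIN", filter_end="END")
    tokens = split_text(clean_text(body), split_punctuation=True)
    print(tokens)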
+ + :param data: Source location or raw text - URL, file path, package resource + identifier, or text content + :param encoding: Character encoding for file reading operations + :return: Loaded text content as string + :raises FileNotFoundError: If local file path does not exist + :raises httpx.HTTPStatusError: If URL request fails """ logger.debug("Loading text: {}", data) @@ -209,29 +270,62 @@ def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: def is_puncutation(text: str) -> bool: """ - Check if the text is a punctuation + Check if a single character is a punctuation mark. - :param text: the text to check - :type text: str - :return: True if the text is a punctuation, False otherwise - :rtype: bool + Identifies punctuation characters by excluding alphanumeric characters + and whitespace from single-character strings. + + :param text: Single character string to test + :return: True if the character is punctuation, False otherwise """ return len(text) == 1 and not text.isalnum() and not text.isspace() class EndlessTextCreator: + """ + Infinite text generator for load testing and content creation operations. + + Provides deterministic text generation by cycling through preprocessed word + tokens from source content. Supports filtering and punctuation handling for + realistic text patterns in benchmarking scenarios. + + Example: + :: + creator = EndlessTextCreator("path/to/source.txt") + generated = creator.create_text(start=0, length=100) + more_text = creator.create_text(start=50, length=200) + """ + def __init__( self, - data: Union[str, Path], - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + data: str | Path, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ): + """ + Initialize text creator with source content and optional filtering. + + :param data: Source text location or content - file path, URL, or raw text + :param filter_start: Starting marker or index for content filtering + :param filter_end: Ending marker or index for content filtering + """ self.data = data self.text = load_text(data) self.filtered_text = filter_text(self.text, filter_start, filter_end) self.words = split_text(self.filtered_text, split_punctuation=True) def create_text(self, start: int, length: int) -> str: + """ + Generate text by cycling through word tokens from the specified position. + + Creates deterministic text sequences by selecting consecutive tokens from + the preprocessed word list, wrapping around when reaching the end. + Maintains proper spacing and punctuation formatting. 
+ + :param start: Starting position in the token sequence + :param length: Number of tokens to include in generated text + :return: Generated text string with proper spacing and punctuation + """ text = "" for counter in range(length): @@ -244,3 +338,9 @@ def create_text(self, start: int, length: int) -> str: text += add_word return text + + +from faker import Faker + +fake = Faker() +fake.text() diff --git a/src/guidellm/utils/threading.py b/src/guidellm/utils/threading.py deleted file mode 100644 index 37dbea0a..00000000 --- a/src/guidellm/utils/threading.py +++ /dev/null @@ -1,149 +0,0 @@ -import asyncio -import contextlib -import functools -import time -from collections.abc import Generator, Iterable, Iterator -from multiprocessing.synchronize import Barrier as ProcessingBarrier -from multiprocessing.synchronize import Event as ProcessingEvent -from threading import Barrier as ThreadingBarrier -from threading import BrokenBarrierError, Thread -from threading import Event as ThreadingEvent -from typing import Any, Callable, Literal, Optional, Union - -__all__ = ["synchronous_to_exitable_async"] - - -def _start_barrier_monitor_thread( - barrier: Optional[Union[ThreadingBarrier, ProcessingBarrier]], - barrier_event: ThreadingEvent, -): - if barrier is None: - return - - def _watch() -> None: - try: - barrier.wait() - except BrokenBarrierError: - pass - finally: - barrier_event.set() - - Thread(target=_watch, daemon=True).start() - - -def _check_event_set( - events: list[tuple[str, Union[ThreadingEvent, ProcessingEvent]]], -) -> Optional[str]: - for name, event in events: - if event.is_set(): - return name - return None - - -def _run_worker( - events_list: list[tuple[str, Union[ThreadingEvent, ProcessingEvent]]], - exit_barrier: Optional[Union[ThreadingBarrier, ProcessingBarrier]], - synchronous: Optional[Union[Iterator, Iterable, Generator, Callable]], - poll_interval: float, - args: tuple, - kwargs: dict, -) -> tuple[str, Any]: - finish_reason: str = "completed" - last_val: Any = None - - try: - barrier_event = list(filter(lambda x: x[0] == "barrier", events_list))[0][1] - _start_barrier_monitor_thread(exit_barrier, barrier_event) - - if isinstance(synchronous, Iterable): - synchronous = iter(synchronous) - - while True: - if (check_event := _check_event_set(events_list)) is not None: - finish_reason = check_event - break - - if isinstance(synchronous, (Iterator, Generator)): - try: - last_val = next(synchronous) - except StopIteration: - break - elif isinstance(synchronous, Callable): - last_val = synchronous(*args, **kwargs) - break - - time.sleep(poll_interval) - - if ( - finish_reason == "completed" - and (check_event := _check_event_set(events_list)) is not None - ): - # Final check for any exit signals - finish_reason = check_event - except Exception as err: # noqa: BLE001 - finish_reason = "internal_error" - last_val = err - finally: - if exit_barrier is not None: - with contextlib.suppress(BrokenBarrierError, RuntimeError): - exit_barrier.abort() - - return finish_reason, last_val - - -async def synchronous_to_exitable_async( - synchronous: Optional[Union[Iterator, Iterable, Generator, Callable]], - exit_events: Optional[dict[str, Union[ThreadingEvent, ProcessingEvent]]] = None, - exit_barrier: Optional[Union[ThreadingBarrier, ProcessingBarrier]] = None, - poll_interval: float = 0.1, - *args, - **kwargs, -) -> tuple[Union[Literal["completed", "canceled", "barrier"], str], Any]: - """ - Run a sync callable or iterable inside an async context with exit controls. 
- Supports cooperative termination via exit events and an optional barrier. - - :param synchronous: Callable (invoked once) or iterable/iterator (next()). If - None, only watch exit events (poll mode). - :param exit_events: Optional mapping of name -> Event objects to signal exit. - 'canceled', 'barrier', and 'internal_error' are reserved keywords. - :param exit_barrier: Optional barrier to coordinate shutdown; when it trips or is - aborted, the worker exits with reason "barrier". On exit, this function aborts - the barrier to release any waiters. - :param poll_interval: Sleep duration (seconds) used only in poll mode. - :param args: Positional arguments passed to the callable (if provided). - :param kwargs: Keyword arguments passed to the callable (if provided). - :return: (exit_reason, last_item). exit_reason is "completed", "canceled", - "barrier", or a key from exit_events. last_item is the last yielded value for - an iterator or the return value for a callable. - :raises asyncio.CancelledError: If the async task is canceled. - """ - events_map = exit_events or {} - - canceled_event = ThreadingEvent() - barrier_event = ThreadingEvent() - events_list = [ - ("canceled", canceled_event), - ("barrier", barrier_event), - *list(events_map.items()), - ] - worker = functools.partial( - _run_worker, - events_list, - exit_barrier, - synchronous, - poll_interval, - args, - kwargs, - ) - - try: - return await asyncio.to_thread(worker) - except asyncio.CancelledError: - if exit_barrier is not None: - with contextlib.suppress(BrokenBarrierError, RuntimeError): - exit_barrier.abort() - canceled_event.set() - raise - except Exception as err: # noqa: BLE001 - print(f"******EXCEPTION in synchronous_to_exitable_async: {err}") diff --git a/src/guidellm/utils/typing.py b/src/guidellm/utils/typing.py new file mode 100644 index 00000000..59358221 --- /dev/null +++ b/src/guidellm/utils/typing.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Literal, Union, get_args, get_origin + +if TYPE_CHECKING: + from collections.abc import Iterator + +# Backwards compatibility for Python <3.10 +try: + from types import UnionType # type: ignore[attr-defined] +except ImportError: + UnionType = Union + +# Backwards compatibility for Python <3.12 +try: + from typing import TypeAliasType # type: ignore[attr-defined] +except ImportError: + from typing_extensions import TypeAliasType + + +__all__ = ["get_literal_vals"] + + +def get_literal_vals(alias) -> frozenset[str]: + """Extract all literal values from a (possibly nested) type alias.""" + + def resolve(alias) -> Iterator[str]: + origin = get_origin(alias) + + # Base case: Literal types + if origin is Literal: + for literal_val in get_args(alias): + yield str(literal_val) + # Unwrap Annotated type + elif origin is Annotated: + yield from resolve(get_args(alias)[0]) + # Unwrap TypeAliasTypes + elif isinstance(alias, TypeAliasType): + yield from resolve(alias.__value__) + # Iterate over unions + elif origin in (Union, UnionType): + for arg in get_args(alias): + yield from resolve(arg) + # Fallback + else: + yield str(alias) + + return frozenset(resolve(alias)) diff --git a/tests/e2e/test_common_use_cases.py b/tests/e2e/test_common_use_cases.py new file mode 100644 index 00000000..ab9b9430 --- /dev/null +++ b/tests/e2e/test_common_use_cases.py @@ -0,0 +1,547 @@ +# Property-based E2E tests following Mark Kurtz's specifications +# +# Test Categories: +# - SMOKE: 5 curated use cases (20s each, couple minutes total) +# - SANITY: 
Property-based cartesian product (20s each, couple hours total) +# - REGRESSION: Curated long-running tests (few minutes each, couple hours total) +# +# Uses hypothesis for systematic test case generation instead of manual configuration + +from pathlib import Path +from typing import Optional + +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st +from hypothesis.strategies import composite + +from tests.e2e.utils import ( + GuidellmClient, + assert_no_python_exceptions, + assert_successful_requests_fields, + cleanup_report_file, + load_benchmark_report, +) +from tests.e2e.vllm_sim_server import VllmSimServer + +# Backend performance profiles as specified by Mark Kurtz +BACKEND_PROFILES = { + "fast": {"ttft": 100, "itl": 10}, # TTFT <100ms, ITL <10ms + "medium": {"ttft": 500, "itl": 25}, # TTFT <500ms, ITL <25ms + "slow": {"ttft": 2000, "itl": 100}, # TTFT <2s, ITL <100ms +} + + +# Server fixture factory +def create_server_fixture(profile_name: str, port: int = 8000): + """Create session-scoped server fixture for a backend profile.""" + profile = BACKEND_PROFILES[profile_name] + + @pytest.fixture + def server(): + server = VllmSimServer( + mode="random", + time_to_first_token=profile["ttft"], + inter_token_latency=profile["itl"], + ) + with server: + yield server + + return server + + +# Create server fixtures +fast_server = create_server_fixture("fast") +medium_server = create_server_fixture("medium") +slow_server = create_server_fixture("slow") + +SERVER_FIXTURES = { + "fast": fast_server, + "medium": medium_server, + "slow": slow_server, +} + + +def run_benchmark_test( + server, + strategy: str, + rate: Optional[int], + data_config: str, + max_seconds: Optional[int] = None, + max_requests: Optional[int] = None, + warmup_percent: Optional[int] = None, + cooldown_percent: Optional[int] = None, + timeout_multiplier: float = 1.5, +): + """Simplified benchmark test runner.""" + + # Generate unique report path + test_id = f"{strategy}_{rate}_{max_seconds}s_{max_requests}r" + report_path = Path(f"tests/e2e/property_{test_id}.json") + cleanup_report_file(report_path) + + # Create client + client = GuidellmClient(target=server.get_url(), output_path=report_path) + + # Build command arguments + additional_args = "" + if warmup_percent: + additional_args += f" --warmup-percent {warmup_percent}" + if cooldown_percent: + additional_args += f" --cooldown-percent {cooldown_percent}" + + # Calculate timeout with more generous buffer for high-latency servers + timeout_base = max_seconds or 30 + # Increased buffer from 30s to 60s for high-latency servers + timeout = int((timeout_base + 60) * timeout_multiplier) + + if strategy == "sweep": + timeout = timeout * 10 + + # Start benchmark + benchmark_args = { + "rate_type": strategy, + "rate": rate, + "data": data_config, + "additional_args": additional_args, + } + + if max_seconds: + benchmark_args["max_seconds"] = max_seconds + if max_requests: + benchmark_args["max_requests"] = max_requests + + client.start_benchmark(**benchmark_args) + client.wait_for_completion(timeout=timeout) + + # Validate results - allow application bugs to fail tests + assert_no_python_exceptions(client.stderr) + + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Basic validation + assert "requests" in benchmark + assert "successful" in benchmark["requests"] + assert len(benchmark["requests"]["successful"]) > 0 + + # Cleanup + cleanup_report_file(report_path) + + return benchmark + + +# 
============================================================================= +# SMOKE TESTS +# ============================================================================= + + +@pytest.mark.smoke +@pytest.mark.timeout(90) +def test_interactive_chat_use_case(fast_server): + """ + Interactive chat style use case: + - data: emulated 512x512 + - backend: fast (TTFT <100ms, ITL <10ms) + - strategy: constant (changed from sweep due to baseline issues) + - constraints: max_seconds=60, max_requests=1000 + - aggregation: warmup=10%, cooldown=10% + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="constant", # Changed from sweep to avoid baseline issues + rate=5, # constant rate (reduced for 512x512 tokens) + data_config="prompt_tokens=512,output_tokens=512", + max_seconds=15, # Normal timeout for constant strategy + max_requests=25, # Reduced for quick smoke test + # Removed warmup/cooldown to avoid interaction issues with 512x512 tokens + ) + + # Validate it's a proper interactive chat benchmark + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(60) +def test_rag_throughput_use_case(fast_server): + """ + RAG style use case: + - data: emulated 2048x128 + - backend: fast (changed from medium due to server simulator issues) + - strategy: throughput + - constraints: max_seconds=60, max_requests=500 + - aggregation: None + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="throughput", + rate=10, # Normal rate for fast server + data_config="prompt_tokens=512,output_tokens=128", + max_seconds=15, # Normal timeout for fast server + max_requests=30, # Normal count for smoke test + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(60) +def test_rag_constant_rate_use_case(fast_server): + """ + RAG style with constant rate: + - data: emulated 2048x128 + - backend: fast (changed from medium due to server simulator issues) + - strategy: constant at 10 RPS + - constraints: max_seconds=60, max_requests=500 + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="constant", + rate=5, # Normal rate for fast server + data_config="prompt_tokens=512,output_tokens=128", + max_seconds=15, # Normal timeout for fast server + max_requests=30, # Normal count for smoke test + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(60) +def test_code_generation_use_case(fast_server): + """ + Code generation style use case: + - data: emulated 512x2048 + - backend: fast (changed from medium due to server simulator issues) + - strategy: concurrent at 50 + - constraints: max_seconds=120 + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="concurrent", + rate=5, # Normal rate for fast server + data_config="prompt_tokens=512,output_tokens=512", + max_seconds=15, # Normal timeout for fast server + max_requests=10, # Small count for smoke test + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(60) +def test_fast_perf_stress_use_case(fast_server): + """ + Fast performance stress test: + - data: emulated 64x64 + - backend: fast (TTFT <50ms, ITL <5ms) - using fast server as closest + - strategy: constant at 50 + - aggregation: warmup=5% + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="constant", + rate=5, # Reduced rate for quick test + data_config="prompt_tokens=64,output_tokens=64", + max_seconds=10, # Reduced for 
quick smoke test + max_requests=25, # Reduced for quick smoke test + warmup_percent=5, + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(60) +def test_synchronous_fast_use_case(fast_server): + """ + Synchronous strategy test with fast backend: + - data: emulated 512x512 (interactive chat size) + - backend: fast (TTFT <100ms, ITL <10ms) + - strategy: synchronous + - constraints: max_seconds=15, max_requests=30 + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="synchronous", + rate=None, # synchronous doesn't use rate + data_config="prompt_tokens=512,output_tokens=512", + max_seconds=15, # Short for smoke test + max_requests=30, # Small count for smoke test + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(60) +def test_synchronous_alternative_use_case(fast_server): + """ + Synchronous strategy test with alternative data: + - data: emulated 512x256 (different from other fast server tests) + - backend: fast (changed from medium due to server simulator issues) + - strategy: synchronous + - constraints: max_seconds=15, max_requests=20 + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="synchronous", + rate=None, # synchronous doesn't use rate + data_config="prompt_tokens=512,output_tokens=256", + max_seconds=15, # Normal timeout for fast server + max_requests=10, # Small count for smoke test + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +@pytest.mark.smoke +@pytest.mark.timeout(90) +def test_sweep_smoke_use_case(fast_server): + """ + Sweep strategy smoke test: + - data: emulated 64x64 (small tokens for fast sweep) + - backend: fast (TTFT <100ms, ITL <10ms) + - strategy: sweep (runs 10 sub-benchmarks) + - constraints: max_seconds=8, max_requests=20 (per sub-benchmark) + - Higher timeout due to 10 sub-benchmarks + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="sweep", + rate=10, # Sweep max rate + data_config="prompt_tokens=64,output_tokens=64", + max_seconds=8, # Short per sub-benchmark (8s * 10 = ~80s total) + max_requests=20, # Small count per sub-benchmark + timeout_multiplier=2.0, # Higher multiplier for sweep overhead + ) + + assert len(benchmark["requests"]["successful"]) > 0 + + +# ============================================================================= +# SANITY TESTS - Property-based cartesian product +# ============================================================================= + + +# Hypothesis strategies for test case generation +@composite +def backend_strategy(draw): + """Generate backend profile configurations.""" + return draw(st.sampled_from(["fast", "medium", "slow"])) + + +@composite +def data_strategy(draw): + """Generate data configurations based on Mark's input sizes.""" + sizes = [ + (64, 64), # Fast perf + (512, 128), # Short prompt, short output + (512, 512), # Interactive chat + (512, 2048), # Code generation + (2048, 128), # RAG + (2048, 2048), # Offline throughput + ] + prompt_tokens, output_tokens = draw(st.sampled_from(sizes)) + return f"prompt_tokens={prompt_tokens},output_tokens={output_tokens}" + + +@composite +def strategy_rate_strategy(draw): + """Generate strategy and rate combinations.""" + strategy = draw( + st.sampled_from( + ["synchronous", "sweep", "constant", "concurrent", "throughput"] + ) + ) + + if strategy == "synchronous": + rate = None # synchronous doesn't use rate + elif strategy == "sweep": + rate = draw(st.integers(min_value=5, 
max_value=20)) + elif strategy in ["constant", "concurrent"]: + rate = draw(st.sampled_from([1, 5, 10, 25, 50])) + else: # throughput + rate = draw(st.integers(min_value=5, max_value=50)) + + return strategy, rate + + +@composite +def constraints_strategy(draw): + """Generate constraint configurations.""" + # For sanity tests, keep them short (20s max) + max_seconds = draw(st.integers(min_value=10, max_value=20)) + max_requests = draw(st.sampled_from([25, 50, 100])) + return max_seconds, max_requests + + +@composite +def aggregation_strategy(draw): + """Generate aggregation configurations.""" + use_aggregation = draw(st.booleans()) + if not use_aggregation: + return None, None + + warmup = draw(st.integers(min_value=5, max_value=20)) + cooldown = draw(st.integers(min_value=5, max_value=20)) + return warmup, cooldown + + +@pytest.mark.sanity +@pytest.mark.timeout(3600) +@given( + backend=backend_strategy(), + data_config=data_strategy(), + strategy_rate=strategy_rate_strategy(), + constraints=constraints_strategy(), + aggregation=aggregation_strategy(), +) +@settings( + max_examples=20, # Limit examples for reasonable test time + deadline=None, # Disable deadline for E2E tests + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_sanity_property_based_benchmark( + backend, data_config, strategy_rate, constraints, aggregation +): + """ + Property-based sanity tests covering cartesian product of configurations. + Each test runs for up to 20 seconds with systematic parameter combinations. + """ + strategy, rate = strategy_rate + max_seconds, max_requests = constraints + warmup_percent, cooldown_percent = aggregation + + profile = BACKEND_PROFILES[backend] + + server = VllmSimServer( + mode="random", + time_to_first_token=profile["ttft"], + inter_token_latency=profile["itl"], + ) + with server: + benchmark = run_benchmark_test( + server=server, + strategy=strategy, + rate=rate, + data_config=data_config, + max_seconds=max_seconds, + max_requests=max_requests, + warmup_percent=warmup_percent, + cooldown_percent=cooldown_percent, + timeout_multiplier=1.2, + ) + + # Property-based assertions + assert "requests" in benchmark + assert "successful" in benchmark["requests"] + assert len(benchmark["requests"]["successful"]) > 0 + assert "failed" in benchmark["requests"] + + # Validate metrics structure + assert "metrics" in benchmark + metrics = benchmark["metrics"] + assert "request_rate" in metrics + assert "error_rate" in metrics + + +# ============================================================================= +# REGRESSION TESTS - Curated long-running tests +# ============================================================================= + + +@pytest.mark.regression +@pytest.mark.timeout(600) +def test_regression_high_load_code_generation(medium_server): + """ + Long-running code generation stress test. 
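A small sketch for inspecting the generated parameter space interactively, assuming the test module is importable from the repository root; .example() is a hypothesis debugging helper and should not be called inside tests:

::

    from tests.e2e.test_common_use_cases import (
        backend_strategy,
        data_strategy,
        strategy_rate_strategy,
    )

    print(backend_strategy().example())        # e.g. "medium"
    print(data_strategy().example())           # e.g. "prompt_tokens=512,output_tokens=128"
    print(strategy_rate_strategy().example())  # e.g. ("constant", 10)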
+ - High concurrent load (100) + - Long duration (120s) + - Large outputs (2048 tokens) + """ + + benchmark = run_benchmark_test( + server=medium_server, + strategy="concurrent", + rate=100, + data_config="prompt_tokens=512,output_tokens=2048", + max_seconds=120, + max_requests=1000, + timeout_multiplier=2.0, + ) + + # Validate high-load performance + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) >= 50, ( + f"Too few successful requests: {len(successful_requests)}" + ) + + if successful_requests: + assert_successful_requests_fields(successful_requests) + + +@pytest.mark.regression +@pytest.mark.timeout(600) +def test_regression_offline_throughput_stress(slow_server): + """ + Long-running offline throughput test. + - Large inputs/outputs (2048x2048) + - Slow backend simulation + - High request volume (5000) + """ + + benchmark = run_benchmark_test( + server=slow_server, + strategy="throughput", + rate=50, + data_config="prompt_tokens=2048,output_tokens=2048", + max_requests=1000, # Reduced from 5000 for reasonable test time + timeout_multiplier=3.0, + ) + + # Validate throughput characteristics + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) >= 100, ( + f"Too few successful requests: {len(successful_requests)}" + ) + + +@pytest.mark.regression +@pytest.mark.timeout(600) +def test_regression_sustained_high_rate_constant(fast_server): + """ + Long-running sustained high rate test. + - Fast backend with high constant rate + - Extended duration to test stability + """ + + benchmark = run_benchmark_test( + server=fast_server, + strategy="constant", + rate=500, + data_config="prompt_tokens=64,output_tokens=64", + max_seconds=180, + max_requests=2000, + warmup_percent=5, + timeout_multiplier=2.0, + ) + + # Validate sustained performance + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) >= 200, ( + f"Too few successful requests: {len(successful_requests)}" + ) + + # Check rate sustainability + metrics = benchmark["metrics"] + request_rate = metrics.get("request_rate", 0) + assert request_rate > 100, f"Request rate too low: {request_rate}" diff --git a/tests/e2e/test_max_error_benchmark.py b/tests/e2e/test_max_error_benchmark.py index 6079b21c..221de87a 100644 --- a/tests/e2e/test_max_error_benchmark.py +++ b/tests/e2e/test_max_error_benchmark.py @@ -20,53 +20,54 @@ def server(): Pytest fixture to start and stop the server for the entire module using the TestServer class. """ - server = VllmSimServer(port=8000, model="databricks/dolly-v2-12b", mode="echo") - try: - server.start() + server = VllmSimServer( + mode="random", + time_to_first_token=1, # 1ms TTFT + inter_token_latency=1, # 1ms ITL + ) + with server: yield server # Yield the URL for tests to use - finally: - server.stop() # Teardown: Stop the server after tests are done +@pytest.mark.smoke @pytest.mark.timeout(30) def test_max_error_benchmark(server: VllmSimServer): """ Test that the max error rate constraint is properly triggered when server goes down. 
""" report_path = Path("tests/e2e/max_error_benchmarks.json") + cleanup_report_file(report_path) rate = 10 max_error_rate = 0.1 # Create and configure the guidellm client client = GuidellmClient(target=server.get_url(), output_path=report_path) - try: - # Start the benchmark - client.start_benchmark( - rate=rate, - max_seconds=25, - max_error_rate=max_error_rate, - ) + # Start the benchmark + client.start_benchmark( + rate=rate, + max_seconds=25, + max_error_rate=max_error_rate, + ) - # Wait for the benchmark to complete (server will be stopped after 10 seconds) - client.wait_for_completion(timeout=30, stop_server_after=10, server=server) + # Wait for the benchmark to complete (server will be stopped after 10 seconds) + client.wait_for_completion(timeout=30, stop_server_after=10, server=server) - # Assert no Python exceptions occurred - assert_no_python_exceptions(client.stderr) + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) - # Load and validate the report - report = load_benchmark_report(report_path) - benchmark = report["benchmarks"][0] + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] - # Check that the max error rate constraint was triggered - assert_constraint_triggered( - benchmark, - "max_error_rate", - { - "exceeded_error_rate": True, - "current_error_rate": lambda rate: rate >= max_error_rate, - }, - ) + # Check that the max error rate constraint was triggered + assert_constraint_triggered( + benchmark, + "max_error_rate", + { + "exceeded_error_rate": True, + "current_error_rate": lambda rate: rate >= max_error_rate, + }, + ) - finally: - cleanup_report_file(report_path) + cleanup_report_file(report_path) diff --git a/tests/e2e/test_successful_benchmark.py b/tests/e2e/test_successful_benchmark.py index 8f0181a3..bd6dec20 100644 --- a/tests/e2e/test_successful_benchmark.py +++ b/tests/e2e/test_successful_benchmark.py @@ -22,99 +22,90 @@ def server(): using the TestServer class. """ server = VllmSimServer( - port=8000, - model="databricks/dolly-v2-12b", - mode="echo", + mode="random", time_to_first_token=1, # 1ms TTFT inter_token_latency=1, # 1ms ITL ) - try: - server.start() + with server: yield server # Yield the URL for tests to use - finally: - server.stop() # Teardown: Stop the server after tests are done +@pytest.mark.smoke @pytest.mark.timeout(30) def test_max_seconds_benchmark(server: VllmSimServer): """ Test that the max seconds constraint is properly triggered. 
""" report_path = Path("tests/e2e/max_duration_benchmarks.json") + cleanup_report_file(report_path) rate = 10 # Create and configure the guidellm client client = GuidellmClient(target=server.get_url(), output_path=report_path) - try: - # Start the benchmark - client.start_benchmark( - rate=rate, - max_seconds=1, - ) + # Start the benchmark + client.start_benchmark( + rate=rate, + max_seconds=1, + ) - # Wait for the benchmark to complete - client.wait_for_completion(timeout=30) + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) - # Assert no Python exceptions occurred - assert_no_python_exceptions(client.stderr) + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) - # Load and validate the report - report = load_benchmark_report(report_path) - benchmark = report["benchmarks"][0] + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] - # Check that the max duration constraint was triggered - assert_constraint_triggered( - benchmark, "max_seconds", {"duration_exceeded": True} - ) + # Check that the max duration constraint was triggered + assert_constraint_triggered(benchmark, "max_seconds", {"duration_exceeded": True}) - # Validate successful requests have all expected fields - successful_requests = benchmark["requests"]["successful"] - assert_successful_requests_fields(successful_requests) + # Validate successful requests have all expected fields + successful_requests = benchmark["requests"]["successful"] + assert_successful_requests_fields(successful_requests) - finally: - cleanup_report_file(report_path) + cleanup_report_file(report_path) +@pytest.mark.smoke @pytest.mark.timeout(30) def test_max_requests_benchmark(server: VllmSimServer): """ Test that the max requests constraint is properly triggered. 
""" report_path = Path("tests/e2e/max_number_benchmarks.json") + cleanup_report_file(report_path) rate = 10 # Create and configure the guidellm client client = GuidellmClient(target=server.get_url(), output_path=report_path) - try: - # Start the benchmark - client.start_benchmark( - rate=rate, - max_requests=rate, - ) - - # Wait for the benchmark to complete - client.wait_for_completion(timeout=30) - - # Assert no Python exceptions occurred - assert_no_python_exceptions(client.stderr) - - # Load and validate the report - report = load_benchmark_report(report_path) - benchmark = report["benchmarks"][0] - - # Check that the max requests constraint was triggered - assert_constraint_triggered( - benchmark, "max_requests", {"processed_exceeded": True} - ) - - # Validate successful requests have all expected fields - successful_requests = benchmark["requests"]["successful"] - assert len(successful_requests) == rate, ( - f"Expected {rate} successful requests, got {len(successful_requests)}" - ) - assert_successful_requests_fields(successful_requests) - - finally: - cleanup_report_file(report_path) + # Start the benchmark + client.start_benchmark( + rate=rate, + max_requests=rate, + ) + + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Check that the max requests constraint was triggered + assert_constraint_triggered(benchmark, "max_requests", {"processed_exceeded": True}) + + # Validate successful requests have all expected fields + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) == rate, ( + f"Expected {rate} successful requests, got {len(successful_requests)}" + ) + assert_successful_requests_fields(successful_requests) + + cleanup_report_file(report_path) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 9357949c..bf950df1 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -41,7 +41,7 @@ def __init__(self, target: str, output_path: Path): def start_benchmark( self, rate_type: str = "constant", - rate: int = 10, + rate: Optional[int] = 10, max_seconds: Optional[int] = None, max_requests: Optional[int] = None, max_error_rate: Optional[float] = None, @@ -65,12 +65,15 @@ def start_benchmark( # Build command components cmd_parts = [ - f"GUIDELLM__MAX_CONCURRENCY=10 GUIDELLM__MAX_WORKER_PROCESSES=10 {guidellm_exe} benchmark", + f"HF_HOME=/tmp/huggingface_cache {guidellm_exe} benchmark", f'--target "{self.target}"', f"--rate-type {rate_type}", - f"--rate {rate}", ] + # Only add rate parameter if it's not None (synchronous doesn't use rate) + if rate is not None: + cmd_parts.append(f"--rate {rate}") + if max_seconds is not None: cmd_parts.append(f"--max-seconds {max_seconds}") diff --git a/tests/e2e/vllm_sim_server.py b/tests/e2e/vllm_sim_server.py index 726dba40..41ae9165 100644 --- a/tests/e2e/vllm_sim_server.py +++ b/tests/e2e/vllm_sim_server.py @@ -16,8 +16,8 @@ class VllmSimServer: def __init__( self, - port: int, - model: str, + port: int = 8000, + model: str = "test-model", lora: Optional[list[str]] = None, mode: Optional[str] = None, echo: Optional[bool] = None, @@ -134,3 +134,10 @@ def get_url(self): Returns the base URL of the running server. 
""" return self.server_url + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stop() diff --git a/tests/unit/benchmark/test_benchmarker.py b/tests/unit/benchmark/test_benchmarker.py index df0c6c3a..5f690677 100644 --- a/tests/unit/benchmark/test_benchmarker.py +++ b/tests/unit/benchmark/test_benchmarker.py @@ -23,7 +23,6 @@ from guidellm.benchmark.profile import SynchronousProfile from guidellm.scheduler import ( BackendInterface, - MeasuredRequestTimingsT, NonDistributedEnvironment, RequestT, ResponseT, @@ -72,15 +71,6 @@ def test_response_t(): assert ResponseT.__constraints__ == () -@pytest.mark.smoke -def test_measured_request_timings_t(): - """Test that MeasuredRequestTimingsT is filled out correctly as a TypeVar.""" - assert isinstance(MeasuredRequestTimingsT, type(TypeVar("tmp"))) - assert MeasuredRequestTimingsT.__name__ == "MeasuredRequestTimingsT" - assert MeasuredRequestTimingsT.__bound__ is not None - assert MeasuredRequestTimingsT.__constraints__ == () - - class MockBenchmark: def __init__(self, **kwargs): for key, val in kwargs.items(): diff --git a/tests/unit/benchmark/test_objects.py b/tests/unit/benchmark/test_objects.py index fd74526a..d17f4bba 100644 --- a/tests/unit/benchmark/test_objects.py +++ b/tests/unit/benchmark/test_objects.py @@ -551,7 +551,7 @@ class TestBenchmark: worker_targeted_start_delay_avg=0.1, request_start_delay_avg=0.1, request_time_avg=0.1, - request_targeted_delay_avg=0.1, + request_targeted_start_delay_avg=0.1, ), "start_time": 1000.0, "end_time": 2000.0, @@ -677,7 +677,7 @@ def test_invalid_initialization_values(self): worker_targeted_start_delay_avg=0.1, request_start_delay_avg=0.1, request_time_avg=0.1, - request_targeted_delay_avg=0.1, + request_targeted_start_delay_avg=0.1, ), start_time=0, end_time=1, @@ -980,7 +980,7 @@ class TestGenerativeBenchmark: worker_targeted_start_delay_avg=0.1, request_start_delay_avg=0.1, request_time_avg=0.1, - request_targeted_delay_avg=0.1, + request_targeted_start_delay_avg=0.1, ), "start_time": 1000.0, "end_time": 2000.0, @@ -1099,7 +1099,7 @@ class TestGenerativeBenchmarksReport: worker_targeted_start_delay_avg=0.1, request_start_delay_avg=0.1, request_time_avg=0.1, - request_targeted_delay_avg=0.1, + request_targeted_start_delay_avg=0.1, ), start_time=10, end_time=20, @@ -1154,7 +1154,7 @@ class TestGenerativeBenchmarksReport: worker_targeted_start_delay_avg=0.1, request_start_delay_avg=0.1, request_time_avg=0.1, - request_targeted_delay_avg=0.1, + request_targeted_start_delay_avg=0.1, ), start_time=30, end_time=40, diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py index 4e1476d3..5ac069a8 100644 --- a/tests/unit/mock_backend.py +++ b/tests/unit/mock_backend.py @@ -96,11 +96,9 @@ async def default_model(self) -> Optional[str]: async def resolve( self, request: GenerationRequest, - request_info: ScheduledRequestInfo[GenerationRequestTimings], + request_info: ScheduledRequestInfo, history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, - ) -> AsyncIterator[ - tuple[GenerationResponse, ScheduledRequestInfo[GenerationRequestTimings]] - ]: + ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. 
diff --git a/tests/unit/mock_server/__init__.py b/tests/unit/mock_server/__init__.py new file mode 100644 index 00000000..e02d60bd --- /dev/null +++ b/tests/unit/mock_server/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the GuideLLM mock server package.""" diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py new file mode 100644 index 00000000..ed5c7727 --- /dev/null +++ b/tests/unit/mock_server/test_server.py @@ -0,0 +1,518 @@ +from __future__ import annotations + +import asyncio +import json +import multiprocessing + +import httpx +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.server import MockServer + + +# Start server in a separate process +def _start_server_process(config: MockServerConfig): + server = MockServer(config) + server.run() + + +@pytest_asyncio.fixture(scope="class") +async def mock_server_instance(): + """Instance-level fixture that provides a running server for HTTP testing.""" + + config = MockServerConfig( + host="127.0.0.1", + port=8012, + model="test-model", + ttft_ms=10.0, + itl_ms=1.0, + request_latency=0.1, + ) + base_url = f"http://{config.host}:{config.port}" + server_process = multiprocessing.Process( + target=_start_server_process, args=(config,) + ) + server_process.start() + + # Wait for server to start up and be ready + async def wait_for_startup(): + poll_frequency = 1.0 + async with httpx.AsyncClient() as client: + while True: + try: + response = await client.get(f"{base_url}/health", timeout=1.0) + if response.status_code == 200: + break + except (httpx.RequestError, httpx.TimeoutException): + pass + await asyncio.sleep(poll_frequency) + poll_frequency = min(poll_frequency * 1.5, 2.0) + + timeout = 30.0 + try: + await asyncio.wait_for(wait_for_startup(), timeout) + except TimeoutError: + # Server failed to start within timeout + server_process.terminate() + server_process.kill() + server_process.join(timeout=5) + pytest.fail(f"Server failed to start within {timeout} seconds") + + yield base_url, config + + # Cleanup: terminate the server process + server_process.terminate() + server_process.kill() + server_process.join(timeout=5) + + +class TestMockServerConfig: + """Test suite for MockServerConfig class.""" + + @pytest.mark.smoke + def test_default_initialization(self): + """Test MockServerConfig initialization with default values.""" + config = MockServerConfig() + assert config.host == "127.0.0.1" + assert config.port == 8000 + assert config.workers == 1 + assert config.model == "llama-3.1-8b-instruct" + assert config.processor is None + assert config.request_latency == 3.0 + assert config.request_latency_std == 0.0 + assert config.ttft_ms == 150.0 + assert config.ttft_ms_std == 0.0 + assert config.itl_ms == 10.0 + assert config.itl_ms_std == 0.0 + assert config.output_tokens == 128 + assert config.output_tokens_std == 0.0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("kwargs", "expected_values"), + [ + ( + {"host": "127.0.0.1", "port": 9000, "model": "custom-model"}, + {"host": "127.0.0.1", "port": 9000, "model": "custom-model"}, + ), + ( + {"request_latency": 1.5, "ttft_ms": 100.0, "output_tokens": 256}, + {"request_latency": 1.5, "ttft_ms": 100.0, "output_tokens": 256}, + ), + ], + ) + def test_custom_initialization(self, kwargs, expected_values): + """Test MockServerConfig initialization with custom values.""" + config = MockServerConfig(**kwargs) + for key, expected_value in 
expected_values.items(): + assert getattr(config, key) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("port", "not_int"), + ("request_latency", "not_float"), + ("output_tokens", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test MockServerConfig with invalid field values.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + MockServerConfig(**kwargs) + + +class TestMockServer: + """Test suite for MockServer class.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MockServer class signatures and attributes.""" + assert hasattr(MockServer, "__init__") + assert hasattr(MockServer, "run") + assert hasattr(MockServer, "_setup_middleware") + assert hasattr(MockServer, "_setup_routes") + assert hasattr(MockServer, "_setup_error_handlers") + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test MockServer initialization without required config.""" + with pytest.raises(TypeError): + MockServer() + + +class TestMockServerEndpoints: + """Test suite for MockServer HTTP endpoints with real server instances.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_health_endpoint(self, mock_server_instance): + """Test the health check endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "status" in data + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], (int, float)) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_models_endpoint(self, mock_server_instance): + """Test the models listing endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/v1/models", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "object" in data + assert data["object"] == "list" + assert "data" in data + assert isinstance(data["data"], list) + assert len(data["data"]) > 0 + + model = data["data"][0] + assert "id" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + assert model["object"] == "model" + assert model["owned_by"] == "guidellm-mock" + assert model["id"] == "test-model" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + { + "model": "test-model", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 10, + }, + ["choices", "usage", "model", "object"], + ), + ( + { + "model": "test-model", + "messages": [{"role": "user", "content": "Test"}], + "max_tokens": 5, + "temperature": 0.7, + }, + ["choices", "usage", "model", "object"], + ), + ], + ) + async def test_chat_completions_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the chat completions endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/v1/chat/completions", json=payload, timeout=10.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert len(data["choices"]) > 0 + choice = data["choices"][0] + assert "message" in choice + assert "content" in choice["message"] + assert "role" in choice["message"] + 
assert choice["message"]["role"] == "assistant" + assert isinstance(choice["message"]["content"], str) + assert len(choice["message"]["content"]) > 0 + + # Verify usage information + assert "prompt_tokens" in data["usage"] + assert "completion_tokens" in data["usage"] + assert "total_tokens" in data["usage"] + assert data["usage"]["total_tokens"] == ( + data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"] + ) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_streaming_chat_completions(self, mock_server_instance): + """Test streaming chat completions endpoint.""" + server_url, _ = mock_server_instance + + payload = { + "model": "test-model", + "messages": [{"role": "user", "content": "Hi!"}], + "max_tokens": 5, + "stream": True, + } + + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{server_url}/v1/chat/completions", + json=payload, + timeout=10.0, + ) as response, + ): + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + chunks = [] + async for line in response.aiter_lines(): + if line and line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk_data = json.loads(data_str) + chunks.append(chunk_data) + except json.JSONDecodeError: + continue + + assert len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + assert "delta" in chunk["choices"][0] + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + { + "model": "test-model", + "prompt": "Hello", + "max_tokens": 10, + }, + ["choices", "usage", "model", "object"], + ), + ( + { + "model": "test-model", + "prompt": "Test prompt", + "max_tokens": 5, + "temperature": 0.8, + }, + ["choices", "usage", "model", "object"], + ), + ], + ) + @pytest.mark.asyncio + async def test_completions_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the legacy completions endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/v1/completions", json=payload, timeout=10.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert len(data["choices"]) > 0 + choice = data["choices"][0] + assert "text" in choice + assert isinstance(choice["text"], str) + assert len(choice["text"]) > 0 + + # Verify usage information + assert "prompt_tokens" in data["usage"] + assert "completion_tokens" in data["usage"] + assert "total_tokens" in data["usage"] + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_streaming_completions(self, mock_server_instance): + """Test streaming completions endpoint.""" + server_url, _ = mock_server_instance + payload = { + "model": "test-model", + "prompt": "Hello", + "max_tokens": 5, + "stream": True, + } + + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{server_url}/v1/completions", + json=payload, + timeout=10.0, + ) as response, + ): + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + chunks = [] + async for line in response.aiter_lines(): + if line and line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk_data = json.loads(data_str) + chunks.append(chunk_data) + except json.JSONDecodeError: + continue + + assert 
len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + {"text": "Hello world!"}, + ["tokens", "count"], + ), + ( + {"text": "This is a test sentence."}, + ["tokens", "count"], + ), + ], + ) + @pytest.mark.asyncio + async def test_tokenize_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the tokenize endpoint.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/tokenize", json=payload, timeout=5.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert isinstance(data["tokens"], list) + assert isinstance(data["count"], int) + assert data["count"] == len(data["tokens"]) + assert len(data["tokens"]) > 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + {"tokens": [123, 456, 789]}, + ["text"], + ), + ( + {"tokens": [100, 200]}, + ["text"], + ), + ], + ) + @pytest.mark.asyncio + async def test_detokenize_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the detokenize endpoint.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/detokenize", json=payload, timeout=5.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert isinstance(data["text"], str) + assert len(data["text"]) > 0 + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_options_endpoint(self, mock_server_instance): + """Test the OPTIONS endpoint for CORS support.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.options( + f"{server_url}/v1/chat/completions", timeout=5.0 + ) + assert response.status_code == 204 + assert response.text == "" + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_cors_headers(self, mock_server_instance): + """Test CORS headers are properly set.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health", timeout=5.0) + assert response.status_code == 200 + + # Check for CORS headers + assert response.headers.get("Access-Control-Allow-Origin") == "*" + methods_header = response.headers.get("Access-Control-Allow-Methods", "") + assert "GET, POST, OPTIONS" in methods_header + headers_header = response.headers.get("Access-Control-Allow-Headers", "") + assert "Content-Type, Authorization" in headers_header + assert response.headers.get("Server") == "guidellm-mock-server" + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("endpoint", "method", "payload"), + [ + ("/v1/chat/completions", "POST", {"invalid": "payload"}), + ("/v1/completions", "POST", {"invalid": "payload"}), + ("/tokenize", "POST", {"invalid": "payload"}), + ("/detokenize", "POST", {"invalid": "payload"}), + ], + ) + async def test_invalid_request_handling( + self, mock_server_instance, endpoint, method, payload + ): + """Test handling of invalid requests.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + if method == "POST": + response = await client.post( + f"{server_url}{endpoint}", json=payload, timeout=5.0 + ) + else: + response = 
await client.get(f"{server_url}{endpoint}", timeout=5.0) + + # Should return an error response, not crash + assert response.status_code in [400, 422, 500] + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_nonexistent_endpoint(self, mock_server_instance): + """Test handling of requests to nonexistent endpoints.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/nonexistent", timeout=5.0) + assert response.status_code == 404 diff --git a/tests/unit/objects/test_pydantic.py b/tests/unit/objects/test_pydantic.py deleted file mode 100644 index 515d95ab..00000000 --- a/tests/unit/objects/test_pydantic.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -from pydantic import computed_field - -from guidellm.utils.pydantic_utils import StandardBaseModel - - -class ExampleModel(StandardBaseModel): - name: str - age: int - - @computed_field # type: ignore[misc] - @property - def computed(self) -> str: - return self.name + " " + str(self.age) - - -@pytest.mark.smoke -def test_standard_base_model_initialization(): - example = ExampleModel(name="John Doe", age=30) - assert example.name == "John Doe" - assert example.age == 30 - assert example.computed == "John Doe 30" - - -@pytest.mark.smoke -def test_standard_base_model_invalid_initialization(): - with pytest.raises(ValueError): - ExampleModel(name="John Doe", age="thirty") # type: ignore[arg-type] - - -@pytest.mark.smoke -def test_standard_base_model_marshalling(): - example = ExampleModel(name="John Doe", age=30) - serialized = example.model_dump() - assert serialized["name"] == "John Doe" - assert serialized["age"] == 30 - assert serialized["computed"] == "John Doe 30" - - serialized["computed"] = "Jane Doe 40" - deserialized = ExampleModel.model_validate(serialized) - assert deserialized.name == "John Doe" - assert deserialized.age == 30 - assert deserialized.computed == "John Doe 30" diff --git a/tests/unit/objects/test_statistics.py b/tests/unit/objects/test_statistics.py deleted file mode 100644 index 855bfa5f..00000000 --- a/tests/unit/objects/test_statistics.py +++ /dev/null @@ -1,706 +0,0 @@ -import math -import time -from typing import Literal - -import numpy as np -import pytest - -from guidellm.utils import ( - DistributionSummary, - Percentiles, - RunningStats, - StatusDistributionSummary, - TimeRunningStats, -) - - -def create_default_percentiles() -> Percentiles: - return Percentiles( - p001=0.1, - p01=1.0, - p05=5.0, - p10=10.0, - p25=25.0, - p50=50.0, - p75=75.0, - p90=90.0, - p95=95.0, - p99=99.0, - p999=99.9, - ) - - -def create_default_distribution_summary() -> DistributionSummary: - return DistributionSummary( - mean=50.0, - median=50.0, - mode=50.0, - variance=835, - std_dev=math.sqrt(835), - min=0.0, - max=100.0, - count=1001, - total_sum=50050.0, - percentiles=create_default_percentiles(), - ) - - -@pytest.mark.smoke -def test_percentiles_initialization(): - percentiles = create_default_percentiles() - assert percentiles.p001 == 0.1 - assert percentiles.p01 == 1.0 - assert percentiles.p05 == 5.0 - assert percentiles.p10 == 10.0 - assert percentiles.p25 == 25.0 - assert percentiles.p50 == 50.0 - assert percentiles.p75 == 75.0 - assert percentiles.p90 == 90.0 - assert percentiles.p95 == 95.0 - assert percentiles.p99 == 99.0 - assert percentiles.p999 == 99.9 - - -@pytest.mark.smoke -def test_percentiles_invalid_initialization(): - test_kwargs = { - "p001": 0.1, - "p01": 1.0, - "p05": 5.0, - "p10": 10.0, - "p25": 25.0, - "p50": 50.0, - "p75": 
75.0, - "p90": 90.0, - "p95": 95.0, - "p99": 99.0, - "p999": 99.9, - } - test_missing_keys = list(test_kwargs.keys()) - - for missing_key in test_missing_keys: - kwargs = {key: val for key, val in test_kwargs.items() if key != missing_key} - with pytest.raises(ValueError): - Percentiles(**kwargs) - - -@pytest.mark.smoke -def test_percentiles_marshalling(): - percentiles = create_default_percentiles() - serialized = percentiles.model_dump() - deserialized = Percentiles.model_validate(serialized) - - for key, value in vars(percentiles).items(): - assert getattr(deserialized, key) == value - - -@pytest.mark.smoke -def test_distribution_summary_initilaization(): - distribution_summary = create_default_distribution_summary() - assert distribution_summary.mean == 50.0 - assert distribution_summary.median == 50.0 - assert distribution_summary.mode == 50.0 - assert distribution_summary.variance == 835 - assert distribution_summary.std_dev == math.sqrt(835) - assert distribution_summary.min == 0.0 - assert distribution_summary.max == 100.0 - assert distribution_summary.count == 1001 - assert distribution_summary.total_sum == 50050.0 - assert distribution_summary.percentiles.p001 == 0.1 - assert distribution_summary.percentiles.p01 == 1.0 - assert distribution_summary.percentiles.p05 == 5.0 - assert distribution_summary.percentiles.p10 == 10.0 - assert distribution_summary.percentiles.p25 == 25.0 - assert distribution_summary.percentiles.p50 == 50.0 - assert distribution_summary.percentiles.p75 == 75.0 - assert distribution_summary.percentiles.p90 == 90.0 - assert distribution_summary.percentiles.p95 == 95.0 - assert distribution_summary.percentiles.p99 == 99.0 - assert distribution_summary.percentiles.p999 == 99.9 - - -@pytest.mark.smoke -def test_distribution_summary_invalid_initialization(): - test_kwargs = { - "mean": 50.0, - "median": 50.0, - "mode": 50.0, - "variance": 835, - "std_dev": math.sqrt(835), - "min": 0.0, - "max": 100.0, - "count": 1001, - "total_sum": 50050.0, - "percentiles": create_default_percentiles(), - } - test_missing_keys = list(test_kwargs.keys()) - for missing_key in test_missing_keys: - kwargs = {key: val for key, val in test_kwargs.items() if key != missing_key} - with pytest.raises(ValueError): - DistributionSummary(**kwargs) # type: ignore[arg-type] - - -@pytest.mark.smoke -def test_distribution_summary_marshalling(): - distribution_summary = create_default_distribution_summary() - serialized = distribution_summary.model_dump() - deserialized = DistributionSummary.model_validate(serialized) - - for key, value in vars(distribution_summary).items(): - assert getattr(deserialized, key) == value - - -@pytest.mark.smoke -def test_distribution_summary_from_distribution_function(): - values = [val / 10.0 for val in range(1001)] - distribution = [(val, 1.0) for val in values] - distribution_summary = DistributionSummary.from_distribution_function(distribution) - assert distribution_summary.mean == pytest.approx(np.mean(values)) - assert distribution_summary.median == pytest.approx(np.median(values)) - assert distribution_summary.mode == 0.0 - assert distribution_summary.variance == pytest.approx(np.var(values, ddof=0)) - assert distribution_summary.std_dev == pytest.approx(np.std(values, ddof=0)) - assert distribution_summary.min == min(values) - assert distribution_summary.max == max(values) - assert distribution_summary.count == len(values) - assert distribution_summary.total_sum == sum(values) - assert distribution_summary.percentiles.p001 == pytest.approx( - 
np.percentile(values, 0.1) - ) - assert distribution_summary.percentiles.p01 == pytest.approx( - np.percentile(values, 1.0) - ) - assert distribution_summary.percentiles.p05 == pytest.approx( - np.percentile(values, 5.0) - ) - assert distribution_summary.percentiles.p10 == pytest.approx( - np.percentile(values, 10.0) - ) - assert distribution_summary.percentiles.p25 == pytest.approx( - np.percentile(values, 25.0) - ) - assert distribution_summary.percentiles.p50 == pytest.approx( - np.percentile(values, 50.0) - ) - assert distribution_summary.percentiles.p75 == pytest.approx( - np.percentile(values, 75.0) - ) - assert distribution_summary.percentiles.p90 == pytest.approx( - np.percentile(values, 90.0) - ) - assert distribution_summary.percentiles.p95 == pytest.approx( - np.percentile(values, 95.0) - ) - assert distribution_summary.percentiles.p99 == pytest.approx( - np.percentile(values, 99.0) - ) - assert distribution_summary.percentiles.p999 == pytest.approx( - np.percentile(values, 99.9) - ) - assert distribution_summary.cumulative_distribution_function is None - - distribution_summary_cdf = DistributionSummary.from_distribution_function( - distribution, include_cdf=True - ) - assert distribution_summary_cdf.cumulative_distribution_function is not None - assert len(distribution_summary_cdf.cumulative_distribution_function) == len(values) - - -def test_distribution_summary_from_values(): - values = [val / 10 for val in range(1001)] - distribution_summary = DistributionSummary.from_values(values) - assert distribution_summary.mean == pytest.approx(np.mean(values)) - assert distribution_summary.median == pytest.approx(np.median(values)) - assert distribution_summary.mode == 0.0 - assert distribution_summary.variance == pytest.approx(np.var(values, ddof=0)) - assert distribution_summary.std_dev == pytest.approx(np.std(values, ddof=0)) - assert distribution_summary.min == min(values) - assert distribution_summary.max == max(values) - assert distribution_summary.count == len(values) - assert distribution_summary.total_sum == sum(values) - assert distribution_summary.percentiles.p001 == pytest.approx( - np.percentile(values, 0.1) - ) - assert distribution_summary.percentiles.p01 == pytest.approx( - np.percentile(values, 1.0) - ) - assert distribution_summary.percentiles.p05 == pytest.approx( - np.percentile(values, 5.0) - ) - assert distribution_summary.percentiles.p10 == pytest.approx( - np.percentile(values, 10.0) - ) - assert distribution_summary.percentiles.p25 == pytest.approx( - np.percentile(values, 25.0) - ) - assert distribution_summary.percentiles.p50 == pytest.approx( - np.percentile(values, 50.0) - ) - assert distribution_summary.percentiles.p75 == pytest.approx( - np.percentile(values, 75.0) - ) - assert distribution_summary.percentiles.p90 == pytest.approx( - np.percentile(values, 90.0) - ) - assert distribution_summary.percentiles.p95 == pytest.approx( - np.percentile(values, 95.0) - ) - assert distribution_summary.percentiles.p99 == pytest.approx( - np.percentile(values, 99.0) - ) - assert distribution_summary.percentiles.p999 == pytest.approx( - np.percentile(values, 99.9) - ) - assert distribution_summary.cumulative_distribution_function is None - - distribution_summary_weights = DistributionSummary.from_values( - values, weights=[2] * len(values) - ) - assert distribution_summary_weights.mean == pytest.approx(np.mean(values)) - assert distribution_summary_weights.median == pytest.approx(np.median(values)) - assert distribution_summary_weights.mode == 0.0 - assert 
distribution_summary_weights.variance == pytest.approx( - np.var(values, ddof=0) - ) - assert distribution_summary_weights.std_dev == pytest.approx(np.std(values, ddof=0)) - assert distribution_summary_weights.min == min(values) - assert distribution_summary_weights.max == max(values) - assert distribution_summary_weights.count == len(values) - assert distribution_summary_weights.total_sum == sum(values) - assert distribution_summary_weights.cumulative_distribution_function is None - - distribution_summary_cdf = DistributionSummary.from_values(values, include_cdf=True) - assert distribution_summary_cdf.cumulative_distribution_function is not None - assert len(distribution_summary_cdf.cumulative_distribution_function) == len(values) - - -def test_distribution_summary_from_request_times_concurrency(): - # create consistent timestamped values matching a rate of 10 per second - requests = [(val / 10, val / 10 + 1) for val in range(10001)] - distribution_summary = DistributionSummary.from_request_times( - requests, distribution_type="concurrency" - ) - assert distribution_summary.mean == pytest.approx(10.0, abs=0.01) - assert distribution_summary.median == pytest.approx(10.0) - assert distribution_summary.mode == 10.0 - assert distribution_summary.variance == pytest.approx(0, abs=0.1) - assert distribution_summary.std_dev == pytest.approx(0, abs=0.3) - assert distribution_summary.min == pytest.approx(1) - assert distribution_summary.max == pytest.approx(10.0) - assert distribution_summary.count == 10 - assert distribution_summary.total_sum == pytest.approx(55.0) - assert distribution_summary.percentiles.p001 == pytest.approx(10, abs=5) - assert distribution_summary.percentiles.p01 == pytest.approx(10) - assert distribution_summary.percentiles.p05 == pytest.approx(10) - assert distribution_summary.percentiles.p10 == pytest.approx(10) - assert distribution_summary.percentiles.p25 == pytest.approx(10) - assert distribution_summary.percentiles.p50 == pytest.approx(10) - assert distribution_summary.percentiles.p75 == pytest.approx(10) - assert distribution_summary.percentiles.p90 == pytest.approx(10) - assert distribution_summary.percentiles.p95 == pytest.approx(10) - assert distribution_summary.percentiles.p99 == pytest.approx(10) - assert distribution_summary.percentiles.p999 == pytest.approx(10) - assert distribution_summary.cumulative_distribution_function is None - - distribution_summary_cdf = DistributionSummary.from_request_times( - requests, distribution_type="concurrency", include_cdf=True - ) - assert distribution_summary_cdf.cumulative_distribution_function is not None - assert len(distribution_summary_cdf.cumulative_distribution_function) == 10 - - -def test_distribution_summary_from_request_times_rate(): - # create consistent timestamped values matching a rate of 10 per second - requests = [(val / 10, val / 10 + 1) for val in range(10001)] - distribution_summary = DistributionSummary.from_request_times( - requests, distribution_type="rate" - ) - assert distribution_summary.mean == pytest.approx(10.0, abs=0.01) - assert distribution_summary.median == pytest.approx(10.0) - assert distribution_summary.mode == pytest.approx(10.0) - assert distribution_summary.variance == pytest.approx(0, abs=0.1) - assert distribution_summary.std_dev == pytest.approx(0, abs=0.3) - assert distribution_summary.min == pytest.approx(1.0) - assert distribution_summary.max == pytest.approx(10.0) - assert distribution_summary.count == 12 - assert distribution_summary.total_sum == pytest.approx(111.0) - assert 
distribution_summary.percentiles.p001 == pytest.approx(10.0, abs=0.5) - assert distribution_summary.percentiles.p01 == pytest.approx(10.0) - assert distribution_summary.percentiles.p05 == pytest.approx(10.0) - assert distribution_summary.percentiles.p10 == pytest.approx(10.0) - assert distribution_summary.percentiles.p25 == pytest.approx(10.0) - assert distribution_summary.percentiles.p50 == pytest.approx(10.0) - assert distribution_summary.percentiles.p75 == pytest.approx(10.0) - assert distribution_summary.percentiles.p90 == pytest.approx(10.0) - assert distribution_summary.percentiles.p95 == pytest.approx(10.0) - assert distribution_summary.percentiles.p99 == pytest.approx(10.0) - assert distribution_summary.percentiles.p999 == pytest.approx(10.0) - assert distribution_summary.cumulative_distribution_function is None - - distribution_summary_cdf = DistributionSummary.from_request_times( - requests, distribution_type="rate", include_cdf=True - ) - assert distribution_summary_cdf.cumulative_distribution_function is not None - assert len(distribution_summary_cdf.cumulative_distribution_function) == 12 - - -def test_distribution_summary_from_iterable_request_times(): - # create consistent timestamped values matching a rate of 10 per second - requests = [(val / 10, val / 10 + 1) for val in range(10001)] - # create 9 iterations for each request with first iter at start + 0.1 - # and spaced at 0.1 seconds apart - first_iter_times = [val / 10 + 0.1 for val in range(10001)] - iter_counts = [9 for _ in range(10001)] - first_iter_counts = [1 for _ in range(10001)] - - distribution_summary = DistributionSummary.from_iterable_request_times( - requests, first_iter_times, iter_counts, first_iter_counts - ) - assert distribution_summary.mean == pytest.approx(90.0, abs=0.1) - assert distribution_summary.median == pytest.approx(80.0) - assert distribution_summary.mode == pytest.approx(80.0) - assert distribution_summary.variance == pytest.approx(704.463, abs=0.001) - assert distribution_summary.std_dev == pytest.approx(26.541, abs=0.001) - assert distribution_summary.min == pytest.approx(0.0) - assert distribution_summary.max == pytest.approx(160.0) - assert distribution_summary.count == 44 - assert distribution_summary.total_sum == pytest.approx(3538.85, abs=0.01) - assert distribution_summary.percentiles.p001 == pytest.approx(80.0) - assert distribution_summary.percentiles.p01 == pytest.approx(80.0) - assert distribution_summary.percentiles.p05 == pytest.approx(80.0) - assert distribution_summary.percentiles.p10 == pytest.approx(80.0) - assert distribution_summary.percentiles.p25 == pytest.approx(80.0) - assert distribution_summary.percentiles.p50 == pytest.approx(80.0) - assert distribution_summary.percentiles.p75 == pytest.approx(80.0) - assert distribution_summary.percentiles.p90 == pytest.approx(160.0) - assert distribution_summary.percentiles.p95 == pytest.approx(160.0) - assert distribution_summary.percentiles.p99 == pytest.approx(160.0) - assert distribution_summary.percentiles.p999 == pytest.approx(160.0) - assert distribution_summary.cumulative_distribution_function is None - - distribution_summary_cdf = DistributionSummary.from_iterable_request_times( - requests, first_iter_times, iter_counts, first_iter_counts, include_cdf=True - ) - assert distribution_summary_cdf.cumulative_distribution_function is not None - assert len(distribution_summary_cdf.cumulative_distribution_function) == 44 - - -def test_status_distribution_summary_initialization(): - status_distribution_summary = 
StatusDistributionSummary( - total=create_default_distribution_summary(), - successful=create_default_distribution_summary(), - incomplete=create_default_distribution_summary(), - errored=create_default_distribution_summary(), - ) - assert status_distribution_summary.total.mean == 50.0 - assert status_distribution_summary.successful.mean == 50.0 - assert status_distribution_summary.incomplete.mean == 50.0 - assert status_distribution_summary.errored.mean == 50.0 - - -def test_status_distribution_summary_marshalling(): - status_distribution_summary = StatusDistributionSummary( - total=create_default_distribution_summary(), - successful=create_default_distribution_summary(), - incomplete=create_default_distribution_summary(), - errored=create_default_distribution_summary(), - ) - serialized = status_distribution_summary.model_dump() - deserialized = StatusDistributionSummary.model_validate(serialized) - - for key, value in vars(status_distribution_summary).items(): - for child_key, child_value in vars(value).items(): - assert getattr(getattr(deserialized, key), child_key) == child_value - - -def test_status_distribution_summary_from_values(): - value_types: list[Literal["successful", "incomplete", "error"]] = [ - "successful", - "incomplete", - "error", - ] * 1000 - values = [float(val % 3) for val in range(3000)] - status_distribution_summary = StatusDistributionSummary.from_values( - value_types, values - ) - assert status_distribution_summary.total.count == len(values) - assert status_distribution_summary.total.mean == pytest.approx(np.mean(values)) - assert status_distribution_summary.total.cumulative_distribution_function is None - assert status_distribution_summary.successful.mean == pytest.approx( - np.mean( - [val for ind, val in enumerate(values) if value_types[ind] == "successful"] - ) - ) - assert status_distribution_summary.successful.count == len( - [val for ind, val in enumerate(values) if value_types[ind] == "successful"] - ) - assert ( - status_distribution_summary.successful.cumulative_distribution_function is None - ) - assert status_distribution_summary.incomplete.mean == pytest.approx( - np.mean( - [val for ind, val in enumerate(values) if value_types[ind] == "incomplete"] - ) - ) - assert status_distribution_summary.incomplete.count == len( - [val for ind, val in enumerate(values) if value_types[ind] == "incomplete"] - ) - assert ( - status_distribution_summary.incomplete.cumulative_distribution_function is None - ) - assert status_distribution_summary.errored.mean == pytest.approx( - np.mean([val for ind, val in enumerate(values) if value_types[ind] == "error"]) - ) - assert status_distribution_summary.errored.count == len( - [val for ind, val in enumerate(values) if value_types[ind] == "error"] - ) - assert status_distribution_summary.errored.cumulative_distribution_function is None - - status_distribution_summary_cdf = StatusDistributionSummary.from_values( - value_types, values, include_cdf=True - ) - assert ( - status_distribution_summary_cdf.total.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.successful.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.incomplete.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.errored.cumulative_distribution_function - is not None - ) - - -def test_status_distribution_summary_from_request_times(): - request_types: list[Literal["successful", "incomplete", "error"]] = [ - "successful", - 
"incomplete", - "error", - ] * 1000 - requests = [((val % 3) / 10, (val % 3) / 10 + 1) for val in range(3000)] - status_distribution_summary = StatusDistributionSummary.from_request_times( - request_types, requests, distribution_type="concurrency" - ) - assert status_distribution_summary.total.mean == pytest.approx(2500.0, abs=0.01) - assert status_distribution_summary.total.cumulative_distribution_function is None - assert status_distribution_summary.successful.mean == pytest.approx( - 1000.0, abs=0.01 - ) - assert ( - status_distribution_summary.successful.cumulative_distribution_function is None - ) - assert status_distribution_summary.incomplete.mean == pytest.approx( - 1000.0, abs=0.01 - ) - assert ( - status_distribution_summary.incomplete.cumulative_distribution_function is None - ) - assert status_distribution_summary.errored.mean == pytest.approx(1000.0, abs=0.01) - assert status_distribution_summary.errored.cumulative_distribution_function is None - - status_distribution_summary_cdf = StatusDistributionSummary.from_request_times( - request_types, requests, distribution_type="concurrency", include_cdf=True - ) - assert ( - status_distribution_summary_cdf.total.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.successful.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.incomplete.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.errored.cumulative_distribution_function - is not None - ) - - -def test_status_distribution_summary_from_iterable_request_times(): - request_types: list[Literal["successful", "incomplete", "error"]] = [ - "successful", - "incomplete", - "error", - ] * 1000 - requests = [(val % 3 / 10, val % 3 / 10 + 1) for val in range(3000)] - first_iter_times = [val % 3 / 10 + 0.1 for val in range(3000)] - iter_counts = [9 for _ in range(3000)] - first_iter_counts = [1 for _ in range(3000)] - status_distribution_summary = StatusDistributionSummary.from_iterable_request_times( - request_types, - requests, - first_iter_times, - iter_counts, - first_iter_counts, - ) - assert status_distribution_summary.total.mean == pytest.approx(21666.66, abs=0.01) - assert status_distribution_summary.total.cumulative_distribution_function is None - assert status_distribution_summary.successful.mean == pytest.approx( - 8000.0, abs=0.01 - ) - assert ( - status_distribution_summary.successful.cumulative_distribution_function is None - ) - assert status_distribution_summary.incomplete.mean == pytest.approx( - 8000.0, abs=0.01 - ) - assert ( - status_distribution_summary.incomplete.cumulative_distribution_function is None - ) - assert status_distribution_summary.errored.mean == pytest.approx(8000.0, abs=0.01) - assert status_distribution_summary.errored.cumulative_distribution_function is None - - status_distribution_summary_cdf = ( - StatusDistributionSummary.from_iterable_request_times( - request_types, - requests, - first_iter_times, - iter_counts, - first_iter_counts, - include_cdf=True, - ) - ) - assert ( - status_distribution_summary_cdf.total.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.successful.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.incomplete.cumulative_distribution_function - is not None - ) - assert ( - status_distribution_summary_cdf.errored.cumulative_distribution_function - is not None - ) - - -def test_running_stats_initialization(): 
- running_stats = RunningStats() - assert running_stats.start_time == pytest.approx(time.time(), abs=0.01) - assert running_stats.count == 0 - assert running_stats.total == 0 - assert running_stats.last == 0 - assert running_stats.mean == 0 - assert running_stats.rate == 0 - - -def test_running_stats_marshalling(): - running_stats = RunningStats() - serialized = running_stats.model_dump() - deserialized = RunningStats.model_validate(serialized) - - for key, value in vars(running_stats).items(): - assert getattr(deserialized, key) == value - - -def test_running_stats_update(): - running_stats = RunningStats() - running_stats.update(1) - assert running_stats.count == 1 - assert running_stats.total == 1 - assert running_stats.last == 1 - assert running_stats.mean == 1 - time.sleep(1.0) - assert running_stats.rate == pytest.approx( - 1.0 / (time.time() - running_stats.start_time), abs=0.1 - ) - - running_stats.update(2) - assert running_stats.count == 2 - assert running_stats.total == 3 - assert running_stats.last == 2 - assert running_stats.mean == 1.5 - time.sleep(1) - assert running_stats.rate == pytest.approx( - 3 / (time.time() - running_stats.start_time), abs=0.1 - ) - - -def test_running_stats_add(): - running_stats = RunningStats() - mean = running_stats + 1 - assert mean == 1 - assert mean == running_stats.mean - assert running_stats.count == 1 - assert running_stats.total == 1 - assert running_stats.last == 1 - - -def test_running_stats_iadd(): - running_stats = RunningStats() - running_stats += 1 - assert running_stats.count == 1 - assert running_stats.total == 1 - assert running_stats.last == 1 - assert running_stats.mean == 1 - - -def test_time_running_stats_initialization(): - time_running_stats = TimeRunningStats() - assert time_running_stats.start_time == pytest.approx(time.time(), abs=0.01) - assert time_running_stats.count == 0 - assert time_running_stats.total == 0 - assert time_running_stats.last == 0 - assert time_running_stats.mean == 0 - assert time_running_stats.rate == 0 - assert time_running_stats.total_ms == 0 - assert time_running_stats.last_ms == 0 - assert time_running_stats.mean_ms == 0 - assert time_running_stats.rate_ms == 0 - - -def test_time_running_stats_marshalling(): - time_running_stats = TimeRunningStats() - serialized = time_running_stats.model_dump() - deserialized = TimeRunningStats.model_validate(serialized) - - for key, value in vars(time_running_stats).items(): - assert getattr(deserialized, key) == value - - -def test_time_running_stats_update(): - time_running_stats = TimeRunningStats() - time_running_stats.update(1) - assert time_running_stats.count == 1 - assert time_running_stats.total == 1 - assert time_running_stats.last == 1 - assert time_running_stats.mean == 1 - assert time_running_stats.total_ms == 1000 - assert time_running_stats.last_ms == 1000 - assert time_running_stats.mean_ms == 1000 - time.sleep(1.0) - assert time_running_stats.rate == pytest.approx( - 1.0 / (time.time() - time_running_stats.start_time), abs=0.1 - ) - assert time_running_stats.rate_ms == pytest.approx( - 1000 / (time.time() - time_running_stats.start_time), abs=0.1 - ) - - time_running_stats.update(2) - assert time_running_stats.count == 2 - assert time_running_stats.total == 3 - assert time_running_stats.last == 2 - assert time_running_stats.mean == 1.5 - assert time_running_stats.total_ms == 3000 - assert time_running_stats.last_ms == 2000 - assert time_running_stats.mean_ms == 1500 - time.sleep(1) - assert time_running_stats.rate == pytest.approx( - 3 / 
(time.time() - time_running_stats.start_time), abs=0.1 - ) - assert time_running_stats.rate_ms == pytest.approx( - 3000 / (time.time() - time_running_stats.start_time), abs=0.1 - ) diff --git a/tests/unit/presentation/test_injector.py b/tests/unit/presentation/test_injector.py index cdaa7619..9d97d021 100644 --- a/tests/unit/presentation/test_injector.py +++ b/tests/unit/presentation/test_injector.py @@ -3,8 +3,8 @@ import pytest from pydantic import BaseModel -from guidellm.config import settings from guidellm.presentation.injector import create_report, inject_data +from guidellm.settings import settings class ExampleModel(BaseModel): diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py index d1be6e94..df794ff8 100644 --- a/tests/unit/scheduler/test_objects.py +++ b/tests/unit/scheduler/test_objects.py @@ -2,7 +2,6 @@ import inspect import typing -from abc import ABC from collections.abc import AsyncIterator from typing import Any, Optional, TypeVar, Union @@ -14,7 +13,6 @@ BackendInterface, BackendT, MeasuredRequestTimings, - MeasuredRequestTimingsT, MultiTurnRequestT, RequestSchedulerTimings, RequestT, @@ -43,14 +41,6 @@ def test_response_t(): assert ResponseT.__constraints__ == () -def test_request_timings_t(): - """Validate MeasuredRequestTimingsT is a TypeVar bound to MeasuredRequestTimings.""" - assert isinstance(MeasuredRequestTimingsT, TypeVar) - assert MeasuredRequestTimingsT.__name__ == "MeasuredRequestTimingsT" - assert MeasuredRequestTimingsT.__bound__ == MeasuredRequestTimings - assert MeasuredRequestTimingsT.__constraints__ == () - - def test_backend_t(): """Validate that BackendT is a TypeVar bound to BackendInterface.""" assert isinstance(BackendT, TypeVar) @@ -76,18 +66,10 @@ def test_multi_turn_request_t(): class TestBackendInterface: """Test the BackendInterface abstract base class.""" - @pytest.mark.smoke - def test_is_abstract_base_class(self): - """Test that BackendInterface is an ABC and cannot be instantiated directly.""" - assert issubclass(BackendInterface, ABC) - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - BackendInterface() - @pytest.mark.smoke def test_abstract_methods_defined(self): """Test that all expected abstract methods are defined.""" expected_methods = { - "info", "process_startup", "validate", "process_shutdown", @@ -96,6 +78,7 @@ def test_abstract_methods_defined(self): expected_properties = { "processes_limit", "requests_limit", + "info", } for method_name in expected_methods: @@ -112,50 +95,26 @@ def test_abstract_methods_defined(self): def test_generic_type_parameters(self): """Test that BackendInterface has the correct generic type parameters.""" orig_bases = BackendInterface.__orig_bases__ - abc_base = None + protocol_base = None generic_base = None for base in orig_bases: if hasattr(base, "__origin__"): if base.__origin__ is typing.Generic: generic_base = base - elif base.__name__ == "ABC": - abc_base = base + elif base.__name__ == "Protocol": + protocol_base = base - assert abc_base is not None, "Should inherit from ABC" + assert protocol_base is not None, "Should inherit from Protocol" assert generic_base is not None, "Should inherit from Generic" if hasattr(generic_base, "__args__"): type_params = generic_base.__args__ assert len(type_params) == 3, "Should have 3 type parameters" param_names = [param.__name__ for param in type_params] - expected_names = ["RequestT", "MeasuredRequestTimingsT", "ResponseT"] + expected_names = ["RequestT", "ResponseT"] assert param_names == 
expected_names - @pytest.mark.sanity - def test_invalid_implementation(self): - """Test that a concrete implementation must implement all abstract methods.""" - - class PartialBackend(BackendInterface): - @property - def processes_limit(self): - return 1 - - @property - def requests_limit(self): - return 10 - - def info(self): - return {} - - async def process_startup(self): - pass - - # Missing: validate, process_shutdown, resolve - - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - PartialBackend() - @pytest.mark.smoke def test_implementation_construction(self): """Test that a complete concrete implementation can be instantiated.""" @@ -169,6 +128,7 @@ def processes_limit(self) -> int | None: def requests_limit(self) -> int | None: return 100 + @property def info(self) -> dict[str, Any]: return {"model": "test", "version": "1.0"} @@ -184,11 +144,9 @@ async def process_shutdown(self) -> None: async def resolve( self, request: str, - request_info: ScheduledRequestInfo[MeasuredRequestTimings], + request_info: ScheduledRequestInfo, history: list[tuple[str, str]] | None = None, - ) -> AsyncIterator[ - tuple[str, ScheduledRequestInfo[MeasuredRequestTimings]] - ]: + ) -> AsyncIterator[tuple[str, ScheduledRequestInfo]]: yield f"Response to: {request}", request_info backend = ConcreteBackend() @@ -196,12 +154,12 @@ async def resolve( assert isinstance(backend, ConcreteBackend) assert backend.processes_limit == 4 assert backend.requests_limit == 100 - info = backend.info() + info = backend.info assert info == {"model": "test", "version": "1.0"} @pytest.mark.smoke @pytest.mark.asyncio - async def test_implementation_async_methods(self): + async def test_implementation_async_methods(self): # noqa: C901 """Test that async methods work correctly in concrete implementation.""" class AsyncBackend(BackendInterface[dict, MeasuredRequestTimings, dict]): @@ -218,6 +176,7 @@ def processes_limit(self) -> int | None: def requests_limit(self) -> int | None: return None # Unlimited + @property def info(self) -> dict[str, Any]: return {"backend": "async_test"} @@ -233,11 +192,9 @@ async def process_shutdown(self) -> None: async def resolve( self, request: dict, - request_info: ScheduledRequestInfo[MeasuredRequestTimings], + request_info: ScheduledRequestInfo, history: list[tuple[dict, dict]] | None = None, - ) -> AsyncIterator[ - tuple[dict, ScheduledRequestInfo[MeasuredRequestTimings]] - ]: + ) -> AsyncIterator[tuple[dict, ScheduledRequestInfo]]: response = {"result": request.get("input", ""), "status": "success"} yield response, request_info @@ -271,9 +228,14 @@ async def resolve( @pytest.mark.smoke def test_method_signatures(self): """Test that abstract methods have the expected signatures.""" - info_sig = inspect.signature(BackendInterface.info) - assert len(info_sig.parameters) == 1 - assert list(info_sig.parameters.keys()) == ["self"] + info_prop = BackendInterface.info + assert isinstance(info_prop, property) + + processes_limit_prop = BackendInterface.processes_limit + assert isinstance(processes_limit_prop, property) + + requests_limit_prop = BackendInterface.requests_limit + assert isinstance(requests_limit_prop, property) startup_sig = inspect.signature(BackendInterface.process_startup) assert len(startup_sig.parameters) == 1 # Only self @@ -302,6 +264,7 @@ class TestRequestSchedulerTimings: "targeted_start", "queued", "dequeued", + "scheduled_at", "resolve_start", "resolve_end", "finalized", @@ -314,6 +277,7 @@ class TestRequestSchedulerTimings: "targeted_start": None, 
"queued": None, "dequeued": None, + "scheduled_at": None, "resolve_start": None, "resolve_end": None, "finalized": None, @@ -322,12 +286,14 @@ class TestRequestSchedulerTimings: "targeted_start": 1000.0, "queued": 200.0, "dequeued": 800.0, + "scheduled_at": 900.0, "resolve_start": 1000.5, "resolve_end": 1100.0, "finalized": 1100.5, }, { "queued": 200.0, + "scheduled_at": 250.0, "resolve_start": 1000.5, "resolve_end": 1100.0, }, @@ -335,6 +301,7 @@ class TestRequestSchedulerTimings: "targeted_start": 0.0, "queued": 0.0, "dequeued": 0.0, + "scheduled_at": 0.0, "resolve_start": 0.0, "resolve_end": 0.0, "finalized": 0.0, @@ -388,6 +355,7 @@ def test_initialization(self, valid_instances): ("targeted_start", "invalid_string"), ("queued", "invalid_string"), ("dequeued", [1, 2, 3]), + ("scheduled_at", {"key": "value"}), ("resolve_start", {"key": "value"}), ("resolve_end", [1, 2, 3]), ("finalized", object()), diff --git a/tests/unit/scheduler/test_strategy.py b/tests/unit/scheduler/test_strategy.py index f06707e7..8cb91d82 100644 --- a/tests/unit/scheduler/test_strategy.py +++ b/tests/unit/scheduler/test_strategy.py @@ -5,7 +5,7 @@ import statistics import time from abc import ABC -from typing import TypeVar +from typing import Literal, TypeVar import pytest from pydantic import ValidationError @@ -234,7 +234,7 @@ def test_lifecycle( completion_time = time.time() + offset request_times.append(completion_time) - mock_request = ScheduledRequestInfo( + mock_request: ScheduledRequestInfo = ScheduledRequestInfo( request_id=f"test-{index}", status="completed", scheduler_node_id=0, @@ -565,7 +565,7 @@ def test_invalid_implementation(self): """Test that invalid implementations raise NotImplementedError.""" class InvalidStrategy(SchedulingStrategy): - type_: str = "strategy" + type_: Literal["strategy"] = "strategy" # type: ignore[assignment,annotation-unchecked] strategy = InvalidStrategy() with pytest.raises(NotImplementedError): @@ -576,7 +576,7 @@ def test_concrete_implementation(self): """Test that concrete implementations can be constructed.""" class TestStrategy(SchedulingStrategy): - type_: str = "strategy" + type_: Literal["strategy"] = "strategy" # type: ignore[assignment,annotation-unchecked] def create_request_timings( self, diff --git a/tests/unit/scheduler/test_worker.py b/tests/unit/scheduler/test_worker.py index e7eba9b2..a2ad99c3 100644 --- a/tests/unit/scheduler/test_worker.py +++ b/tests/unit/scheduler/test_worker.py @@ -1,36 +1,34 @@ from __future__ import annotations import asyncio -import contextlib import inspect -import math -import threading +import random import time -from collections import defaultdict +from dataclasses import dataclass from functools import wraps -from multiprocessing import Barrier, Event, Queue +from multiprocessing import Barrier, Event, Process from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from queue import Empty -from typing import Any, Callable, Generic, Literal -from unittest.mock import AsyncMock, patch +from typing import Any, Generic, Literal import pytest +import pytest_asyncio from guidellm.scheduler import ( BackendInterface, + ConstantRateRequestTimings, LastCompletionRequestTimings, MeasuredRequestTimings, + NoDelayRequestTimings, + PoissonRateRequestTimings, ScheduledRequestInfo, ScheduledRequestTimings, + SchedulerMessagingPydanticRegistry, WorkerProcess, ) -from guidellm.scheduler.strategy import ( - ConstantRateRequestTimings, - NoDelayRequestTimings, - 
PoissonRateRequestTimings, -) -from guidellm.utils import MsgpackEncoding, random +from guidellm.utils import InterProcessMessagingQueue + +STANDARD_NUM_REQUESTS: int = 200 def async_timeout(delay): @@ -44,6 +42,18 @@ async def new_func(*args, **kwargs): return decorator +@dataclass +class TimingsBounds: + exact: float | None = None + lower: float | None = None + upper: float | None = None + prev_request: Literal["greater", "greater_equal", "less", "less_equal"] | None = ( + None + ) + tolerance: float = 10e-4 + actual_tolerance: float = 10e-4 + + class MockRequestTimings(MeasuredRequestTimings): """Mock timing implementation for testing.""" @@ -53,11 +63,13 @@ class MockBackend(BackendInterface): def __init__( self, - delay: float = 0.01, + lifecycle_delay: float = 0.1, + resolve_delay: float = 0.0, should_fail: bool = False, request_error_rate: float = 0.0, ): - self.delay = delay + self.lifecycle_delay = lifecycle_delay + self.resolve_delay = resolve_delay self.should_fail = should_fail self.request_error_rate = request_error_rate self.process_startup_called = False @@ -73,100 +85,103 @@ def processes_limit(self) -> int | None: def requests_limit(self) -> int | None: return None + @property def info(self) -> dict[str, Any]: - return {"type": "mock", "delay": self.delay} + return { + "type": "mock", + "lifecycle_delay": self.lifecycle_delay, + "resolve_delay": self.resolve_delay, + } async def process_startup(self): - await asyncio.sleep(self.delay) + await asyncio.sleep(self.lifecycle_delay) self.process_startup_called = True async def validate(self): - await asyncio.sleep(self.delay) + await asyncio.sleep(self.lifecycle_delay) self.validate_called = True if self.should_fail: raise RuntimeError("Mock validation failed") async def process_shutdown(self): - await asyncio.sleep(0.1) + await asyncio.sleep(self.lifecycle_delay) self.process_shutdown_called = True async def resolve(self, request, request_info, request_history): self.resolve_called = True - await asyncio.sleep(self.delay) + await asyncio.sleep( + self.resolve_delay if not str(request).startswith("cancel") else 1000.0 + ) if self.should_fail: raise RuntimeError("Mock resolve failed") if self.request_error_rate > 0.0 and random.random() < self.request_error_rate: raise RuntimeError("Mock resolve failed") - yield f"response_for_{request}" + yield f"response_for_{request}", request_info class TestWorkerProcess: """Test suite for WorkerProcess class.""" - @pytest.fixture( + @pytest_asyncio.fixture( params=[ { - "local_rank": 0, - "local_world_size": 2, - "async_limit": 5, - "poll_intervals": 0.01, - }, - { - "local_rank": 1, - "local_world_size": 3, - "async_limit": 10, - "poll_intervals": 0.05, + "messaging": { + "serialization": "dict", + "encoding": None, + "max_buffer_receive_size": 2, + }, + "worker": { + "async_limit": 1, + }, }, { - "local_rank": 2, - "local_world_size": 4, - "async_limit": 1, - "poll_intervals": 0.1, + "messaging": { + "serialization": "dict", + "encoding": None, + "max_buffer_receive_size": 100, + }, + "worker": { + "async_limit": 1000, + }, }, ], - ids=["basic_config", "multi_worker", "single_async"], ) - def valid_instances(self, request): + async def valid_instances(self, request): """Fixture providing test data for WorkerProcess.""" constructor_args = request.param - backend = MockBackend() - request_timings = LastCompletionRequestTimings() - - instance = WorkerProcess( - startup_barrier=Barrier(constructor_args["local_world_size"]), - shutdown_event=Event(), - error_event=Event(), - 
requests_queue=Queue(), - updates_queue=Queue(), - backend=backend, - request_timings=request_timings, - **constructor_args, + main_messaging = InterProcessMessagingQueue( + **constructor_args["messaging"], poll_interval=0.01 ) - return instance, constructor_args - - @pytest.fixture - def worker_process(self): - """Create a WorkerProcess instance for testing.""" - backend = MockBackend() - request_timings = LastCompletionRequestTimings() - return WorkerProcess( - local_rank=0, - local_world_size=2, - async_limit=5, - startup_barrier=Barrier(2), - shutdown_event=Event(), - error_event=Event(), - requests_queue=Queue(), - updates_queue=Queue(), - backend=backend, - request_timings=request_timings, - poll_intervals=0.01, - ) + try: + instance = WorkerProcess( + messaging=main_messaging.create_worker_copy(0), + backend=MockBackend(), + request_timings=LastCompletionRequestTimings(), + **constructor_args["worker"], + startup_barrier=Barrier(2), + requests_generated_event=Event(), + constraint_reached_event=Event(), + shutdown_event=Event(), + error_event=Event(), + ) + await main_messaging.start( + pydantic_models=list( + SchedulerMessagingPydanticRegistry.registry.values() + ) + ) + yield instance, main_messaging, constructor_args + finally: + await main_messaging.stop() @pytest.mark.smoke - def test_class_signatures(self, worker_process: WorkerProcess): + def test_class_signatures( + self, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + ): """Test inheritance and type relationships.""" + worker_process, main_messaging, constructor_args = valid_instances + # Class assert isinstance(worker_process, Generic) assert issubclass(WorkerProcess, Generic) @@ -184,7 +199,7 @@ def test_class_signatures(self, worker_process: WorkerProcess): ) assert generic_base is not None type_args = getattr(generic_base, "__args__", ()) - assert len(type_args) == 3 # RequestT, MeasuredRequestTimingsT, ResponseT + assert len(type_args) == 2 # RequestT, ResponseT # Function signatures run_sig = inspect.signature(WorkerProcess.run) @@ -195,48 +210,60 @@ def test_class_signatures(self, worker_process: WorkerProcess): assert len(run_async_sig.parameters) == 1 assert "self" in run_async_sig.parameters - stop_processing_sig = inspect.signature(WorkerProcess.run_async_stop_processing) + stop_processing_sig = inspect.signature(WorkerProcess._stop_monitor) assert len(stop_processing_sig.parameters) == 1 assert "self" in stop_processing_sig.parameters - requests_processing_sig = inspect.signature( - WorkerProcess.run_async_requests_processing - ) + requests_processing_sig = inspect.signature(WorkerProcess._process_requests) assert len(requests_processing_sig.parameters) == 1 assert "self" in requests_processing_sig.parameters @pytest.mark.smoke - def test_initialization(self, valid_instances): + def test_initialization( + self, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + ): """Test basic initialization of WorkerProcess.""" - instance, constructor_args = valid_instances - - # worker info - assert instance.local_rank == constructor_args["local_rank"] - assert instance.local_world_size == constructor_args["local_world_size"] - assert instance.async_limit == constructor_args["async_limit"] + instance, main_messaging, constructor_args = valid_instances + + # messaging + assert instance.messaging is not None + assert isinstance(instance.messaging, InterProcessMessagingQueue) + assert instance.messaging is not main_messaging + assert instance.messaging.worker_index is not 
None + assert instance.messaging.worker_index == 0 + assert ( + instance.messaging.serialization + == constructor_args["messaging"]["serialization"] + ) + assert instance.messaging.encoding == constructor_args["messaging"]["encoding"] + assert ( + instance.messaging.max_buffer_receive_size + == constructor_args["messaging"]["max_buffer_receive_size"] + ) - # process synchronization + # worker + assert instance.async_limit == constructor_args["worker"]["async_limit"] + assert instance.startup_barrier is not None assert isinstance(instance.startup_barrier, ProcessingBarrier) + assert instance.shutdown_event is not None assert isinstance(instance.shutdown_event, ProcessingEvent) + assert instance.error_event is not None assert isinstance(instance.error_event, ProcessingEvent) - assert hasattr(instance.requests_queue, "put") - assert hasattr(instance.requests_queue, "get") - assert hasattr(instance.updates_queue, "put") - assert hasattr(instance.updates_queue, "get") - - # local synchronization - assert instance.pending_requests_queue is None - assert instance.pending_updates_queue is None - - # request processing + assert instance.requests_generated_event is not None + assert isinstance(instance.requests_generated_event, ProcessingEvent) + assert instance.constraint_reached_event is not None + assert isinstance(instance.constraint_reached_event, ProcessingEvent) + assert instance.backend is not None assert isinstance(instance.backend, MockBackend) - assert instance.poll_intervals == constructor_args["poll_intervals"] + assert instance.request_timings is not None assert isinstance(instance.request_timings, LastCompletionRequestTimings) - assert instance.startup_completed is False + assert not instance.startup_completed @pytest.mark.sanity def test_invalid_initialization(self): """Test that invalid initialization raises appropriate errors.""" + # Test with missing required parameters with pytest.raises(TypeError): WorkerProcess() @@ -247,36 +274,34 @@ def test_invalid_initialization(self): barrier = Barrier(2) shutdown_event = Event() error_event = Event() - requests_queue = Queue() - updates_queue = Queue() + requests_generated_event = Event() + constraint_reached_event = Event() + messaging = InterProcessMessagingQueue() # Test missing each required parameter one by one required_params = [ - "local_rank", - "local_world_size", + "messaging", + "backend", + "request_timings", "async_limit", "startup_barrier", + "requests_generated_event", + "constraint_reached_event", "shutdown_event", "error_event", - "requests_queue", - "updates_queue", - "backend", - "request_timings", ] for param_to_remove in required_params: kwargs = { - "local_rank": 0, - "local_world_size": 2, + "messaging": messaging, + "backend": backend, + "request_timings": request_timings, "async_limit": 5, "startup_barrier": barrier, + "requests_generated_event": requests_generated_event, + "constraint_reached_event": constraint_reached_event, "shutdown_event": shutdown_event, "error_event": error_event, - "requests_queue": requests_queue, - "updates_queue": updates_queue, - "backend": backend, - "request_timings": request_timings, - "poll_intervals": 0.01, } del kwargs[param_to_remove] @@ -284,755 +309,364 @@ def test_invalid_initialization(self): with pytest.raises(TypeError): WorkerProcess(**kwargs) - @pytest.mark.smoke - @patch("asyncio.run") - def test_run(self, mock_asyncio_run, worker_process: WorkerProcess): - """ - Test that run method functions as expected (calls run_async, handles errors) - """ - # Test successful 
execution - with patch.object( - worker_process, "run_async", new_callable=AsyncMock - ) as mock_run_async: - worker_process.run() - mock_asyncio_run.assert_called_once() - mock_run_async.assert_called_once() - - mock_asyncio_run.reset_mock() - - # Test exception during execution - test_exception = RuntimeError("Test error in run_async") - with patch.object( - worker_process, "run_async", new_callable=AsyncMock - ) as mock_run_async: - mock_asyncio_run.side_effect = test_exception - - with pytest.raises( - RuntimeError, match="Worker process 0 encountered an error" - ): - worker_process.run() - - assert worker_process.error_event.is_set() - @pytest.mark.smoke @pytest.mark.asyncio - @async_timeout(5.0) + # @async_timeout(15) @pytest.mark.parametrize( - ("stop_action", "req_action"), + ("num_requests", "num_canceled", "error_rate"), [ - ("complete_short", "complete_short"), - ("complete_long", "error"), - ("error", "complete_long"), - ("error", "error"), - ("complete_long", "cancel"), - ("cancel", "complete_long"), - ("cancel", "cancel"), + (20, 0, 0), + (STANDARD_NUM_REQUESTS, 20, 0.5), ], ) - async def test_run_async( # noqa: C901 + async def test_run_async_lifecycle( # noqa: C901, PLR0912 self, - worker_process: WorkerProcess, - stop_action: Literal["complete_short", "complete_long", "error", "cancel"], - req_action: Literal["complete_short", "complete_long", "error", "cancel"], + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + num_requests: int, + num_canceled: int, + error_rate: float, ): - def make_task(action: str, state: dict): - loops = {"error": 1, "cancel": 2, "complete_short": 3, "complete_long": 50}[ - action - ] - - async def _run(self): - state.update(called=True, iterations=0) - try: - for _ in range(loops): - await asyncio.sleep(0.01) - state["iterations"] += 1 - if action == "error": - state["errored"] = True - raise RuntimeError(state["error_message"]) - if action == "cancel": - state["cancelled"] = True - raise asyncio.CancelledError(state["cancel_message"]) - if action == "complete_short": - state["completed_short"] = True - if action == "complete_long": - state["completed_long"] = True - except asyncio.CancelledError: - state["cancelled"] = True - raise - - return _run, loops - - def init_state(prefix): - return { - "called": False, - "iterations": 0, - "completed_short": False, - "completed_long": False, - "errored": False, - "cancelled": False, - "error_message": f"{prefix} processing error", - "cancel_message": f"{prefix} processing cancelled", - } + """Test the asynchronous request processing of WorkerProcess.""" + instance, main_messaging, constructor_args = valid_instances + instance.backend.request_error_rate = error_rate + instance_task = asyncio.create_task(instance.run_async()) + + try: + await asyncio.to_thread(instance.startup_barrier.wait) + start_time = time.time() + + # Send regular requests + requests_tracker = {} + for index in range(num_requests): + request = f"request_{index}" + request_info = ScheduledRequestInfo( + request_id=request, + scheduler_start_time=start_time, + scheduler_process_id=0, + ) + request_info.scheduler_timings.queued = time.time() + requests_tracker[request] = { + "sent": True, + "received_pending": 0, + "received_in_progress": 0, + "received_resolved": 0, + } + await main_messaging.put( + (request, request_info), + timeout=2.0, + ) - stop_state, req_state = init_state("Stop"), init_state("Requests") - stop_fn, stop_loops = make_task(stop_action, stop_state) - req_fn, req_loops = make_task(req_action, 
req_state) + # Process regular requests + error_count = 0 + for _ in range(num_requests * 3): + # Each request must have a pending, in_progress, and resolution + response, request, request_info = await main_messaging.get(timeout=2.0) + assert request is not None + assert request_info is not None + assert request_info.request_id is not None + assert request_info.status is not None + assert request_info.scheduler_node_id > -1 + assert request_info.scheduler_process_id > -1 + assert request_info.scheduler_start_time == start_time + assert request_info.scheduler_timings is not None + assert request_info.scheduler_timings.targeted_start is not None + assert request_info.scheduler_timings.targeted_start >= start_time + + if request_info.status == "pending": + requests_tracker[request]["received_pending"] += 1 + assert request_info.scheduler_timings.dequeued is not None + assert ( + request_info.scheduler_timings.dequeued + >= request_info.scheduler_timings.targeted_start + ) + elif request_info.status == "in_progress": + requests_tracker[request]["received_in_progress"] += 1 + assert request_info.scheduler_timings.scheduled_at is not None + assert ( + request_info.scheduler_timings.scheduled_at + >= request_info.scheduler_timings.dequeued + ) + assert request_info.scheduler_timings.resolve_start is not None + assert ( + request_info.scheduler_timings.resolve_start + >= request_info.scheduler_timings.scheduled_at + ) + elif request_info.status == "completed": + assert response == f"response_for_{request}" + requests_tracker[request]["received_resolved"] += 1 + assert request_info.scheduler_timings.resolve_end is not None + assert ( + request_info.scheduler_timings.resolve_end + > request_info.scheduler_timings.resolve_start + ) + elif request_info.status == "errored": + assert response is None + requests_tracker[request]["received_resolved"] += 1 + error_count += 1 + assert request_info.scheduler_timings.resolve_end is not None + assert ( + request_info.scheduler_timings.resolve_end + > request_info.scheduler_timings.resolve_start + ) + else: + raise ValueError(f"Unexpected status: {request_info.status}") - expected_exc = RuntimeError if "error" in {stop_action, req_action} else None - with ( - patch.object( - type(worker_process), "run_async_stop_processing", new=stop_fn - ), - patch.object( - type(worker_process), "run_async_requests_processing", new=req_fn - ), - ): - if expected_exc: - with pytest.raises(expected_exc): - await worker_process.run_async() - else: - await worker_process.run_async() - - assert stop_state["called"] - assert req_state["called"] - - # build unified expected outcome table - def is_long(a): - return a == "complete_long" - - def is_short(a): - return a in {"complete_short", "error", "cancel"} - - expectations = { - "stop": { - "errored": stop_action == "error", - "cancelled": stop_action == "cancel" - or (is_short(req_action) and is_long(stop_action)) - or (req_action == "error" and is_long(stop_action)), - }, - "req": { - "errored": req_action == "error", - "cancelled": req_action == "cancel" - or (is_short(stop_action) and is_long(req_action)) - or (stop_action == "error" and is_long(req_action)), - }, - } + # Ensure correct error rate + assert float(error_count) / num_requests == pytest.approx( + error_rate, rel=0.2 + ) - # assert final state matches expectations - for label, (state, action) in { - "stop": (stop_state, stop_action), - "req": (req_state, req_action), - }.items(): - if expectations[label]["errored"]: - assert state["errored"] - if 
expectations[label]["cancelled"]:
-                assert state["cancelled"]
-            if action.startswith("complete_") and not expectations[label]["cancelled"]:
-                key = (
-                    "completed_short"
-                    if action == "complete_short"
-                    else "completed_long"
+            # Ensure no extra statuses
+            with pytest.raises(asyncio.TimeoutError):
+                await main_messaging.get(timeout=0.5)
+
+            # Send cancel requests
+            for index in range(num_canceled):
+                cancel_request = f"cancel_request_{index}"
+                cancel_info = ScheduledRequestInfo(
+                    request_id=cancel_request,
+                    scheduler_start_time=start_time,
+                    scheduler_process_id=0,
+                )
+                cancel_info.scheduler_timings.queued = time.time()
+                requests_tracker[cancel_request] = {
+                    "sent": True,
+                    "received_pending": 0,
+                    "received_in_progress": 0,
+                    "received_resolved": 0,
+                }
+                await main_messaging.put(
+                    (cancel_request, cancel_info),
+                    timeout=2.0,
                 )
-                assert state[key]
 
     @pytest.mark.smoke
-    @pytest.mark.asyncio
-    @async_timeout(3.0)
-    @pytest.mark.parametrize(
-        "stop_action",
-        ["error_event", "shutdown_event", "cancel_event"],
-    )
-    async def test_run_async_stop_processing(
-        self, worker_process: WorkerProcess, stop_action
-    ):
-        # ensure initial state
-        assert not worker_process.error_event.is_set()
-        assert not worker_process.shutdown_event.is_set()
-
-        action = stop_action
-        early_check_delay = 0.01
-        trigger_delay = 0.05
-
-        task = asyncio.create_task(worker_process.run_async_stop_processing())
-        time_start = time.time()
-        await asyncio.sleep(early_check_delay)
-        assert not task.done(), "Task finished before any stop signal was triggered"
-
-        async def trigger():
-            await asyncio.sleep(trigger_delay - early_check_delay)
-            if action == "error_event":
-                worker_process.error_event.set()
-            elif action == "shutdown_event":
-                worker_process.shutdown_event.set()
-            elif action == "cancel_event":
-                task.cancel()
-
-        trigger_task = asyncio.create_task(trigger())
-
-        if action == "error_event":
-            with pytest.raises(RuntimeError):
-                await asyncio.wait_for(task, timeout=1.0)
-        elif action in {"shutdown_event", "cancel_event"}:
-            with pytest.raises(asyncio.CancelledError):
-                await asyncio.wait_for(task, timeout=1.0)
-        else:
-            raise ValueError(f"Unknown stop action: {action}")
-
-        await asyncio.gather(trigger_task, return_exceptions=True)
-
-        # validate correct ending states
-        elapsed = time.time() - time_start
-        assert elapsed >= trigger_delay - 0.01, (
-            "Task completed too early: "
-            f"elapsed={elapsed:.3f}s < trigger={trigger_delay:.3f}s"
-        )
-        if action == "error_event":
-            assert worker_process.error_event.is_set()
-            assert not worker_process.shutdown_event.is_set()
-        elif action == "shutdown_event":
-            assert worker_process.shutdown_event.is_set()
-            assert not worker_process.error_event.is_set()
-        elif action == "cancel_event":
-            assert not worker_process.error_event.is_set()
-            assert not worker_process.shutdown_event.is_set()
+            # Receive the expected updates for the cancel requests, up to the async limit
+            for _ in range(2 * min(num_canceled, instance.async_limit)):
+                # Each in-flight request (up to the async limit) emits pending and in_progress
+                response, request, request_info = await main_messaging.get(timeout=2.0)
+                if request_info.status == "pending":
+                    requests_tracker[request]["received_pending"] += 1
+                elif request_info.status == "in_progress":
+                    requests_tracker[request]["received_in_progress"] += 1
+                    error_count += 1
+                else:
+                    raise ValueError(f"Unexpected status: {request_info.status}")
+
+            # Signal constraints reached to start canceling
+            instance.constraint_reached_event.set()
+            await asyncio.sleep(0)
+
+            # Receive the
remaining canceled updates + for _ in range(num_canceled): + # All cancel requests should resolve with canceled (no other statuses) + response, request, request_info = await main_messaging.get(timeout=2.0) + assert request is not None + assert request_info is not None + assert request_info.request_id is not None + assert request_info.status is not None + assert request_info.scheduler_node_id > -1 + assert request_info.scheduler_process_id > -1 + assert request_info.scheduler_start_time == start_time + assert request_info.scheduler_timings is not None + + if request_info.status == "cancelled": + requests_tracker[request]["received_resolved"] += 1 + assert request_info.scheduler_timings.resolve_end is not None + assert request_info.scheduler_timings.resolve_end > start_time + else: + raise ValueError(f"Unexpected status: {request_info.status}") + + # Ensure no extra statuses + with pytest.raises(asyncio.TimeoutError): + await main_messaging.get(timeout=0.5) + + # Signal requests stop now that all requests have been processed + instance.requests_generated_event.set() + await asyncio.sleep(0) + + # Validate all the requests are correct + for request_key in [f"request_{index}" for index in range(num_requests)]: + assert request_key in requests_tracker + assert requests_tracker[request_key]["sent"] + assert requests_tracker[request_key]["received_pending"] == 1 + assert requests_tracker[request_key]["received_resolved"] == 1 + if request_key.startswith("request"): + assert requests_tracker[request_key]["received_in_progress"] == 1 + finally: + # Shut down + instance.shutdown_event.set() + await asyncio.wait_for(instance_task, timeout=2.0) @pytest.mark.smoke @pytest.mark.asyncio - @async_timeout(10.0) + @async_timeout(15) @pytest.mark.parametrize( - ("request_timings_const", "async_limit"), + ("request_timings", "timing_bounds"), [ - (lambda: LastCompletionRequestTimings(), 1), - (lambda: PoissonRateRequestTimings(rate=10000), 2), - (lambda: ConstantRateRequestTimings(rate=10000), 3), - (lambda: NoDelayRequestTimings(), 4), + ( + LastCompletionRequestTimings(offset=0.1), + [ + TimingsBounds(lower=0.1, prev_request="greater_equal") + for _ in range(STANDARD_NUM_REQUESTS) + ], + ), + ( + NoDelayRequestTimings(offset=0.05), + [ + TimingsBounds(lower=0.05, upper=0.05, actual_tolerance=1.0) + for _ in range(STANDARD_NUM_REQUESTS) + ], + ), + ( + ConstantRateRequestTimings(rate=100, offset=0.2), + [ + TimingsBounds( + exact=0.2 + ind * 0.01, + lower=0.2, + prev_request="greater", + actual_tolerance=10e-2, + ) + for ind in range(STANDARD_NUM_REQUESTS) + ], + ), + ( + PoissonRateRequestTimings(rate=200, offset=0.01), + [ + TimingsBounds(lower=0.01, prev_request="greater") + for ind in range(STANDARD_NUM_REQUESTS) + ], + ), + ], + ids=[ + "LastCompletion", + "NoDelay", + "ConstantRate", + "PoissonRate", ], ) - async def test_run_async_requests_processing( # noqa: C901 + async def test_run_with_timings( # noqa: C901, PLR0912 self, - request_timings_const: Callable[[], ScheduledRequestTimings], - async_limit: int, + valid_instances: tuple[WorkerProcess, InterProcessMessagingQueue, dict], + request_timings: ScheduledRequestTimings, + timing_bounds: list[TimingsBounds], ): - startup_barrier = Barrier(2) - requests_queue = Queue() - updates_queue = Queue() - backend = MockBackend(delay=0.001) - worker_process = WorkerProcess( - local_rank=0, - local_world_size=1, - async_limit=async_limit, - startup_barrier=startup_barrier, - shutdown_event=Event(), - error_event=Event(), - requests_queue=requests_queue, - 
updates_queue=updates_queue, - backend=backend, - request_timings=request_timings_const(), - poll_intervals=0.01, - ) - - def _trip_barrier_later(): - time.sleep(0.02) - with contextlib.suppress(RuntimeError): - # barrier may be aborted (suppressed) during cancellation - worker_process.startup_barrier.wait(timeout=1.0) - - threading.Thread(target=_trip_barrier_later, daemon=True).start() - - run_task = asyncio.create_task(worker_process.run_async_requests_processing()) - await asyncio.sleep(0.05) # small delay to allow start up first - - # validate start up - assert worker_process.backend.process_startup_called - assert worker_process.backend.validate_called - assert worker_process.pending_requests_queue is not None - assert worker_process.pending_updates_queue is not None - assert worker_process.startup_completed - - # ensure full processing of requests - for index in range(20): - requests_queue.put( - MsgpackEncoding.encode( + instance, main_messaging, constructor_args = valid_instances + instance.request_timings = request_timings + num_requests = STANDARD_NUM_REQUESTS + assert len(timing_bounds) == num_requests + + # Start process + process = Process(target=instance.run) + process.start() + + try: + await asyncio.to_thread(instance.startup_barrier.wait) + start_time = time.time() + 0.1 + + # Send regular requests + requests_tracker = {} + for ind in range(num_requests): + request = f"request_{ind}" + requests_tracker[request] = { + "sent": True, + "target_start_time": -1, + "actual_start_time": -1, + "received_pending": 0, + "received_in_progress": 0, + "received_resolved": 0, + } + await main_messaging.put( ( - f"req-{index}", - ScheduledRequestInfo[MeasuredRequestTimings]( - request_id=f"req-{index}", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ), - ) - ) - ) - - updates = [] - num_failures = 0 - max_wait_time = 5.0 - start_time = time.time() - while time.time() - start_time < max_wait_time: - try: - update_message = updates_queue.get_nowait() - updates.append(MsgpackEncoding.decode(update_message)) - num_failures = 0 - except Empty: - num_failures += 1 - if len(updates) >= 40: # We got all expected updates - break - await asyncio.sleep(0.05) - - # validate updates are correct for each request - assert len(updates) == 40 - per_request = defaultdict(dict) - for update in updates: - response, request, info = update - if info.status == "in_progress": - per_request[info.request_id]["start"] = (response, request, info) - per_request[info.request_id]["targeted_start"] = ( - info.scheduler_timings.targeted_start + request, + ScheduledRequestInfo(scheduler_start_time=start_time), + ), + timeout=2.0, ) - per_request[info.request_id]["resolve_start"] = ( - info.scheduler_timings.resolve_start - ) - elif info.status == "completed": - per_request[info.request_id]["complete"] = (response, request, info) - per_request[info.request_id]["resolve_end"] = ( - info.scheduler_timings.resolve_end - ) - assert len(per_request) == 20 - assert all( - "start" in parts and "complete" in parts for parts in per_request.values() - ) - # validate request times match expected - last_targeted_start = -1 * math.inf - for index in range(20): - targeted_start = per_request[f"req-{index}"]["targeted_start"] - resolve_start = per_request[f"req-{index}"]["resolve_start"] - resolve_end = per_request[f"req-{index}"]["resolve_end"] - assert targeted_start >= last_targeted_start - assert targeted_start < resolve_start - assert resolve_start == 
pytest.approx(targeted_start) - assert resolve_end == pytest.approx(resolve_start + backend.delay) - - # Validate concurrency limits are respected - events = [] - for req_id in per_request: - events.append((per_request[req_id]["resolve_start"], 1)) - events.append((per_request[req_id]["resolve_end"], -1)) - events.sort() - max_concurrent = concurrent = 0 - for _, delta in events: - concurrent += delta - max_concurrent = max(max_concurrent, concurrent) - assert max_concurrent <= async_limit - - # validate cancellation - backend.delay = 10 - # max concurrent for backend + 2 queued for backend - num_cancel_tasks = (async_limit + 2) * 2 - for index in range(20, 20 + num_cancel_tasks): - requests_queue.put( - MsgpackEncoding.encode( - ( - f"req-{index}", - ScheduledRequestInfo[MeasuredRequestTimings]( - request_id=f"req-{index}", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ), + # Process regular requests + for _ in range(num_requests * 3): + # Each request must have pending, in_progress, and resolved statuses + response, request, request_info = await main_messaging.get(timeout=2.0) + if request_info.status == "pending": + requests_tracker[request]["received_pending"] += 1 + elif request_info.status == "in_progress": + requests_tracker[request]["received_in_progress"] += 1 + requests_tracker[request]["target_start_time"] = ( + request_info.scheduler_timings.targeted_start ) - ) - ) - await asyncio.sleep(0.5) - run_task.cancel() - await asyncio.gather(run_task, return_exceptions=True) - assert worker_process.backend.process_shutdown_called - assert worker_process.pending_requests_queue is None - assert worker_process.pending_updates_queue is None - - # validate canceled tasks - updates = [] - num_failures = 0 - while True: - try: - update_message = updates_queue.get_nowait() - updates.append(MsgpackEncoding.decode(update_message)) - except Empty: - num_failures += 1 - if num_failures > 3: - break - await asyncio.sleep(0.1) - # Ensure we get all updates we expected (async_limit for pending + 2 for queued) - assert len(updates) >= 2 * (async_limit + 2) - # Ensure we didn't process all requests on the queue and shutdown early - assert len(updates) < 2 * 2 * (async_limit + 2) - - @pytest.mark.smoke - @pytest.mark.parametrize( - ("request_timings_const", "async_limit", "request_error_rate"), - [ - (lambda: LastCompletionRequestTimings(), 1, 0.1), - (lambda: PoissonRateRequestTimings(rate=10000), 2, 0.2), - (lambda: ConstantRateRequestTimings(rate=10000), 3, 0.3), - (lambda: NoDelayRequestTimings(), 4, 0.4), - ], - ) - def test_run_lifecycle( - self, - request_timings_const: Callable[[], ScheduledRequestTimings], - async_limit: int, - request_error_rate: float, - ): - backend = MockBackend( - delay=0.01, - request_error_rate=request_error_rate, - ) - startup_barrier = Barrier(2) - shutdown_event = Event() - requests_queue = Queue() - updates_queue = Queue() - backend = MockBackend(delay=0.001) - worker_process = WorkerProcess( - local_rank=0, - local_world_size=1, - async_limit=async_limit, - startup_barrier=startup_barrier, - shutdown_event=shutdown_event, - error_event=Event(), - requests_queue=requests_queue, - updates_queue=updates_queue, - backend=backend, - request_timings=request_timings_const(), - poll_intervals=0.01, - ) - - def _background_thread(): - time.sleep(0.1) # delay for startup - startup_barrier.wait() - - for index in range(20): - requests_queue.put( - MsgpackEncoding.encode( - ( - f"req-{index}", - 
ScheduledRequestInfo[MeasuredRequestTimings]( - request_id=f"req-{index}", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ), - ) + requests_tracker[request]["actual_start_time"] = ( + request_info.scheduler_timings.resolve_start ) + elif request_info.status == "completed": + assert response == f"response_for_{request}" + requests_tracker[request]["received_resolved"] += 1 + else: + raise ValueError(f"Unexpected status: {request_info.status}") + + # Ensure no extra statuses + with pytest.raises(asyncio.TimeoutError): + await main_messaging.get(timeout=0.1) + + # Trigger stopping for constraints and requests + instance.requests_generated_event.set() + instance.constraint_reached_event.set() + await asyncio.sleep(0) + + # Validate request values are correct + for ind in range(num_requests): + request = f"request_{ind}" + assert requests_tracker[request]["received_pending"] == 1 + assert requests_tracker[request]["received_in_progress"] == 1 + assert requests_tracker[request]["received_resolved"] == 1 + + bounds = timing_bounds[ind] + target_offset = ( + requests_tracker[request]["target_start_time"] - start_time ) - - time.sleep(0.5) # delay for processing - shutdown_event.set() - - threading.Thread(target=_background_thread).start() - worker_process.run() - - updates = [] - max_attempts = 50 - attempts = 0 - while attempts < max_attempts: - try: - update_message = updates_queue.get_nowait() - updates.append(MsgpackEncoding.decode(update_message)) - except Empty: - attempts += 1 - if len(updates) >= 40: # We got all expected updates - break - time.sleep(0.05) - - # Validate updates - assert len(updates) == 40 - per_request = defaultdict(dict) - for update in updates: - response, request, info = update - if info.status == "in_progress": - per_request[info.request_id]["start"] = (response, request, info) - per_request[info.request_id]["targeted_start"] = ( - info.scheduler_timings.targeted_start - ) - per_request[info.request_id]["resolve_start"] = ( - info.scheduler_timings.resolve_start + actual_offset = ( + requests_tracker[request]["actual_start_time"] - start_time ) - elif info.status == "completed": - per_request[info.request_id]["complete"] = (response, request, info) - per_request[info.request_id]["resolve_end"] = ( - info.scheduler_timings.resolve_end + prev_offset = ( + requests_tracker[f"request_{ind - 1}"]["target_start_time"] + - start_time + if ind > 0 + else None ) - assert len(per_request) == 20 - assert all( - "start" in parts and "complete" in parts for parts in per_request.values() - ) - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(10.0) - async def test_initialize_requests_processing(self, valid_instances): - """Test _initialize_requests_processing method.""" - instance, _ = valid_instances - - await instance._initialize_requests_processing() - - # Verify backend methods were called - assert instance.backend.process_startup_called - assert instance.backend.validate_called - - # Verify queues are initialized - assert instance.pending_requests_queue is not None - assert instance.pending_updates_queue is not None - assert instance.requests_canceled is not None - assert instance.pull_requests_stopped is not None - assert instance.pull_task is not None - assert instance.push_task is not None - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(5.0) - async def test_start_ready_requests_processing(self, valid_instances): - """Test _start_ready_requests_processing method.""" - instance, 
constructor_args = valid_instances - - def _trip_barrier_later(): - time.sleep(0.02) - with contextlib.suppress(RuntimeError): - instance.startup_barrier.wait(timeout=1.0) - - threading.Thread(target=_trip_barrier_later, daemon=True).start() - - await instance._start_ready_requests_processing() - assert instance.startup_completed is True - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(5.0) - async def test_shutdown_requests_processing(self, valid_instances): - """Test _shutdown_requests_processing method.""" - instance, _ = valid_instances - - # Initialize first to have something to shutdown - await instance._initialize_requests_processing() - - # Now shutdown - await instance._shutdown_requests_processing() - - # Verify backend shutdown was called - assert instance.backend.process_shutdown_called - - # Verify state reset - assert instance.pending_requests_queue is None - assert instance.pending_updates_queue is None - assert instance.pull_task is None - assert instance.push_task is None - assert instance.requests_canceled is None - - @pytest.mark.sanity - @pytest.mark.asyncio - @async_timeout(3.0) - async def test_handle_request_update_status_transitions(self, valid_instances): - """Test _handle_request_update with different status transitions.""" - instance, _ = valid_instances - await instance._initialize_requests_processing() - - request = "test_request" - request_info = ScheduledRequestInfo[MeasuredRequestTimings]( - request_id="test-123", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ) - - # Simulate that we've got this request from the queue (so task_done is expected) - await instance.pending_requests_queue.async_put((request, request_info)) - - # Test handling different status updates - but go through full flow - await instance._handle_request_update( - new_status="completed", - response="test_response", - request=request, - request_info=request_info, - ) - - @pytest.mark.smoke - def test_pull_requests_generator(self, valid_instances): - """Test _pull_requests_generator method.""" - instance, _ = valid_instances - - # Initialize necessary attributes that the generator needs - instance.requests_canceled = threading.Event() - instance.pull_requests_stopped = threading.Event() - # Create a minimal pending_requests_queue for the generator - import culsans - - instance.pending_requests_queue = culsans.Queue(maxsize=2) - - # Set the stop condition before creating the generator - instance.requests_canceled.set() - - # Initialize the generator - generator = instance._pull_requests_generator() - - # Test that generator can be created - assert generator is not None - # The generator should stop when requests_canceled is set - with pytest.raises(StopIteration): - next(generator) - - @pytest.mark.smoke - def test_push_updates_generator(self, valid_instances): - """Test _push_updates_generator method.""" - instance, _ = valid_instances - - # Initialize the generator - generator = instance._push_updates_generator() - - # Test that generator can be created - assert generator is not None - - @pytest.mark.sanity - @pytest.mark.asyncio - @async_timeout(3.0) - async def test_process_next_request_multi_turn_error(self, valid_instances): - """Test _process_next_request with multi-turn requests raises - NotImplementedError.""" - instance, _ = valid_instances - await instance._initialize_requests_processing() - - # Put a multi-turn request (tuple/list) in the queue - multi_turn_request = ["request1", "request2"] - request_info = 
ScheduledRequestInfo[MeasuredRequestTimings]( - request_id="test-123", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ) - - await instance.pending_requests_queue.async_put( - (multi_turn_request, request_info) - ) - - # The NotImplementedError gets caught and converted to an errored status update - # So the method completes normally, but we can check that the error is set - await instance._process_next_request() - - # Check that the request_info.error contains the expected error message - assert "Multi-turn requests are not yet supported" in request_info.error - - @pytest.mark.sanity - @pytest.mark.asyncio - @async_timeout(3.0) - async def test_process_next_request_cancellation(self, valid_instances): - """Test _process_next_request handles cancellation properly.""" - instance, _ = valid_instances - await instance._initialize_requests_processing() - - request = "test_request" - request_info = ScheduledRequestInfo[MeasuredRequestTimings]( - request_id="test-123", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ) - - await instance.pending_requests_queue.async_put((request, request_info)) - - # Create task and cancel it immediately - task = asyncio.create_task(instance._process_next_request()) - await asyncio.sleep(0.01) # Let it start - task.cancel() - - with pytest.raises(asyncio.CancelledError): - await task - - @pytest.mark.sanity - @pytest.mark.asyncio - @async_timeout(5.0) - async def test_cancel_pending_requests(self, valid_instances): - """Test _cancel_pending_requests method.""" - instance, _ = valid_instances - - # Create worker with larger queue buffer to avoid blocking - backend = MockBackend() - request_timings = LastCompletionRequestTimings() - worker_with_larger_buffer = WorkerProcess( - local_rank=0, - local_world_size=2, - async_limit=5, - startup_barrier=Barrier(2), - shutdown_event=Event(), - error_event=Event(), - requests_queue=Queue(), - updates_queue=Queue(), - backend=backend, - request_timings=request_timings, - poll_intervals=0.01, - max_requests_queue_buffer=10, # Larger buffer to avoid blocking - ) - - await worker_with_larger_buffer._initialize_requests_processing() - - # Add some requests to cancel - use smaller number to avoid queue size issues - for i in range(3): - request = f"test_request_{i}" - request_info = ScheduledRequestInfo[MeasuredRequestTimings]( - request_id=f"test-{i}", - status="queued", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=time.time(), - ) - await worker_with_larger_buffer.pending_requests_queue.async_put( - (request, request_info) + if bounds.exact is not None: + assert target_offset == pytest.approx( + bounds.exact, rel=bounds.tolerance + ) + assert target_offset == pytest.approx( + actual_offset, rel=bounds.actual_tolerance or bounds.tolerance + ) + if bounds.lower is not None: + assert target_offset >= bounds.lower - bounds.tolerance + assert actual_offset >= bounds.lower - ( + bounds.actual_tolerance or bounds.tolerance + ) + if bounds.upper is not None: + assert target_offset <= bounds.upper + bounds.tolerance + assert actual_offset <= bounds.upper + ( + bounds.actual_tolerance or bounds.tolerance + ) + if bounds.prev_request is not None and prev_offset is not None: + if bounds.prev_request == "greater": + assert target_offset > prev_offset - bounds.tolerance + elif bounds.prev_request == "greater_equal": + assert target_offset >= prev_offset - bounds.tolerance + elif bounds.prev_request == 
"less": + assert target_offset < prev_offset + bounds.tolerance + elif bounds.prev_request == "less_equal": + assert target_offset <= prev_offset + bounds.tolerance + finally: + # Trigger shutdown + instance.shutdown_event.set() + await asyncio.to_thread(process.join, timeout=2.0) + + if process.is_alive(): + process.terminate() + await asyncio.to_thread(process.join, timeout=2.0) + assert process.exitcode <= 0, ( + f"Process exited with error code: {process.exitcode}" ) - - # Set the stop flag - worker_with_larger_buffer.pull_requests_stopped.set() - - await worker_with_larger_buffer._cancel_pending_requests() - - # Verify queue is empty - assert worker_with_larger_buffer.pending_requests_queue.qsize() == 0 - - @pytest.mark.smoke - @pytest.mark.parametrize( - ("max_requests_queue_buffer", "poll_intervals"), - [ - (1, 0.01), - (5, 0.05), - (10, 0.1), - ], - ) - def test_initialization_with_optional_params( - self, max_requests_queue_buffer, poll_intervals - ): - """Test WorkerProcess initialization with optional parameters.""" - backend = MockBackend() - request_timings = LastCompletionRequestTimings() - - instance = WorkerProcess( - local_rank=0, - local_world_size=2, - async_limit=5, - startup_barrier=Barrier(2), - shutdown_event=Event(), - error_event=Event(), - requests_queue=Queue(), - updates_queue=Queue(), - backend=backend, - request_timings=request_timings, - poll_intervals=poll_intervals, - max_requests_queue_buffer=max_requests_queue_buffer, - ) - - assert instance.poll_intervals == poll_intervals - assert instance.max_requests_queue_buffer == max_requests_queue_buffer diff --git a/tests/unit/scheduler/test_worker_group.py b/tests/unit/scheduler/test_worker_group.py index f80a368d..e741334b 100644 --- a/tests/unit/scheduler/test_worker_group.py +++ b/tests/unit/scheduler/test_worker_group.py @@ -2,35 +2,33 @@ import asyncio import inspect -import math -import os -import queue -import threading import time -from collections import defaultdict from functools import wraps -from multiprocessing import get_context -from queue import Empty -from typing import Any, Generic +from multiprocessing.context import BaseContext +from multiprocessing.managers import BaseManager +from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Barrier, Event +from typing import Any, Generic, Literal -import culsans import pytest +from pydantic import Field from guidellm.scheduler import ( AsyncConstantStrategy, - AsyncPoissonStrategy, BackendInterface, ConcurrentStrategy, + MaxDurationConstraint, MaxNumberConstraint, MeasuredRequestTimings, ScheduledRequestInfo, + SchedulerMessagingPydanticRegistry, SchedulerState, SynchronousStrategy, ThroughputStrategy, WorkerProcessGroup, - worker_group, ) -from guidellm.utils import MsgpackEncoding +from guidellm.scheduler.worker_group import WorkerGroupState +from guidellm.utils import InterProcessMessaging def async_timeout(delay): @@ -44,99 +42,11 @@ async def new_func(*args, **kwargs): return decorator -class MockWorker: - """Picklable mock worker used to validate create_processes logic.""" - - @classmethod - def __class_getitem__(cls, item): - return cls - - def __init__( - self, - local_rank, - local_world_size, - async_limit, - startup_barrier, - shutdown_event, - error_event, - requests_queue, - updates_queue, - backend, - request_timings, - poll_intervals, - ): - self.local_rank = local_rank - self.local_world_size = local_world_size - self.async_limit = async_limit - self.startup_barrier = startup_barrier - 
self.shutdown_event = shutdown_event - self.error_event = error_event - self.requests_queue = requests_queue - self.updates_queue = updates_queue - self.backend = backend - self.request_timings = request_timings - self.poll_intervals = poll_intervals - - def run(self): - try: - # Access parameters to ensure they're usable and wait for barrier - shutdown_is_set = self.shutdown_event.is_set() - error_is_set = self.error_event.is_set() - backend_info = self.backend.info() - - self.startup_barrier.wait() - - # Publish diagnostics back to parent for assertions - payload = ( - "diag", - self.local_rank, - { - "child_pid": os.getpid(), - "local_rank": self.local_rank, - "local_world_size": self.local_world_size, - "async_limit": self.async_limit, - "backend_info": backend_info, - "shutdown_is_set": shutdown_is_set, - "error_is_set": error_is_set, - "passed_barrier": True, - "request_timings_type": type(self.request_timings).__name__, - }, - ) - self.updates_queue.put(payload) - except Exception as err: # noqa: BLE001 - try: - self.error_event.set() - self.updates_queue.put(("error", self.local_rank, repr(err))) - finally: - raise - - -class MockWorkerProcessor(MockWorker): - def run(self): - self.startup_barrier.wait() - - while not self.shutdown_event.is_set() and not self.error_event.is_set(): - try: - request_msg = self.requests_queue.get(timeout=0.1) - except queue.Empty: - continue - - request, request_info = MsgpackEncoding.decode(request_msg) - request_info.status = "in_progress" - self.updates_queue.put( - MsgpackEncoding.encode((None, request, request_info)) - ) - time.sleep(0.01) - request_info.status = "completed" - response = f"response_for_{request}" - self.updates_queue.put( - MsgpackEncoding.encode((response, request, request_info)) - ) - - class MockRequestTimings(MeasuredRequestTimings): """Mock timing implementation for testing.""" + timings_type: Literal["mock"] = Field(default="mock") + class MockBackend(BackendInterface): """Mock backend for testing worker group functionality.""" @@ -170,62 +80,79 @@ async def process_shutdown(self): pass async def resolve(self, request, request_info, request_history): - yield f"response_for_{request}" + request_info.request_timings = MockRequestTimings( + request_start=time.time(), request_end=time.time() + ) + yield f"response_for_{request}", request_info class TestWorkerProcessGroup: """Test suite for WorkerProcessGroup class.""" + def setup_method(self): + self._original_messaging_registry = ( + SchedulerMessagingPydanticRegistry.registry.copy() + if SchedulerMessagingPydanticRegistry.registry + else {} + ) + self._original_timings_registry = ( + MeasuredRequestTimings.registry.copy() + if MeasuredRequestTimings.registry + else {} + ) + MeasuredRequestTimings.register_decorator(MockRequestTimings, "mock") + SchedulerMessagingPydanticRegistry.register_decorator( + MockRequestTimings, "mock" + ) + + def teardown_method(self): + SchedulerMessagingPydanticRegistry.registry = self._original_messaging_registry + MeasuredRequestTimings.registry = self._original_timings_registry + MeasuredRequestTimings.model_rebuild(force=True) + ScheduledRequestInfo.model_rebuild(force=True) + @pytest.fixture( params=[ { - "requests": ["request1", "request2", "request3"], + "requests": None, + "cycle_requests": ["request1", "request2", "request3"], "strategy": SynchronousStrategy(), - "constraints": {"max_requests": MaxNumberConstraint(max_num=10)}, + "constraints": {"max_num": MaxNumberConstraint(max_num=10)}, }, { - "requests": ["req_a", "req_b"], + 
"requests": None, + "cycle_requests": ["req_a", "req_b"], "strategy": ConcurrentStrategy(streams=2), - "constraints": {}, + "constraints": {"max_num": MaxNumberConstraint(max_num=5)}, }, { - "requests": iter(["req_x", "req_y", "req_z"]), + "requests": ["req_x", "req_y", "req_z"], + "cycle_requests": None, "strategy": ThroughputStrategy(max_concurrency=5), - "constraints": {"max_num": MaxNumberConstraint(max_num=5)}, - "infinite_requests": False, + "constraints": {}, + }, + { + "requests": None, + "cycle_requests": ["req_8", "req_9", "req_10"], + "strategy": AsyncConstantStrategy(rate=20), + "constraints": {"max_duration": MaxDurationConstraint(max_duration=1)}, }, ], - ids=["basic_sync", "concurrent", "throughput_iterator"], + ids=["sync_max", "concurrent_max", "throughput_no_cycle", "constant_duration"], ) def valid_instances(self, request): """Fixture providing test data for WorkerProcessGroup.""" constructor_args = request.param.copy() - backend = MockBackend() - constructor_args["backend"] = backend - - instance = WorkerProcessGroup(**constructor_args) + instance = WorkerProcessGroup(**request.param, backend=MockBackend()) return instance, constructor_args - @pytest.fixture - def worker_process_group(self): - """Create a basic WorkerProcessGroup instance for testing.""" - backend = MockBackend() - requests = ["request1", "request2", "request3"] - strategy = SynchronousStrategy() - constraints = {"max_requests": MaxNumberConstraint(max_num=10)} - - return WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=strategy, - constraints=constraints, - ) - @pytest.mark.smoke - def test_class_signatures(self, worker_process_group: WorkerProcessGroup): + def test_class_signatures(self, valid_instances): """Test inheritance and type relationships.""" + instance, _ = valid_instances + # Class - assert isinstance(worker_process_group, Generic) + assert isinstance(instance, Generic) assert issubclass(WorkerProcessGroup, Generic) # Generics @@ -241,7 +168,7 @@ def test_class_signatures(self, worker_process_group: WorkerProcessGroup): ) assert generic_base is not None type_args = getattr(generic_base, "__args__", ()) - assert len(type_args) == 3 + assert len(type_args) == 2 # Function signatures create_processes_sig = inspect.signature(WorkerProcessGroup.create_processes) @@ -269,651 +196,278 @@ def test_initialization(self, valid_instances): # Core attributes assert isinstance(instance.backend, MockBackend) assert instance.requests is constructor_args["requests"] + assert instance.cycle_requests is constructor_args["cycle_requests"] assert isinstance(instance.strategy, type(constructor_args["strategy"])) assert isinstance(instance.constraints, dict) assert instance.constraints == constructor_args["constraints"] - # Optional attributes - expected_infinite = constructor_args.get("infinite_requests", None) - assert instance.infinite_requests == expected_infinite - # Multiprocessing attributes (should be None initially) assert instance.mp_context is None + assert instance.mp_manager is None assert instance.processes is None # Synchronization primitives (should be None initially) assert instance.startup_barrier is None assert instance.shutdown_event is None assert instance.error_event is None + assert instance.requests_generated_event is None + assert instance.constraint_reached_event is None - # Queues (should be None initially) - assert instance.requests_queue is None - assert instance.updates_queue is None - assert instance.pending_updates_queue is None - assert 
instance.pending_requests_complete is None - assert instance.pending_updates_complete is None - - # Scheduler state and tasks (should be None initially) - assert instance.state_update_lock is None - assert instance.scheduler_state is None - assert instance.populate_requests_task is None - assert instance.populate_updates_task is None + # Scheduler state and messaging (should be None initially) + assert instance.state is None + assert instance.messaging is None @pytest.mark.sanity - def test_invalid_initialization_values(self): - """Test WorkerProcessGroup with invalid field values.""" - backend = MockBackend() - requests = ["req1"] - strategy = SynchronousStrategy() - constraints = {} - - # Test with None requests (will likely fail during create_processes) - group1 = WorkerProcessGroup( - requests=None, - backend=backend, - strategy=strategy, - constraints=constraints, - ) - assert group1.requests is None - - # Test with None backend (will likely fail during create_processes) - group2 = WorkerProcessGroup( - requests=requests, - backend=None, - strategy=strategy, - constraints=constraints, - ) - assert group2.backend is None - - # Test with None strategy (will likely fail during create_processes) - group3 = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=None, - constraints=constraints, - ) - assert group3.strategy is None - - # Test with None constraints (will likely fail during create_processes) - group4 = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=strategy, - constraints=None, - ) - assert group4.constraints is None - - @pytest.mark.smoke - @pytest.mark.asyncio @pytest.mark.parametrize( - ("strategy", "expected_num_procs", "expected_max_conc"), + ("requests", "cycle_requests", "expected_error"), [ - (SynchronousStrategy(), 1, 1), - (ConcurrentStrategy(streams=3), 3, 3), - (ThroughputStrategy(max_concurrency=6), 3, 6), - (AsyncConstantStrategy(rate=100.0), 3, 12), - (AsyncPoissonStrategy(rate=100.0), 3, 12), + (None, None, ValueError), + ([], iter([]), ValueError), # cycle_requests as Iterator + (None, iter(["req1"]), ValueError), # cycle_requests as Iterator ], + ids=["no_requests", "cycle_as_iterator_empty", "cycle_as_iterator_data"], ) - async def test_create_processes( - self, - monkeypatch, - strategy, - expected_num_procs, - expected_max_conc, + def test_invalid_initialization_values( + self, requests, cycle_requests, expected_error ): - # Patch required mock settings - monkeypatch.setattr( - worker_group.settings, "max_worker_processes", 3, raising=False - ) - monkeypatch.setattr(worker_group.settings, "max_concurrency", 12, raising=False) - monkeypatch.setattr( - worker_group.settings, "scheduler_poll_interval", 0.01, raising=False - ) - monkeypatch.setattr(worker_group, "WorkerProcess", MockWorker, raising=True) - - # Setup group to test - backend = MockBackend() - requests = [f"r{i}" for i in range(10)] - constraints = {"max_requests": MaxNumberConstraint(max_num=100)} - group = WorkerProcessGroup( - backend=backend, - requests=requests, - strategy=strategy, - constraints=constraints, - ) + """Test WorkerProcessGroup with invalid initialization values.""" + with pytest.raises(expected_error): + WorkerProcessGroup( + requests=requests, + cycle_requests=cycle_requests, + backend=MockBackend(), + strategy=SynchronousStrategy(), + constraints={}, + ) - # Run within a reasonable time limit - try: - await asyncio.wait_for(group.create_processes(), timeout=5.0) - except asyncio.TimeoutError: - pytest.fail("create_processes() timed out 
after 5 seconds") - - # Check expected attributes are created - assert group.mp_context is not None - assert hasattr(group.mp_context, "Barrier") - assert hasattr(group.mp_context, "Event") - assert hasattr(group.mp_context, "Queue") - assert group.processes is not None - assert len(group.processes) == expected_num_procs - - # Validate processes ran correctly - diags: dict[int, dict] = {} - for _ in range(expected_num_procs): - kind, rank, payload = group.updates_queue.get(timeout=3) - if kind == "error": - pytest.fail(f"Worker {rank} reported error: {payload}") - assert kind == "diag" - diags[rank] = payload - - # Verify returned processes state - main_pid = os.getpid() - assert len(diags) == expected_num_procs - for rank, payload in diags.items(): - assert payload["local_rank"] == rank - assert payload["local_world_size"] == expected_num_procs - assert payload["passed_barrier"] is True - assert payload["shutdown_is_set"] is False - assert payload["error_is_set"] is False - assert isinstance(payload["backend_info"], dict) - assert payload["child_pid"] != main_pid - per_proc = math.ceil(expected_max_conc / expected_num_procs) - expected_last = expected_max_conc - per_proc * (expected_num_procs - 1) - for rank, payload in diags.items(): - exp_limit = per_proc if rank < expected_num_procs - 1 else expected_last - assert payload["async_limit"] == exp_limit - - exceptions = await group.shutdown() - assert len(exceptions) == 0, f"Shutdown encountered exceptions: {exceptions}" + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test WorkerProcessGroup initialization without required fields.""" + with pytest.raises(TypeError): + WorkerProcessGroup() @pytest.mark.smoke + @async_timeout(10) @pytest.mark.asyncio - async def test_start(self, monkeypatch): - # Patch required mock settings - monkeypatch.setattr( - worker_group.settings, "max_worker_processes", 1, raising=False - ) - monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) - monkeypatch.setattr( - worker_group.settings, "scheduler_poll_interval", 0.01, raising=False - ) - monkeypatch.setattr(worker_group, "WorkerProcess", MockWorker, raising=True) - - # Setup group and mimic create_processes - backend = MockBackend() - requests = [f"r{i}" for i in range(5)] # to few requests, test new iter logic - group = WorkerProcessGroup( - backend=backend, - requests=requests, - strategy=SynchronousStrategy(), - constraints={"max_num": MaxNumberConstraint(max_num=10)}, - ) - group.mp_context = get_context("fork") - group.startup_barrier = group.mp_context.Barrier(2) - group.shutdown_event = group.mp_context.Event() - group.error_event = group.mp_context.Event() - group.requests_queue = group.mp_context.Queue() - group.updates_queue = group.mp_context.Queue() - group.pending_updates_queue = culsans.Queue() - group.pending_updates_complete = threading.Event() - group.processes = [None] - - # Validate function runs and returns at start_time - start_time = time.time() + 0.2 - await asyncio.wait_for(group.start(start_time), timeout=3.0) - end_time = time.time() - assert end_time == pytest.approx(start_time, abs=0.01) - - # Validate instance state - assert group.state_update_lock is not None - assert hasattr(group.state_update_lock, "acquire") - assert group.scheduler_state is not None - assert group.scheduler_state.num_processes == 1 - assert group.scheduler_state.start_time == start_time - assert isinstance(group.populate_requests_task, asyncio.Task) - assert isinstance(group.populate_updates_task, 
asyncio.Task) - - # Pull the queued requests - await asyncio.sleep(0.1) - sent_requests = [] - while True: - await asyncio.sleep(0) - try: - req = group.requests_queue.get(timeout=1.0) - sent_requests.append(req) - except Empty: - break - assert len(sent_requests) == 10 - - # Enqueue lifecycle updates - for req in requests + requests: - group.updates_queue.put( - MsgpackEncoding.encode( - ( - None, - req, - ScheduledRequestInfo[MockRequestTimings]( - request_id=str(req), - status="in_progress", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=start_time, - ), - ) + async def test_lifecycle(self, valid_instances: tuple[WorkerProcessGroup, dict]): + """Test the lifecycle methods of WorkerProcessGroup.""" + instance, constructor_args = valid_instances + assert instance.requests or instance.cycle_requests + assert instance.backend + assert instance.strategy + assert instance.constraints is not None + + # Validate create_processes works and sets correct state + await instance.create_processes() + assert instance.mp_context is not None + assert isinstance(instance.mp_context, BaseContext) + assert instance.mp_manager is not None + assert isinstance(instance.mp_manager, BaseManager) + assert instance.processes is not None + assert isinstance(instance.processes, list) + assert len(instance.processes) > 0 + assert all(isinstance(proc, BaseProcess) for proc in instance.processes) + assert all(proc.is_alive() for proc in instance.processes) + assert instance.startup_barrier is not None + assert isinstance(instance.startup_barrier, Barrier) + assert instance.requests_generated_event is not None + assert isinstance(instance.requests_generated_event, Event) + assert instance.constraint_reached_event is not None + assert isinstance(instance.constraint_reached_event, Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, Event) + assert instance.error_event is not None + assert isinstance(instance.error_event, Event) + assert instance.messaging is not None + assert isinstance(instance.messaging, InterProcessMessaging) + assert instance.messaging.worker_index is None + + # Validate start works and sets correct state + start_time = time.time() + 0.1 + await instance.start(start_time=start_time) + assert instance.state is not None + assert isinstance(instance.state, WorkerGroupState) + assert not instance.requests_generated_event.is_set() + assert not instance.constraint_reached_event.is_set() + assert not instance.shutdown_event.is_set() + assert not instance.error_event.is_set() + + # Test iter updates + requests_tracker = {} + + async for ( + response, + request, + request_info, + scheduler_state, + ) in instance.request_updates(): + # Validate returned request + assert request is not None + + # Validate returned request info and response + assert request_info is not None + assert isinstance(request_info, ScheduledRequestInfo) + assert request_info.request_id is not None + assert request_info.status is not None + if request_info.request_id not in requests_tracker: + requests_tracker[request_info.request_id] = { + "received_pending": 0, + "received_in_progress": 0, + "received_resolved": 0, + "received_cancelled": 0, + } + assert request_info.scheduler_node_id > -1 + assert request_info.scheduler_process_id > -1 + assert request_info.scheduler_start_time == start_time + assert request_info.scheduler_timings is not None + if request_info.status == "pending": + requests_tracker[request_info.request_id]["received_pending"] += 1 + assert 
request_info.scheduler_timings.dequeued is not None + assert request_info.scheduler_timings.targeted_start is not None + assert request_info.scheduler_timings.targeted_start >= start_time + elif request_info.status == "in_progress": + requests_tracker[request_info.request_id]["received_in_progress"] += 1 + assert request_info.scheduler_timings.scheduled_at is not None + assert ( + request_info.scheduler_timings.scheduled_at + >= request_info.scheduler_timings.dequeued ) - ) - group.updates_queue.put( - MsgpackEncoding.encode( - ( - None, - req, - ScheduledRequestInfo[MockRequestTimings]( - request_id=str(req), - status="completed", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=start_time, - ), - ) + assert request_info.scheduler_timings.resolve_start is not None + assert ( + request_info.scheduler_timings.resolve_start + >= request_info.scheduler_timings.scheduled_at ) - ) - await asyncio.sleep(0) - - # Drain 3 updates per request (queued, started, completed) - await asyncio.sleep(0.1) - updates = [] - for _ in range(3 * 10): - try: - update = await asyncio.wait_for( - group.pending_updates_queue.async_get(), timeout=1.0 + elif request_info.status == "completed": + requests_tracker[request_info.request_id]["received_resolved"] += 1 + assert response is not None + assert request_info.scheduler_timings.resolve_end is not None + assert ( + request_info.scheduler_timings.resolve_end + > request_info.scheduler_timings.resolve_start ) - updates.append(update) - except asyncio.TimeoutError: - break - assert len(updates) == 3 * 10 - - # Ensure tasks finish - if not group.populate_requests_task.done(): - await asyncio.wait_for(group.populate_requests_task, timeout=1.0) - if not group.populate_updates_task.done(): - await asyncio.wait_for(group.populate_updates_task, timeout=1.0) - - # Clean up resources - group.processes = None - exceptions = await group.shutdown() - assert len(exceptions) == 0, f"Shutdown encountered exceptions: {exceptions}" - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(3.0) - async def test_error_handling_basic(self, monkeypatch): - """Test basic error handling patterns.""" - self._setup_test_environment(monkeypatch) - - backend = MockBackend() - requests = ["req1"] - # Create group directly without using helper (which calls start automatically) - group = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=SynchronousStrategy(), - constraints={}, - ) - - # Test that error_event can be accessed when not initialized - # First save the existing error_event - original_error_event = group.error_event - - # Temporarily set to None to test this state - group.error_event = None - assert group.error_event is None - - # Restore it for the start test - group.error_event = original_error_event - - # Test basic group state validation - with pytest.raises( - RuntimeError, match="create_processes.*must be called before start" - ): - await group.start(time.time()) - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(10.0) - async def test_shutdown_event_stops_tasks(self, monkeypatch): - """Test that setting shutdown event stops background tasks.""" - self._setup_test_environment(monkeypatch) - - # Setup group - backend = MockBackend() - requests = [f"req_{i}" for i in range(5)] - group = self._create_test_group(backend, requests) - - # Start and verify tasks - start_time = time.time() + 0.1 - await group.start(start_time) - - # Simulate some processing - self._process_test_requests(group, start_time, count=2) - await 
asyncio.sleep(0.05) - - # Set shutdown event and verify tasks stop - group.shutdown_event.set() - await asyncio.sleep(0.1) # Allow propagation - - assert group.pending_requests_complete.is_set() - assert group.populate_requests_task.done() - - # Clean up - await group.shutdown() - - def _setup_test_environment(self, monkeypatch): - """Helper to setup test environment with mocked settings.""" - monkeypatch.setattr( - worker_group.settings, "max_worker_processes", 1, raising=False - ) - monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) - monkeypatch.setattr( - worker_group.settings, "scheduler_poll_interval", 0.01, raising=False - ) - monkeypatch.setattr(worker_group, "WorkerProcess", MockWorker, raising=True) - - def _create_test_group(self, backend, requests): - """Helper to create a test group with mocked multiprocessing components.""" - group = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=SynchronousStrategy(), - constraints={}, - ) - group.mp_context = get_context("fork") - group.startup_barrier = group.mp_context.Barrier(2) - group.shutdown_event = group.mp_context.Event() - group.error_event = group.mp_context.Event() - group.requests_queue = group.mp_context.Queue(maxsize=1) - group.updates_queue = group.mp_context.Queue() - group.pending_updates_queue = culsans.Queue() - group.pending_updates_complete = threading.Event() - # Create mock process objects instead of None - mock_process = type( - "MockProcess", - (), - {"join": lambda self, timeout=None: None, "exitcode": 0, "pid": 12345}, - )() - group.processes = [mock_process] - return group - - def _process_test_requests(self, group, start_time, count=1): - """Helper to process test requests and generate updates.""" - for _ in range(count): - try: - req, req_info = MsgpackEncoding.decode( - group.requests_queue.get(timeout=0.1) + assert request_info.request_timings is not None + assert isinstance(request_info.request_timings, MockRequestTimings) + assert request_info.request_timings.request_start is not None + assert ( + request_info.request_timings.request_start + >= request_info.scheduler_timings.targeted_start ) - # Simulate in_progress update - group.updates_queue.put( - MsgpackEncoding.encode( - ( - None, - req, - ScheduledRequestInfo[MockRequestTimings]( - request_id=str(req), - status="in_progress", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=start_time, - ), - ) - ) + assert request_info.request_timings.request_end is not None + assert ( + request_info.request_timings.request_end + >= request_info.request_timings.request_start ) - # Simulate completed update - group.updates_queue.put( - MsgpackEncoding.encode( - ( - None, - req, - ScheduledRequestInfo[MockRequestTimings]( - request_id=str(req), - status="completed", - scheduler_node_id=0, - scheduler_process_id=0, - scheduler_start_time=start_time, - ), - ) - ) + elif request_info.status in ("errored", "cancelled"): + assert response is None + requests_tracker[request_info.request_id]["received_resolved"] += 1 + assert request_info.scheduler_timings.resolve_end is not None + assert ( + request_info.scheduler_timings.resolve_end + > request_info.scheduler_start_time ) - except Empty: - break - - @pytest.mark.smoke - @pytest.mark.asyncio - async def test_request_updates(self, monkeypatch): - """Test the request_updates async iterator functionality.""" - # Configure settings for controlled testing - monkeypatch.setattr( - worker_group.settings, "max_worker_processes", 1, raising=False - ) - 
monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) - monkeypatch.setattr( - worker_group.settings, "scheduler_poll_interval", 0.01, raising=False - ) - monkeypatch.setattr( - worker_group, "WorkerProcess", MockWorkerProcessor, raising=True - ) - - # Setup group - backend = MockBackend() - requests = [f"req_{index}" for index in range(20)] - group = WorkerProcessGroup( - backend=backend, - requests=requests, - strategy=SynchronousStrategy(), - constraints={"max_num": MaxNumberConstraint(max_num=10)}, - ) - - # Mimic create_processes to set required state - await group.create_processes() - await group.start(time.time() + 0.05) - - # Collect all updates from request_updates iterator - received_updates = defaultdict(list) - received_responses = [] - count = 0 - async for resp, req, req_info, state in group.request_updates(): - assert isinstance(req_info, ScheduledRequestInfo) - assert isinstance(state, SchedulerState) - received_updates[req].append(req_info.status) - if resp is not None: - received_responses.append(resp) - count += 1 - - # Check we have all expected updates (10 requests) - assert len(received_updates) == 10 - for index, (req, statuses, resp) in enumerate( - zip(received_updates.keys(), received_updates.values(), received_responses) - ): - assert req == f"req_{index}" - assert resp == f"response_for_req_{index}" - assert statuses == ["queued", "in_progress", "completed"] - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(10.0) - async def test_shutdown_basic(self): - """Test basic shutdown functionality.""" - backend = MockBackend() - requests = ["req1", "req2"] - group = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=SynchronousStrategy(), - constraints={}, + if request_info.status == "cancelled": + requests_tracker[request_info.request_id]["received_cancelled"] += 1 + + # Validate state structure + assert scheduler_state is not None + assert isinstance(scheduler_state, SchedulerState) + assert scheduler_state.node_id > -1 + assert scheduler_state.start_time == start_time + assert scheduler_state.end_time is not None + if constructor_args.get("constraints"): + assert scheduler_state.remaining_fraction is not None + assert scheduler_state.remaining_fraction >= 0.0 + assert scheduler_state.remaining_fraction <= 1.0 + if constructor_args.get("constraints", {}).get("max_num") is not None: + assert scheduler_state.remaining_requests is not None + assert scheduler_state.remaining_requests >= 0 + assert ( + scheduler_state.remaining_requests + <= constructor_args["constraints"]["max_num"].max_num + ) + if constructor_args.get("constraints", {}).get("max_duration") is not None: + assert scheduler_state.remaining_duration is not None + assert scheduler_state.remaining_duration >= 0.0 + assert ( + scheduler_state.remaining_duration + <= constructor_args["constraints"]["max_duration"].max_duration + ) + assert scheduler_state.created_requests >= 0 + assert scheduler_state.queued_requests >= 0 + assert scheduler_state.pending_requests >= 0 + assert scheduler_state.processing_requests >= 0 + assert scheduler_state.processed_requests >= 0 + assert scheduler_state.successful_requests >= 0 + assert scheduler_state.errored_requests >= 0 + assert scheduler_state.cancelled_requests >= 0 + + # Validate correctness of all updates + for _, counts in requests_tracker.items(): + assert counts["received_cancelled"] in (0, 1) + if counts["received_cancelled"] == 0: + assert counts["received_pending"] == 1 + assert 
counts["received_in_progress"] >= 1 + assert counts["received_resolved"] == 1 + assert scheduler_state is not None # last yielded state + assert scheduler_state.end_time > scheduler_state.start_time + assert scheduler_state.end_queuing_time is not None + assert scheduler_state.end_queuing_constraints is not None + assert scheduler_state.end_processing_time is not None + assert scheduler_state.end_processing_time >= scheduler_state.start_time + assert scheduler_state.end_processing_constraints is not None + assert scheduler_state.scheduler_constraints is not None + assert scheduler_state.created_requests == len(requests_tracker) + assert scheduler_state.queued_requests == 0 + assert scheduler_state.pending_requests == 0 + assert scheduler_state.processing_requests == 0 + assert scheduler_state.processed_requests == len(requests_tracker) + assert scheduler_state.successful_requests >= 0 + assert scheduler_state.errored_requests >= 0 + assert scheduler_state.cancelled_requests >= 0 + assert ( + scheduler_state.successful_requests + + scheduler_state.errored_requests + + scheduler_state.cancelled_requests + == len(requests_tracker) ) - - # Test shutdown with empty state - should return no exceptions - exceptions = await group.shutdown() + if constructor_args.get("constraints"): + assert list(scheduler_state.scheduler_constraints.keys()) == list( + constructor_args["constraints"].keys() + ) + assert scheduler_state.remaining_fraction == 0.0 + if "max_num" in constructor_args["constraints"]: + assert "max_num" in scheduler_state.end_queuing_constraints + assert "max_num" in scheduler_state.end_processing_constraints + max_num = constructor_args["constraints"]["max_num"].max_num + assert scheduler_state.created_requests == max_num + assert scheduler_state.successful_requests == max_num + assert scheduler_state.errored_requests == 0 + assert scheduler_state.cancelled_requests == 0 + if "max_duration" in constructor_args["constraints"]: + assert "max_duration" in scheduler_state.end_queuing_constraints + assert "max_duration" in scheduler_state.end_processing_constraints + assert scheduler_state.remaining_duration == 0.0 + else: + assert "requests_exhausted" in scheduler_state.scheduler_constraints + assert "requests_exhausted" in scheduler_state.end_queuing_constraints + assert "requests_exhausted" in scheduler_state.end_processing_constraints + assert scheduler_state.remaining_fraction is None + assert scheduler_state.remaining_requests is None + assert scheduler_state.remaining_duration is None + + # Test shutdown + exceptions = await instance.shutdown() + + # Check valid shutdown behavior + assert isinstance(exceptions, list) assert len(exceptions) == 0 - assert group.processes is None - assert group.mp_context is None - assert group.shutdown_event is None - - @pytest.mark.sanity - @pytest.mark.asyncio - @async_timeout(5.0) - async def test_start_without_create_processes(self): - """Test that start() raises error when create_processes() not called.""" - backend = MockBackend() - requests = ["req1", "req2"] - group = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=SynchronousStrategy(), - constraints={}, - ) - - with pytest.raises( - RuntimeError, - match="create_processes\\(\\) must be called before start\\(\\)", - ): - await group.start(time.time()) - - @pytest.mark.sanity - @pytest.mark.asyncio - @async_timeout(5.0) - async def test_create_processes_invalid_limits(self, monkeypatch): - """Test create_processes with invalid process and concurrency limits.""" - # Test 
zero processes limit - monkeypatch.setattr( - worker_group.settings, "max_worker_processes", 0, raising=False - ) - monkeypatch.setattr(worker_group.settings, "max_concurrency", 1, raising=False) - - backend = MockBackend() - requests = ["req1"] - group = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=SynchronousStrategy(), - constraints={}, - ) - - with pytest.raises(RuntimeError, match="num_processes resolved to 0"): - await group.create_processes() - - # Test zero concurrency limit - monkeypatch.setattr( - worker_group.settings, "max_worker_processes", 1, raising=False - ) - monkeypatch.setattr(worker_group.settings, "max_concurrency", 0, raising=False) - - group2 = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=SynchronousStrategy(), - constraints={}, - ) - - with pytest.raises(RuntimeError, match="max_concurrency resolved to 0"): - await group2.create_processes() - - @pytest.mark.smoke - @pytest.mark.asyncio - @async_timeout(10.0) - async def test_request_updates_error_handling(self, monkeypatch): - """Test request_updates handles error events correctly.""" - # Use the helper method that creates mocked multiprocessing components - self._setup_test_environment(monkeypatch) - - backend = MockBackend() - requests = ["req1"] - group = self._create_test_group(backend, requests) - - # Start the group with mocked components - start_time = time.time() + 0.1 - await group.start(start_time) - - # Set error event to simulate error - group.error_event.set() - - # Test that request_updates raises RuntimeError when error event is set - with pytest.raises( - RuntimeError, match="error_event is set in WorkerProcessGroup" - ): - async for _ in group.request_updates(): - pass - - # Clean up - await group.shutdown() - - @pytest.mark.smoke - def test_valid_instances_fixture(self): - """Test the valid_instances fixture provides correct data.""" - backend = MockBackend() - requests = ["request1", "request2", "request3"] - strategy = SynchronousStrategy() - constraints = {"max_requests": MaxNumberConstraint(max_num=10)} - - instance = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=strategy, - constraints=constraints, - ) - - assert isinstance(instance, WorkerProcessGroup) - assert instance.requests is requests - assert instance.backend is backend - assert instance.strategy is strategy - assert instance.constraints is constraints - - @pytest.mark.smoke - @pytest.mark.parametrize( - "infinite_requests", - [ - None, - True, - False, - ], - ) - def test_initialization_infinite_requests(self, infinite_requests): - """Test initialization with different infinite_requests values.""" - backend = MockBackend() - requests = ["req1", "req2"] - strategy = SynchronousStrategy() - constraints = {} - - group = WorkerProcessGroup( - requests=requests, - backend=backend, - strategy=strategy, - constraints=constraints, - infinite_requests=infinite_requests, - ) - - assert group.infinite_requests == infinite_requests - - @pytest.mark.sanity - @pytest.mark.parametrize( - "missing_param", - [ - "requests", - "backend", - "strategy", - "constraints", - ], - ) - def test_invalid_initialization_missing_params(self, missing_param): - """Test invalid initialization with missing required parameters.""" - # Create complete valid parameters - params = { - "requests": ["req1"], - "backend": MockBackend(), - "strategy": SynchronousStrategy(), - "constraints": {}, - } - - # Remove the specified parameter - del params[missing_param] - - with pytest.raises(TypeError): - 
WorkerProcessGroup(**params) + assert instance.messaging is None + assert instance.state is None + assert instance.processes is None + assert instance.startup_barrier is None + assert instance.requests_generated_event is None + assert instance.constraint_reached_event is None + assert instance.shutdown_event is None + assert instance.error_event is None + assert instance.mp_manager is None + assert instance.mp_context is None diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 53e8b664..792c9770 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -3,7 +3,7 @@ import pytest from guidellm import configure_logger, logger -from guidellm.config import LoggingSettings +from guidellm.settings import LoggingSettings @pytest.fixture(autouse=True) diff --git a/tests/unit/test_config.py b/tests/unit/test_settings.py similarity index 99% rename from tests/unit/test_config.py rename to tests/unit/test_settings.py index f5d9415c..42c8901d 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_settings.py @@ -1,6 +1,6 @@ import pytest -from guidellm.config import ( +from guidellm.settings import ( DatasetSettings, Environment, LoggingSettings, diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py index daadbd5e..cc71bce3 100644 --- a/tests/unit/utils/test_auto_importer.py +++ b/tests/unit/utils/test_auto_importer.py @@ -2,6 +2,8 @@ Unit tests for the auto_importer module. """ +from __future__ import annotations + from unittest import mock import pytest @@ -9,49 +11,77 @@ from guidellm.utils import AutoImporterMixin -class MockHelper: - """Helper class to create consistent mock objects for testing.""" - - @staticmethod - def create_mock_package(name: str, path: str): - """Create a mock package with required attributes.""" - package = mock.MagicMock() - package.__name__ = name - package.__path__ = [path] - return package +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" - @staticmethod - def create_mock_module(name: str): - """Create a mock module with required attributes.""" - module = mock.MagicMock() - module.__name__ = name - return module + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] -class TestAutoImporterMixin: - """Test suite for AutoImporterMixin functionality.""" + return TestClass, config @pytest.mark.smoke - def test_mixin_initialization(self): - """Test that AutoImporterMixin initializes with correct default values.""" + def 
test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables assert AutoImporterMixin.auto_package is None assert AutoImporterMixin.auto_ignore_modules is None assert AutoImporterMixin.auto_imported_modules is None @pytest.mark.smoke - def test_subclass_attributes(self): - """Test that subclass can set auto_package attribute.""" + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert test_class.auto_imported_modules is None - class TestClass(AutoImporterMixin): - auto_package = "test.package" - - assert TestClass.auto_package == "test.package" - assert TestClass.auto_ignore_modules is None - assert TestClass.auto_imported_modules is None - - @pytest.mark.smoke - def test_missing_package_raises_error(self): - """Test that missing auto_package raises ValueError.""" + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" class TestClass(AutoImporterMixin): pass @@ -62,121 +92,70 @@ class TestClass(AutoImporterMixin): @pytest.mark.smoke @mock.patch("importlib.import_module") @mock.patch("pkgutil.walk_packages") - def test_single_package_import(self, mock_walk, mock_import): - """Test importing modules from a single package.""" - - class TestClass(AutoImporterMixin): - auto_package = "test.package" - - # Setup mocks - mock_package = MockHelper.create_mock_package("test.package", "test/package") - mock_module1 = MockHelper.create_mock_module("test.package.module1") - mock_module2 = MockHelper.create_mock_module("test.package.module2") - - mock_import.side_effect = lambda name: { - "test.package": mock_package, - "test.package.module1": mock_module1, - "test.package.module2": mock_module2, - }[name] - - mock_walk.return_value = [ - (None, "test.package.module1", False), - (None, "test.package.module2", False), - ] - - # Execute - TestClass.auto_import_package_modules() - - # Verify - assert TestClass.auto_imported_modules == [ - "test.package.module1", - "test.package.module2", - ] - mock_import.assert_any_call("test.package") - mock_import.assert_any_call("test.package.module1") - mock_import.assert_any_call("test.package.module2") - - @pytest.mark.sanity - @mock.patch("importlib.import_module") - @mock.patch("pkgutil.walk_packages") - def test_multiple_package_import(self, mock_walk, mock_import): - """Test importing modules from multiple packages.""" - - class TestClass(AutoImporterMixin): - auto_package = ("test.package1", "test.package2") - - # Setup mocks - packages = { - "test.package1": MockHelper.create_mock_package( - "test.package1", "test/package1" - ), - "test.package2": MockHelper.create_mock_package( - "test.package2", "test/package2" - ), - } - modules = { - "test.package1.moduleA": MockHelper.create_mock_module( - "test.package1.moduleA" - ), - "test.package2.moduleB": MockHelper.create_mock_module( - "test.package2.moduleB" - ), - } - - mock_import.side_effect = 
lambda name: {**packages, **modules}[name] + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, mock.MagicMock() + ) def walk_side_effect(path, prefix): - if prefix == "test.package1.": - return [(None, "test.package1.moduleA", False)] - elif prefix == "test.package2.": - return [(None, "test.package2.moduleB", False)] - return [] + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] mock_walk.side_effect = walk_side_effect # Execute - TestClass.auto_import_package_modules() + test_class.auto_import_package_modules() # Verify - assert TestClass.auto_imported_modules == [ - "test.package1.moduleA", - "test.package2.moduleB", - ] + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) @pytest.mark.sanity @mock.patch("importlib.import_module") @mock.patch("pkgutil.walk_packages") - def test_ignore_modules(self, mock_walk, mock_import): - """Test that modules in auto_ignore_modules are skipped.""" + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" class TestClass(AutoImporterMixin): auto_package = "test.package" - auto_ignore_modules = ("test.package.module1",) - # Setup mocks - mock_package = MockHelper.create_mock_package("test.package", "test/package") - mock_module2 = MockHelper.create_mock_module("test.package.module2") - - mock_import.side_effect = lambda name: { - "test.package": mock_package, - "test.package.module2": mock_module2, - }.get(name, mock.MagicMock()) - - mock_walk.return_value = [ - (None, "test.package.module1", False), - (None, "test.package.module2", False), - ] - - # Execute - TestClass.auto_import_package_modules() + # Test import error handling + mock_import.side_effect = ImportError("Module not found") - # Verify - assert TestClass.auto_imported_modules == ["test.package.module2"] - mock_import.assert_any_call("test.package") - mock_import.assert_any_call("test.package.module2") - # module1 should not be imported - with pytest.raises(AssertionError): - mock_import.assert_any_call("test.package.module1") + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() @pytest.mark.sanity @mock.patch("importlib.import_module") @@ -269,3 +248,22 @@ class TestClass(AutoImporterMixin): # Verify assert TestClass.auto_imported_modules == ["test.package.module"] assert mock_import.call_count == 2 # Package + module (not duplicate) + + 
+class MockHelper: + """Helper class to create consistent mock objects for testing.""" + + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package + + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py index 404a8671..da1f63ee 100644 --- a/tests/unit/utils/test_encoding.py +++ b/tests/unit/utils/test_encoding.py @@ -1,222 +1,556 @@ +from __future__ import annotations + +import uuid from typing import Any, Generic, TypeVar import pytest from pydantic import BaseModel, Field -from guidellm.utils.encoding import MsgpackEncoding +from guidellm.backend.objects import ( + GenerationRequest, + GenerationResponse, +) +from guidellm.scheduler.objects import RequestSchedulerTimings, ScheduledRequestInfo +from guidellm.utils.encoding import Encoder, MessageEncoding, Serializer + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str = Field(description="Name field for testing") + value: int = Field(description="Value field for testing") + + +class SampleModelSubclass(SampleModel): + """Subclass of SampleModel for testing.""" + extra_field: str -class SimpleModel(BaseModel): - name: str - value: int +SampleModelT = TypeVar("SampleModelT", bound=SampleModel) -class NestedModel(BaseModel): - simple: SimpleModel - items: list[str] - metadata: dict[str, Any] +class ComplexModel(BaseModel, Generic[SampleModelT]): + """Complex Pydantic model for testing.""" -T = TypeVar("T") + items: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + nested: SampleModelT | None = Field(default=None) -class GenericModel(BaseModel, Generic[T]): - data: T - count: int +class GenricModelWrapper(Generic[SampleModelT]): + """Simulates a layered generic type.""" + def method(self, **kwargs) -> ComplexModel[SampleModelT]: + return ComplexModel[SampleModelT](**kwargs) -class ComplexModel(BaseModel): - id: str = Field(description="Unique identifier") - nested: NestedModel - numbers: list[int] - mapping: dict[str, SimpleModel] +class TestMessageEncoding: + """Test suite for MessageEncoding class.""" + + @pytest.fixture( + params=[ + {"serialization": None, "encoding": None}, + {"serialization": "dict", "encoding": None}, + {"serialization": "sequence", "encoding": None}, + {"serialization": None, "encoding": "msgpack"}, + {"serialization": "dict", "encoding": "msgpack"}, + {"serialization": "sequence", "encoding": "msgpack"}, + {"serialization": None, "encoding": "msgspec"}, + {"serialization": "dict", "encoding": "msgspec"}, + {"serialization": "sequence", "encoding": "msgspec"}, + {"serialization": None, "encoding": ["msgspec", "msgpack"]}, + {"serialization": "dict", "encoding": ["msgspec", "msgpack"]}, + ], + ids=[ + "no_serialization_no_encoding", + "dict_serialization_no_encoding", + "str_serialization_no_encoding", + "no_serialization_msgpack", + "dict_serialization_msgpack", + "str_serialization_msgpack", + "no_serialization_msgspec", + "dict_serialization_msgspec", + "str_serialization_msgspec", + "no_serialization_encoding_list", + "dict_serialization_encoding_list", + ], + ) + def valid_instances(self, request): + """Fixture providing test data for MessageEncoding.""" + 
constructor_args = request.param + try: + instance = MessageEncoding(**constructor_args) + return instance, constructor_args + except ImportError: + pytest.skip("Required encoding library not available") + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MessageEncoding inheritance and type relationships.""" + assert issubclass(MessageEncoding, Generic) + assert hasattr(MessageEncoding, "DEFAULT_ENCODING_PREFERENCE") + assert isinstance(MessageEncoding.DEFAULT_ENCODING_PREFERENCE, list) + assert MessageEncoding.DEFAULT_ENCODING_PREFERENCE == ["msgspec", "msgpack"] + + # Check classmethods + assert hasattr(MessageEncoding, "encode_message") + assert callable(MessageEncoding.encode_message) + assert hasattr(MessageEncoding, "decode_message") + assert callable(MessageEncoding.decode_message) + + # Check instance methods + assert hasattr(MessageEncoding, "__init__") + assert hasattr(MessageEncoding, "register_pydantic") + assert hasattr(MessageEncoding, "encode") + assert hasattr(MessageEncoding, "decode") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test MessageEncoding initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, MessageEncoding) + assert hasattr(instance, "serializer") + assert isinstance(instance.serializer, Serializer) + assert instance.serializer.serialization == constructor_args["serialization"] + assert hasattr(instance, "encoder") + assert isinstance(instance.encoder, Encoder) + + expected_encoding = constructor_args["encoding"] + if isinstance(expected_encoding, list): + assert instance.encoder.encoding in expected_encoding + else: + assert instance.encoder.encoding == expected_encoding -class TestMsgpackEncoding: @pytest.mark.smoke @pytest.mark.parametrize( - "primitive_data", + "obj", [ - # Basic primitives - 42, - 3.14, - True, - False, None, - "hello world", - "", - [], - [1, 2, 3], - {}, - {"key": "value"}, - # Nested collections - [1, [2, 3], {"nested": True}], - {"outer": {"inner": [1, 2, 3]}}, - # Mixed types - [1, "string", 3.14, True, None], - {"int": 42, "str": "hello", "float": 3.14, "bool": True, "null": None}, + 0, + 0.0, + "0.1.2.3", + [0, 0.0, "0.1.2.3", None], + (0, 0.0, "0.1.2.3", None), + {"key1": 0, "key2": 0.0, "key3": "0.1.2.3", "key4": None}, ], ) - def test_encode_decode_primitives(self, primitive_data): - """Test encoding and decoding of Python primitives and collections.""" - encoded = MsgpackEncoding.encode(primitive_data) - assert isinstance(encoded, bytes) + def test_encode_decode_python(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with comprehensive data types.""" + instance, constructor_args = valid_instances - decoded = MsgpackEncoding.decode(encoded) - assert decoded == primitive_data - assert isinstance(decoded, type(primitive_data)) + message = instance.encode(obj) + decoded = instance.decode(message) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj @pytest.mark.smoke @pytest.mark.parametrize( - ("tuple_data", "expected_list"), + "obj", [ - ((), []), - ((1, 2, 3), [1, 2, 3]), - ((1, (2, 3), {"tuple_dict": True}), [1, [2, 3], {"tuple_dict": True}]), + SampleModel(name="sample", value=123), + ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + ( + SampleModel(name="sample", value=123), + None, + ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", 
value=123), + ), + ), + { + "key1": SampleModel(name="sample", value=123), + "key2": None, + "key3": ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + }, ], ) - def test_encode_decode_tuples(self, tuple_data, expected_list): - encoded = MsgpackEncoding.encode(tuple_data) - assert isinstance(encoded, bytes) + def test_encode_decode_pydantic(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with Pydantic models.""" + instance, constructor_args = valid_instances + + if ( + constructor_args["serialization"] is None + and constructor_args["encoding"] is not None + ): + # msgpack/msgspec don't support Pydantic models natively + pytest.skip("Skipping unsupported Pydantic serialization/encoding combo") + + # Register Pydantic models for proper serialization + instance.register_pydantic(SampleModel) + instance.register_pydantic(ComplexModel) - decoded = MsgpackEncoding.decode(encoded) - assert decoded == expected_list - assert isinstance(decoded, list) + message = instance.encode(obj) + decoded = instance.decode(message) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj @pytest.mark.smoke @pytest.mark.parametrize( - "model_data", + "obj", [ - SimpleModel(name="test", value=42), - NestedModel( - simple=SimpleModel(name="nested", value=100), - items=["a", "b", "c"], - metadata={"key": "value", "number": 123}, + ( + None, + GenerationRequest(content="test content"), + ScheduledRequestInfo( + scheduler_timings=RequestSchedulerTimings( + targeted_start=1.0, + queued=0.1, + dequeued=0.2, + scheduled_at=0.3, + resolve_start=1.1, + resolve_end=1.5, + finalized=1.6, + ) + ), ), - ComplexModel( - id="test-123", - nested=NestedModel( - simple=SimpleModel(name="complex", value=999), - items=["x", "y"], - metadata={"complex": True}, + ( + GenerationResponse( + request_id=str(uuid.uuid4()), + request_args={}, + value="test response", + request_prompt_tokens=2, + request_output_tokens=3, + response_prompt_tokens=4, + response_output_tokens=6, + ), + GenerationRequest(content="test content"), + ScheduledRequestInfo( + scheduler_timings=RequestSchedulerTimings( + targeted_start=1.0, + queued=0.1, + dequeued=0.2, + scheduled_at=0.3, + resolve_start=1.1, + resolve_end=1.5, + finalized=1.6, + ) ), - numbers=[1, 2, 3, 4, 5], - mapping={ - "first": SimpleModel(name="first", value=1), - "second": SimpleModel(name="second", value=2), - }, ), ], ) - def test_encode_decode_pydantic_models(self, model_data): - """Test encoding and decoding of Pydantic models.""" - encoded = MsgpackEncoding.encode(model_data) - assert isinstance(encoded, bytes) + def test_encode_decode_generative(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with generative models.""" + instance, constructor_args = valid_instances + + if ( + constructor_args["serialization"] is None + and constructor_args["encoding"] is not None + ): + # msgpack/msgspec don't support Pydantic models natively + pytest.skip("Skipping unsupported Pydantic serialization/encoding combo") + + instance.register_pydantic(GenerationRequest) + instance.register_pydantic(GenerationResponse) + instance.register_pydantic(ScheduledRequestInfo) - decoded = MsgpackEncoding.decode(encoded) - assert decoded == model_data - assert isinstance(decoded, type(model_data)) - assert decoded.model_dump() == model_data.model_dump() + message = instance.encode(obj) + decoded = instance.decode(message) + + assert list(decoded) == 
list(obj) @pytest.mark.smoke @pytest.mark.parametrize( - ("generic_model", "expected_type"), + "serialization", [ - (GenericModel[str](data="hello", count=1), str), - (GenericModel[int](data=42, count=2), int), - (GenericModel[list[str]](data=["a", "b"], count=3), list), + None, + "dict", + "sequence", ], ) - def test_encode_decode_generic_models(self, generic_model, expected_type): - """Test encoding and decoding of generic Pydantic models.""" - encoded = MsgpackEncoding.encode(generic_model) - assert isinstance(encoded, bytes) - - decoded = MsgpackEncoding.decode(encoded) - assert decoded == generic_model - assert decoded.data == generic_model.data - assert decoded.count == generic_model.count - assert isinstance(decoded.data, expected_type) - - @pytest.mark.smoke @pytest.mark.parametrize( - "mixed_data", + "encoding", + [None, "msgpack", "msgspec"], + ) + @pytest.mark.parametrize( + "obj", [ - [SimpleModel(name="item1", value=1), SimpleModel(name="item2", value=2)], - {"model": SimpleModel(name="dict_value", value=42), "primitive": "string"}, + "0.1.2.3", + [0, 0.0, "0.1.2.3", None, SampleModel(name="sample", value=123)], { - "models": [ - SimpleModel(name="item1", value=1), - SimpleModel(name="item2", value=2), - ], - "data": {"nested": {"deep": SimpleModel(name="deep", value=999)}}, + "key1": 0, + "key2": 0.0, + "key3": "0.1.2.3", + "key4": None, + "key5": ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), }, - [ - { - "id": "test", - "model": NestedModel( - simple=SimpleModel(name="nested_in_list", value=456), - items=["nested", "list"], - metadata={"in_list": True}, - ), - "primitives": [1, 2, 3], - } - ], ], ) - def test_encode_decode_mixed_collections(self, mixed_data): - encoded = MsgpackEncoding.encode(mixed_data) - assert isinstance(encoded, bytes) + def test_encode_decode_message(self, serialization, encoding, obj): + """Test MessageEncoding.encode_message and decode_message class methods.""" + if encoding is not None and serialization is None and obj != "0.1.2.3": + pytest.skip("Skipping unsupported serialization/encoding combo") + + try: + serializer = Serializer(serialization) if serialization else None + encoder = Encoder(encoding) if encoding else None - decoded = MsgpackEncoding.decode(encoded) - assert decoded == mixed_data - assert isinstance(decoded, type(mixed_data)) + message = MessageEncoding.encode_message(obj, serializer, encoder) + decoded = MessageEncoding.decode_message(message, serializer, encoder) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj + except ImportError: + pytest.skip("Required encoding library not available") @pytest.mark.smoke - def test_round_trip_consistency(self): - original_data = { - "simple": SimpleModel(name="test", value=42), - "nested": NestedModel( - simple=SimpleModel(name="nested", value=100), - items=["a", "b", "c"], - metadata={"key": "value"}, - ), - "primitives": [1, 2, 3, "string", True, None], - "list_data": [1, 2, SimpleModel(name="list", value=999)], - } + def test_register_pydantic(self): + """Test MessageEncoding.register_pydantic functionality.""" + instance = MessageEncoding(serialization="dict", encoding=None) + assert len(instance.serializer.pydantic_registry) == 0 + instance.register_pydantic(SampleModel) + assert len(instance.serializer.pydantic_registry) == 1 + assert ( + instance.serializer.pydantic_registry.values().__iter__().__next__() + is SampleModel + ) - current_data = original_data - 
for _ in range(3): - encoded = MsgpackEncoding.encode(current_data) - current_data = MsgpackEncoding.decode(encoded) + @pytest.mark.sanity + def test_invalid_initialization_values(self): + """Test invalid initialization (unsupported encoding).""" + inst = MessageEncoding(serialization="dict", encoding=["invalid_encoding"]) # type: ignore[arg-type] + assert inst.encoder.encoding is None + with pytest.raises(ImportError): + MessageEncoding(serialization="dict", encoding="invalid") # type: ignore[arg-type] - assert current_data == original_data + +class TestEncoder: + """Test suite for Encoder class.""" + + @pytest.fixture( + params=[ + None, + "msgpack", + "msgspec", + ["msgspec", "msgpack"], + ["msgpack", "msgspec"], + ], + ids=[ + "none", + "msgpack", + "msgspec", + "list_pref_msgspec_first", + "list_pref_msgpack_first", + ], + ) + def valid_instances(self, request): + args = request.param + try: + inst = Encoder(args) + except ImportError: + pytest.skip("Encoding backend missing") + return inst, args @pytest.mark.smoke - def test_empty_collections(self): - test_cases = [[], {}] + def test_class_signatures(self): + assert hasattr(Encoder, "encode") + assert hasattr(Encoder, "decode") + assert hasattr(Encoder, "_resolve_encoding") - for empty_collection in test_cases: - encoded = MsgpackEncoding.encode(empty_collection) - decoded = MsgpackEncoding.decode(encoded) - assert decoded == empty_collection - assert isinstance(decoded, type(empty_collection)) + @pytest.mark.smoke + def test_initialization(self, valid_instances): + inst, args = valid_instances + assert isinstance(inst, Encoder) + if isinstance(args, list): + assert inst.encoding in args or inst.encoding is None + else: + assert inst.encoding == args + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + with pytest.raises(ImportError): + Encoder("invalid") # type: ignore[arg-type] @pytest.mark.smoke - def test_pydantic_constants(self): - """Test that the Pydantic-related constants are properly defined.""" - assert MsgpackEncoding.PYDANTIC_TAG == "__pydantic__" - assert MsgpackEncoding.PYDANTIC_DATA == "data" - assert MsgpackEncoding.PYDANTIC_ARGS == "args" + @pytest.mark.parametrize("obj", [None, 0, 1.2, "text", [1, 2], {"a": 1}]) + def test_encode_decode(self, valid_instances, obj): + inst, _ = valid_instances + msg = inst.encode(obj) + out = inst.decode(msg) + assert out == obj + + +class TestSerializer: + """Test suite for Serializer class.""" + + @pytest.fixture(params=[None, "dict", "sequence"], ids=["none", "dict", "sequence"]) + def valid_instances(self, request): + inst = Serializer(request.param) + return inst, request.param + + @pytest.mark.smoke + def test_class_signatures(self): + assert hasattr(Serializer, "serialize") + assert hasattr(Serializer, "deserialize") + assert hasattr(Serializer, "register_pydantic") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + inst, mode = valid_instances + assert isinstance(inst, Serializer) + assert inst.serialization == mode + + @pytest.mark.smoke + def test_register_pydantic(self, valid_instances): + inst, _ = valid_instances + assert len(inst.pydantic_registry) == 0 + inst.register_pydantic(SampleModel) + assert len(inst.pydantic_registry) == 1 + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + 1, + "str_val", + [1, 2, 3], + SampleModel(name="x", value=1), + {"k": SampleModel(name="y", value=2)}, + ], + ) + def test_serialize_deserialize(self, valid_instances, obj): + inst, mode = valid_instances + 
inst.register_pydantic(SampleModel) + msg = inst.serialize(obj) + out = inst.deserialize(msg) + if isinstance(obj, list): + assert list(out) == obj + else: + assert out == obj + + @pytest.mark.regression + def test_sequence_mapping_roundtrip(self): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + data = { + "a": SampleModel(name="a", value=1), + "b": SampleModel(name="b", value=2), + } + msg = inst.serialize(data) + out = inst.deserialize(msg) + assert out == data + + @pytest.mark.sanity + def test_to_from_dict_variations(self): + inst = Serializer("dict") + inst.register_pydantic(SampleModel) + model = SampleModel(name="n", value=3) + lst = [model, 5] + mp = {"k1": model, "k2": 9} + assert inst.from_dict(inst.to_dict(model)) == model + assert inst.from_dict(inst.to_dict(lst)) == lst + assert inst.from_dict(inst.to_dict(mp)) == mp + + @pytest.mark.sanity + @pytest.mark.parametrize( + "collection", + [ + [SampleModel(name="x", value=1), 2, 3], + (SampleModel(name="y", value=2), None), + ], + ) + def test_to_from_sequence_collections(self, collection): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + seq = inst.to_sequence(collection) + out = inst.from_sequence(seq) + assert len(out) == len(collection) + assert all(a == b for a, b in zip(out, list(collection))) @pytest.mark.sanity - def test_encode_invalid_data(self): - """Test encoding behavior with edge cases.""" + def test_to_from_sequence_mapping(self): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + data = {"k": SampleModel(name="z", value=7), "j": 1} + seq = inst.to_sequence(data) + out = inst.from_sequence(seq) + assert out == data - class CustomClass: - def __init__(self, value): - self.value = value + @pytest.mark.sanity + def test_sequence_multiple_root_raises(self): + inst = Serializer("sequence") + part1 = inst.pack_next_sequence("python", inst.to_sequence_python(1), None) + part2 = inst.pack_next_sequence("python", inst.to_sequence_python(2), None) + with pytest.raises(ValueError): + inst.from_sequence(part1 + part2) # type: ignore[operator] + + @pytest.mark.sanity + def test_pack_next_sequence_type_mismatch(self): + inst = Serializer("sequence") + first_payload = inst.to_sequence_python(1) + first = inst.pack_next_sequence("python", first_payload, None) + bad_payload: Any = ( + first_payload.decode() if isinstance(first_payload, bytes) else b"1" + ) + with pytest.raises(ValueError): + inst.pack_next_sequence("python", bad_payload, first) - custom_obj = CustomClass(42) - primitive = MsgpackEncoding.to_primitive(custom_obj) - assert primitive is custom_obj + @pytest.mark.sanity + def test_unpack_invalid(self): + inst = Serializer("sequence") + with pytest.raises(ValueError): + inst.unpack_next_sequence("X|3|abc") + with pytest.raises(ValueError): + inst.unpack_next_sequence("p?bad") + + @pytest.mark.sanity + def test_dynamic_import_load_pydantic(self, monkeypatch): + inst = Serializer("dict") + inst.pydantic_registry.clear() + sample = SampleModel(name="dyn", value=5) + dumped = inst.to_dict(sample) + inst.pydantic_registry.clear() + restored = inst.from_dict(dumped) + assert restored == sample + + @pytest.mark.sanity + def test_generic_model(self): + inst = Serializer("dict") + inst.register_pydantic(ComplexModel[SampleModelSubclass]) + nested = ComplexModel[SampleModelSubclass]( + items=["i1", "i2"], + metadata={"m": 1}, + nested=SampleModelSubclass(name="nested", value=10, extra_field="extra"), + ) + dumped = inst.to_dict(nested) + restored = 
inst.from_dict(dumped) + assert restored == nested + + @pytest.mark.sanity + @pytest.mark.xfail( + reason="A generic object returned by a generic method loses its type args" + ) + def test_generic_emitted_type(self): + generic_instance = GenricModelWrapper[SampleModelSubclass]() + + inst = Serializer("dict") + inst.register_pydantic(ComplexModel[SampleModelSubclass]) + nested = generic_instance.method( + items=["i1", "i2"], + metadata={"m": 1}, + nested=SampleModelSubclass(name="nested", value=10, extra_field="extra"), + ) + dumped = inst.to_dict(nested) + restored = inst.from_dict(dumped) + assert restored == nested diff --git a/tests/unit/utils/test_functions.py b/tests/unit/utils/test_functions.py new file mode 100644 index 00000000..3b353759 --- /dev/null +++ b/tests/unit/utils/test_functions.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from datetime import datetime + +import pytest + +from guidellm.utils.functions import ( + all_defined, + safe_add, + safe_divide, + safe_format_timestamp, + safe_getattr, + safe_multiply, +) + + +class TestAllDefined: + """Test suite for all_defined function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "expected"), + [ + ((1, 2, 3), True), + (("test", "hello"), True), + ((0, False, ""), True), + ((1, None, 3), False), + ((None,), False), + ((None, None), False), + ((), True), + ], + ) + def test_invocation(self, values, expected): + """Test all_defined with valid inputs.""" + result = all_defined(*values) + assert result == expected + + @pytest.mark.sanity + def test_mixed_types(self): + """Test all_defined with mixed data types.""" + result = all_defined(1, "test", [], {}, 0.0, False) + assert result is True + + result = all_defined(1, "test", None, {}) + assert result is False + + +class TestSafeGetattr: + """Test suite for safe_getattr function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj", "attr", "default", "expected"), + [ + (None, "any_attr", "default_val", "default_val"), + (None, "any_attr", None, None), + ("test_string", "nonexistent", "default_val", "default_val"), + ], + ) + def test_invocation(self, obj, attr, default, expected): + """Test safe_getattr with valid inputs.""" + result = safe_getattr(obj, attr, default) + assert result == expected + + @pytest.mark.smoke + def test_with_object(self): + """Test safe_getattr with actual object attributes.""" + + class TestObj: + test_attr = "test_value" + + obj = TestObj() + result = safe_getattr(obj, "test_attr", "default") + assert result == "test_value" + + result = safe_getattr(obj, "missing_attr", "default") + assert result == "default" + + # Test with method attribute + result = safe_getattr("test_string", "upper", None) + assert callable(result) + assert result() == "TEST_STRING" + + +class TestSafeDivide: + """Test suite for safe_divide function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("numerator", "denominator", "num_default", "den_default", "expected"), + [ + (10, 2, 0.0, 1.0, 5.0), + (None, 2, 6.0, 1.0, 3.0), + (10, None, 0.0, 5.0, 2.0), + (None, None, 8.0, 4.0, 2.0), + (10, 0, 0.0, 1.0, 10 / 1e-10), + ], + ) + def test_invocation( + self, numerator, denominator, num_default, den_default, expected + ): + """Test safe_divide with valid inputs.""" + result = safe_divide(numerator, denominator, num_default, den_default) + assert result == pytest.approx(expected, rel=1e-6) + + @pytest.mark.sanity + def test_zero_division_protection(self): + """Test safe_divide protection against zero division.""" + result = 
safe_divide(10, 0) + assert result == 10 / 1e-10 + + result = safe_divide(5, None, den_default=0) + assert result == 5 / 1e-10 + + +class TestSafeMultiply: + """Test suite for safe_multiply function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "default", "expected"), + [ + ((2, 3, 4), 1.0, 24.0), + ((2, None, 4), 1.0, 8.0), + ((None, None), 5.0, 5.0), + ((), 3.0, 3.0), + ((2, 3, None, 5), 2.0, 60.0), + ], + ) + def test_invocation(self, values, default, expected): + """Test safe_multiply with valid inputs.""" + result = safe_multiply(*values, default=default) + assert result == expected + + @pytest.mark.sanity + def test_with_zero(self): + """Test safe_multiply with zero values.""" + result = safe_multiply(2, 0, 3, default=1.0) + assert result == 0.0 + + result = safe_multiply(None, 0, None, default=5.0) + assert result == 0.0 + + +class TestSafeAdd: + """Test suite for safe_add function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "signs", "default", "expected"), + [ + ((1, 2, 3), None, 0.0, 6.0), + ((1, None, 3), None, 5.0, 9.0), + ((10, 5), [1, -1], 0.0, 5.0), + ((None, None), [1, -1], 2.0, 0.0), + ((), None, 3.0, 3.0), + ((1, 2, 3), [1, 1, -1], 0.0, 0.0), + ], + ) + def test_invocation(self, values, signs, default, expected): + """Test safe_add with valid inputs.""" + result = safe_add(*values, signs=signs, default=default) + assert result == expected + + @pytest.mark.sanity + def test_invalid_signs_length(self): + """Test safe_add with invalid signs length.""" + with pytest.raises( + ValueError, match="Length of signs must match length of values" + ): + safe_add(1, 2, 3, signs=[1, -1]) + + @pytest.mark.sanity + def test_single_value(self): + """Test safe_add with single value.""" + result = safe_add(5, default=1.0) + assert result == 5.0 + + result = safe_add(None, default=3.0) + assert result == 3.0 + + +class TestSafeFormatTimestamp: + """Test suite for safe_format_timestamp function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("timestamp", "format_", "default", "expected"), + [ + (1609459200.0, "%Y-%m-%d", "N/A", "2020-12-31"), + (1609459200.0, "%H:%M:%S", "N/A", "19:00:00"), + (None, "%H:%M:%S", "N/A", "N/A"), + (-1, "%H:%M:%S", "N/A", "N/A"), + (2**32, "%H:%M:%S", "N/A", "N/A"), + ], + ) + def test_invocation(self, timestamp, format_, default, expected): + """Test safe_format_timestamp with valid inputs.""" + result = safe_format_timestamp(timestamp, format_, default) + assert result == expected + + @pytest.mark.sanity + def test_edge_cases(self): + """Test safe_format_timestamp with edge case timestamps.""" + result = safe_format_timestamp(0.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(1.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(2**31 - 1, "%Y", "N/A") + expected_year = datetime.fromtimestamp(2**31 - 1).strftime("%Y") + assert result == expected_year + + @pytest.mark.sanity + def test_invalid_timestamp_ranges(self): + """Test safe_format_timestamp with invalid timestamp ranges.""" + result = safe_format_timestamp(2**31 + 1, "%Y", "ERROR") + assert result == "ERROR" + + result = safe_format_timestamp(-1000, "%Y", "ERROR") + assert result == "ERROR" diff --git a/tests/unit/utils/test_messaging.py b/tests/unit/utils/test_messaging.py new file mode 100644 index 00000000..d6627e88 --- /dev/null +++ b/tests/unit/utils/test_messaging.py @@ -0,0 +1,974 @@ +from __future__ import annotations + +import asyncio +import multiprocessing +import threading +from 
functools import wraps +from typing import Any, TypeVar + +import culsans +import pytest +from pydantic import BaseModel + +from guidellm.backend import ( + GenerationRequest, + GenerationResponse, +) +from guidellm.scheduler import ScheduledRequestInfo +from guidellm.utils import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, +) +from guidellm.utils.messaging import ReceiveMessageT, SendMessageT + + +def async_timeout(delay: float): + """Decorator to add timeout to async test functions.""" + + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockMessage(BaseModel): + content: str + num: int + + +class MockProcessTarget: + """Mock process target for testing.""" + + def __init__( + self, + messaging: InterProcessMessaging, + num_messages: int, + worker_index: int = 0, + ): + self.messaging = messaging + self.num_messages = num_messages + self.worker_index = worker_index + + def run(self): + loop = asyncio.new_event_loop() + + try: + asyncio.set_event_loop(loop) + asyncio.run(asyncio.wait_for(self._async_runner(), timeout=10.0)) + except RuntimeError: + pass + finally: + loop.close() + + async def _async_runner(self): + await self.messaging.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + + try: + for _ in range(self.num_messages): + obj = await self.messaging.get(timeout=2.0) + await self.messaging.put(obj, timeout=2.0) + finally: + await self.messaging.stop() + + +@pytest.fixture( + params=[ + {"ctx_name": "fork"}, + {"ctx_name": "spawn"}, + ], + ids=["fork_ctx", "spawn_ctx"], +) +def multiprocessing_contexts(request): + context = multiprocessing.get_context(request.param["ctx_name"]) + manager = context.Manager() + try: + yield manager, context + finally: + manager.shutdown() + + +def test_send_message_type(): + """Test that SendMessageT is filled out correctly as a TypeVar.""" + assert isinstance(SendMessageT, type(TypeVar("test"))) + assert SendMessageT.__name__ == "SendMessageT" + assert SendMessageT.__bound__ is Any + assert SendMessageT.__constraints__ == () + + +def test_receive_message_type(): + """Test that ReceiveMessageT is filled out correctly as a TypeVar.""" + assert isinstance(ReceiveMessageT, type(TypeVar("test"))) + assert ReceiveMessageT.__name__ == "ReceiveMessageT" + assert ReceiveMessageT.__bound__ is Any + assert ReceiveMessageT.__constraints__ == () + + +class TestInterProcessMessaging: + """Test suite for InterProcessMessaging abstract base class.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessaging abstract class signatures.""" + assert hasattr(InterProcessMessaging, "__init__") + assert hasattr(InterProcessMessaging, "create_worker_copy") + assert hasattr(InterProcessMessaging, "create_send_messages_threads") + assert hasattr(InterProcessMessaging, "create_receive_messages_threads") + assert hasattr(InterProcessMessaging, "start") + assert hasattr(InterProcessMessaging, "stop") + assert hasattr(InterProcessMessaging, "get") + assert hasattr(InterProcessMessaging, "put") + + # Check abstract methods + assert getattr( + InterProcessMessaging.create_worker_copy, "__isabstractmethod__", False + ) + assert getattr( + InterProcessMessaging.create_send_messages_threads, + "__isabstractmethod__", + False, + ) + assert getattr( + 
InterProcessMessaging.create_receive_messages_threads, + "__isabstractmethod__", + False, + ) + + @pytest.mark.smoke + def test_cannot_instantiate_directly(self): + """Test InterProcessMessaging cannot be instantiated directly.""" + with pytest.raises(TypeError): + InterProcessMessaging() + + +class TestInterProcessMessagingQueue: + """Test suite for InterProcessMessagingQueue.""" + + @pytest.fixture( + params=[ + { + "serialization": "dict", + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + { + "serialization": "sequence", + "encoding": None, + "max_pending_size": 10, + "max_buffer_send_size": 2, + "max_done_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "serialization": None, + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingQueue.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingQueue( + **constructor_args, poll_interval=0.01, mp_context=context + ) + + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingQueue inheritance and signatures.""" + assert issubclass(InterProcessMessagingQueue, InterProcessMessaging) + assert hasattr(InterProcessMessagingQueue, "__init__") + assert hasattr(InterProcessMessagingQueue, "create_worker_copy") + assert hasattr(InterProcessMessagingQueue, "create_send_messages_threads") + assert hasattr(InterProcessMessagingQueue, "create_receive_messages_threads") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingQueue initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingQueue) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_pending_size == constructor_args["max_pending_size"] + assert instance.max_done_size == constructor_args["max_done_size"] + assert hasattr(instance, "pending_queue") + assert hasattr(instance, "done_queue") + assert instance.running is False + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingQueue.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 42 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingQueue) + assert worker_copy.worker_index == worker_index + assert worker_copy.pending_queue is instance.pending_queue + assert worker_copy.done_queue is instance.done_queue + assert worker_copy.max_pending_size == instance.max_pending_size + assert worker_copy.max_done_size == instance.max_done_size + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "stop_events_lambda", + [ + list, + lambda: [threading.Event()], + lambda: [multiprocessing.Event()], + lambda: [threading.Event(), multiprocessing.Event()], + ], + ) + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances, stop_events_lambda): + """Test InterProcessMessagingQueue start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = stop_events_lambda() + + # Initially not running + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is 
None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start( + send_stop_criteria=stop_events, receive_stop_criteria=stop_events + ) + assert instance.running is True + assert instance.send_stopped_event is not None + assert isinstance(instance.send_stopped_event, threading.Event) + assert instance.receive_stopped_event is not None + assert isinstance(instance.receive_stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + if stop_events: + for event in stop_events: + event.set() + + await asyncio.sleep(0.1) + assert instance.send_stopped_event.is_set() + assert instance.receive_stopped_event.is_set() + assert instance.send_task.done() + assert instance.receive_task.done() + + await instance.stop() + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + 
process.join(timeout=1.0) + + await instance.stop() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get_iter(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + def _received_callback(msg): + if not isinstance(test_obj, tuple): + assert msg == test_obj + else: + assert list(msg) == list(test_obj) + return "changed_obj" + + # Local startup and wait + await instance.start( + send_items=[test_obj for _ in range(5)], + receive_callback=_received_callback, + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + val = await instance.get(timeout=2.0) + assert val == "changed_obj" + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + +class TestInterProcessMessagingManagerQueue: + """Test suite for InterProcessMessagingManagerQueue.""" + + @pytest.fixture( + params=[ + { + "serialization": "dict", + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + { + "serialization": "sequence", + "encoding": None, + "max_pending_size": 10, + "max_buffer_send_size": 2, + "max_done_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "serialization": None, + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingManagerQueue.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingManagerQueue( + **constructor_args, manager=manager, poll_interval=0.01 + ) + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingManagerQueue inheritance and signatures.""" + assert issubclass(InterProcessMessagingManagerQueue, InterProcessMessaging) + assert issubclass(InterProcessMessagingManagerQueue, InterProcessMessagingQueue) + assert hasattr(InterProcessMessagingManagerQueue, "__init__") + assert hasattr(InterProcessMessagingManagerQueue, "create_worker_copy") + assert hasattr(InterProcessMessagingManagerQueue, "_send_messages_task_thread") + assert hasattr( + InterProcessMessagingManagerQueue, "_receive_messages_task_thread" + ) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingManagerQueue initialization.""" + instance, constructor_args, _, _ 
= valid_instances + + assert isinstance(instance, InterProcessMessagingManagerQueue) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_pending_size == constructor_args["max_pending_size"] + assert instance.max_done_size == constructor_args["max_done_size"] + assert hasattr(instance, "pending_queue") + assert hasattr(instance, "done_queue") + assert instance.running is False + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingQueue.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 42 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingManagerQueue) + assert worker_copy.worker_index == worker_index + assert worker_copy.pending_queue is instance.pending_queue + assert worker_copy.done_queue is instance.done_queue + assert worker_copy.max_pending_size == instance.max_pending_size + assert worker_copy.max_done_size == instance.max_done_size + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "stop_events_lambda", + [ + list, + lambda: [threading.Event()], + lambda: [multiprocessing.Event()], + lambda: [threading.Event(), multiprocessing.Event()], + ], + ) + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances, stop_events_lambda): + """Test InterProcessMessagingQueue start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = stop_events_lambda() + + # Initially not running + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start( + send_stop_criteria=stop_events, receive_stop_criteria=stop_events + ) + assert instance.running is True + assert instance.send_stopped_event is not None + assert isinstance(instance.send_stopped_event, threading.Event) + assert instance.receive_stopped_event is not None + assert isinstance(instance.receive_stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + if stop_events: + for event in stop_events: + event.set() + + await asyncio.sleep(0.1) + assert instance.send_stopped_event.is_set() + assert instance.receive_stopped_event.is_set() + assert instance.send_task.done() + assert instance.receive_task.done() + + await instance.stop() + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + 
"asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, _, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get_iter(self, valid_instances, test_obj): + instance, constructor_args, _, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + def _received_callback(msg): + if not isinstance(test_obj, tuple): + assert msg == test_obj + else: + assert list(msg) == list(test_obj) + return "changed_obj" + + # Local startup and wait + await instance.start( + send_items=[test_obj for _ in range(5)], + receive_callback=_received_callback, + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + val = await instance.get(timeout=2.0) + assert val == "changed_obj" + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + +class TestInterProcessMessagingPipe: + """Test suite for InterProcessMessagingPipe.""" + + @pytest.fixture( + params=[ + { + "num_workers": 2, + "serialization": "dict", + "encoding": None, + "max_pending_size": None, + 
"max_done_size": None, + "worker_index": None, + }, + { + "num_workers": 1, + "serialization": "sequence", + "encoding": None, + "max_pending_size": 10, + "max_buffer_send_size": 2, + "max_done_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "num_workers": 1, + "serialization": None, + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingPipe.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingPipe(**constructor_args, poll_interval=0.01) + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingPipe inheritance and signatures.""" + assert issubclass(InterProcessMessagingPipe, InterProcessMessaging) + assert hasattr(InterProcessMessagingPipe, "__init__") + assert hasattr(InterProcessMessagingPipe, "create_worker_copy") + assert hasattr(InterProcessMessagingPipe, "_send_messages_task_thread") + assert hasattr(InterProcessMessagingPipe, "_receive_messages_task_thread") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingPipe initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingPipe) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_pending_size == constructor_args["max_pending_size"] + assert instance.max_done_size == constructor_args["max_done_size"] + assert instance.num_workers == constructor_args["num_workers"] + assert hasattr(instance, "pipes") + assert len(instance.pipes) == constructor_args["num_workers"] + assert len(instance.pipes) == constructor_args["num_workers"] + assert instance.running is False + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("kwargs", "expected_error"), + [ + ({"invalid_param": "value"}, TypeError), + ({"num_workers": 1, "unknown_arg": "test"}, TypeError), + ], + ) + def test_invalid_initialization_values(self, kwargs, expected_error): + """Test InterProcessMessagingPipe with invalid field values.""" + with pytest.raises(expected_error): + InterProcessMessagingPipe(**kwargs) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test InterProcessMessagingPipe initialization without required field.""" + with pytest.raises(TypeError): + InterProcessMessagingPipe() + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingPipe.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 0 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingPipe) + assert worker_copy.worker_index == worker_index + assert worker_copy.pipes[0] is instance.pipes[worker_index] + assert worker_copy.max_pending_size == instance.max_pending_size + assert worker_copy.max_done_size == instance.max_done_size + assert worker_copy.num_workers == instance.num_workers + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances): + """Test InterProcessMessagingPipe start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = [] + + # Initially not running + assert instance.running is False + assert instance.send_stopped_event is None + 
assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start( + send_stop_criteria=stop_events, receive_stop_criteria=stop_events + ) + assert instance.running is True + assert instance.send_stopped_event is not None + assert isinstance(instance.send_stopped_event, threading.Event) + assert instance.receive_stopped_event is not None + assert isinstance(instance.receive_stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + await instance.stop() + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + processes = [] + for index in range(constructor_args["num_workers"]): + process_target = MockProcessTarget( + instance.create_worker_copy(index), num_messages=5 + ) + process = context.Process(target=process_target.run) + processes.append(process) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5 * constructor_args["num_workers"]): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5 * constructor_args["num_workers"]): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + for process in processes: + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() diff --git 
a/tests/unit/utils/test_mixins.py b/tests/unit/utils/test_mixins.py new file mode 100644 index 00000000..cd8990de --- /dev/null +++ b/tests/unit/utils/test_mixins.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import pytest + +from guidellm.utils.mixins import InfoMixin + + +class TestInfoMixin: + """Test suite for InfoMixin.""" + + @pytest.fixture( + params=[ + {"attr_one": "test_value", "attr_two": 42}, + {"attr_one": "hello_world", "attr_two": 100, "attr_three": [1, 2, 3]}, + ], + ids=["basic_attributes", "extended_attributes"], + ) + def valid_instances(self, request): + """Fixture providing test data for InfoMixin.""" + constructor_args = request.param + + class TestClass(InfoMixin): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + instance = TestClass(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InfoMixin class signatures and methods.""" + assert hasattr(InfoMixin, "extract_from_obj") + assert callable(InfoMixin.extract_from_obj) + assert hasattr(InfoMixin, "create_info_dict") + assert callable(InfoMixin.create_info_dict) + assert hasattr(InfoMixin, "info") + assert isinstance(InfoMixin.info, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InfoMixin initialization through inheritance.""" + instance, constructor_args = valid_instances + assert isinstance(instance, InfoMixin) + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.smoke + def test_info_property(self, valid_instances): + """Test InfoMixin.info property.""" + instance, constructor_args = valid_instances + result = instance.info + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "TestClass" + assert result["class"] == "TestClass" + assert isinstance(result["attributes"], dict) + for key, value in constructor_args.items(): + assert key in result["attributes"] + assert result["attributes"][key] == value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": True}), + ({"nested": {"key": "value"}}, {"nested": {"key": "value"}}), + ], + ) + def test_create_info_dict(self, obj_data, expected_attributes): + """Test InfoMixin.create_info_dict class method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.create_info_dict(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": True}), + ], + ) + def test_extract_from_obj_without_info(self, obj_data, expected_attributes): + """Test InfoMixin.extract_from_obj with objects without info method.""" + + 
class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + def test_extract_from_obj_with_info_method(self): + """Test InfoMixin.extract_from_obj with objects that have info method.""" + + class ObjectWithInfoMethod: + def info(self): + return {"custom": "info_method", "type": "custom_type"} + + obj = ObjectWithInfoMethod() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_method", "type": "custom_type"} + + @pytest.mark.smoke + def test_extract_from_obj_with_info_property(self): + """Test InfoMixin.extract_from_obj with objects that have info property.""" + + class ObjectWithInfoProperty: + @property + def info(self): + return {"custom": "info_property", "type": "custom_type"} + + obj = ObjectWithInfoProperty() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_property", "type": "custom_type"} + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("obj_type", "obj_value"), + [ + (str, "test_string"), + (int, 42), + (float, 3.14), + (list, [1, 2, 3]), + (dict, {"key": "value"}), + ], + ) + def test_extract_from_obj_builtin_types(self, obj_type, obj_value): + """Test InfoMixin.extract_from_obj with built-in types.""" + result = InfoMixin.extract_from_obj(obj_value) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert result["type"] == obj_type.__name__ + assert result["str"] == str(obj_value) + + @pytest.mark.sanity + def test_extract_from_obj_without_dict(self): + """Test InfoMixin.extract_from_obj with objects without __dict__.""" + obj = 42 + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "attributes" in result + assert result["attributes"] == {} + assert result["type"] == "int" + assert result["str"] == "42" + + @pytest.mark.sanity + def test_extract_from_obj_with_private_attributes(self): + """Test InfoMixin.extract_from_obj filters private attributes.""" + + class ObjectWithPrivate: + def __init__(self): + self.public_attr = "public" + self._private_attr = "private" + self.__very_private = "very_private" + + obj = ObjectWithPrivate() + result = InfoMixin.extract_from_obj(obj) + + assert "public_attr" in result["attributes"] + assert result["attributes"]["public_attr"] == "public" + assert "_private_attr" not in result["attributes"] + assert "__very_private" not in result["attributes"] + + @pytest.mark.sanity + def test_extract_from_obj_complex_attributes(self): + """Test InfoMixin.extract_from_obj with complex attribute types.""" + + class ComplexObject: + def __init__(self): + self.simple_str = "test" + self.simple_int = 42 + self.simple_list = [1, 2, 3] + self.simple_dict = {"key": "value"} + self.complex_object = object() + + obj = ComplexObject() + result = InfoMixin.extract_from_obj(obj) + + attributes = result["attributes"] + assert attributes["simple_str"] == "test" + assert attributes["simple_int"] == 42 + assert attributes["simple_list"] == [1, 2, 3] + assert attributes["simple_dict"] == {"key": "value"} + assert isinstance(attributes["complex_object"], str) + + 
@pytest.mark.regression + def test_create_info_dict_consistency(self, valid_instances): + """Test InfoMixin.create_info_dict produces consistent results.""" + instance, _ = valid_instances + + result1 = InfoMixin.create_info_dict(instance) + result2 = InfoMixin.create_info_dict(instance) + + assert result1 == result2 + assert result1 is not result2 + + @pytest.mark.regression + def test_info_property_uses_create_info_dict(self, valid_instances): + """Test InfoMixin.info property uses create_info_dict method.""" + instance, _ = valid_instances + + info_result = instance.info + create_result = InfoMixin.create_info_dict(instance) + + assert info_result == create_result diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py index 8f8d1eeb..726b5ddf 100644 --- a/tests/unit/utils/test_pydantic_utils.py +++ b/tests/unit/utils/test_pydantic_utils.py @@ -1,245 +1,1002 @@ """ -Unit tests for the pydantic_utils module in the Speculators library. +Unit tests for the pydantic_utils module. """ -from typing import ClassVar +from __future__ import annotations + +from typing import ClassVar, TypeVar from unittest import mock import pytest -from pydantic import BaseModel - -from guidellm.utils import PydanticClassRegistryMixin, ReloadableBaseModel - -# ===== ReloadableBaseModel Tests ===== - - -@pytest.mark.smoke -def test_reloadable_base_model_initialization(): - class TestModel(ReloadableBaseModel): - name: str - - model = TestModel(name="test") - assert model.name == "test" +from pydantic import BaseModel, Field, ValidationError + +from guidellm.utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) +from guidellm.utils.pydantic_utils import ( + BaseModelT, + ErroredT, + IncompleteT, + RegisterClassT, + SuccessfulT, + TotalT, +) @pytest.mark.smoke -def test_reloadable_base_model_reload_schema(): - class TestModel(ReloadableBaseModel): - name: str - - model = TestModel(name="test") - assert model.name == "test" - - # Mock the model_rebuild method to simulate schema reload - with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: - TestModel.reload_schema() - mock_rebuild.assert_called_once() - - -# ===== PydanticClassRegistryMixin Tests ===== +def test_base_model_t(): + """Test that BaseModelT is configured correctly as a TypeVar.""" + assert isinstance(BaseModelT, type(TypeVar("test"))) + assert BaseModelT.__name__ == "BaseModelT" + assert BaseModelT.__bound__ is BaseModel + assert BaseModelT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_subclass_init(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - return cls - - assert TestBaseModel.registry is None - assert TestBaseModel.schema_discriminator == "test_type" +def test_register_class_t(): + """Test that RegisterClassT is configured correctly as a TypeVar.""" + assert isinstance(RegisterClassT, type(TypeVar("test"))) + assert RegisterClassT.__name__ == "RegisterClassT" + assert RegisterClassT.__bound__ is None + assert RegisterClassT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_subclass_missing_base_type(): - class InvalidBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - with pytest.raises(TypeError): - InvalidBaseModel(test_type="test") # type: 
ignore[abstract] - - -@pytest.mark.sanity -def test_pydantic_class_registry_decorator(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register() - class TestSubModel(TestBaseModel): - test_type: str = "TestSubModel" - value: str - - assert TestBaseModel.registry is not None - assert "TestSubModel" in TestBaseModel.registry - assert TestBaseModel.registry["TestSubModel"] is TestSubModel - - -@pytest.mark.sanity -def test_pydantic_class_registry_decorator_with_name(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register("custom_name") - class TestSubModel(TestBaseModel): - test_type: str = "custom_name" - value: str - - assert TestBaseModel.registry is not None - assert "custom_name" in TestBaseModel.registry - assert TestBaseModel.registry["custom_name"] is TestSubModel +def test_successful_t(): + """Test that SuccessfulT is configured correctly as a TypeVar.""" + assert isinstance(SuccessfulT, type(TypeVar("test"))) + assert SuccessfulT.__name__ == "SuccessfulT" + assert SuccessfulT.__bound__ is None + assert SuccessfulT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_decorator_invalid_type(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - class RegularClass: - pass - - with pytest.raises(TypeError) as exc_info: - TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] - - assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) +def test_errored_t(): + """Test that ErroredT is configured correctly as a TypeVar.""" + assert isinstance(ErroredT, type(TypeVar("test"))) + assert ErroredT.__name__ == "ErroredT" + assert ErroredT.__bound__ is None + assert ErroredT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_subclass_marshalling(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register("test_sub") - class TestSubModel(TestBaseModel): - test_type: str = "test_sub" - value: str - - TestBaseModel.reload_schema() - - # Test direct construction of subclass - sub_instance = TestSubModel(value="test_value") - assert isinstance(sub_instance, TestSubModel) - assert sub_instance.test_type == "test_sub" - assert sub_instance.value == "test_value" - - # Test serialization with model_dump - dump_data = sub_instance.model_dump() - assert isinstance(dump_data, dict) - assert dump_data["test_type"] == "test_sub" - assert dump_data["value"] == "test_value" - - # Test deserialization via model_validate - recreated = TestSubModel.model_validate(dump_data) - assert isinstance(recreated, TestSubModel) - assert recreated.test_type == "test_sub" - assert 
recreated.value == "test_value" - - # Test polymorphic deserialization via base class - recreated = TestBaseModel.model_validate(dump_data) # type: ignore[assignment] - assert isinstance(recreated, TestSubModel) - assert recreated.test_type == "test_sub" - assert recreated.value == "test_value" +def test_incomplete_t(): + """Test that IncompleteT is configured correctly as a TypeVar.""" + assert isinstance(IncompleteT, type(TypeVar("test"))) + assert IncompleteT.__name__ == "IncompleteT" + assert IncompleteT.__bound__ is None + assert IncompleteT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_parent_class_marshalling(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @classmethod - def __pydantic_generate_base_schema__(cls, handler): - return handler(cls) - - @TestBaseModel.register("sub_a") - class TestSubModelA(TestBaseModel): - test_type: str = "sub_a" - value_a: str - - @TestBaseModel.register("sub_b") - class TestSubModelB(TestBaseModel): - test_type: str = "sub_b" - value_b: int - - class ContainerModel(BaseModel): - name: str - model: TestBaseModel - models: list[TestBaseModel] - - sub_a = TestSubModelA(value_a="test") - sub_b = TestSubModelB(value_b=123) - - container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) - assert isinstance(container.model, TestSubModelA) - assert container.model.test_type == "sub_a" - assert container.model.value_a == "test" - assert isinstance(container.models[0], TestSubModelA) - assert isinstance(container.models[1], TestSubModelB) - assert container.models[0].test_type == "sub_a" - assert container.models[1].test_type == "sub_b" - assert container.models[0].value_a == "test" - assert container.models[1].value_b == 123 - - # Test serialization with model_dump - dump_data = container.model_dump() - assert isinstance(dump_data, dict) - assert dump_data["name"] == "container" - assert dump_data["model"]["test_type"] == "sub_a" - assert dump_data["model"]["value_a"] == "test" - assert len(dump_data["models"]) == 2 - assert dump_data["models"][0]["test_type"] == "sub_a" - assert dump_data["models"][0]["value_a"] == "test" - assert dump_data["models"][1]["test_type"] == "sub_b" - assert dump_data["models"][1]["value_b"] == 123 - - # Test deserialization via model_validate - recreated = ContainerModel.model_validate(dump_data) - assert isinstance(recreated, ContainerModel) - assert recreated.name == "container" - assert isinstance(recreated.model, TestSubModelA) - assert recreated.model.test_type == "sub_a" - assert recreated.model.value_a == "test" - assert len(recreated.models) == 2 - assert isinstance(recreated.models[0], TestSubModelA) - assert isinstance(recreated.models[1], TestSubModelB) - assert recreated.models[0].test_type == "sub_a" - assert recreated.models[1].test_type == "sub_b" - assert recreated.models[0].value_a == "test" - assert recreated.models[1].value_b == 123 +def test_total_t(): + """Test that TotalT is configured correctly as a TypeVar.""" + assert isinstance(TotalT, type(TypeVar("test"))) + assert TotalT.__name__ == "TotalT" + assert TotalT.__bound__ is None + assert TotalT.__constraints__ == () + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": 
"hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + # Check model configuration + config = ReloadableBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestStandardBaseModel: + """Test suite for StandardBaseModel.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "field_int": 42}, + {"field_str": "hello_world", "field_int": 100}, + {"field_str": "another_test", "field_int": 0}, + ], + ids=["basic_values", "positive_values", "zero_value"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseModel, dict[str, int | str]]: + """Fixture providing test data for StandardBaseModel.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + 
@pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseModel inheritance and class variables.""" + assert issubclass(StandardBaseModel, BaseModel) + assert hasattr(StandardBaseModel, "model_config") + assert hasattr(StandardBaseModel, "get_default") + + # Check model configuration + config = StandardBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["from_attributes"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseModel) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + assert instance.field_int == constructor_args["field_int"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ("field_int", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseModel with invalid field values.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + data = {field: value} + if field == "field_str": + data["field_int"] = 42 + else: + data["field_str"] = "test" + + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseModel initialization without required field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_get_default(self): + """Test StandardBaseModel.get_default method.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=42, description="Test integer field") + + default_value = TestModel.get_default("field_int") + assert default_value == 42 + + @pytest.mark.sanity + def test_get_default_invalid(self): + """Test StandardBaseModel.get_default with invalid field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + + with pytest.raises(KeyError): + TestModel.get_default("nonexistent_field") + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + assert data_dict["field_int"] == constructor_args["field_int"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + assert recreated.field_int == constructor_args["field_int"] + + +class TestStandardBaseDict: + """Test suite for StandardBaseDict.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "extra_field": "extra_value"}, + {"field_str": "hello_world", "another_extra": 123}, + {"field_str": "another_test", "complex_extra": {"nested": "value"}}, + ], + ids=["string_extra", "int_extra", "dict_extra"], + ) + def valid_instances( + self, request 
+ ) -> tuple[StandardBaseDict, dict[str, str | int | dict[str, str]]]: + """Fixture providing test data for StandardBaseDict.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseDict inheritance and class variables.""" + assert issubclass(StandardBaseDict, StandardBaseModel) + assert hasattr(StandardBaseDict, "model_config") + + # Check model configuration + config = StandardBaseDict.model_config + assert config["extra"] == "allow" + assert config["use_enum_values"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseDict initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseDict) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + + # Check extra fields are preserved + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseDict with invalid field values.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseDict initialization without required field.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseDict serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + + # Check extra fields are in the serialized data + for key, value in constructor_args.items(): + if key != "field_str": + assert key in data_dict + assert data_dict[key] == value + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + + # Check extra fields are preserved after deserialization + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(recreated, key) + assert getattr(recreated, key) == value + + +class TestStatusBreakdown: + """Test suite for StatusBreakdown.""" + + @pytest.fixture( + params=[ + {"successful": 100, "errored": 5, "incomplete": 10, "total": 115}, + { + "successful": "success_data", + "errored": "error_data", + "incomplete": "incomplete_data", + "total": "total_data", + }, + { + "successful": [1, 2, 3], + "errored": [4, 5], + "incomplete": [6], + "total": [1, 2, 3, 4, 5, 6], + }, + ], + ids=["int_values", "string_values", "list_values"], + ) + def valid_instances(self, request) -> tuple[StatusBreakdown, dict]: + """Fixture providing test data for StatusBreakdown.""" + 
constructor_args = request.param + instance = StatusBreakdown(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StatusBreakdown inheritance and type relationships.""" + assert issubclass(StatusBreakdown, BaseModel) + # Check if Generic is in the MRO (method resolution order) + assert any(cls.__name__ == "Generic" for cls in StatusBreakdown.__mro__) + assert "successful" in StatusBreakdown.model_fields + assert "errored" in StatusBreakdown.model_fields + assert "incomplete" in StatusBreakdown.model_fields + assert "total" in StatusBreakdown.model_fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StatusBreakdown initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StatusBreakdown) + assert instance.successful == constructor_args["successful"] + assert instance.errored == constructor_args["errored"] + assert instance.incomplete == constructor_args["incomplete"] + assert instance.total == constructor_args["total"] + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test StatusBreakdown initialization with default values.""" + instance: StatusBreakdown = StatusBreakdown() + assert instance.successful is None + assert instance.errored is None + assert instance.incomplete is None + assert instance.total is None + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StatusBreakdown serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["successful"] == constructor_args["successful"] + assert data_dict["errored"] == constructor_args["errored"] + assert data_dict["incomplete"] == constructor_args["incomplete"] + assert data_dict["total"] == constructor_args["total"] + + recreated: StatusBreakdown = StatusBreakdown.model_validate(data_dict) + assert isinstance(recreated, StatusBreakdown) + assert recreated.successful == constructor_args["successful"] + assert recreated.errored == constructor_args["errored"] + assert recreated.incomplete == constructor_args["incomplete"] + assert recreated.total == constructor_args["total"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, 
"schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + assert hasattr(PydanticClassRegistryMixin, "registered_classes") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "TestSubModel" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["TestSubModel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + 
@TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "guidellm.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.smoke + def test_registered_classes(self): + """Test PydanticClassRegistryMixin.registered_classes method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = False + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "test_sub_a" + value_a: str + + @TestBaseModel.register("test_sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "test_sub_b" + value_b: int + + # Test normal case with registered classes + registered = TestBaseModel.registered_classes() + assert isinstance(registered, tuple) + assert len(registered) == 2 + assert TestSubModelA in registered + assert TestSubModelB in registered + + @pytest.mark.sanity + def test_registered_classes_with_auto_discovery(self): + """Test PydanticClassRegistryMixin.registered_classes with auto discovery.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with mock.patch.object( + TestBaseModel, "auto_populate_registry" + ) as mock_auto_populate: + # Mock the registry to simulate registered classes + TestBaseModel.registry = {"test_class": type("TestClass", (), {})} + mock_auto_populate.return_value = False + + registered = TestBaseModel.registered_classes() + mock_auto_populate.assert_called_once() + 
assert isinstance(registered, tuple) + assert len(registered) == 1 + + @pytest.mark.sanity + def test_registered_classes_no_registry(self): + """Test PydanticClassRegistryMixin.registered_classes with no registry.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + # Ensure registry is None + TestBaseModel.registry = None + + with pytest.raises(ValueError) as exc_info: + TestBaseModel.registered_classes() + + assert "must be called after registering classes" in str(exc_info.value) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert 
isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) + + @pytest.mark.smoke + def test_register_preserves_pydantic_metadata(self): # noqa: C901 + """Test that registered Pydantic classes retain docs, types, and methods.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "model_type" + model_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + + return TestBaseModel + + @TestBaseModel.register("documented_model") + class DocumentedModel(TestBaseModel): + """This is a documented Pydantic model with methods and type hints.""" + + model_type: str = "documented_model" + value: int = Field(description="An integer value for the model") + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedModel: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedModel instance + """ + return cls(value=int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + def model_post_init(self, __context) -> None: + """Post-initialization processing. 
+ + :param __context: Validation context + """ + if self.value < 0: + raise ValueError("Value must be non-negative") + + # Check that the class was registered + assert TestBaseModel.is_registered("documented_model") + registered_class = TestBaseModel.get_registered_object("documented_model") + assert registered_class is DocumentedModel + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented Pydantic model with methods" in registered_class.__doc__ + + # Check that methods retain their documentation + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + assert registered_class.model_post_init.__doc__ is not None + assert ( + "Post-initialization processing" in registered_class.model_post_init.__doc__ + ) + + # Check that methods are callable and work correctly + instance = DocumentedModel(value=42) + assert isinstance(instance, DocumentedModel) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + assert instance.model_type == "documented_model" + + # Check class methods work + instance2 = DocumentedModel.from_string("123") + assert instance2.get_value() == 123 + assert instance2.model_type == "documented_model" + + # Check static methods work + assert DocumentedModel.validate_value(10) is True + assert DocumentedModel.validate_value(-5) is False + + # Check that Pydantic functionality is preserved + data_dict = instance.model_dump() + assert data_dict["value"] == 100 + assert data_dict["model_type"] == "documented_model" + + recreated = DocumentedModel.model_validate(data_dict) + assert isinstance(recreated, DocumentedModel) + assert recreated.value == 100 + assert recreated.model_type == "documented_model" + + # Test field validation + with pytest.raises(ValidationError): + DocumentedModel(value="not_an_int") + + # Test post_init validation + with pytest.raises(ValueError, match="Value must be non-negative"): + DocumentedModel(value=-10) + + # Check that Pydantic field metadata is preserved + value_field = DocumentedModel.model_fields["value"] + assert value_field.description == "An integer value for the model" + + # Check that type annotations are preserved (if accessible) + import inspect + + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(DocumentedModel.get_value) + return_ann = annotations.get("return") + assert return_ann is int or return_ann == "int" + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert DocumentedModel.__name__ == "DocumentedModel" + assert DocumentedModel.__qualname__.endswith("DocumentedModel") + + # Verify that the class is still properly integrated with the registry system + all_registered = TestBaseModel.registered_classes() + assert DocumentedModel in all_registered + + # Test that the registered class is the same as the original + assert registered_class is DocumentedModel diff --git a/tests/unit/utils/test_registry.py 
b/tests/unit/utils/test_registry.py index d4c337d1..eed126d3 100644 --- a/tests/unit/utils/test_registry.py +++ b/tests/unit/utils/test_registry.py @@ -2,212 +2,303 @@ Unit tests for the registry module. """ +from __future__ import annotations + +import inspect +from typing import TypeVar from unittest import mock import pytest -from guidellm.utils.registry import RegistryMixin +from guidellm.utils import RegistryMixin +from guidellm.utils.registry import RegisterT, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is None + assert RegistryObjT.__constraints__ == () + +def test_registered_type(): + """Test that RegisterT is configured correctly as a TypeVar.""" + assert isinstance(RegisterT, type(TypeVar("test"))) + assert RegisterT.__name__ == "RegisterT" + assert RegisterT.__bound__ is None + assert RegisterT.__constraints__ == () -class TestBasicRegistration: - """Test suite for basic registry functionality.""" + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config @pytest.mark.smoke - def test_registry_initialization(self): - """Test that RegistryMixin initializes with correct defaults.""" + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" class TestRegistryClass(RegistryMixin): - pass + registry_auto_discovery = True - assert TestRegistryClass.registry is None - assert TestRegistryClass.registry_auto_discovery is False - assert TestRegistryClass.registry_populated is False + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() @pytest.mark.smoke @pytest.mark.parametrize( - ("register_name", "expected_key"), + ("name", "expected_key"), [ ("custom_name", "custom_name"), - ("CamelCase", "camelcase"), - ("UPPERCASE", "uppercase"), - ("snake_case", "snake_case"), + 
(["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), ], ) - def test_register_with_name(self, register_name, expected_key): - """Test registering objects with explicit names.""" - - class TestRegistryClass(RegistryMixin): - pass + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances - @TestRegistryClass.register(register_name) + @registry_class.register(name) class TestClass: pass - assert TestRegistryClass.registry is not None - assert expected_key in TestRegistryClass.registry - assert TestRegistryClass.registry[expected_key] is TestClass + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass - @pytest.mark.smoke - def test_register_without_name(self): - """Test registering objects without explicit names.""" + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances - class TestRegistryClass(RegistryMixin): - pass + # The register method returns a decorator, so we need to apply it to test + # validation + decorator = registry_class.register(invalid_name) - @TestRegistryClass.register() class TestClass: pass - assert TestRegistryClass.registry is not None - assert "testclass" in TestRegistryClass.registry - assert TestRegistryClass.registry["testclass"] is TestClass + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + decorator(TestClass) @pytest.mark.smoke - def test_register_decorator_direct(self): - """Test direct usage of register_decorator.""" - - class TestRegistryClass(RegistryMixin): - pass + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances - @TestRegistryClass.register_decorator class TestClass: pass - assert TestRegistryClass.registry is not None - assert "testclass" in TestRegistryClass.registry - assert TestRegistryClass.registry["testclass"] is TestClass + registry_class.register_decorator(TestClass, name=name) - @pytest.mark.smoke - def test_register_multiple_names(self): - """Test registering an object with multiple names.""" + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass - class TestRegistryClass(RegistryMixin): - pass + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances - @TestRegistryClass.register(["name1", "name2", "Name3"]) class TestClass: pass - 
assert TestRegistryClass.registry is not None - assert "name1" in TestRegistryClass.registry - assert "name2" in TestRegistryClass.registry - assert "name3" in TestRegistryClass.registry - assert all( - TestRegistryClass.registry[key] is TestClass - for key in ["name1", "name2", "name3"] - ) + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) @pytest.mark.smoke - def test_registered_objects(self): - """Test retrieving all registered objects.""" + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" - class TestRegistryClass(RegistryMixin): - pass + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" - @TestRegistryClass.register() - class TestClass1: - pass + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True - @TestRegistryClass.register("custom_name") - class TestClass2: - pass + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() - registered = TestRegistryClass.registered_objects() - assert isinstance(registered, tuple) - assert len(registered) == 2 - assert TestClass1 in registered - assert TestClass2 in registered + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + class TestDisabledRegistry(RegistryMixin): + registry_auto_discovery = False -class TestRegistrationValidation: - """Test suite for registration validation and error handling.""" + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() - @pytest.mark.sanity - @pytest.mark.parametrize( - "invalid_name", [123, 42.5, True, {"key": "value"}, object()] - ) - def test_register_invalid_name_type(self, invalid_name): - """Test that invalid name types raise ValueError.""" + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances - class TestRegistryClass(RegistryMixin): + @registry_class.register("class1") + class TestClass1: pass - with pytest.raises(ValueError, match="name must be a string, list of strings"): - TestRegistryClass.register(invalid_name) - - @pytest.mark.sanity - def test_register_decorator_invalid_object(self): - """Test that register_decorator validates object has __name__ attribute.""" - - class TestRegistryClass(RegistryMixin): + @registry_class.register("class2") + class TestClass2: pass - with pytest.raises(AttributeError): - TestRegistryClass.register_decorator("not_a_class") + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects @pytest.mark.sanity - @pytest.mark.parametrize("invalid_name", [123, 42.5, True, {"key": "value"}]) - def test_register_decorator_invalid_name_type(self, invalid_name): - """Test that invalid name types in 
register_decorator raise ValueError.""" + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" class TestRegistryClass(RegistryMixin): pass - class TestClass: - pass - with pytest.raises( - ValueError, match="name must be a string or an iterable of strings" + ValueError, match="must be called after registering objects" ): - TestRegistryClass.register_decorator(TestClass, name=invalid_name) - - @pytest.mark.sanity - def test_register_decorator_invalid_list_element(self): - """Test that invalid elements in name list raise ValueError.""" + TestRegistryClass.registered_objects() - class TestRegistryClass(RegistryMixin): - pass + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + ("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + @registry_class.register(register_name) class TestClass: pass - with pytest.raises( - ValueError, match="name must be a string or a list of strings" - ): - TestRegistryClass.register_decorator(TestClass, name=["valid", 123]) + result = registry_class.is_registered(check_name) + assert result == expected - @pytest.mark.sanity - def test_register_duplicate_name(self): - """Test that duplicate names raise ValueError.""" - - class TestRegistryClass(RegistryMixin): - pass + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances - @TestRegistryClass.register("test_name") - class TestClass1: + @registry_class.register(register_name) + class TestClass: pass - with pytest.raises(ValueError, match="already registered"): - - @TestRegistryClass.register("test_name") - class TestClass2: - pass + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass @pytest.mark.sanity - def test_registered_objects_empty_registry(self): - """Test that registered_objects raises error when no objects registered.""" + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances - class TestRegistryClass(RegistryMixin): + @registry_class.register("valid_name") + class TestClass: pass - with pytest.raises( - ValueError, match="must be called after registering objects" - ): - TestRegistryClass.registered_objects() - - -class TestRegistryIsolation: - """Test suite for registry isolation between different classes.""" + result = registry_class.get_registered_object(lookup_name) + assert result is None @pytest.mark.regression def test_multiple_registries_isolation(self): @@ -230,45 +321,10 @@ class TestClass2: assert Registry1.registry is not None assert Registry2.registry is not None assert Registry1.registry != Registry2.registry - assert "testclass1" in Registry1.registry - assert "testclass2" in Registry2.registry - assert "testclass1" not in Registry2.registry - assert "testclass2" not 
in Registry1.registry - - @pytest.mark.regression - def test_inheritance_registry_sharing(self): - """Test that inherited registry classes share the same registry.""" - - class BaseRegistry(RegistryMixin): - pass - - class ChildRegistry(BaseRegistry): - pass - - @BaseRegistry.register() - class BaseClass: - pass - - @ChildRegistry.register() - class ChildClass: - pass - - # Child classes share the same registry as their parent - assert BaseRegistry.registry is ChildRegistry.registry - - # Both classes can see all registered objects - base_objects = BaseRegistry.registered_objects() - child_objects = ChildRegistry.registered_objects() - - assert len(base_objects) == 2 - assert len(child_objects) == 2 - assert base_objects == child_objects - assert BaseClass in base_objects - assert ChildClass in base_objects - - -class TestAutoDiscovery: - """Test suite for auto-discovery functionality.""" + assert "TestClass1" in Registry1.registry + assert "TestClass2" in Registry2.registry + assert "TestClass1" not in Registry2.registry + assert "TestClass2" not in Registry1.registry @pytest.mark.smoke def test_auto_discovery_initialization(self): @@ -284,54 +340,112 @@ class TestAutoRegistry(RegistryMixin): assert TestAutoRegistry.registry_auto_discovery is True @pytest.mark.smoke - def test_auto_populate_registry(self): - """Test auto population mechanism.""" + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" class TestAutoRegistry(RegistryMixin): registry_auto_discovery = True auto_package = "test_package.modules" with mock.patch.object( - TestAutoRegistry, "auto_import_package_modules" - ) as mock_import: - result = TestAutoRegistry.auto_populate_registry() - assert result is True - mock_import.assert_called_once() - assert TestAutoRegistry.registry_populated is True + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") - result = TestAutoRegistry.auto_populate_registry() - assert result is False - mock_import.assert_called_once() + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass + + with pytest.raises(ValueError, match="already registered"): + + @registry_class.register("duplicate_name") + class TestClass2: + pass @pytest.mark.sanity - def test_auto_populate_registry_disabled(self): - """Test that auto population fails when disabled.""" + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances - class TestDisabledAutoRegistry(RegistryMixin): - auto_package = "test_package.modules" + class TestClass1: + pass - with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): - TestDisabledAutoRegistry.auto_populate_registry() + class TestClass2: + pass + + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") @pytest.mark.sanity - def test_auto_registered_objects(self): - """Test automatic population during registered_objects call.""" + def 
test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances - class TestAutoRegistry(RegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" + class TestClass: + pass - with mock.patch.object( - TestAutoRegistry, "auto_populate_registry" - ) as mock_populate: - TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} - objects = TestAutoRegistry.registered_objects() - mock_populate.assert_called_once() - assert objects == ("obj1", "obj2") + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + + @pytest.mark.sanity + def test_register_decorator_invalid_object(self, valid_instances): + """Test register_decorator with object lacking __name__ attribute.""" + registry_class, _ = valid_instances + + with pytest.raises(AttributeError): + registry_class.register_decorator("not_a_class") + @pytest.mark.sanity + def test_register_decorator_empty_string_name(self, valid_instances): + """Test register_decorator with empty string name.""" + registry_class, _ = valid_instances -class TestAutoDiscoveryIntegration: - """Test suite for comprehensive auto-discovery integration scenarios.""" + class TestClass: + pass + + registry_class.register_decorator(TestClass, name="") + assert "" in registry_class.registry + assert registry_class.registry[""] is TestClass + + @pytest.mark.sanity + def test_register_decorator_none_in_list(self, valid_instances): + """Test register_decorator with None in name list.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", None]) + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None @pytest.mark.regression def test_auto_registry_integration(self): @@ -377,37 +491,103 @@ def walk_packages(package_path, package_name): assert len(objects) == 1 assert TestAutoRegistry.registry_populated is True assert TestAutoRegistry.registry is not None - assert "module1class" in TestAutoRegistry.registry - - @pytest.mark.regression - def test_auto_registry_multiple_packages(self): - """Test auto-discovery with multiple packages.""" + assert "Module1Class" in TestAutoRegistry.registry - class TestMultiPackageRegistry(RegistryMixin): - registry_auto_discovery = True - auto_package = ("package1", "package2") + @pytest.mark.smoke + def test_register_preserves_class_metadata(self): + """Test that registered classes retain docs, types, and methods.""" - with mock.patch.object( - TestMultiPackageRegistry, "auto_import_package_modules" - ) as mock_import: - TestMultiPackageRegistry.registry = {} - TestMultiPackageRegistry.registered_objects() - mock_import.assert_called_once() - assert TestMultiPackageRegistry.registry_populated is True + class TestRegistry(RegistryMixin): + pass - @pytest.mark.regression - def 
test_auto_registry_import_error(self): - """Test handling of import errors during auto-discovery.""" + @TestRegistry.register("documented_class") + class DocumentedClass: + """This is a documented class with methods and type hints.""" + + def __init__(self, value: int) -> None: + """Initialize with a value. + + :param value: An integer value + """ + self.value = value + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedClass: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedClass instance + """ + return cls(int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + # Check that the class was registered + assert TestRegistry.is_registered("documented_class") + registered_class = TestRegistry.get_registered_object("documented_class") + assert registered_class is DocumentedClass + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented class with methods" in registered_class.__doc__ + assert registered_class.__init__.__doc__ is not None + assert "Initialize with a value" in registered_class.__init__.__doc__ + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) - class TestErrorRegistry(RegistryMixin): - registry_auto_discovery = True - auto_package = "nonexistent.package" + # Check that methods are callable and work correctly + instance = registered_class(42) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + instance2 = registered_class.from_string("123") + assert instance2.get_value() == 123 + assert registered_class.validate_value(10) is True + assert registered_class.validate_value(-5) is False + + # Check that type annotations are preserved (if accessible) + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(registered_class.__init__) + assert "value" in annotations + assert annotations["value"] is int + return_ann = annotations.get("return") + assert return_ann is None or return_ann is type(None) + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass - with mock.patch.object( - TestErrorRegistry, - "auto_import_package_modules", - side_effect=ValueError("auto_package must be set"), - ) as mock_import: - with pytest.raises(ValueError, match="auto_package must be set"): - TestErrorRegistry.auto_populate_registry() - mock_import.assert_called_once() + # Check that the class name is preserved + assert registered_class.__name__ == "DocumentedClass" + assert 
registered_class.__qualname__.endswith("DocumentedClass") diff --git a/tests/unit/utils/test_singleton.py b/tests/unit/utils/test_singleton.py new file mode 100644 index 00000000..ee01ead1 --- /dev/null +++ b/tests/unit/utils/test_singleton.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import threading +import time + +import pytest + +from guidellm.utils.singleton import SingletonMixin, ThreadSafeSingletonMixin + + +class TestSingletonMixin: + """Test suite for SingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "test_value"}, + {"init_value": "another_value"}, + ], + ids=["basic_singleton", "different_value"], + ) + def valid_instances(self, request): + """Provide parameterized test configurations for singleton testing.""" + config = request.param + + class TestSingleton(SingletonMixin): + def __init__(self): + # Check if we need to initialize before calling super().__init__() + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SingletonMixin inheritance and exposed attributes.""" + assert hasattr(SingletonMixin, "__new__") + assert hasattr(SingletonMixin, "__init__") + assert hasattr(SingletonMixin, "initialized") + assert isinstance(SingletonMixin.initialized, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test SingletonMixin initialization.""" + singleton_class, config = valid_instances + + # Create first instance + instance1 = singleton_class() + + assert isinstance(instance1, singleton_class) + assert instance1.initialized is True + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + # Check that the class has the singleton instance stored + instance_attr = f"_singleton_instance_{singleton_class.__name__}" + assert hasattr(singleton_class, instance_attr) + assert getattr(singleton_class, instance_attr) is instance1 + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test that multiple instantiations return the same instance.""" + singleton_class, config = valid_instances + + # Create multiple instances + instance1 = singleton_class() + instance2 = singleton_class() + instance3 = singleton_class() + + # All should be the same instance + assert instance1 is instance2 + assert instance2 is instance3 + assert instance1 is instance3 + + # Value should remain from first initialization + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert instance2.value == config["init_value"] + assert instance3.value == config["init_value"] + + @pytest.mark.sanity + def test_initialization_called_once(self, valid_instances): + """Test that __init__ is only called once despite multiple instantiations.""" + singleton_class, config = valid_instances + + class TestSingletonWithCounter(SingletonMixin): + init_count = 0 + + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + TestSingletonWithCounter.init_count += 1 + self.value = config["init_value"] + + # Create multiple instances + instance1 = TestSingletonWithCounter() + instance2 = TestSingletonWithCounter() + + assert TestSingletonWithCounter.init_count == 1 + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + 
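+    # NOTE: the `should_initialize` guard used throughout these fixtures exists
+    # because SingletonMixin appears to hand back the cached instance from
+    # __new__, so Python re-runs __init__ on every instantiation; checking the
+    # `_singleton_initialized` flag before calling super().__init__() keeps
+    # later calls from clobbering state set up by the first one. A minimal
+    # sketch of the same pattern (class and attribute names below are
+    # illustrative only, not taken from guidellm):
+    #
+    #     class ConnectionPool(SingletonMixin):
+    #         def __init__(self):
+    #             first_init = not getattr(self, "_singleton_initialized", False)
+    #             super().__init__()
+    #             if first_init:
+    #                 self.connections = []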
@pytest.mark.regression + def test_multiple_singleton_classes_isolation(self): + """Test that different singleton classes maintain separate instances.""" + + class Singleton1(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class Singleton2(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1a = Singleton1() + instance2a = Singleton2() + instance1b = Singleton1() + instance2b = Singleton2() + + # Each class has its own singleton instance + assert instance1a is instance1b + assert instance2a is instance2b + assert instance1a is not instance2a + + # Each maintains its own value + assert hasattr(instance1a, "value") + assert hasattr(instance2a, "value") + assert instance1a.value == "value1" + assert instance2a.value == "value2" + + @pytest.mark.regression + def test_inheritance_singleton_sharing(self): + """Test that inherited singleton classes maintain separate singleton instances.""" + + class BaseSingleton(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildSingleton(BaseSingleton): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.extra = "extra_value" + + # Child classes now have separate singleton instances + base_instance = BaseSingleton() + child_instance = ChildSingleton() + + # They should be different instances now (fixed inheritance behavior) + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(child_instance, "value") + assert child_instance.value == "base_value" + assert hasattr(child_instance, "extra") + assert child_instance.extra == "extra_value" + + @pytest.mark.sanity + def test_without_super_init_call(self): + """Test singleton behavior when subclass doesn't call super().__init__().""" + + class BadSingleton(SingletonMixin): + def __init__(self): + # Not calling super().__init__() + self.value = "bad_value" + + instance1 = BadSingleton() + instance2 = BadSingleton() + + assert instance1 is instance2 + assert hasattr(instance1, "initialized") + assert instance1.initialized is False + + +class TestThreadSafeSingletonMixin: + """Test suite for ThreadSafeSingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "thread_safe_value"}, + {"init_value": "concurrent_value"}, + ], + ids=["basic_thread_safe", "concurrent_test"], + ) + def valid_instances(self, request): + """Fixture providing test data for ThreadSafeSingletonMixin subclasses.""" + config = request.param + + class TestThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestThreadSafeSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ThreadSafeSingletonMixin inheritance and exposed attributes.""" + assert issubclass(ThreadSafeSingletonMixin, SingletonMixin) + assert hasattr(ThreadSafeSingletonMixin, "get_singleton_lock") + assert hasattr(ThreadSafeSingletonMixin, "__new__") + assert
hasattr(ThreadSafeSingletonMixin, "__init__") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ThreadSafeSingletonMixin initialization.""" + singleton_class, config = valid_instances + + instance = singleton_class() + + assert isinstance(instance, singleton_class) + assert instance.initialized is True + assert hasattr(instance, "value") + assert instance.value == config["init_value"] + assert hasattr(instance, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance.thread_lock, lock_type) + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test multiple instantiations return same instance with thread safety.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert hasattr(instance1, "thread_lock") + + @pytest.mark.regression + def test_thread_safety_concurrent_creation(self, valid_instances): + """Test thread safety during concurrent instance creation.""" + singleton_class, config = valid_instances + + instances = [] + exceptions = [] + creation_count = 0 + lock = threading.Lock() + + class ThreadSafeTestSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + nonlocal creation_count + with lock: + creation_count += 1 + + time.sleep(0.01) + self.value = config["init_value"] + + def create_instance(): + try: + instance = ThreadSafeTestSingleton() + instances.append(instance) + except (TypeError, ValueError, AttributeError) as exc: + exceptions.append(exc) + + threads = [] + for _ in range(10): + thread = threading.Thread(target=create_instance) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}" + + assert len(instances) == 10 + for instance in instances: + assert instance is instances[0] + + assert creation_count == 1 + assert all(instance.value == config["init_value"] for instance in instances) + + @pytest.mark.sanity + def test_thread_lock_creation(self, valid_instances): + """Test that thread_lock is created during initialization.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert hasattr(instance1, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance1.thread_lock, lock_type) + assert instance1.thread_lock is instance2.thread_lock + + @pytest.mark.regression + def test_multiple_thread_safe_classes_isolation(self): + """Test thread-safe singleton classes behavior with separate locks.""" + + class ThreadSafeSingleton1(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class ThreadSafeSingleton2(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1 = ThreadSafeSingleton1() + instance2 = ThreadSafeSingleton2() + + lock1 = ThreadSafeSingleton1.get_singleton_lock() + lock2 = ThreadSafeSingleton2.get_singleton_lock() + + assert lock1 is not None + assert lock2 is not None + assert lock1 is 
not lock2 + + assert instance1 is not instance2 + assert hasattr(instance1, "value") + assert hasattr(instance2, "value") + assert instance1.value == "value1" + assert instance2.value == "value2" + + @pytest.mark.sanity + def test_inheritance_with_thread_safety(self): + """Test inheritance behavior with thread-safe singletons.""" + + class BaseThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildThreadSafeSingleton(BaseThreadSafeSingleton): + def __init__(self): + super().__init__() + + base_instance = BaseThreadSafeSingleton() + child_instance = ChildThreadSafeSingleton() + + base_lock = BaseThreadSafeSingleton.get_singleton_lock() + child_lock = ChildThreadSafeSingleton.get_singleton_lock() + + assert base_lock is not None + assert child_lock is not None + assert base_lock is not child_lock + + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(base_instance, "thread_lock") diff --git a/tests/unit/utils/test_synchronous.py b/tests/unit/utils/test_synchronous.py new file mode 100644 index 00000000..4a3b1893 --- /dev/null +++ b/tests/unit/utils/test_synchronous.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import asyncio +import multiprocessing +import threading +from functools import wraps +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from typing import Union + +import pytest + +from guidellm.utils.synchronous import ( + SyncObjectTypesAlias, + wait_for_sync_barrier, + wait_for_sync_event, + wait_for_sync_objects, +) + + +def async_timeout(delay: float): + """Decorator to add timeout to async functions.""" + + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +def test_sync_object_types_alias(): + """Test that SyncObjectTypesAlias is defined correctly as a type alias.""" + assert hasattr(SyncObjectTypesAlias, "__origin__") + if hasattr(SyncObjectTypesAlias, "__args__"): + actual_type = SyncObjectTypesAlias.__args__[0] + assert hasattr(actual_type, "__origin__") + assert actual_type.__origin__ is Union + union_args = actual_type.__args__ + assert threading.Event in union_args + assert ProcessingEvent in union_args + assert threading.Barrier in union_args + assert ProcessingBarrier in union_args + + +class TestWaitForSyncEvent: + """Test suite for wait_for_sync_event function.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "event_type", + [threading.Event, multiprocessing.Event], + ids=["threading", "multiprocessing"], + ) + @async_timeout(2.0) + async def test_invocation(self, event_type): + """Test wait_for_sync_event with valid events that get set.""" + event: threading.Event | ProcessingEvent = event_type() + + async def set_event(): + await asyncio.sleep(0.01) + event.set() + + asyncio.create_task(set_event()) + await wait_for_sync_event(event, poll_interval=0.001) + assert event.is_set() + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + "event_type", + [threading.Event, multiprocessing.Event], + ids=["threading", "multiprocessing"], + ) + @async_timeout(2.0) + async def test_cancellation_stops_waiting(self, event_type): 
+ """Test that cancelling the task stops waiting for the event.""" + event: threading.Event | ProcessingEvent = event_type() + + async def waiter(): + await wait_for_sync_event(event, poll_interval=0.001) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.02) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + +class TestWaitForSyncBarrier: + """Test suite for wait_for_sync_barrier function.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "barrier_type", + [threading.Barrier, multiprocessing.Barrier], + ids=["threading", "multiprocessing"], + ) + @async_timeout(5.0) + async def test_invocation(self, barrier_type): + """Test wait_for_sync_barrier with barrier that gets reached.""" + barrier: threading.Barrier | ProcessingBarrier = barrier_type(2) + + async def reach_barrier(): + await asyncio.sleep(0.01) + print("waiting for barrier from reach_barrier") + await asyncio.to_thread(barrier.wait) + + task = asyncio.create_task(reach_barrier()) + await wait_for_sync_barrier(barrier, poll_interval=0.01) + await task + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + "barrier_type", + [threading.Barrier, multiprocessing.Barrier], + ids=["threading", "multiprocessing"], + ) + @async_timeout(2.0) + async def test_cancellation_stops_waiting(self, barrier_type): + """Test that cancelling the task stops waiting for the barrier.""" + barrier: threading.Barrier | ProcessingBarrier = barrier_type(2) + + async def waiter(): + await wait_for_sync_barrier(barrier, 0.01) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + +class TestWaitForSyncObjects: + """Test suite for wait_for_sync_objects function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("objects_types", "expected_result"), + [ + (threading.Event, 0), + (multiprocessing.Event, 0), + (threading.Barrier, 0), + (multiprocessing.Barrier, 0), + ([threading.Event, multiprocessing.Barrier], 1), + ([multiprocessing.Event, threading.Barrier], 0), + ( + [ + threading.Event, + multiprocessing.Event, + threading.Barrier, + multiprocessing.Barrier, + ], + 2, + ), + ( + { + "multiprocessing.Event": multiprocessing.Event, + "threading.Barrier": threading.Barrier, + }, + "threading.Barrier", + ), + ( + { + "threading.Event": threading.Event, + "multiprocessing.Barrier": multiprocessing.Barrier, + }, + "threading.Event", + ), + ( + { + "multiprocessing.Event": multiprocessing.Event, + "threading.Event": threading.Event, + "multiprocessing.Barrier": multiprocessing.Barrier, + "threading.Barrier": threading.Barrier, + }, + "threading.Event", + ), + ], + ids=[ + "threading_event", + "multiprocessing_event", + "threading_barrier", + "multiprocessing_barrier", + "mixed_list_event_barrier_1", + "mixed_list_event_barrier_2", + "mixed_list_all", + "mixed_dict_event_barrier_1", + "mixed_dict_event_barrier_2", + "mixed_dict_all", + ], + ) + @pytest.mark.asyncio + @async_timeout(2.0) + async def test_invocation(self, objects_types, expected_result): + """Test wait_for_sync_objects with various object configurations.""" + if isinstance(objects_types, list): + objects = [ + obj() + if obj not in (threading.Barrier, multiprocessing.Barrier) + else obj(2) + for obj in objects_types + ] + elif isinstance(objects_types, dict): + objects = { + key: ( + obj() + if obj not in (threading.Barrier, multiprocessing.Barrier) + else obj(2) + ) + for key, obj in objects_types.items() + } 
+ else: + objects = [ + objects_types() + if objects_types not in (threading.Barrier, multiprocessing.Barrier) + else objects_types(2) + ] + + async def set_target(): + await asyncio.sleep(0.01) + obj = objects[expected_result] + if isinstance(obj, (threading.Event, ProcessingEvent)): + obj.set() + else: + await asyncio.to_thread(obj.wait) + + task = asyncio.create_task(set_target()) + result = await wait_for_sync_objects(objects, poll_interval=0.001) + await task + + assert result == expected_result diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py new file mode 100644 index 00000000..50f18ce3 --- /dev/null +++ b/tests/unit/utils/test_text.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import gzip +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import httpx +import pytest + +from guidellm.utils.text import ( + MAX_PATH_LENGTH, + EndlessTextCreator, + clean_text, + filter_text, + format_value_display, + is_puncutation, + load_text, + split_text, + split_text_list_by_length, +) + + +def test_max_path_length(): + """Test that MAX_PATH_LENGTH is correctly defined.""" + assert isinstance(MAX_PATH_LENGTH, int) + assert MAX_PATH_LENGTH == 4096 + + +class TestFormatValueDisplay: + """Test suite for format_value_display.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "value", + "label", + "units", + "total_characters", + "digits_places", + "decimal_places", + "expected", + ), + [ + (42.0, "test", "", None, None, None, "42 [info]test[/info]"), + (42.5, "test", "ms", None, None, 1, "42.5ms [info]test[/info]"), + (42.123, "test", "", None, 5, 2, " 42.12 [info]test[/info]"), + ( + 42.0, + "test", + "ms", + 30, + None, + 0, + " 42ms [info]test[/info]", + ), + ], + ) + def test_invocation( + self, + value, + label, + units, + total_characters, + digits_places, + decimal_places, + expected, + ): + """Test format_value_display with various parameters.""" + result = format_value_display( + value=value, + label=label, + units=units, + total_characters=total_characters, + digits_places=digits_places, + decimal_places=decimal_places, + ) + assert label in result + assert units in result + value_check = ( + str(int(value)) + if decimal_places == 0 + else ( + f"{value:.{decimal_places}f}" + if decimal_places is not None + else str(value) + ) + ) + assert value_check in result or str(value) in result + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("value", "label"), + [ + (None, "test"), + (42.0, None), + ("not_number", "test"), + ], + ) + def test_invocation_with_none_values(self, value, label): + """Test format_value_display with None/invalid inputs still works.""" + result = format_value_display(value, label) + assert isinstance(result, str) + if label is not None: + assert str(label) in result + if value is not None: + assert str(value) in result + + +class TestSplitTextListByLength: + """Test suite for split_text_list_by_length.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "text_list", + "max_characters", + "pad_horizontal", + "pad_vertical", + "expected_structure", + ), + [ + ( + ["hello world", "test"], + 5, + False, + False, + [["hello", "world"], ["test"]], + ), + ( + ["short", "longer text"], + [5, 10], + True, + True, + [[" short"], ["longer", "text"]], + ), + ( + ["a", "b", "c"], + 10, + True, + True, + [[" a"], [" b"], [" c"]], + ), + ], + ) + def test_invocation( + self, + text_list, + max_characters, + pad_horizontal, + pad_vertical, + expected_structure, + ): + """Test 
split_text_list_by_length with various parameters.""" + result = split_text_list_by_length( + text_list, max_characters, pad_horizontal, pad_vertical + ) + assert len(result) == len(text_list) + if pad_vertical: + max_lines = max(len(lines) for lines in result) + assert all(len(lines) == max_lines for lines in result) + + @pytest.mark.sanity + def test_invalid_max_characters_length(self): + """Test split_text_list_by_length with mismatched max_characters length.""" + error_msg = "max_characters must be a list of the same length" + with pytest.raises(ValueError, match=error_msg): + split_text_list_by_length(["a", "b"], [5, 10, 15]) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text_list", "max_characters"), + [ + (None, 5), + (["test"], None), + (["test"], []), + ], + ) + def test_invalid_invocation(self, text_list, max_characters): + """Test split_text_list_by_length with invalid inputs.""" + with pytest.raises((TypeError, ValueError)): + split_text_list_by_length(text_list, max_characters) + + +class TestFilterText: + """Test suite for filter_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end", "expected"), + [ + ("hello world test", "world", None, "world test"), + ("hello world test", None, "world", "hello "), + ("hello world test", "hello", "test", "hello world "), + ("hello world test", 6, 11, "world test"), + ("hello world test", 0, 5, "hello"), + ("hello world test", None, None, "hello world test"), + ], + ) + def test_invocation(self, text, filter_start, filter_end, expected): + """Test filter_text with various start and end markers.""" + result = filter_text(text, filter_start, filter_end) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end"), + [ + ("hello", "notfound", None), + ("hello", None, "notfound"), + ("hello", "invalid_type", None), + ("hello", None, "invalid_type"), + ], + ) + def test_invalid_invocation(self, text, filter_start, filter_end): + """Test filter_text with invalid markers.""" + with pytest.raises((ValueError, TypeError)): + filter_text(text, filter_start, filter_end) + + +class TestCleanText: + """Test suite for clean_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + ("hello world", "hello world"), + (" hello\n\nworld ", "hello world"), + ("hello\tworld\r\ntest", "hello world test"), + ("", ""), + (" ", ""), + ], + ) + def test_invocation(self, text, expected): + """Test clean_text with various whitespace scenarios.""" + result = clean_text(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test clean_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + clean_text(text) + + +class TestSplitText: + """Test suite for split_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "split_punctuation", "expected"), + [ + ("hello world", False, ["hello", "world"]), + ("hello, world!", True, ["hello", ",", "world", "!"]), + ("test.example", False, ["test.example"]), + ("test.example", True, ["test", ".", "example"]), + ("", False, []), + ], + ) + def test_invocation(self, text, split_punctuation, expected): + """Test split_text with various punctuation options.""" + result = split_text(text, split_punctuation) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def 
test_invalid_invocation(self, text): + """Test split_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + split_text(text) + + +class TestLoadText: + """Test suite for load_text.""" + + @pytest.mark.smoke + def test_empty_data(self): + """Test load_text with empty data.""" + result = load_text("") + assert result == "" + + @pytest.mark.smoke + def test_raw_text(self): + """Test load_text with raw text that's not a file.""" + long_text = "a" * (MAX_PATH_LENGTH + 1) + result = load_text(long_text) + assert result == long_text + + @pytest.mark.smoke + def test_local_file(self): + """Test load_text with local file.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp: + test_content = "test file content" + tmp.write(test_content) + tmp.flush() + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + def test_gzipped_file(self): + """Test load_text with gzipped file.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as tmp: + test_content = "test gzipped content" + with gzip.open(tmp.name, "wt") as gzf: + gzf.write(test_content) + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + @patch("httpx.Client") + def test_url_loading(self, mock_client): + """Test load_text with HTTP URL.""" + mock_response = Mock() + mock_response.text = "url content" + mock_client.return_value.__enter__.return_value.get.return_value = mock_response + + result = load_text("http://example.com/test.txt") + assert result == "url content" + + @pytest.mark.smoke + @patch("guidellm.utils.text.files") + @patch("guidellm.utils.text.as_file") + def test_package_data_loading(self, mock_as_file, mock_files): + """Test load_text with package data.""" + mock_resource = Mock() + mock_files.return_value.joinpath.return_value = mock_resource + + mock_file = Mock() + mock_file.read.return_value = "package data content" + mock_as_file.return_value.__enter__.return_value = mock_file + + with patch("gzip.open") as mock_gzip: + mock_gzip.return_value.__enter__.return_value = mock_file + result = load_text("data:test.txt") + assert result == "package data content" + + @pytest.mark.sanity + def test_nonexistent_file(self): + """Test load_text with nonexistent file returns the path as raw text.""" + result = load_text("/nonexistent/path/file.txt") + assert result == "/nonexistent/path/file.txt" + + @pytest.mark.sanity + @patch("httpx.Client") + def test_url_error(self, mock_client): + """Test load_text with HTTP error.""" + mock_client.return_value.__enter__.return_value.get.side_effect = ( + httpx.HTTPStatusError("HTTP error", request=None, response=None) + ) + + with pytest.raises(httpx.HTTPStatusError): + load_text("http://example.com/error.txt") + + +class TestIsPuncutation: + """Test suite for is_puncutation.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + (".", True), + (",", True), + ("!", True), + ("?", True), + (";", True), + ("a", False), + ("1", False), + (" ", False), + ("ab", False), + ("", False), + ], + ) + def test_invocation(self, text, expected): + """Test is_puncutation with various characters.""" + result = is_puncutation(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test is_puncutation with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + 
is_puncutation(text) + + +class TestEndlessTextCreator: + """Test suite for EndlessTextCreator.""" + + @pytest.fixture( + params=[ + { + "data": "hello world test", + "filter_start": None, + "filter_end": None, + }, + { + "data": "hello world test", + "filter_start": "world", + "filter_end": None, + }, + {"data": "one two three four", "filter_start": 0, "filter_end": 9}, + ], + ids=["no_filter", "string_filter", "index_filter"], + ) + def valid_instances(self, request): + """Fixture providing test data for EndlessTextCreator.""" + constructor_args = request.param + instance = EndlessTextCreator(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test EndlessTextCreator signatures and methods.""" + assert hasattr(EndlessTextCreator, "__init__") + assert hasattr(EndlessTextCreator, "create_text") + instance = EndlessTextCreator("test") + assert hasattr(instance, "data") + assert hasattr(instance, "text") + assert hasattr(instance, "filtered_text") + assert hasattr(instance, "words") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test EndlessTextCreator initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, EndlessTextCreator) + assert instance.data == constructor_args["data"] + assert isinstance(instance.text, str) + assert isinstance(instance.filtered_text, str) + assert isinstance(instance.words, list) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("data", "filter_start", "filter_end"), + [ + ("test", "notfound", None), + ], + ) + def test_invalid_initialization_values(self, data, filter_start, filter_end): + """Test EndlessTextCreator with invalid initialization values.""" + with pytest.raises((TypeError, ValueError)): + EndlessTextCreator(data, filter_start, filter_end) + + @pytest.mark.smoke + def test_initialization_with_none(self): + """Test EndlessTextCreator handles None data gracefully.""" + instance = EndlessTextCreator(None) + assert isinstance(instance, EndlessTextCreator) + assert instance.data is None + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "expected_length"), + [ + (0, 5, 5), + (2, 3, 3), + (0, 0, 0), + ], + ) + def test_create_text(self, valid_instances, start, length, expected_length): + """Test EndlessTextCreator.create_text.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + if length > 0 and instance.words: + assert len(result) > 0 + + @pytest.mark.smoke + def test_create_text_cycling(self): + """Test EndlessTextCreator.create_text cycling behavior.""" + instance = EndlessTextCreator("one two three") + result1 = instance.create_text(0, 3) + result2 = instance.create_text(3, 3) + assert isinstance(result1, str) + assert isinstance(result2, str) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("start", "length"), + [ + ("invalid", 5), + (0, "invalid"), + ], + ) + def test_create_text_invalid(self, valid_instances, start, length): + """Test EndlessTextCreator.create_text with invalid inputs.""" + instance, constructor_args = valid_instances + with pytest.raises((TypeError, ValueError)): + instance.create_text(start, length) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "min_length"), + [ + (-1, 5, 0), + (0, -1, 0), + ], + ) + def test_create_text_edge_cases(self, valid_instances, start, length, min_length): + """Test EndlessTextCreator.create_text with edge cases.""" + instance, 
constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + assert len(result) >= min_length diff --git a/tests/unit/utils/test_threading.py b/tests/unit/utils/test_threading.py deleted file mode 100644 index 887bf82c..00000000 --- a/tests/unit/utils/test_threading.py +++ /dev/null @@ -1,141 +0,0 @@ -import asyncio -import threading -from collections.abc import Iterator - -import pytest - -from guidellm.utils.threading import synchronous_to_exitable_async - - -def _infinite_counter() -> Iterator[int]: - i = 0 - while True: - i += 1 - yield i - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_callable_completed_returns_value(): - async def run(): - def add(a: int, b: int) -> int: - return a + b - - reason, value = await synchronous_to_exitable_async(add, None, None, 0.01, 2, 3) - return reason, value - - reason, value = await run() - assert reason == "completed" - assert value == 5 - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_iterable_completed_returns_last_item(): - items = ["a", "b", "c"] - reason, value = await synchronous_to_exitable_async(items, None, None, 0.005) - assert reason == "completed" - assert value == "c" - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_iterator_exits_on_custom_event(): - stop_event = threading.Event() - - async def trigger_event(): - await asyncio.sleep(0.02) - stop_event.set() - - task = asyncio.create_task( - synchronous_to_exitable_async( - _infinite_counter(), - exit_events={"stop": stop_event}, - exit_barrier=None, - poll_interval=0.005, - ) - ) - trigger = asyncio.create_task(trigger_event()) - reason, value = await task - await trigger - - assert reason == "stop" - assert isinstance(value, int) - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_barrier_triggers_exit(): - barrier = threading.Barrier(2) - - waiter = threading.Thread(target=barrier.wait, daemon=True) - waiter.start() - - reason, _ = await synchronous_to_exitable_async( - _infinite_counter(), - exit_events=None, - exit_barrier=barrier, - poll_interval=0.005, - ) - - assert reason == "barrier" - - -@pytest.mark.sanity -@pytest.mark.asyncio -async def test_cancellation_sets_canceled_and_aborts_barrier(): - barrier = threading.Barrier(2) - - async def runner(): - return await synchronous_to_exitable_async( - _infinite_counter(), - exit_events=None, - exit_barrier=barrier, - poll_interval=0.01, - ) - - task = asyncio.create_task(runner()) - await asyncio.sleep(0.02) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - for _ in range(50): - if barrier.broken: - break - await asyncio.sleep(0.01) - assert barrier.broken is True - - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_callable_internal_error_propagates_in_tuple(): - def boom(): - raise ValueError("boom!") - - reason, err = await synchronous_to_exitable_async(boom, None, None, 0.001) - assert reason == "internal_error" - assert isinstance(err, ValueError) - assert str(err) == "boom!" 
- - -@pytest.mark.smoke -@pytest.mark.asyncio -async def test_poll_mode_only_exits_on_custom_event(): - stop_event = threading.Event() - - async def trigger(): - await asyncio.sleep(0.02) - stop_event.set() - - trigger_task = asyncio.create_task(trigger()) - reason, last = await synchronous_to_exitable_async( - None, - exit_events={"stop": stop_event}, - exit_barrier=None, - poll_interval=0.005, - ) - await trigger_task - - assert reason == "stop" - assert last is None
diff --git a/tests/unit/utils/test_typing.py b/tests/unit/utils/test_typing.py
new file mode 100644
index 00000000..fafa8765
--- /dev/null
+++ b/tests/unit/utils/test_typing.py
@@ -0,0 +1,123 @@
+"""
+Test suite for the typing utilities module.
+"""
+
+from typing import Annotated, Literal, Union
+
+import pytest
+from typing_extensions import TypeAlias
+
+from guidellm.utils.typing import get_literal_vals
+
+# Local type definitions to avoid imports from other modules
+LocalProfileType = Literal["synchronous", "async", "concurrent", "throughput", "sweep"]
+LocalStrategyType = Annotated[
+    Literal["synchronous", "concurrent", "throughput", "constant", "poisson"],
+    "Valid strategy type identifiers for scheduling request patterns",
+]
+StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType]
+
+
+class TestGetLiteralVals:
+    """Test cases for the get_literal_vals function."""
+
+    @pytest.mark.sanity
+    def test_profile_type(self):
+        """
+        Test extracting values from ProfileType.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(LocalProfileType)
+        expected = frozenset(
+            {"synchronous", "async", "concurrent", "throughput", "sweep"}
+        )
+        assert result == expected
+
+    @pytest.mark.sanity
+    def test_strategy_type(self):
+        """
+        Test extracting values from StrategyType.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(LocalStrategyType)
+        expected = frozenset(
+            {"synchronous", "concurrent", "throughput", "constant", "poisson"}
+        )
+        assert result == expected
+
+    @pytest.mark.smoke
+    def test_inline_union_type(self):
+        """
+        Test extracting values from inline union of ProfileType | StrategyType.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(Union[LocalProfileType, LocalStrategyType])
+        expected = frozenset(
+            {
+                "synchronous",
+                "async",
+                "concurrent",
+                "throughput",
+                "constant",
+                "poisson",
+                "sweep",
+            }
+        )
+        assert result == expected
+
+    @pytest.mark.smoke
+    def test_type_alias(self):
+        """
+        Test extracting values from type alias union.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(StrategyProfileType)
+        expected = frozenset(
+            {
+                "synchronous",
+                "async",
+                "concurrent",
+                "throughput",
+                "constant",
+                "poisson",
+                "sweep",
+            }
+        )
+        assert result == expected
+
+    @pytest.mark.sanity
+    def test_single_literal(self):
+        """
+        Test extracting values from single Literal type.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(Literal["test"])
+        expected = frozenset({"test"})
+        assert result == expected
+
+    @pytest.mark.sanity
+    def test_multi_literal(self):
+        """
+        Test extracting values from multi-value Literal type.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(Literal["test", "test2"])
+        expected = frozenset({"test", "test2"})
+        assert result == expected
+
+    @pytest.mark.smoke
+    def test_literal_union(self):
+        """
+        Test extracting values from union of Literal types.
+
+        ### WRITTEN BY AI ###
+        """
+        result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]])
+        expected = frozenset({"test", "test2", "test3"})
+        assert result == expected