From 6b6ed9886da8fad98917711652682bf247ea06f6 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 03:55:12 +0000 Subject: [PATCH 1/3] refactors for the utility modules Signed-off-by: Mark Kurtz --- src/guidellm/utils/__init__.py | 96 ++- src/guidellm/utils/auto_importer.py | 98 +++ src/guidellm/utils/cli.py | 2 +- src/guidellm/utils/console.py | 183 ++++ src/guidellm/utils/dict.py | 23 - src/guidellm/utils/encoding.py | 787 +++++++++++++++++ src/guidellm/utils/functions.py | 133 +++ src/guidellm/utils/messaging.py | 1029 +++++++++++++++++++++++ src/guidellm/utils/mixins.py | 115 +++ src/guidellm/utils/pydantic_utils.py | 401 +++++++++ src/guidellm/utils/registry.py | 214 +++++ src/guidellm/utils/singleton.py | 130 +++ src/guidellm/utils/statistics.py | 990 ++++++++++++++++++++++ src/guidellm/utils/synchronous.py | 161 ++++ src/guidellm/utils/text.py | 199 ++++- tests/unit/utils/dict.py | 71 -- tests/unit/utils/test_auto_importer.py | 269 ++++++ tests/unit/utils/test_encoding.py | 556 ++++++++++++ tests/unit/utils/test_functions.py | 222 +++++ tests/unit/utils/test_messaging.py | 974 +++++++++++++++++++++ tests/unit/utils/test_mixins.py | 245 ++++++ tests/unit/utils/test_pydantic_utils.py | 1002 ++++++++++++++++++++++ tests/unit/utils/test_registry.py | 593 +++++++++++++ tests/unit/utils/test_singleton.py | 371 ++++++++ tests/unit/utils/test_synchronous.py | 238 ++++++ tests/unit/utils/test_text.py | 531 ++++++++++++ tests/unit/utils/text.py | 13 - 27 files changed, 9492 insertions(+), 154 deletions(-) create mode 100644 src/guidellm/utils/auto_importer.py create mode 100644 src/guidellm/utils/console.py delete mode 100644 src/guidellm/utils/dict.py create mode 100644 src/guidellm/utils/encoding.py create mode 100644 src/guidellm/utils/functions.py create mode 100644 src/guidellm/utils/messaging.py create mode 100644 src/guidellm/utils/mixins.py create mode 100644 src/guidellm/utils/pydantic_utils.py create mode 100644 src/guidellm/utils/registry.py create mode 100644 src/guidellm/utils/singleton.py create mode 100644 src/guidellm/utils/statistics.py create mode 100644 src/guidellm/utils/synchronous.py delete mode 100644 tests/unit/utils/dict.py create mode 100644 tests/unit/utils/test_auto_importer.py create mode 100644 tests/unit/utils/test_encoding.py create mode 100644 tests/unit/utils/test_functions.py create mode 100644 tests/unit/utils/test_messaging.py create mode 100644 tests/unit/utils/test_mixins.py create mode 100644 tests/unit/utils/test_pydantic_utils.py create mode 100644 tests/unit/utils/test_registry.py create mode 100644 tests/unit/utils/test_singleton.py create mode 100644 tests/unit/utils/test_synchronous.py create mode 100644 tests/unit/utils/test_text.py delete mode 100644 tests/unit/utils/text.py diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 02f2427f..83a276b2 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,6 +1,21 @@ -from .colors import Colors +from .auto_importer import AutoImporterMixin +from .console import Colors, Console, ConsoleUpdateStep, StatusIcons, StatusStyles from .default_group import DefaultGroupHandler -from .dict import recursive_key_update +from .encoding import ( + Encoder, + EncodingTypesAlias, + MessageEncoding, + SerializationTypesAlias, + Serializer, +) +from .functions import ( + all_defined, + safe_add, + safe_divide, + safe_format_timestamp, + safe_getattr, + safe_multiply, +) from .hf_datasets import ( SUPPORTED_TYPES, save_dataset_to_file, @@ -8,32 
+23,103 @@ from .hf_transformers import ( check_load_processor, ) +from .messaging import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, + SendMessageT, +) +from .mixins import InfoMixin +from .pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) from .random import IntegerRangeSampler +from .registry import RegistryMixin, RegistryObjT +from .singleton import SingletonMixin, ThreadSafeSingletonMixin +from .statistics import ( + DistributionSummary, + Percentiles, + RunningStats, + StatusDistributionSummary, + TimeRunningStats, +) +from .synchronous import ( + wait_for_sync_barrier, + wait_for_sync_event, + wait_for_sync_objects, +) from .text import ( EndlessTextCreator, - camelize_str, clean_text, filter_text, + format_value_display, is_puncutation, load_text, split_text, split_text_list_by_length, ) +from .typing import get_literal_vals __all__ = [ "SUPPORTED_TYPES", + "AutoImporterMixin", + "Colors", "Colors", + "Console", + "ConsoleUpdateStep", "DefaultGroupHandler", + "DistributionSummary", + "Encoder", + "EncodingTypesAlias", "EndlessTextCreator", + "InfoMixin", "IntegerRangeSampler", - "camelize_str", + "InterProcessMessaging", + "InterProcessMessagingManagerQueue", + "InterProcessMessagingPipe", + "InterProcessMessagingQueue", + "MessageEncoding", + "MessageEncoding", + "Percentiles", + "PydanticClassRegistryMixin", + "RegistryMixin", + "RegistryObjT", + "ReloadableBaseModel", + "RunningStats", + "SendMessageT", + "SerializationTypesAlias", + "Serializer", + "SingletonMixin", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", + "StatusDistributionSummary", + "StatusIcons", + "StatusStyles", + "ThreadSafeSingletonMixin", + "TimeRunningStats", + "all_defined", "check_load_processor", "clean_text", "filter_text", + "format_value_display", + "get_literal_vals", "is_puncutation", "load_text", - "recursive_key_update", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", "save_dataset_to_file", "split_text", "split_text_list_by_length", + "wait_for_sync_barrier", + "wait_for_sync_event", + "wait_for_sync_objects", ] diff --git a/src/guidellm/utils/auto_importer.py b/src/guidellm/utils/auto_importer.py new file mode 100644 index 00000000..5b939014 --- /dev/null +++ b/src/guidellm/utils/auto_importer.py @@ -0,0 +1,98 @@ +""" +Automatic module importing utilities for dynamic class discovery. + +This module provides a mixin class for automatic module importing within a package, +enabling dynamic discovery of classes and implementations without explicit imports. +It is particularly useful for auto-registering classes in a registry pattern where +subclasses need to be discoverable at runtime. + +The AutoImporterMixin can be combined with registration mechanisms to create +extensible systems where new implementations are automatically discovered and +registered when they are placed in the correct package structure. +""" + +from __future__ import annotations + +import importlib +import pkgutil +import sys +from typing import ClassVar + +__all__ = ["AutoImporterMixin"] + + +class AutoImporterMixin: + """ + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations without + explicit imports by automatically importing all modules within specified + packages. 
It is designed for use with class registration mechanisms to enable + automatic discovery and registration of classes when they are placed in the + correct package structure. + + Example: + :: + from guidellm.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported + """ + + auto_package: ClassVar[str | tuple[str, ...] | None] = None + auto_ignore_modules: ClassVar[tuple[str, ...] | None] = None + auto_imported_modules: ClassVar[list[str] | None] = None + + @classmethod + def auto_import_package_modules(cls) -> None: + """ + Automatically import all modules within the specified package(s). + + Scans the package(s) defined in the `auto_package` class variable and imports + all modules found, tracking them in `auto_imported_modules`. Skips packages + (directories) and any modules listed in `auto_ignore_modules`. + + :raises ValueError: If the `auto_package` class variable is not set + """ + if cls.auto_package is None: + raise ValueError( + "The class variable 'auto_package' must be set to the package name to " + "import modules from." + ) + + cls.auto_imported_modules = [] + packages = ( + cls.auto_package + if isinstance(cls.auto_package, tuple) + else (cls.auto_package,) + ) + + for package_name in packages: + package = importlib.import_module(package_name) + + for _, module_name, is_pkg in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + if ( + is_pkg + or ( + cls.auto_ignore_modules is not None + and module_name in cls.auto_ignore_modules + ) + or module_name in cls.auto_imported_modules + ): + # Skip packages and ignored modules + continue + + if module_name in sys.modules: + # Avoid circular imports + cls.auto_imported_modules.append(module_name) + else: + importlib.import_module(module_name) + cls.auto_imported_modules.append(module_name) diff --git a/src/guidellm/utils/cli.py b/src/guidellm/utils/cli.py index 69cf15d3..4d83526a 100644 --- a/src/guidellm/utils/cli.py +++ b/src/guidellm/utils/cli.py @@ -35,7 +35,7 @@ def __init__(self, *types: click.ParamType): self.types = types self.name = "".join(t.name for t in types) - def convert(self, value, param, ctx): # noqa: RET503 + def convert(self, value, param, ctx): fails = [] for t in self.types: try: diff --git a/src/guidellm/utils/console.py b/src/guidellm/utils/console.py new file mode 100644 index 00000000..c8cd6825 --- /dev/null +++ b/src/guidellm/utils/console.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any, Literal + +from rich.console import Console as RichConsole +from rich.padding import Padding +from rich.status import Status +from rich.text import Text + +__all__ = [ + "Colors", + "Console", + "ConsoleUpdateStep", + "StatusIcons", + "StatusStyles", +] + + +class Colors: + # Core states + info: str = "light_steel_blue" + progress: str = "dark_slate_gray1" + success: str = "chartreuse1" + warning: str = "#FDB516" + error: str = "orange_red1" + + # Branding + primary: str = "#30A2FF" + secondary: str = "#FDB516" + tertiary: str = "#008080" + + +StatusIcons: Mapping[str, str] = { + "debug": "…", + "info": "ℹ", + "warning": "⚠", + "error": "✖", + "critical": "‼", + 
"notset": "⟳", + "success": "✔", +} + +StatusStyles: Mapping[str, str] = { + "debug": "dim", + "info": f"bold {Colors.info}", + "warning": f"bold {Colors.warning}", + "error": f"bold {Colors.error}", + "critical": "bold red reverse", + "notset": f"bold {Colors.progress}", + "success": f"bold {Colors.success}", +} + + +@dataclass +class ConsoleUpdateStep: + console: Console + title: str + details: Any | None = None + status_level: Literal[ + "debug", + "info", + "warning", + "error", + "critical", + "notset", + "success", + ] = "info" + spinner: str = "dots" + _status: Status | None = None + + def __enter__(self): + if self.console.quiet: + return self + + self._status = self.console.status( + f"[{StatusStyles.get(self.status_level, 'bold')}]{self.title}[/]", + spinner=self.spinner, + ) + self._status.__enter__() + return self + + def update( + self, + title: str, + status_level: Literal[ + "debug", + "info", + "warning", + "error", + "critical", + "notset", + "success", + ] + | None = None, + ): + self.title = title + if status_level is not None: + self.status_level = status_level + if self._status: + self._status.update( + status=f"[{StatusStyles.get(self.status_level, 'bold')}]{title}[/]" + ) + + def finish( + self, + title: str, + details: Any | None = None, + status_level: Literal[ + "debug", + "info", + "warning", + "error", + "critical", + "notset", + "success", + ] = "info", + ): + self.title = title + self.status_level = status_level + if self._status: + self._status.stop() + self.console.print_update(title, details, status_level) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._status: + return self._status.__exit__(exc_type, exc_val, exc_tb) + return False + + +class Console(RichConsole): + def print_update( + self, + title: str, + details: str | None = None, + status: Literal[ + "debug", + "info", + "warning", + "error", + "critical", + "notset", + "success", + ] = "info", + ) -> None: + icon = StatusIcons.get(status, "•") + style = StatusStyles.get(status, "bold") + line = Text.assemble(f"{icon} ", (title, style)) + self.print(line) + self.print_update_details(details) + + def print_update_details(self, details: Any | None): + if details: + block = Padding( + Text.from_markup(str(details)), + (0, 0, 0, 2), + style=StatusStyles.get("debug"), + ) + self.print(block) + + def print_update_step( + self, + title: str, + status: Literal[ + "debug", + "info", + "warning", + "error", + "critical", + "notset", + "success", + ] = "info", + details: Any | None = None, + spinner: str = "dots", + ) -> ConsoleUpdateStep: + return ConsoleUpdateStep( + console=self, + title=title, + details=details, + status_level=status, + spinner=spinner, + ) diff --git a/src/guidellm/utils/dict.py b/src/guidellm/utils/dict.py deleted file mode 100644 index 5b4579c9..00000000 --- a/src/guidellm/utils/dict.py +++ /dev/null @@ -1,23 +0,0 @@ -def recursive_key_update(d, key_update_func): - if not isinstance(d, dict) and not isinstance(d, list): - return d - - if isinstance(d, list): - for item in d: - recursive_key_update(item, key_update_func) - return d - - updated_key_pairs = [] - for key, _ in d.items(): - updated_key = key_update_func(key) - if key != updated_key: - updated_key_pairs.append((key, updated_key)) - - for key_pair in updated_key_pairs: - old_key, updated_key = key_pair - d[updated_key] = d[old_key] - del d[old_key] - - for _, value in d.items(): - recursive_key_update(value, key_update_func) - return d diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py new 
file mode 100644 index 00000000..ccd26982 --- /dev/null +++ b/src/guidellm/utils/encoding.py @@ -0,0 +1,787 @@ +""" +Message encoding utilities for multiprocess communication with Pydantic model support. + +Provides binary serialization and deserialization of Python objects using various +serialization formats and encoding packages to enable performance configurations +for distributed scheduler operations. Supports configurable two-stage processing +pipeline: object serialization (to dict/sequence) followed by binary encoding +(msgpack/msgspec) with specialized Pydantic model handling for type preservation. +""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar + +try: + import msgpack + from msgpack import Packer, Unpacker + + HAS_MSGPACK = True +except ImportError: + msgpack = Packer = Unpacker = None + HAS_MSGPACK = False + +try: + from msgspec.msgpack import Decoder as MsgspecDecoder + from msgspec.msgpack import Encoder as MsgspecEncoder + + HAS_MSGSPEC = True +except ImportError: + MsgspecDecoder = MsgspecEncoder = None + HAS_MSGSPEC = False + +try: + import orjson + + HAS_ORJSON = True +except ImportError: + orjson = None + HAS_ORJSON = False + +from pydantic import BaseModel +from typing_extensions import TypeAlias + +__all__ = [ + "Encoder", + "EncodingTypesAlias", + "MessageEncoding", + "MsgT", + "ObjT", + "SerializationTypesAlias", + "Serializer", +] + +ObjT = TypeVar("ObjT") +MsgT = TypeVar("MsgT") + +SerializationTypesAlias: TypeAlias = Annotated[ + Optional[Literal["dict", "sequence"]], + "Type alias for available serialization strategies", +] +EncodingTypesAlias: TypeAlias = Annotated[ + Optional[Literal["msgpack", "msgspec"]], + "Type alias for available binary encoding formats", +] + + +class MessageEncoding(Generic[ObjT, MsgT]): + """ + High-performance message encoding and decoding for multiprocessing communication. + + Supports configurable object serialization and binary encoding with specialized + handling for Pydantic models. Provides a two-stage pipeline of serialization + (object to dict/str) followed by encoding (dict/str to binary) for optimal + performance and compatibility across different transport mechanisms used in + distributed scheduler operations. + + Example: + :: + from guidellm.utils.encoding import MessageEncoding + from pydantic import BaseModel + + class DataModel(BaseModel): + name: str + value: int + + # Configure with dict serialization and msgpack encoding + encoding = MessageEncoding(serialization="dict", encoding="msgpack") + encoding.register_pydantic(DataModel) + + # Encode and decode objects + data = DataModel(name="test", value=42) + encoded_msg = encoding.encode(data) + decoded_data = encoding.decode(encoded_msg) + + :cvar DEFAULT_ENCODING_PREFERENCE: Preferred encoding formats in priority order + """ + + DEFAULT_ENCODING_PREFERENCE: ClassVar[list[str]] = ["msgspec", "msgpack"] + + @classmethod + def encode_message( + cls, + obj: ObjT, + serializer: Serializer | None, + encoder: Encoder | None, + ) -> MsgT: + """ + Encode object using specified serializer and encoder. 
+ + :param obj: Object to encode + :param serializer: Serializer for object conversion, None for no serialization + :param encoder: Encoder for binary conversion, None for no encoding + :return: Encoded message ready for transport + """ + serialized = serializer.serialize(obj) if serializer else obj + + return encoder.encode(serialized) if encoder else serialized + + @classmethod + def decode_message( + cls, + message: MsgT, + serializer: Serializer | None, + encoder: Encoder | None, + ) -> ObjT: + """ + Decode message using specified serializer and encoder. + Must match the encoding configuration originally used. + + :param message: Encoded message to decode + :param serializer: Serializer for object reconstruction, None for no + serialization + :param encoder: Encoder for binary decoding, None for no encoding + :return: Reconstructed object + """ + serialized = encoder.decode(message) if encoder else message + + return serializer.deserialize(serialized) if serializer else serialized + + def __init__( + self, + serialization: SerializationTypesAlias = None, + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, + pydantic_models: list[type[BaseModel]] | None = None, + ) -> None: + """ + Initialize MessageEncoding with serialization and encoding strategies. + + :param serialization: Serialization strategy (None, "dict", or "sequence") + :param encoding: Encoding strategy (None, "msgpack", "msgspec", or + preference list) + """ + self.serializer = Serializer(serialization, pydantic_models) + self.encoder = Encoder(encoding) + + def register_pydantic(self, model: type[BaseModel]) -> None: + """ + Register Pydantic model for specialized serialization handling. + + :param model: Pydantic model class to register for type preservation + """ + self.serializer.register_pydantic(model) + + def encode(self, obj: ObjT) -> MsgT: + """ + Encode object using instance configuration. + + :param obj: Object to encode using configured serialization and encoding + :return: Encoded message ready for transport + """ + return self.encode_message( + obj=obj, + serializer=self.serializer, + encoder=self.encoder, + ) + + def decode(self, message: MsgT) -> ObjT: + """ + Decode message using instance configuration. + + :param message: Encoded message to decode using configured strategies + :return: Reconstructed object + """ + return self.decode_message( + message=message, + serializer=self.serializer, + encoder=self.encoder, + ) + + +class Encoder: + """ + Binary encoding and decoding using MessagePack or msgspec formats. + + Handles binary serialization of Python objects using configurable encoding + strategies with automatic fallback when dependencies are unavailable. Supports + both standalone instances and pooled encoder/decoder pairs for performance + optimization in high-throughput scenarios. + """ + + def __init__( + self, encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None + ) -> None: + """ + Initialize encoder with specified encoding strategy. + + :param encoding: Encoding format preference (None, "msgpack", "msgspec", or + preference list) + """ + self.encoding, self.encoder, self.decoder = self._resolve_encoding(encoding) + + def encode(self, obj: Any) -> bytes | Any: + """ + Encode object to binary format using configured encoding strategy. 
+ + :param obj: Object to encode (must be serializable by chosen format) + :return: Encoded bytes or original object if no encoding configured + :raises ImportError: If required encoding library is not available + """ + if self.encoding == "msgpack": + if not HAS_MSGPACK: + raise ImportError("msgpack is not available") + + return self.encoder.pack(obj) if self.encoder else msgpack.packb(obj) + + if self.encoding == "msgspec": + if not HAS_MSGSPEC: + raise ImportError("msgspec is not available") + + return ( + self.encoder.encode(obj) + if self.encoder + else MsgspecEncoder().encode(obj) + ) + + return obj + + def decode(self, data: bytes | Any) -> Any: + """ + Decode binary data using configured encoding strategy. + + :param data: Binary data to decode or object if no encoding configured + :return: Decoded Python object + :raises ImportError: If required encoding library is not available + """ + if self.encoding == "msgpack": + if not HAS_MSGPACK: + raise ImportError("msgpack is not available") + + if self.decoder is not None: + self.decoder.feed(data) + return self.decoder.unpack() + + return msgpack.unpackb(data, raw=False) + + if self.encoding == "msgspec": + if not HAS_MSGSPEC: + raise ImportError("msgspec is not available") + + if self.decoder is not None: + return self.decoder.decode(data) + + return MsgspecDecoder().decode(data) + + return data + + def _resolve_encoding( + self, encoding: EncodingTypesAlias | list[EncodingTypesAlias] | None + ) -> tuple[EncodingTypesAlias, Any, Any]: + def _get_available_encoder_decoder( + encoding: EncodingTypesAlias, + ) -> tuple[Any, Any]: + if encoding == "msgpack" and HAS_MSGPACK: + return Packer(), Unpacker(raw=False) + if encoding == "msgspec" and HAS_MSGSPEC: + return MsgspecEncoder(), MsgspecDecoder() + return None, None + + if not isinstance(encoding, list): + if encoding is None: + return None, None, None + + encoder, decoder = _get_available_encoder_decoder(encoding) + if encoder is None or decoder is None: + raise ImportError(f"Encoding '{encoding}' is not available.") + + return encoding, encoder, decoder + + for test_encoding in encoding: + encoder, decoder = _get_available_encoder_decoder(test_encoding) + if encoder is not None and decoder is not None: + return test_encoding, encoder, decoder + + return None, None, None + + +class Serializer: + """ + Object serialization with specialized Pydantic model support. + + Converts Python objects to serializable formats (dict/sequence) with type + preservation for Pydantic models. Maintains object integrity through + encoding/decoding cycles by storing class metadata and enabling proper + reconstruction of complex objects. Supports both dictionary-based and + sequence-based serialization strategies for different use cases. + """ + + def __init__( + self, + serialization: SerializationTypesAlias = None, + pydantic_models: list[type[BaseModel]] | None = None, + ): + """ + Initialize serializer with strategy and Pydantic registry. + + :param serialization: Default serialization strategy for this instance + """ + self.serialization = serialization + self.pydantic_registry: dict[tuple[str, str], type[BaseModel]] = {} + if pydantic_models: + for model in pydantic_models: + self.register_pydantic(model) + + def register_pydantic(self, model: type[BaseModel]) -> None: + """ + Register Pydantic model for specialized serialization handling. 
+ + :param model: Pydantic model class to register for type preservation + """ + key = (model.__module__, model.__name__) + self.pydantic_registry[key] = model + + def load_pydantic(self, type_name: str, module_name: str) -> type[BaseModel]: + """ + Load Pydantic class by name with registry fallback to dynamic import. + + :param type_name: Class name to load + :param module_name: Module containing the class + :return: Loaded Pydantic model class + """ + key = (module_name, type_name) + + if key in self.pydantic_registry: + return self.pydantic_registry[key] + + # Dynamic import fallback; need to update to better handle generics + module = __import__(module_name, fromlist=[type_name]) + pydantic_class = getattr(module, type_name) + self.pydantic_registry[key] = pydantic_class + + return pydantic_class + + def serialize(self, obj: Any) -> Any: + """ + Serialize object using specified or configured strategy. + + :param obj: Object to serialize + :return: Serialized representation (dict, str, or original object) + """ + if self.serialization == "dict": + return self.to_dict(obj) + elif self.serialization == "sequence": + return self.to_sequence(obj) + + return obj + + def deserialize(self, msg: Any) -> Any: + """ + Deserialize object using specified or configured strategy. + + :param msg: Serialized message to deserialize + :return: Reconstructed object + """ + if self.serialization == "dict": + return self.from_dict(msg) + elif self.serialization == "sequence": + return self.from_sequence(msg) + + return msg + + def to_dict(self, obj: Any) -> Any: + """ + Convert object to dictionary with Pydantic model type preservation. + + :param obj: Object to convert (BaseModel, collections, or primitive) + :return: Dictionary representation with type metadata for Pydantic models + """ + if isinstance(obj, BaseModel): + return self.to_dict_pydantic(obj) + + if isinstance(obj, (list, tuple)) and any( + isinstance(item, BaseModel) for item in obj + ): + return [ + self.to_dict_pydantic(item) if isinstance(item, BaseModel) else item + for item in obj + ] + + if isinstance(obj, dict) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + return { + key: self.to_dict_pydantic(value) + if isinstance(value, BaseModel) + else value + for key, value in obj.items() + } + + return obj + + def from_dict(self, data: Any) -> Any: + """ + Reconstruct object from dictionary with Pydantic model type restoration. + + :param data: Dictionary representation possibly containing type metadata + :return: Reconstructed object with proper types restored + """ + if isinstance(data, (list, tuple)): + return [ + self.from_dict_pydantic(item) + if isinstance(item, dict) and "*PYD*" in item + else item + for item in data + ] + elif isinstance(data, dict) and data: + if "*PYD*" in data: + return self.from_dict_pydantic(data) + + return { + key: self.from_dict_pydantic(value) + if isinstance(value, dict) and "*PYD*" in value + else value + for key, value in data.items() + } + + return data + + def to_dict_pydantic(self, item: Any) -> Any: + """ + Convert item to dictionary with Pydantic type metadata. + + :param item: Item to convert (may or may not be a Pydantic model) + :return: Dictionary with type preservation metadata + """ + return { + "*PYD*": True, + "typ": item.__class__.__name__, + "mod": item.__class__.__module__, + "dat": item.model_dump(mode="python"), + } + + def from_dict_pydantic(self, item: dict[str, Any]) -> Any: + """ + Reconstruct object from dictionary with Pydantic type metadata. 
+ + :param item: Dictionary containing type metadata and data + :return: Reconstructed Pydantic model or original data + """ + type_name = item["typ"] + module_name = item["mod"] + model_class = self.load_pydantic(type_name, module_name) + + return model_class.model_validate(item["dat"]) + + def to_sequence(self, obj: Any) -> str | Any: + """ + Convert object to sequence format with type-aware serialization. + + Handles Pydantic models, collections, and mappings with proper type + preservation through structured sequence encoding. + + :param obj: Object to serialize to sequence format + :return: Serialized sequence string or bytes + """ + if isinstance(obj, BaseModel): + payload_type = "pydantic" + payload = self.to_sequence_pydantic(obj) + elif isinstance(obj, (list, tuple)) and any( + isinstance(item, BaseModel) for item in obj + ): + payload_type = "collection_sequence" + payload = None + + for item in obj: + is_pydantic = isinstance(item, BaseModel) + payload = self.pack_next_sequence( + type_="pydantic" if is_pydantic else "python", + payload=( + self.to_sequence_pydantic(item) + if is_pydantic + else self.to_sequence_python(item) + ), + current=payload, + ) + elif isinstance(obj, Mapping) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + payload_type = "collection_mapping" + keys = ",".join(str(key) for key in obj) + payload = keys.encode() + b"|" if HAS_ORJSON else keys + "|" + for item in obj.values(): + is_pydantic = isinstance(item, BaseModel) + payload = self.pack_next_sequence( + type_="pydantic" if is_pydantic else "python", + payload=( + self.to_sequence_pydantic(item) + if is_pydantic + else self.to_sequence_python(item) + ), + current=payload, + ) + else: + payload_type = "python" + payload = self.to_sequence_python(obj) + + return self.pack_next_sequence(payload_type, payload, None) + + def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912 + """ + Reconstruct object from sequence format with type restoration. + + Handles deserialization of objects encoded with to_sequence, properly + restoring Pydantic models and collection structures. 
+ + :param data: Serialized sequence data to reconstruct + :return: Reconstructed object with proper types + :raises ValueError: If sequence format is invalid or contains multiple + packed sequences + """ + type_, payload, remaining = self.unpack_next_sequence(data) + if remaining is not None: + raise ValueError("Data contains multiple packed sequences; expected one.") + + if type_ == "pydantic": + return self.from_sequence_pydantic(payload) + + if type_ == "python": + return self.from_sequence_python(payload) + + if type_ in {"collection_sequence", "collection_tuple"}: + items = [] + while payload: + type_, item_payload, payload = self.unpack_next_sequence(payload) + if type_ == "pydantic": + items.append(self.from_sequence_pydantic(item_payload)) + elif type_ == "python": + items.append(self.from_sequence_python(item_payload)) + else: + raise ValueError("Invalid type in collection sequence") + return items + + if type_ != "collection_mapping": + raise ValueError(f"Invalid type for mapping sequence: {type_}") + + if isinstance(payload, bytes): + keys_end = payload.index(b"|") + keys = payload[:keys_end].decode().split(",") + payload = payload[keys_end + 1 :] + else: + keys_end = payload.index("|") + keys = payload[:keys_end].split(",") + payload = payload[keys_end + 1 :] + + items = {} + index = 0 + while payload: + type_, item_payload, payload = self.unpack_next_sequence(payload) + if type_ == "pydantic": + items[keys[index]] = self.from_sequence_pydantic(item_payload) + elif type_ == "python": + items[keys[index]] = self.from_sequence_python(item_payload) + else: + raise ValueError("Invalid type in mapping sequence") + index += 1 + return items + + def to_sequence_pydantic(self, obj: BaseModel) -> str | bytes: + """ + Serialize Pydantic model to sequence format with class metadata. + + :param obj: Pydantic model instance to serialize + :return: Sequence string or bytes containing class info and JSON data + """ + class_name: str = obj.__class__.__name__ + class_module: str = obj.__class__.__module__ + json_data = obj.__pydantic_serializer__.to_json(obj) + + return ( + (class_name.encode() + b"|" + class_module.encode() + b"|" + json_data) + if HAS_ORJSON + else ( + class_name + "|" + class_module + "|" + json_data.decode() + if isinstance(json_data, bytes) + else json_data + ) + ) + + def from_sequence_pydantic(self, data: str | bytes) -> BaseModel: + """ + Reconstruct Pydantic model from sequence format. + + :param data: Sequence data containing class metadata and JSON + :return: Reconstructed Pydantic model instance + """ + if isinstance(data, bytes): + class_name_end = data.index(b"|") + class_name = data[:class_name_end].decode() + module_name_end = data.index(b"|", class_name_end + 1) + module_name = data[class_name_end + 1 : module_name_end].decode() + json_data = data[module_name_end + 1 :] + else: + class_name_end = data.index("|") + class_name = data[:class_name_end] + module_name_end = data.index("|", class_name_end + 1) + module_name = data[class_name_end + 1 : module_name_end] + json_data = data[module_name_end + 1 :] + + model_class = self.load_pydantic(class_name, module_name) + + return model_class.model_validate_json(json_data) + + def to_sequence_python(self, obj: Any) -> str | bytes: + """ + Serialize Python object to JSON format. 
+ + :param obj: Python object to serialize + :return: JSON string or bytes representation + """ + return orjson.dumps(obj) if HAS_ORJSON else json.dumps(obj) + + def from_sequence_python(self, data: str | bytes) -> Any: + """ + Deserialize Python object from JSON format. + + :param data: JSON string or bytes to deserialize + :return: Reconstructed Python object + :raises ImportError: If orjson is required but not available + """ + if isinstance(data, bytes): + if not HAS_ORJSON: + raise ImportError("orjson is not available, cannot deserialize bytes") + return orjson.loads(data) + + return json.loads(data) + + def pack_next_sequence( # noqa: C901, PLR0912 + self, + type_: Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", + ], + payload: str | bytes, + current: str | bytes | None, + ) -> str | bytes: + """ + Pack payload into sequence format with type and length metadata. + + :param type_: Type identifier for the payload + :param payload: Data to pack into sequence + :param current: Current sequence data to append to (unused but maintained + for signature compatibility) + :return: Packed sequence with type, length, and payload + :raises ValueError: If payload type doesn't match current type or unknown + type specified + """ + if current is not None and type(payload) is not type(current): + raise ValueError("Payload and current must be of the same type") + + payload_len = len(payload) + + if isinstance(payload, bytes): + payload_len = payload_len.to_bytes( + length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1, + byteorder="big", + ) + if type_ == "pydantic": + payload_type = b"P" + elif type_ == "python": + payload_type = b"p" + elif type_ == "collection_tuple": + payload_type = b"T" + elif type_ == "collection_sequence": + payload_type = b"S" + elif type_ == "collection_mapping": + payload_type = b"M" + else: + raise ValueError(f"Unknown type for packing: {type_}") + delimiter = b"|" + else: + payload_len = str(payload_len) + if type_ == "pydantic": + payload_type = "P" + elif type_ == "python": + payload_type = "p" + elif type_ == "collection_tuple": + payload_type = "T" + elif type_ == "collection_sequence": + payload_type = "S" + elif type_ == "collection_mapping": + payload_type = "M" + else: + raise ValueError(f"Unknown type for packing: {type_}") + delimiter = "|" + + next_sequence = payload_type + delimiter + payload_len + delimiter + payload + + return current + next_sequence if current else next_sequence + + def unpack_next_sequence( # noqa: C901, PLR0912 + self, data: str | bytes + ) -> tuple[ + Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", + ], + str | bytes, + str | bytes | None, + ]: + """ + Unpack sequence format to extract type, payload, and remaining data. 
+ + :param data: Packed sequence data to unpack + :return: Tuple of (type, payload, remaining_data) + :raises ValueError: If sequence format is invalid or unknown type character + """ + if isinstance(data, bytes): + if len(data) < len(b"T|N") or data[1:2] != b"|": + raise ValueError("Invalid packed data format") + + type_char = data[0:1] + if type_char == b"P": + type_ = "pydantic" + elif type_char == b"p": + type_ = "python" + elif type_char == b"T": + type_ = "collection_tuple" + elif type_char == b"S": + type_ = "collection_sequence" + elif type_char == b"M": + type_ = "collection_mapping" + else: + raise ValueError("Unknown type character in packed data") + + len_end = data.index(b"|", 2) + payload_len = int.from_bytes(data[2:len_end], "big") + payload = data[len_end + 1 : len_end + 1 + payload_len] + remaining = ( + data[len_end + 1 + payload_len :] + if len_end + 1 + payload_len < len(data) + else None + ) + + return type_, payload, remaining + + if len(data) < len("T|N") or data[1] != "|": + raise ValueError("Invalid packed data format") + + type_char = data[0] + if type_char == "P": + type_ = "pydantic" + elif type_char == "p": + type_ = "python" + elif type_char == "S": + type_ = "collection_sequence" + elif type_char == "M": + type_ = "collection_mapping" + else: + raise ValueError("Unknown type character in packed data") + + len_end = data.index("|", 2) + payload_len = int(data[2:len_end]) + payload = data[len_end + 1 : len_end + 1 + payload_len] + remaining = ( + data[len_end + 1 + payload_len :] + if len_end + 1 + payload_len < len(data) + else None + ) + + return type_, payload, remaining diff --git a/src/guidellm/utils/functions.py b/src/guidellm/utils/functions.py new file mode 100644 index 00000000..6343cbf2 --- /dev/null +++ b/src/guidellm/utils/functions.py @@ -0,0 +1,133 @@ +""" +Utility functions for safe operations and value handling. + +Provides defensive programming utilities for common operations that may encounter +None values, invalid inputs, or edge cases. Includes safe arithmetic operations, +attribute access, and timestamp formatting. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +__all__ = [ + "all_defined", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", +] + + +def safe_getattr(obj: Any | None, attr: str, default: Any = None) -> Any: + """ + Safely get an attribute from an object with None handling. + + :param obj: Object to get the attribute from, or None + :param attr: Name of the attribute to retrieve + :param default: Value to return if object is None or attribute doesn't exist + :return: Attribute value or default if not found or object is None + """ + if obj is None: + return default + + return getattr(obj, attr, default) + + +def all_defined(*values: Any | None) -> bool: + """ + Check if all provided values are defined (not None). + + :param values: Variable number of values to check for None + :return: True if all values are not None, False otherwise + """ + return all(value is not None for value in values) + + +def safe_divide( + numerator: int | float | None, + denominator: int | float | None, + num_default: float = 0.0, + den_default: float = 1.0, +) -> float: + """ + Safely divide two numbers with None handling and zero protection. 
+ + :param numerator: Number to divide, or None to use num_default + :param denominator: Number to divide by, or None to use den_default + :param num_default: Default value for numerator if None + :param den_default: Default value for denominator if None + :return: Division result with protection against division by zero + """ + numerator = numerator if numerator is not None else num_default + denominator = denominator if denominator is not None else den_default + + return numerator / (denominator or 1e-10) + + +def safe_multiply(*values: int | float | None, default: float = 1.0) -> float: + """ + Safely multiply multiple numbers with None handling. + + :param values: Variable number of values to multiply, None values treated as 1.0 + :param default: Starting value for multiplication + :return: Product of all non-None values multiplied by default + """ + result = default + for val in values: + result *= val if val is not None else 1.0 + return result + + +def safe_add( + *values: int | float | None, signs: list[int] | None = None, default: float = 0.0 +) -> float: + """ + Safely add multiple numbers with None handling and optional signs. + + :param values: Variable number of values to add, None values use default + :param signs: Optional list of 1 (add) or -1 (subtract) for each value. + If None, all values are added. Must match length of values. + :param default: Value to substitute for None values + :return: Result of adding all values safely (default used when value is None) + """ + if not values: + return default + + values = list(values) + + if signs is None: + signs = [1] * len(values) + + if len(signs) != len(values): + raise ValueError("Length of signs must match length of values") + + result = values[0] if values[0] is not None else default + + for ind in range(1, len(values)): + val = values[ind] if values[ind] is not None else default + result += signs[ind] * val + + return result + + +def safe_format_timestamp( + timestamp: float | None, format_: str = "%H:%M:%S", default: str = "N/A" +) -> str: + """ + Safely format a timestamp with error handling and validation. + + :param timestamp: Unix timestamp to format, or None + :param format_: Strftime format string for timestamp formatting + :param default: Value to return if timestamp is invalid or None + :return: Formatted timestamp string or default value + """ + if timestamp is None or timestamp < 0 or timestamp > 2**31: + return default + + try: + return datetime.fromtimestamp(timestamp).strftime(format_) + except (ValueError, OverflowError, OSError): + return default diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py new file mode 100644 index 00000000..c56ec29a --- /dev/null +++ b/src/guidellm/utils/messaging.py @@ -0,0 +1,1029 @@ +""" +Inter-process messaging abstractions for distributed scheduler coordination. + +Provides high-level interfaces for asynchronous message passing between worker +processes using various transport mechanisms including queues and pipes. Supports +configurable encoding, serialization, error handling, and flow control with +buffering and stop event coordination for distributed scheduler operations. 
+""" + +from __future__ import annotations + +import asyncio +import contextlib +import multiprocessing +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections.abc import Iterable +from multiprocessing.connection import Connection +from multiprocessing.context import BaseContext +from multiprocessing.managers import SyncManager +from multiprocessing.synchronize import Event as ProcessingEvent +from threading import Event as ThreadingEvent +from typing import Any, Callable, Generic, Protocol, TypeVar + +import culsans +from pydantic import BaseModel + +from guidellm.utils.encoding import ( + EncodingTypesAlias, + MessageEncoding, + SerializationTypesAlias, +) + +__all__ = [ + "InterProcessMessaging", + "InterProcessMessagingManagerQueue", + "InterProcessMessagingPipe", + "InterProcessMessagingQueue", + "MessagingStopCallback", + "ReceiveMessageT", + "SendMessageT", +] + +SendMessageT = TypeVar("SendMessageT", bound=Any) +"""Generic type variable for messages sent through the messaging system""" +ReceiveMessageT = TypeVar("ReceiveMessageT", bound=Any) +"""Generic type variable for messages received through the messaging system""" + + +class MessagingStopCallback(Protocol): + """Protocol for evaluating stop conditions in messaging operations.""" + + def __call__( + self, messaging: InterProcessMessaging, pending: bool, queue_empty: int + ) -> bool: + """ + Evaluate whether messaging operations should stop. + + :param messaging: The messaging instance to evaluate + :param pending: Whether there are pending operations + :param queue_empty: The number of times in a row the queue has been empty + :return: True if operations should stop, False otherwise + """ + ... + + +class InterProcessMessaging(Generic[SendMessageT, ReceiveMessageT], ABC): + """ + Abstract base for inter-process messaging in distributed scheduler coordination. + + Provides unified interface for asynchronous message passing between scheduler + components using configurable transport mechanisms, encoding schemes, and + flow control policies. Manages buffering, serialization, error handling, + and coordinated shutdown across worker processes for distributed operations. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingQueue + + messaging = InterProcessMessagingQueue( + serialization="pickle", + max_pending_size=100 + ) + + await messaging.start() + await messaging.put(request_data) + response = await messaging.get(timeout=5.0) + await messaging.stop() + """ + + STOP_REQUIRED_QUEUE_EMPTY: int = 3 + + def __init__( + self, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + ): + """ + Initialize inter-process messaging coordinator. 
+ + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in done queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + """ + self.worker_index: int | None = worker_index + self.mp_context = mp_context or multiprocessing.get_context() + self.serialization = serialization + self.encoding = encoding + self.max_pending_size = max_pending_size + self.max_buffer_send_size = max_buffer_send_size + self.max_done_size = max_done_size + self.max_buffer_receive_size = max_buffer_receive_size + self.poll_interval = poll_interval + + self.send_stopped_event: ThreadingEvent | ProcessingEvent = None + self.receive_stopped_event: ThreadingEvent | ProcessingEvent = None + self.shutdown_event: ThreadingEvent = None + self.buffer_send_queue: culsans.Queue[SendMessageT] = None + self.buffer_receive_queue: culsans.Queue[ReceiveMessageT] = None + self.send_task: asyncio.Task = None + self.receive_task: asyncio.Task = None + self.running = False + + @abstractmethod + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessaging[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for distributed process coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured messaging instance for the specified worker + """ + ... + + @abstractmethod + def create_send_messages_threads( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create send message processing threads for transport implementation. + + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + ... + + @abstractmethod + def create_receive_messages_threads( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create receive message processing threads for transport implementation. + + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + ... 
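# Illustrative sketch: one way to drive shutdown of the send/receive loops is a
# callable conforming to the MessagingStopCallback protocol defined above; the
# loops evaluate check_stop(pending, queue_empty) built from the criteria passed
# to start(). The factory name make_idle_stop_criterion and the max_idle_polls
# threshold below are hypothetical examples, not part of the guidellm API.
def make_idle_stop_criterion(max_idle_polls: int = 10):
    def _should_stop(
        messaging: "InterProcessMessaging", pending: bool, queue_empty: int
    ) -> bool:
        # Stop once nothing is in flight and the transport queue has come up
        # empty for max_idle_polls consecutive polls.
        return not pending and queue_empty >= max_idle_polls

    return _should_stop

# Typical wiring with a concrete transport (assumed, for illustration):
#     await messaging.start(receive_stop_criteria=[make_idle_stop_criterion()])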
+ + async def start( + self, + send_items: Iterable[Any] | None = None, + receive_callback: Callable[[Any], Any] | None = None, + send_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ) = None, + send_stopped_event: ThreadingEvent | ProcessingEvent | None = None, + receive_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ) = None, + receive_stopped_event: ThreadingEvent | ProcessingEvent | None = None, + pydantic_models: list[type[BaseModel]] | None = None, + ): + """ + Start asynchronous message processing tasks with buffering. + + :param send_items: Optional collection of items to send during processing + :param receive_callback: Optional callback for processing received messages + :param send_stop_criteria: Events and callables that trigger send task shutdown + :param send_stopped_event: Event set when send task has fully stopped + :param receive_stop_criteria: Events and callables that trigger receive shutdown + :param receive_stopped_event: Event set when receive task has fully stopped + :param pydantic_models: Optional list of Pydantic models for serialization + """ + self.running = True + self.send_stopped_event = send_stopped_event or ThreadingEvent() + self.receive_stopped_event = receive_stopped_event or ThreadingEvent() + self.shutdown_event = ThreadingEvent() + self.buffer_send_queue = culsans.Queue[SendMessageT]( + maxsize=self.max_buffer_send_size or 0 + ) + self.buffer_receive_queue = culsans.Queue[ReceiveMessageT]( + maxsize=self.max_buffer_receive_size or 0 + ) + self.tasks_lock = threading.Lock() + + message_encoding = MessageEncoding( + serialization=self.serialization, + encoding=self.encoding, + pydantic_models=pydantic_models, + ) + send_stop_criteria = send_stop_criteria or [] + receive_stop_events = receive_stop_criteria or [] + + self.send_task = asyncio.create_task( + self.send_messages_coroutine( + send_items=send_items, + message_encoding=message_encoding, + send_stop_criteria=send_stop_criteria, + ) + ) + self.receive_task = asyncio.create_task( + self.receive_messages_coroutine( + receive_callback=receive_callback, + message_encoding=message_encoding, + receive_stop_criteria=receive_stop_events, + ) + ) + + async def stop(self): + """ + Stop message processing tasks and clean up resources. + """ + self.shutdown_event.set() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather( + self.send_task, self.receive_task, return_exceptions=True + ) + self.send_task = None + self.receive_task = None + if self.worker_index is None: + self.buffer_send_queue.clear() + await self.buffer_send_queue.aclose() + self.buffer_receive_queue.clear() + await self.buffer_receive_queue.aclose() + self.buffer_send_queue = None + self.buffer_receive_queue = None + self.send_stopped_event = None + self.receive_stopped_event = None + self.shutdown_event = None + self.running = False + + async def send_messages_coroutine( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + send_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ), + ): + """ + Execute send message processing with encoding and stop condition handling. 
+ + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param send_stop_criteria: Events and callables that trigger send task shutdown + """ + canceled_event = ThreadingEvent() + + try: + await asyncio.gather( + *[ + asyncio.to_thread(thread, *args) + for (thread, args) in self.create_send_messages_threads( + send_items=send_items, + message_encoding=message_encoding, + check_stop=self._create_check_stop_callable( + send_stop_criteria, canceled_event + ), + ) + ] + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.send_stopped_event.set() + + async def receive_messages_coroutine( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + receive_stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ), + ): + """ + Execute receive message processing with decoding and callback handling. + + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param receive_stop_criteria: Events and callables that trigger receive shutdown + """ + canceled_event = ThreadingEvent() + + try: + await asyncio.gather( + *[ + asyncio.to_thread(thread, *args) + for thread, args in self.create_receive_messages_threads( + receive_callback=receive_callback, + message_encoding=message_encoding, + check_stop=self._create_check_stop_callable( + receive_stop_criteria, canceled_event + ), + ) + ] + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.receive_stopped_event.set() + + async def get(self, timeout: float | None = None) -> ReceiveMessageT: + """ + Retrieve message from receive buffer with optional timeout. + + :param timeout: Maximum time to wait for a message + :return: Decoded message from the receive buffer + """ + return await asyncio.wait_for( + self.buffer_receive_queue.async_get(), timeout=timeout + ) + + def get_sync(self, timeout: float | None = None) -> ReceiveMessageT: + """ + Retrieve message from receive buffer synchronously with optional timeout. + + :param timeout: Maximum time to wait for a message, if <=0 uses get_nowait + :return: Decoded message from the receive buffer + """ + if timeout is not None and timeout <= 0: + return self.buffer_receive_queue.get_nowait() + else: + return self.buffer_receive_queue.sync_get(timeout=timeout) + + async def put(self, item: SendMessageT, timeout: float | None = None): + """ + Add message to send buffer with optional timeout. + + :param item: Message item to add to the send buffer + :param timeout: Maximum time to wait for buffer space + """ + await asyncio.wait_for(self.buffer_send_queue.async_put(item), timeout=timeout) + + def put_sync(self, item: SendMessageT, timeout: float | None = None): + """ + Add message to send buffer synchronously with optional timeout. 
+ + :param item: Message item to add to the send buffer + :param timeout: Maximum time to wait for buffer space, if <=0 uses put_nowait + """ + if timeout is not None and timeout <= 0: + self.buffer_send_queue.put_nowait(item) + else: + self.buffer_send_queue.sync_put(item, timeout=timeout) + + def _create_check_stop_callable( + self, + stop_criteria: ( + list[ThreadingEvent | ProcessingEvent | MessagingStopCallback] | None + ), + canceled_event: ThreadingEvent, + ): + stop_events = tuple( + item + for item in stop_criteria or [] + if isinstance(item, (ThreadingEvent, ProcessingEvent)) + ) + stop_callbacks = tuple(item for item in stop_criteria or [] if callable(item)) + + def check_stop(pending: bool, queue_empty: int) -> bool: + if canceled_event.is_set(): + return True + + if stop_callbacks and any( + cb(self, pending, queue_empty) for cb in stop_callbacks + ): + return True + + return ( + not pending + and queue_empty >= self.STOP_REQUIRED_QUEUE_EMPTY + and ( + self.shutdown_event.is_set() + or any(event.is_set() for event in stop_events) + ) + ) + + return check_stop + + +class InterProcessMessagingQueue(InterProcessMessaging[SendMessageT, ReceiveMessageT]): + """ + Queue-based inter-process messaging for distributed scheduler coordination. + + Provides message passing using multiprocessing.Queue objects for communication + between scheduler workers and main process. Handles message encoding, buffering, + flow control, and coordinated shutdown with configurable queue behavior and + error handling policies for distributed operations. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingQueue + + messaging = InterProcessMessagingQueue( + serialization="pickle", + max_pending_size=100 + ) + + # Create worker copy for distributed processing + worker_messaging = messaging.create_worker_copy(worker_index=0) + """ + + def __init__( + self, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + pending_queue: multiprocessing.Queue | None = None, + done_queue: multiprocessing.Queue | None = None, + ): + """ + Initialize queue-based messaging for inter-process communication. 
+ + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pending_queue: Multiprocessing queue for sending messages + :param done_queue: Multiprocessing queue for receiving completed messages + :param context: Multiprocessing context for creating queues + """ + super().__init__( + mp_context=mp_context, + serialization=serialization, + encoding=encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=max_buffer_send_size, + max_done_size=max_done_size, + max_buffer_receive_size=max_buffer_receive_size, + poll_interval=poll_interval, + worker_index=worker_index, + ) + self.pending_queue = pending_queue or self.mp_context.Queue( + maxsize=max_pending_size or 0 + ) + self.done_queue = done_queue or self.mp_context.Queue( + maxsize=max_done_size or 0 + ) + + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessagingQueue[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for distributed queue-based coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured queue messaging instance for the specified worker + """ + copy_args = { + "mp_context": self.mp_context, + "serialization": self.serialization, + "encoding": self.encoding, + "max_pending_size": self.max_pending_size, + "max_buffer_send_size": self.max_buffer_send_size, + "max_done_size": self.max_done_size, + "max_buffer_receive_size": self.max_buffer_receive_size, + "poll_interval": self.poll_interval, + "worker_index": worker_index, + "pending_queue": self.pending_queue, + "done_queue": self.done_queue, + } + copy_args.update(kwargs) + + return InterProcessMessagingQueue[ReceiveMessageT, SendMessageT](**copy_args) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await super().stop() + if self.worker_index is None: + # only main process should close the queues + with contextlib.suppress(queue.Empty): + while True: + self.pending_queue.get_nowait() + self.pending_queue.close() + + with contextlib.suppress(queue.Empty): + while True: + self.done_queue.get_nowait() + self.done_queue.close() + + self.pending_queue = None + self.done_queue = None + + def create_send_messages_threads( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create send message processing threads for queue-based transport. 
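A minimal sketch of how the worker-copy pattern above is typically wired up; the `spawn_workers` helper and `worker_main` entry point are hypothetical, while `create_worker_copy` and `mp_context` come from the class itself:

    def spawn_workers(messaging, num_workers: int, worker_main):
        # Each worker receives a copy that shares the same pending/done queues
        # but carries its own worker_index, so the send/receive directions are
        # swapped relative to the main process.
        procs = []
        for index in range(num_workers):
            worker_messaging = messaging.create_worker_copy(worker_index=index)
            proc = messaging.mp_context.Process(
                target=worker_main, args=(worker_messaging,)
            )
            proc.start()
            procs.append(proc)
        return procs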
+ + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + return [ + ( + self._send_messages_task_thread, + (send_items, message_encoding, check_stop), + ) + ] + + def create_receive_messages_threads( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create receive message processing threads for queue-based transport. + + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + return [ + ( + self._receive_messages_task_thread, + (receive_callback, message_encoding, check_stop), + ) + ] + + def _send_messages_task_thread( # noqa: C901, PLR0912 + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + send_items_iter = iter(send_items) if send_items is not None else None + pending_item = None + queue_empty = 0 + + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if send_items_iter is not None: + item = next(send_items_iter) + else: + item = self.buffer_send_queue.sync_get( + timeout=self.poll_interval + ) + pending_item = message_encoding.encode(item) + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty, StopIteration): + queue_empty += 1 + + if pending_item is not None: + try: + if self.worker_index is None: + # Main publisher + self.pending_queue.put(pending_item, timeout=self.poll_interval) + else: + # Worker + self.done_queue.put(pending_item, timeout=self.poll_interval) + if send_items_iter is None: + self.buffer_send_queue.task_done() + pending_item = None + except (culsans.QueueFull, queue.Full): + pass + + def _receive_messages_task_thread( # noqa: C901 + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + pending_item = None + received_item = None + queue_empty = 0 + + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if self.worker_index is None: + # Main publisher + item = self.done_queue.get(timeout=self.poll_interval) + else: + # Worker + item = self.pending_queue.get(timeout=self.poll_interval) + pending_item = message_encoding.decode(item) + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty): + queue_empty += 1 + + if pending_item is not None or received_item is not None: + try: + if received_item is None: + received_item = ( + pending_item + if not receive_callback + else receive_callback(pending_item) + ) + + self.buffer_receive_queue.sync_put(received_item) + pending_item = None + received_item = None + except (culsans.QueueFull, queue.Full): + pass + + +class InterProcessMessagingManagerQueue( + InterProcessMessagingQueue[SendMessageT, ReceiveMessageT] +): + """ + Manager-based queue messaging for inter-process scheduler coordination. + + Extends queue-based messaging with multiprocessing.Manager support for + shared state coordination across worker processes. 
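A hedged sketch of a MessagingStopCallback as consumed by the check_stop logic above; the deadline-based policy is illustrative, not something this patch ships:

    import time

    def stop_after_deadline(deadline: float):
        # The callable receives the messaging instance, whether an encoded item
        # is still pending, and the consecutive empty-poll count; returning True
        # requests shutdown of the send/receive worker thread.
        def _callback(messaging, pending: bool, queue_empty: int) -> bool:
            return time.time() >= deadline

        return _callback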
Provides managed queues + for reliable message passing in distributed scheduler environments with + enhanced process synchronization and resource management capabilities. + + Example: + :: + import multiprocessing + from guidellm.utils.messaging import InterProcessMessagingManagerQueue + + manager = multiprocessing.Manager() + messaging = InterProcessMessagingManagerQueue( + manager=manager, + serialization="pickle" + ) + """ + + def __init__( + self, + manager: SyncManager, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + pending_queue: multiprocessing.Queue | None = None, + done_queue: multiprocessing.Queue | None = None, + ): + """ + Initialize manager-based queue messaging for inter-process communication. + + :param manager: Multiprocessing manager for shared queue creation + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pending_queue: Managed multiprocessing queue for sending messages + :param done_queue: Managed multiprocessing queue for receiving completed + messages + """ + super().__init__( + mp_context=mp_context, + serialization=serialization, + encoding=encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=max_buffer_send_size, + max_done_size=max_done_size, + max_buffer_receive_size=max_buffer_receive_size, + poll_interval=poll_interval, + worker_index=worker_index, + pending_queue=pending_queue or manager.Queue(maxsize=max_pending_size or 0), # type: ignore [assignment] + done_queue=done_queue or manager.Queue(maxsize=max_done_size or 0), # type: ignore [assignment] + ) + + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessagingManagerQueue[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for managed queue-based coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured manager queue messaging instance for the specified worker + """ + copy_args = { + "manager": None, + "mp_context": self.mp_context, + "serialization": self.serialization, + "encoding": self.encoding, + "max_pending_size": self.max_pending_size, + "max_buffer_send_size": self.max_buffer_send_size, + "max_done_size": self.max_done_size, + "max_buffer_receive_size": self.max_buffer_receive_size, + "poll_interval": self.poll_interval, + "worker_index": worker_index, + "pending_queue": self.pending_queue, + "done_queue": self.done_queue, + } + copy_args.update(kwargs) + + return InterProcessMessagingManagerQueue(**copy_args) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. 
+ """ + await InterProcessMessaging.stop(self) + self.pending_queue = None + self.done_queue = None + + +class InterProcessMessagingPipe(InterProcessMessaging[SendMessageT, ReceiveMessageT]): + """ + Pipe-based inter-process messaging for distributed scheduler coordination. + + Provides message passing using multiprocessing.Pipe objects for direct + communication between scheduler workers and main process. Offers lower + latency than queue-based messaging with duplex communication channels + for high-performance distributed operations. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingPipe + + messaging = InterProcessMessagingPipe( + num_workers=4, + serialization="pickle", + poll_interval=0.05 + ) + + # Create worker copy for specific worker process + worker_messaging = messaging.create_worker_copy(worker_index=0) + """ + + def __init__( + self, + num_workers: int, + mp_context: BaseContext | None = None, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_pending_size: int | None = None, + max_buffer_send_size: int | None = None, + max_done_size: int | None = None, + max_buffer_receive_size: int | None = None, + poll_interval: float = 0.1, + worker_index: int | None = None, + pipe: tuple[Connection, Connection] | None = None, + ): + """ + Initialize pipe-based messaging for inter-process communication. + + :param num_workers: Number of worker processes requiring pipe connections + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_pending_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_done_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pipe: Existing pipe connection for worker-specific instances + """ + super().__init__( + mp_context=mp_context, + serialization=serialization, + encoding=encoding, + max_pending_size=max_pending_size, + max_buffer_send_size=max_buffer_send_size, + max_done_size=max_done_size, + max_buffer_receive_size=max_buffer_receive_size, + poll_interval=poll_interval, + worker_index=worker_index, + ) + self.num_workers = num_workers + + if pipe is None: + self.pipes: list[tuple[Connection, Connection]] = [ + self.mp_context.Pipe(duplex=True) for _ in range(num_workers) + ] + else: + self.pipes: list[tuple[Connection, Connection]] = [pipe] + + def create_worker_copy( + self, worker_index: int, **kwargs + ) -> InterProcessMessagingPipe[ReceiveMessageT, SendMessageT]: + """ + Create worker-specific copy for pipe-based coordination. 
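For readers less familiar with the primitive underneath, a short standard-library sketch of the duplex pipe pairs created above (one pair per worker); the names here are illustrative only:

    import multiprocessing

    def demo_duplex_pipe():
        # Pipe(duplex=True) returns two Connection endpoints; either side can
        # send and recv. The class above keeps one such pair per worker and
        # hands pipes[worker_index] to the matching worker copy.
        main_conn, worker_conn = multiprocessing.Pipe(duplex=True)
        main_conn.send({"work": 1})
        if worker_conn.poll(timeout=1.0):
            print(worker_conn.recv())  # {'work': 1}
        main_conn.close()
        worker_conn.close()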
+ + :param worker_index: Index of the worker process for pipe routing + :return: Configured pipe messaging instance for the specified worker + """ + copy_args = { + "num_workers": self.num_workers, + "mp_context": self.mp_context, + "serialization": self.serialization, + "encoding": self.encoding, + "max_pending_size": self.max_pending_size, + "max_buffer_send_size": self.max_buffer_send_size, + "max_done_size": self.max_done_size, + "max_buffer_receive_size": self.max_buffer_receive_size, + "poll_interval": self.poll_interval, + "worker_index": worker_index, + "pipe": self.pipes[worker_index], + } + copy_args.update(kwargs) + + return InterProcessMessagingPipe(**copy_args) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await super().stop() + if self.worker_index is None: + # Only main process should close the pipes + for main_con, worker_con in self.pipes: + main_con.close() + worker_con.close() + + def create_send_messages_threads( + self, + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create send message processing threads for pipe-based transport. + + :param send_items: Optional collection of items to send during processing + :param message_encoding: Message encoding configuration for serialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + if self.worker_index is None: + # Create a separate task for each worker's pipe + return [ + ( + self._send_messages_task_thread, + (self.pipes[index], send_items, message_encoding, check_stop), + ) + for index in range(self.num_workers) + ] + else: + return [ + ( + self._send_messages_task_thread, + (self.pipes[0], send_items, message_encoding, check_stop), + ) + ] + + def create_receive_messages_threads( + self, + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ) -> list[tuple[Callable, tuple[Any, ...]]]: + """ + Create receive message processing threads for pipe-based transport. 
+ + :param receive_callback: Optional callback for processing received messages + :param message_encoding: Message encoding configuration for deserialization + :param check_stop: Callable for evaluating stop conditions during processing + :return: List of thread callables with their arguments for execution + """ + if self.worker_index is None: + # Create a separate task for each worker's pipe + return [ + ( + self._receive_messages_task_thread, + (self.pipes[index], receive_callback, message_encoding, check_stop), + ) + for index in range(self.num_workers) + ] + else: + return [ + ( + self._receive_messages_task_thread, + (self.pipes[0], receive_callback, message_encoding, check_stop), + ) + ] + + def _send_messages_task_thread( # noqa: C901, PLR0912 + self, + pipe: tuple[Connection, Connection], + send_items: Iterable[Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + local_stop = ThreadingEvent() + send_connection: Connection = pipe[0] if self.worker_index is None else pipe[1] + send_items_iter = iter(send_items) if send_items is not None else None + pending_item = None + queue_empty = 0 + pipe_item = None + pipe_lock = threading.Lock() + + def _background_pipe_recv(): + nonlocal pipe_item + + while not local_stop.is_set(): + try: + with pipe_lock: + pending = pipe_item + pipe_item = None + + if pending is not None: + send_connection.send(pending) + except (EOFError, ConnectionResetError): + break + + if send_items_iter is None: + threading.Thread(target=_background_pipe_recv, daemon=True).start() + + try: + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if send_items_iter is not None: + item = next(send_items_iter) + else: + item = self.buffer_send_queue.sync_get( + timeout=self.poll_interval + ) + pending_item = message_encoding.encode(item) + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty, StopIteration): + queue_empty += 1 + + if pending_item is not None: + try: + with pipe_lock: + if pipe_item is not None: + time.sleep(self.poll_interval / 100) + raise queue.Full + else: + pipe_item = pending_item + if send_items_iter is None: + self.buffer_send_queue.task_done() + pending_item = None + except (culsans.QueueFull, queue.Full): + pass + finally: + local_stop.set() + + def _receive_messages_task_thread( # noqa: C901 + self, + pipe: tuple[Connection, Connection], + receive_callback: Callable[[Any], Any] | None, + message_encoding: MessageEncoding, + check_stop: Callable[[bool, bool], bool], + ): + receive_connection: Connection = ( + pipe[0] if self.worker_index is not None else pipe[1] + ) + pending_item = None + received_item = None + queue_empty = 0 + + while not check_stop(pending_item is not None, queue_empty): + if pending_item is None: + try: + if receive_connection.poll(self.poll_interval): + item = receive_connection.recv() + pending_item = message_encoding.decode(item) + else: + raise queue.Empty + queue_empty = 0 + except (culsans.QueueEmpty, queue.Empty): + queue_empty += 1 + + if pending_item is not None or received_item is not None: + try: + if received_item is None: + received_item = ( + pending_item + if not receive_callback + else receive_callback(pending_item) + ) + + self.buffer_receive_queue.sync_put(received_item) + pending_item = None + received_item = None + except (culsans.QueueFull, queue.Full): + pass diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py new file mode 100644 index 00000000..b001ff2d --- /dev/null +++ 
b/src/guidellm/utils/mixins.py @@ -0,0 +1,115 @@ +""" +Mixin classes for common metadata extraction and object introspection. + +Provides reusable mixins for extracting structured metadata from objects, +enabling consistent information exposure across different class hierarchies. +""" + +from __future__ import annotations + +from typing import Any + +__all__ = ["InfoMixin"] + + +PYTHON_PRIMITIVES = (str, int, float, bool, list, tuple, dict) +"""Type alias for serialized object representations""" + + +class InfoMixin: + """ + Mixin class providing standardized metadata extraction for introspection. + + Enables consistent object metadata extraction patterns across different + class hierarchies for debugging, serialization, and runtime analysis. + Provides both instance and class-level methods for extracting structured + information from arbitrary objects with fallback handling for objects + without built-in info capabilities. + + Example: + :: + from guidellm.utils.mixins import InfoMixin + + class ConfiguredClass(InfoMixin): + def __init__(self, setting: str): + self.setting = setting + + obj = ConfiguredClass("value") + # Returns {'str': 'ConfiguredClass(...)', 'type': 'ConfiguredClass', ...} + print(obj.info) + """ + + @classmethod + def extract_from_obj(cls, obj: Any) -> dict[str, Any]: + """ + Extract structured metadata from any object. + + Attempts to use the object's own `info` method or property if available, + otherwise constructs metadata from object attributes and type information. + Provides consistent metadata format across different object types. + + :param obj: Object to extract metadata from + :return: Dictionary containing object metadata including type, class, + module, and public attributes + """ + if hasattr(obj, "info"): + return obj.info() if callable(obj.info) else obj.info + + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val if isinstance(val, PYTHON_PRIMITIVES) else repr(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @classmethod + def create_info_dict(cls, obj: Any) -> dict[str, Any]: + """ + Create a structured info dictionary for the given object. + + Builds standardized metadata dictionary containing object identification, + type information, and accessible attributes. Used internally by other + info extraction methods and available for direct metadata construction. + + :param obj: Object to extract info from + :return: Dictionary containing structured metadata about the object + """ + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val + if isinstance(val, (str, int, float, bool, list, dict)) + else repr(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @property + def info(self) -> dict[str, Any]: + """ + Return structured metadata about this instance. + + Provides consistent access to object metadata for debugging, serialization, + and introspection. Uses the create_info_dict method to generate standardized + metadata format including class information and public attributes. 
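A small illustration of the fallback path described above, using a plain object that has no `info` attribute of its own; the `Endpoint` class is a made-up example:

    from guidellm.utils import InfoMixin

    class Endpoint:
        def __init__(self, url: str, retries: int = 3):
            self.url = url
            self.retries = retries
            self._secret = "hidden"  # private attributes are skipped

    info = InfoMixin.extract_from_obj(Endpoint("http://localhost:8000"))
    # info["type"] == "Endpoint"
    # info["attributes"] == {"url": "http://localhost:8000", "retries": 3}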
+ + :return: Dictionary containing class name, module, and public attributes + """ + return self.create_info_dict(self) diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py new file mode 100644 index 00000000..27c2e1cf --- /dev/null +++ b/src/guidellm/utils/pydantic_utils.py @@ -0,0 +1,401 @@ +""" +Pydantic utilities for polymorphic model serialization and registry integration. + +Provides integration between Pydantic and the registry system, enabling +polymorphic serialization and deserialization of Pydantic models using +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema +from typing_extensions import get_args, get_origin + +from guidellm.utils.registry import RegistryMixin + +__all__ = [ + "PydanticClassRegistryMixin", + "ReloadableBaseModel", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", +] + + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) +RegisterClassT = TypeVar("RegisterClassT") +SuccessfulT = TypeVar("SuccessfulT") +ErroredT = TypeVar("ErroredT") +IncompleteT = TypeVar("IncompleteT") +TotalT = TypeVar("TotalT") + + +class ReloadableBaseModel(BaseModel): + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + @classmethod + def reload_schema(cls, parents: bool = True) -> None: + """ + Reload the class schema with updated registry information. + + Forces a complete rebuild of the Pydantic model schema to incorporate + any changes made to associated registries or validation rules. + + :param parents: Whether to also rebuild schemas for any pydantic parent + types that reference this model. + """ + cls.model_rebuild(force=True) + + if parents: + cls.reload_parent_schemas() + + @classmethod + def reload_parent_schemas(cls): + """ + Recursively reload schemas for all parent Pydantic models. + + Traverses the inheritance hierarchy to find all parent classes that + are Pydantic models and triggers schema rebuilding on each to ensure + that any changes in child models are reflected in parent schemas. 
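A generic pydantic-v2 sketch, with placeholder model names, of why a forced rebuild matters when referenced types only become resolvable later; this is the same mechanism reload_schema leans on:

    from pydantic import BaseModel

    class Wrapper(BaseModel):
        payload: "Item"  # forward reference, not yet defined

    class Item(BaseModel):
        name: str

    # Until the reference is resolvable the model stays incomplete; a forced
    # rebuild re-evaluates annotations against the current namespace.
    Wrapper.model_rebuild(force=True)
    Wrapper.model_validate({"payload": {"name": "example"}})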
+ """ + potential_parents: set[type[BaseModel]] = {BaseModel} + stack: list[type[BaseModel]] = [BaseModel] + + while stack: + current = stack.pop() + for subclass in current.__subclasses__(): + if ( + issubclass(subclass, BaseModel) + and subclass is not cls + and subclass not in potential_parents + ): + potential_parents.add(subclass) + stack.append(subclass) + + for check in cls.__mro__: + if isinstance(check, type) and issubclass(check, BaseModel): + cls._reload_schemas_depending_on(check, potential_parents) + + @classmethod + def _reload_schemas_depending_on(cls, target: type[BaseModel], types: set[type]): + changed = True + while changed: + changed = False + for candidate in types: + if ( + isinstance(candidate, type) + and issubclass(candidate, BaseModel) + and any( + cls._uses_type(target, field_info.annotation) + for field_info in candidate.model_fields.values() + if field_info.annotation is not None + ) + ): + try: + before = candidate.model_json_schema() + except Exception: # noqa: BLE001 + before = None + candidate.model_rebuild(force=True) + if before is not None: + after = candidate.model_json_schema() + changed |= before != after + + @classmethod + def _uses_type(cls, target: type, candidate: type) -> bool: + if target is candidate: + return True + + origin = get_origin(candidate) + + if origin is None: + return isinstance(candidate, type) and issubclass(candidate, target) + + if isinstance(origin, type) and ( + target is origin or issubclass(origin, target) + ): + return True + + for arg in get_args(candidate) or []: + if isinstance(arg, type) and cls._uses_type(target, arg): + return True + + return False + + +class StandardBaseModel(BaseModel): + """ + Base Pydantic model with standardized configuration for GuideLLM. + + Provides consistent validation behavior and configuration settings across + all Pydantic models in the application, including field validation, + attribute conversion, and default value handling. + + Example: + :: + class MyModel(StandardBaseModel): + name: str + value: int = 42 + + # Access default values + default_value = MyModel.get_default("value") # Returns 42 + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + from_attributes=True, + ) + + @classmethod + def get_default(cls: type[BaseModel], field: str) -> Any: + """ + Get default value for a model field. + + :param field: Name of the field to get the default value for + :return: Default value of the specified field + :raises KeyError: If the field does not exist in the model + """ + return cls.model_fields[field].default + + +class StandardBaseDict(StandardBaseModel): + """ + Base Pydantic model allowing arbitrary additional fields. + + Extends StandardBaseModel to accept extra fields beyond those explicitly + defined in the model schema. Useful for flexible data structures that + need to accommodate varying or unknown field sets while maintaining + type safety for known fields. + """ + + model_config = ConfigDict( + extra="allow", + use_enum_values=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + +class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): + """ + Generic model for organizing results by processing status. + + Provides structured categorization of results into successful, errored, + incomplete, and total status groups. Supports flexible typing for each + status category to accommodate different result types while maintaining + consistent organization patterns across the application. 
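As a brief aside on the `_uses_type` helper above, a stdlib sketch of the typing introspection it relies on:

    from typing import get_args, get_origin

    assert get_origin(list[int]) is list
    assert get_args(list[int]) == (int,)
    assert get_origin(dict[str, list[int]]) is dict
    assert get_args(dict[str, list[int]]) == (str, list[int])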
+ + Example: + :: + from guidellm.utils import StatusBreakdown + + # Define a breakdown for request counts + breakdown = StatusBreakdown[int, int, int, int]( + successful=150, + errored=5, + incomplete=10, + total=165 + ) + """ + + successful: SuccessfulT = Field( + description="Results or metrics for requests with successful completion status", + default=None, # type: ignore[assignment] + ) + errored: ErroredT = Field( + description="Results or metrics for requests with error completion status", + default=None, # type: ignore[assignment] + ) + incomplete: IncompleteT = Field( + description="Results or metrics for requests with incomplete processing status", + default=None, # type: ignore[assignment] + ) + total: TotalT = Field( + description="Aggregated results or metrics combining all status categories", + default=None, # type: ignore[assignment] + ) + + +class PydanticClassRegistryMixin( + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] +): + """ + Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. + + Integrates Pydantic validation with the registry system to enable polymorphic + serialization and deserialization based on a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings, + providing a foundation for extensible plugin-style architectures. + + Example: + :: + from speculators.utils import PydanticClassRegistryMixin + + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): + schema_discriminator: ClassVar[str] = "config_type" + config_type: str = Field(description="Configuration type identifier") + + @classmethod + def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: + return BaseConfig + + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") + + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name used for polymorphic type discrimination + """ + + schema_discriminator: ClassVar[str] = "model_type" + + @classmethod + def register_decorator( + cls, clazz: RegisterClassT, name: str | list[str] | None = None + ) -> RegisterClassT: + """ + Register a Pydantic model class with type validation and schema reload. + + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. + + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass + """ + if not issubclass(clazz, BaseModel): + raise TypeError( + f"Cannot register {clazz.__name__} as it is not a subclass of " + "Pydantic BaseModel" + ) + + super().register_decorator(clazz, name=name) + cls.reload_schema() + + return clazz + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate polymorphic validation schema for dynamic type instantiation. 
+ + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. + + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema + """ + if source_type == cls.__pydantic_schema_base_type__(): + if not cls.registry: + return cls.__pydantic_generate_base_schema__(handler) + + choices = { + name: handler(model_class) for name, model_class in cls.registry.items() + } + + return core_schema.tagged_union_schema( + choices=choices, + discriminator=cls.schema_discriminator, + ) + + return handler(cls) + + @classmethod + @abstractmethod + def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: + """ + Define the base type for polymorphic validation hierarchy. + + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. + + :return: Base class type for the polymorphic model hierarchy + """ + ... + + @classmethod + def __pydantic_generate_base_schema__( + cls, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate fallback schema for polymorphic models without registry. + + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. Used as fallback during + schema generation when the registry has not been populated. + + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input + """ + return core_schema.any_schema() + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Initialize registry with auto-discovery and reload validation schema. + + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. + + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled + """ + populated = super().auto_populate_registry() + cls.reload_schema() + + return populated + + @classmethod + def registered_classes(cls) -> tuple[type[BaseModelT], ...]: + """ + Get all registered pydantic classes from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered classes including auto-discovered ones + :raises ValueError: If called before any objects have been registered + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "ClassRegistryMixin.registered_classes() must be called after " + "registering classes with ClassRegistryMixin.register()." + ) + + return tuple(cls.registry.values()) diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py new file mode 100644 index 00000000..b9e3faf5 --- /dev/null +++ b/src/guidellm/utils/registry.py @@ -0,0 +1,214 @@ +""" +Registry system for dynamic object registration and discovery. + +Provides a flexible object registration system with optional auto-discovery +capabilities through decorators and module imports. 
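A hedged sketch of configuration-driven lookup using the registry helpers defined further below (`register`, `get_registered_object`); the backend classes are hypothetical:

    from guidellm.utils import RegistryMixin

    class Backend(RegistryMixin):
        """Hypothetical base class for illustration."""

    @Backend.register("openai")
    class OpenAIBackend(Backend):
        pass

    def build_backend(name: str) -> Backend:
        # Lookup matches the exact name first, then falls back to lowercase.
        cls = Backend.get_registered_object(name)
        if cls is None:
            raise ValueError(f"Unknown backend: {name}")
        return cls()

    backend = build_backend("OpenAI")  # resolves case-insensitively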
Enables dynamic discovery +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. +""" + +from __future__ import annotations + +from typing import Callable, ClassVar, Generic, TypeVar, cast + +from guidellm.utils.auto_importer import AutoImporterMixin + +__all__ = ["RegisterT", "RegistryMixin", "RegistryObjT"] + + +RegistryObjT = TypeVar("RegistryObjT") +"""Generic type variable for objects managed by the registry system.""" +RegisterT = TypeVar("RegisterT") +"""Generic type variable for the args and return values within the registry.""" + + +class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): + """ + Generic mixin for creating object registries with optional auto-discovery. + + Enables classes to maintain separate registries of objects that can be dynamically + discovered and instantiated through decorators and module imports. Supports both + manual registration via decorators and automatic discovery through package scanning + for extensible plugin architectures. + + Example: + :: + class BaseAlgorithm(RegistryMixin): + pass + + @BaseAlgorithm.register() + class ConcreteAlgorithm(BaseAlgorithm): + pass + + @BaseAlgorithm.register("custom_name") + class AnotherAlgorithm(BaseAlgorithm): + pass + + # Get all registered implementations + algorithms = BaseAlgorithm.registered_objects() + + Example with auto-discovery: + :: + class TokenProposal(RegistryMixin): + registry_auto_discovery = True + auto_package = "mypackage.proposals" + + # Automatically imports and registers decorated objects + proposals = TokenProposal.registered_objects() + + :cvar registry: Dictionary mapping names to registered objects + :cvar registry_auto_discovery: Enable automatic package-based discovery + :cvar registry_populated: Track whether auto-discovery has completed + """ + + registry: ClassVar[dict[str, RegistryObjT] | None] = None + registry_auto_discovery: ClassVar[bool] = False + registry_populated: ClassVar[bool] = False + + @classmethod + def register( + cls, name: str | list[str] | None = None + ) -> Callable[[RegisterT], RegisterT]: + """ + Decorator for registering objects with the registry. + + :param name: Optional name(s) to register the object under. + If None, uses the object's __name__ attribute + :return: Decorator function that registers the decorated object + :raises ValueError: If name is not a string, list of strings, or None + """ + + def _decorator(obj: RegisterT) -> RegisterT: + cls.register_decorator(obj, name=name) + return obj + + return _decorator + + @classmethod + def register_decorator( + cls, obj: RegisterT, name: str | list[str] | None = None + ) -> RegisterT: + """ + Register an object directly with the registry. + + :param obj: The object to register + :param name: Optional name(s) to register the object under. + If None, uses the object's __name__ attribute + :return: The registered object + :raises ValueError: If the object is already registered or name is invalid + """ + + if name is None: + name = obj.__name__ + elif not isinstance(name, (str, list)): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"an iterable of strings. Got {name}." 
+ ) + + if cls.registry is None: + cls.registry = {} + + names = [name] if isinstance(name, str) else list(name) + + for register_name in names: + if not isinstance(register_name, str): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"a list of strings. Got {register_name}." + ) + + if register_name in cls.registry: + raise ValueError( + f"RegistryMixin.register_decorator cannot register an object " + f"{obj} with the name {register_name} because it is already " + "registered." + ) + + cls.registry[register_name] = cast("RegistryObjT", obj) + + return obj + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Import and register all modules from the auto_package. + + Automatically called by registered_objects when registry_auto_discovery is True + to ensure all available implementations are discovered. + + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is False + """ + if not cls.registry_auto_discovery: + raise ValueError( + "RegistryMixin.auto_populate_registry() cannot be called " + "because registry_auto_discovery is set to False. " + "Set registry_auto_discovery to True to enable auto-discovery." + ) + + if cls.registry_populated: + return False + + cls.auto_import_package_modules() + cls.registry_populated = True + + return True + + @classmethod + def registered_objects(cls) -> tuple[RegistryObjT, ...]: + """ + Get all registered objects from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered objects including auto-discovered ones + :raises ValueError: If called before any objects have been registered + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "RegistryMixin.registered_objects() must be called after " + "registering objects with RegistryMixin.register()." + ) + + return tuple(cls.registry.values()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if an object is registered under the given name. + It matches first by exact name, then by str.lower(). + + :param name: The name to check for registration. + :return: True if the object is registered, False otherwise. + """ + if cls.registry is None: + return False + + return name in cls.registry or name.lower() in [ + key.lower() for key in cls.registry + ] + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT | None: + """ + Get a registered object by its name. It matches first by exact name, + then by str.lower(). + + :param name: The name of the registered object. + :return: The registered object if found, None otherwise. + """ + if cls.registry is None: + return None + + if name in cls.registry: + return cls.registry[name] + + lower_key_map = {key.lower(): key for key in cls.registry} + + return cls.registry.get(lower_key_map.get(name.lower())) diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py new file mode 100644 index 00000000..3ec10f79 --- /dev/null +++ b/src/guidellm/utils/singleton.py @@ -0,0 +1,130 @@ +""" +Singleton pattern implementations for ensuring single instance classes. + +Provides singleton mixins for creating classes that maintain a single instance +throughout the application lifecycle, with support for both basic and thread-safe +implementations. 
These mixins integrate with the scheduler and other system components +to ensure consistent state management and prevent duplicate resource allocation. +""" + +from __future__ import annotations + +import threading + +__all__ = ["SingletonMixin", "ThreadSafeSingletonMixin"] + + +class SingletonMixin: + """ + Basic singleton mixin ensuring single instance per class. + + Implements the singleton pattern using class variables to control instance + creation. Subclasses must call super().__init__() for proper initialization + state management. Suitable for single-threaded environments or when external + synchronization is provided. + + Example: + :: + class ConfigManager(SingletonMixin): + def __init__(self, config_path: str): + super().__init__() + if not self.initialized: + self.config = load_config(config_path) + + manager1 = ConfigManager("config.json") + manager2 = ConfigManager("config.json") + assert manager1 is manager2 + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance. + + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific attribute name to avoid inheritance issues + attr_name = f"_singleton_instance_{cls.__name__}" + + if not hasattr(cls, attr_name) or getattr(cls, attr_name) is None: + instance = super().__new__(cls) + setattr(cls, attr_name, instance) + instance._singleton_initialized = False + return getattr(cls, attr_name) + + def __init__(self): + """Initialize the singleton instance exactly once.""" + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def initialized(self): + """Return True if the singleton has been initialized.""" + return getattr(self, "_singleton_initialized", False) + + +class ThreadSafeSingletonMixin(SingletonMixin): + """ + Thread-safe singleton mixin with locking mechanisms. + + Extends SingletonMixin with thread safety using locks to prevent race + conditions during instance creation in multi-threaded environments. Essential + for scheduler components and other shared resources accessed concurrently. + + Example: + :: + class SchedulerResource(ThreadSafeSingletonMixin): + def __init__(self): + super().__init__() + if not self.initialized: + self.resource_pool = initialize_resources() + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance with thread safety. 
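A usage sketch, under the assumption that construction may race across threads; the `ResourcePool` class and its connection list are illustrative only:

    import threading
    from guidellm.utils import ThreadSafeSingletonMixin

    class ResourcePool(ThreadSafeSingletonMixin):
        def __init__(self):
            super().__init__()
            with self.thread_lock:
                # Guarded so the pool is only built on first construction.
                if not hasattr(self, "connections"):
                    self.connections = ["conn-0", "conn-1"]

    instances = []
    threads = [
        threading.Thread(target=lambda: instances.append(ResourcePool()))
        for _ in range(8)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert all(obj is instances[0] for obj in instances)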
+ + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific lock and instance names to avoid inheritance issues + lock_attr_name = f"_singleton_lock_{cls.__name__}" + instance_attr_name = f"_singleton_instance_{cls.__name__}" + + with getattr(cls, lock_attr_name): + instance_exists = ( + hasattr(cls, instance_attr_name) + and getattr(cls, instance_attr_name) is not None + ) + if not instance_exists: + instance = super(SingletonMixin, cls).__new__(cls) + setattr(cls, instance_attr_name, instance) + instance._singleton_initialized = False + instance._init_lock = threading.Lock() + return getattr(cls, instance_attr_name) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + lock_attr_name = f"_singleton_lock_{cls.__name__}" + setattr(cls, lock_attr_name, threading.Lock()) + + def __init__(self): + """Initialize the singleton instance with thread-safe initialization.""" + with self._init_lock: + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def thread_lock(self): + """Return the thread lock for this singleton instance.""" + return getattr(self, "_init_lock", None) + + @classmethod + def get_singleton_lock(cls): + """Get the class-specific singleton creation lock.""" + lock_attr_name = f"_singleton_lock_{cls.__name__}" + return getattr(cls, lock_attr_name, None) diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py new file mode 100644 index 00000000..c820de9d --- /dev/null +++ b/src/guidellm/utils/statistics.py @@ -0,0 +1,990 @@ +""" +Statistical analysis utilities for distribution calculations and running metrics. + +Provides comprehensive statistical computation tools for analyzing numerical +distributions, percentiles, and streaming data. Includes specialized support for +request timing analysis, concurrency measurement, and rate calculations. Integrates +with Pydantic for serializable statistical models and supports both weighted and +unweighted distributions with cumulative distribution function (CDF) generation. +""" + +from __future__ import annotations + +import math +import time as timer +from collections import defaultdict +from typing import Any, Literal + +import numpy as np +from pydantic import Field, computed_field + +from guidellm.utils.pydantic_utils import StandardBaseModel, StatusBreakdown + +__all__ = [ + "DistributionSummary", + "Percentiles", + "RunningStats", + "StatusDistributionSummary", + "TimeRunningStats", +] + + +class Percentiles(StandardBaseModel): + """ + Standard percentiles model for statistical distribution analysis. + + Provides complete percentile coverage from 0.1th to 99.9th percentiles for + statistical distribution characterization. Used as a component within + DistributionSummary to provide detailed distribution shape analysis. 
+ """ + + p001: float = Field( + description="The 0.1th percentile of the distribution.", + ) + p01: float = Field( + description="The 1st percentile of the distribution.", + ) + p05: float = Field( + description="The 5th percentile of the distribution.", + ) + p10: float = Field( + description="The 10th percentile of the distribution.", + ) + p25: float = Field( + description="The 25th percentile of the distribution.", + ) + p50: float = Field( + description="The 50th percentile of the distribution.", + ) + p75: float = Field( + description="The 75th percentile of the distribution.", + ) + p90: float = Field( + description="The 90th percentile of the distribution.", + ) + p95: float = Field( + description="The 95th percentile of the distribution.", + ) + p99: float = Field( + description="The 99th percentile of the distribution.", + ) + p999: float = Field( + description="The 99.9th percentile of the distribution.", + ) + + +class DistributionSummary(StandardBaseModel): + """ + Comprehensive statistical summary for numerical value distributions. + + Calculates and stores complete statistical metrics including central tendency, + dispersion, extremes, and percentiles for any numerical distribution. Supports + both weighted and unweighted data with optional cumulative distribution function + generation. Primary statistical analysis tool for request timing, performance + metrics, and benchmark result characterization. + + Example: + :: + # Create from simple values + summary = DistributionSummary.from_values([1.0, 2.0, 3.0, 4.0, 5.0]) + print(f"Mean: {summary.mean}, P95: {summary.percentiles.p95}") + + # Create from request timings for concurrency analysis + requests = [(0.0, 1.0), (0.5, 2.0), (1.0, 2.5)] + concurrency = DistributionSummary.from_request_times( + requests, "concurrency" + ) + """ + + mean: float = Field( + description="The mean/average of the distribution.", + ) + median: float = Field( + description="The median of the distribution.", + ) + mode: float = Field( + description="The mode of the distribution.", + ) + variance: float = Field( + description="The variance of the distribution.", + ) + std_dev: float = Field( + description="The standard deviation of the distribution.", + ) + min: float = Field( + description="The minimum value of the distribution.", + ) + max: float = Field( + description="The maximum value of the distribution.", + ) + count: int = Field( + description="The number of values in the distribution.", + ) + total_sum: float = Field( + description="The total sum of the values in the distribution.", + ) + percentiles: Percentiles = Field( + description="The percentiles of the distribution.", + ) + cumulative_distribution_function: list[tuple[float, float]] | None = Field( + description="The cumulative distribution function (CDF) of the distribution.", + default=None, + ) + + @staticmethod + def from_distribution_function( + distribution: list[tuple[float, float]], + include_cdf: bool = False, + ) -> DistributionSummary: + """ + Create statistical summary from weighted distribution or probability function. + + Converts weighted numerical values or probability distribution function (PDF) + into comprehensive statistical summary. Normalizes weights to probabilities + and calculates all statistical metrics including percentiles. 
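A standalone numpy sketch of the weighted-CDF percentile lookup performed above; the values and weights are made up:

    import numpy as np

    values = np.array([1.0, 2.0, 4.0, 8.0])
    weights = np.array([1.0, 1.0, 1.0, 1.0])

    probabilities = weights / weights.sum()  # normalize weights into a PDF
    order = np.argsort(values)
    values, probabilities = values[order], probabilities[order]
    cdf = np.cumsum(probabilities)           # cumulative distribution

    # Percentile lookup: first value whose cumulative probability reaches p
    median = values[np.argmax(cdf >= 0.50)]
    p95 = values[np.argmax(cdf >= 0.95)]
    print(median, p95)  # 2.0 8.0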
+ + :param distribution: List of (value, weight) or (value, probability) tuples + representing the distribution + :param include_cdf: Whether to include cumulative distribution function + in the output + :return: DistributionSummary instance with calculated statistical metrics + """ + values, weights = zip(*distribution) if distribution else ([], []) + values = np.array(values) # type: ignore[assignment] + weights = np.array(weights) # type: ignore[assignment] + + # create the PDF + probabilities = weights / np.sum(weights) # type: ignore[operator] + pdf = np.column_stack((values, probabilities)) + pdf = pdf[np.argsort(pdf[:, 0])] + values = pdf[:, 0] # type: ignore[assignment] + probabilities = pdf[:, 1] + + # calculate the CDF + cumulative_probabilities = np.cumsum(probabilities) + cdf = np.column_stack((values, cumulative_probabilities)) + + # calculate statistics + mean = np.sum(values * probabilities).item() # type: ignore[attr-defined] + median = cdf[np.argmax(cdf[:, 1] >= 0.5), 0].item() if len(cdf) > 0 else 0 # noqa: PLR2004 + mode = values[np.argmax(probabilities)].item() if len(values) > 0 else 0 # type: ignore[call-overload] + variance = np.sum((values - mean) ** 2 * probabilities).item() # type: ignore[attr-defined] + std_dev = math.sqrt(variance) + minimum = values[0].item() if len(values) > 0 else 0 + maximum = values[-1].item() if len(values) > 0 else 0 + count = len(values) + total_sum = np.sum(values).item() # type: ignore[attr-defined] + + return DistributionSummary( + mean=mean, + median=median, + mode=mode, + variance=variance, + std_dev=std_dev, + min=minimum, + max=maximum, + count=count, + total_sum=total_sum, + percentiles=( + Percentiles( + p001=cdf[np.argmax(cdf[:, 1] >= 0.001), 0].item(), # noqa: PLR2004 + p01=cdf[np.argmax(cdf[:, 1] >= 0.01), 0].item(), # noqa: PLR2004 + p05=cdf[np.argmax(cdf[:, 1] >= 0.05), 0].item(), # noqa: PLR2004 + p10=cdf[np.argmax(cdf[:, 1] >= 0.1), 0].item(), # noqa: PLR2004 + p25=cdf[np.argmax(cdf[:, 1] >= 0.25), 0].item(), # noqa: PLR2004 + p50=cdf[np.argmax(cdf[:, 1] >= 0.50), 0].item(), # noqa: PLR2004 + p75=cdf[np.argmax(cdf[:, 1] >= 0.75), 0].item(), # noqa: PLR2004 + p90=cdf[np.argmax(cdf[:, 1] >= 0.9), 0].item(), # noqa: PLR2004 + p95=cdf[np.argmax(cdf[:, 1] >= 0.95), 0].item(), # noqa: PLR2004 + p99=cdf[np.argmax(cdf[:, 1] >= 0.99), 0].item(), # noqa: PLR2004 + p999=cdf[np.argmax(cdf[:, 1] >= 0.999), 0].item(), # noqa: PLR2004 + ) + if len(cdf) > 0 + else Percentiles( + p001=0, + p01=0, + p05=0, + p10=0, + p25=0, + p50=0, + p75=0, + p90=0, + p95=0, + p99=0, + p999=0, + ) + ), + cumulative_distribution_function=cdf.tolist() if include_cdf else None, + ) + + @staticmethod + def from_values( + values: list[float], + weights: list[float] | None = None, + include_cdf: bool = False, + ) -> DistributionSummary: + """ + Create statistical summary from numerical values with optional weights. + + Wrapper around from_distribution_function for simple value lists. If weights + are not provided, all values are equally weighted. Enables statistical + analysis of any numerical dataset. + + :param values: Numerical values representing the distribution + :param weights: Optional weights for each value. 
If not provided, all values + are equally weighted + :param include_cdf: Whether to include cumulative distribution function in + the output DistributionSummary + :return: DistributionSummary instance with calculated statistical metrics + :raises ValueError: If values and weights lists have different lengths + """ + if weights is None: + weights = [1.0] * len(values) + + if len(values) != len(weights): + raise ValueError( + "The length of values and weights must be the same.", + ) + + return DistributionSummary.from_distribution_function( + distribution=list(zip(values, weights)), + include_cdf=include_cdf, + ) + + @staticmethod + def from_request_times( + requests: list[tuple[float, float]], + distribution_type: Literal["concurrency", "rate"], + include_cdf: bool = False, + epsilon: float = 1e-6, + ) -> DistributionSummary: + """ + Create statistical summary from request timing data. + + Analyzes request start/end times to calculate concurrency or rate + distributions. Converts timing events into statistical metrics for + performance analysis and load characterization. + + :param requests: List of (start_time, end_time) tuples for each request + :param distribution_type: Type of analysis - "concurrency" for simultaneous + requests or "rate" for completion rates + :param include_cdf: Whether to include cumulative distribution function + :param epsilon: Threshold for merging close timing events + :return: DistributionSummary with timing-based statistical metrics + :raises ValueError: If distribution_type is not "concurrency" or "rate" + """ + if distribution_type == "concurrency": + # convert to delta changes based on when requests were running + time_deltas: dict[float, int] = defaultdict(int) + for start, end in requests: + time_deltas[start] += 1 + time_deltas[end] -= 1 + + # convert to the events over time measuring concurrency changes + events = [] + active = 0 + + for time, delta in sorted(time_deltas.items()): + active += delta + events.append((time, active)) + elif distribution_type == "rate": + # convert to events for when requests finished + global_start = min(start for start, _ in requests) if requests else 0 + events = [(global_start, 1)] + [(end, 1) for _, end in requests] + else: + raise ValueError( + f"Invalid distribution_type '{distribution_type}'. " + "Must be 'concurrency' or 'rate'." 
+ ) + + # combine any events that are very close together + flattened_events: list[tuple[float, float]] = [] + for time, val in sorted(events): + last_time, last_val = ( + flattened_events[-1] if flattened_events else (None, None) + ) + + if ( + last_time is not None + and last_val is not None + and abs(last_time - time) <= epsilon + ): + flattened_events[-1] = (last_time, last_val + val) + else: + flattened_events.append((time, val)) + + # convert to value distribution function + distribution: dict[float, float] = defaultdict(float) + + for ind in range(len(flattened_events) - 1): + start_time, value = flattened_events[ind] + end_time, _ = flattened_events[ind + 1] + duration = end_time - start_time + + if distribution_type == "concurrency": + # weight the concurrency value by the duration + distribution[value] += duration + elif distribution_type == "rate": + # weight the rate value by the duration + rate = value / duration + distribution[rate] += duration + + distribution_list: list[tuple[float, float]] = sorted(distribution.items()) + + return DistributionSummary.from_distribution_function( + distribution=distribution_list, + include_cdf=include_cdf, + ) + + @staticmethod + def from_iterable_request_times( + requests: list[tuple[float, float]], + first_iter_times: list[float], + iter_counts: list[int], + first_iter_counts: list[int] | None = None, + include_cdf: bool = False, + epsilon: float = 1e-6, + ) -> DistributionSummary: + """ + Create statistical summary from iterative request timing data. + + Analyzes autoregressive or streaming requests with multiple iterations + between start and end times. Calculates rate distributions based on + iteration timing patterns for LLM token generation analysis. + + :param requests: List of (start_time, end_time) tuples for each request + :param first_iter_times: Times when first iteration was received for + each request + :param iter_counts: Total iteration counts for each request from first + iteration to end + :param first_iter_counts: Iteration counts for first iteration (defaults + to 1 for each request) + :param include_cdf: Whether to include cumulative distribution function + :param epsilon: Threshold for merging close timing events + :return: DistributionSummary with iteration rate statistical metrics + :raises ValueError: If input lists have mismatched lengths + """ + + if first_iter_counts is None: + first_iter_counts = [1] * len(requests) + + if ( + len(requests) != len(first_iter_times) + or len(requests) != len(iter_counts) + or len(requests) != len(first_iter_counts) + ): + raise ValueError( + "requests, first_iter_times, iter_counts, and first_iter_counts must" + "be the same length." 
+ f"Given {len(requests)}, {len(first_iter_times)}, {len(iter_counts)}, " + f"{len(first_iter_counts)}", + ) + + # first break up the requests into individual iterable events + events = defaultdict(int) + global_start = min(start for start, _ in requests) if requests else 0 + global_end = max(end for _, end in requests) if requests else 0 + events[global_start] = 0 + events[global_end] = 0 + + for (_, end), first_iter, first_iter_count, total_count in zip( + requests, first_iter_times, first_iter_counts, iter_counts + ): + events[first_iter] += first_iter_count + + if total_count > 1: + iter_latency = (end - first_iter) / (total_count - 1) + for ind in range(1, total_count): + events[first_iter + ind * iter_latency] += 1 + + # combine any events that are very close together + flattened_events: list[tuple[float, int]] = [] + + for time, count in sorted(events.items()): + last_time, last_count = ( + flattened_events[-1] if flattened_events else (None, None) + ) + + if ( + last_time is not None + and last_count is not None + and abs(last_time - time) <= epsilon + ): + flattened_events[-1] = (last_time, last_count + count) + else: + flattened_events.append((time, count)) + + # convert to value distribution function + distribution: dict[float, float] = defaultdict(float) + + for ind in range(len(flattened_events) - 1): + start_time, count = flattened_events[ind] + end_time, _ = flattened_events[ind + 1] + duration = end_time - start_time + rate = count / duration + distribution[rate] += duration + + distribution_list = sorted(distribution.items()) + + return DistributionSummary.from_distribution_function( + distribution=distribution_list, + include_cdf=include_cdf, + ) + + +class StatusDistributionSummary( + StatusBreakdown[ + DistributionSummary, + DistributionSummary, + DistributionSummary, + DistributionSummary, + ] +): + """ + Status-grouped statistical summary for request processing analysis. + + Provides comprehensive statistical analysis grouped by request status (total, + successful, incomplete, errored). Enables performance analysis across different + request outcomes for benchmarking and monitoring applications. Each status + category maintains complete DistributionSummary metrics. + + Example: + :: + status_summary = StatusDistributionSummary.from_values( + value_types=["successful", "error", "successful"], + values=[1.5, 10.0, 2.1] + ) + print(f"Success mean: {status_summary.successful.mean}") + print(f"Error rate: {status_summary.errored.count}") + """ + + @staticmethod + def from_values( + value_types: list[Literal["successful", "incomplete", "error"]], + values: list[float], + weights: list[float] | None = None, + include_cdf: bool = False, + ) -> StatusDistributionSummary: + """ + Create status-grouped statistical summary from values and status types. + + Groups numerical values by request status and calculates complete + statistical summaries for each category. Enables performance analysis + across different request outcomes. 
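+        The total summary is always computed across every value, regardless of
+        its status category.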
+ + :param value_types: Status type for each value ("successful", "incomplete", + or "error") + :param values: Numerical values representing the distribution + :param weights: Optional weights for each value (defaults to equal weighting) + :param include_cdf: Whether to include cumulative distribution functions + :return: StatusDistributionSummary with statistics grouped by status + :raises ValueError: If input lists have mismatched lengths or invalid + status types + """ + if any( + type_ not in {"successful", "incomplete", "error"} for type_ in value_types + ): + raise ValueError( + "value_types must be one of 'successful', 'incomplete', or 'error'. " + f"Got {value_types} instead.", + ) + + if weights is None: + weights = [1.0] * len(values) + + if len(value_types) != len(values) or len(value_types) != len(weights): + raise ValueError( + "The length of value_types, values, and weights must be the same.", + ) + + _, successful_values, successful_weights = ( + zip(*successful) + if ( + successful := list( + filter( + lambda val: val[0] == "successful", + zip(value_types, values, weights), + ) + ) + ) + else ([], [], []) + ) + _, incomplete_values, incomplete_weights = ( + zip(*incomplete) + if ( + incomplete := list( + filter( + lambda val: val[0] == "incomplete", + zip(value_types, values, weights), + ) + ) + ) + else ([], [], []) + ) + _, errored_values, errored_weights = ( + zip(*errored) + if ( + errored := list( + filter( + lambda val: val[0] == "error", + zip(value_types, values, weights), + ) + ) + ) + else ([], [], []) + ) + + return StatusDistributionSummary( + total=DistributionSummary.from_values( + values, + weights, + include_cdf=include_cdf, + ), + successful=DistributionSummary.from_values( + successful_values, # type: ignore[arg-type] + successful_weights, # type: ignore[arg-type] + include_cdf=include_cdf, + ), + incomplete=DistributionSummary.from_values( + incomplete_values, # type: ignore[arg-type] + incomplete_weights, # type: ignore[arg-type] + include_cdf=include_cdf, + ), + errored=DistributionSummary.from_values( + errored_values, # type: ignore[arg-type] + errored_weights, # type: ignore[arg-type] + include_cdf=include_cdf, + ), + ) + + @staticmethod + def from_request_times( + request_types: list[Literal["successful", "incomplete", "error"]], + requests: list[tuple[float, float]], + distribution_type: Literal["concurrency", "rate"], + include_cdf: bool = False, + epsilon: float = 1e-6, + ) -> StatusDistributionSummary: + """ + Create status-grouped statistical summary from request timing data. + + Analyzes request timings grouped by status to calculate concurrency or + rate distributions for each outcome category. Enables comparative + performance analysis across successful, incomplete, and errored requests. + + :param request_types: Status type for each request ("successful", + "incomplete", or "error") + :param requests: List of (start_time, end_time) tuples for each request + :param distribution_type: Analysis type - "concurrency" or "rate" + :param include_cdf: Whether to include cumulative distribution functions + :param epsilon: Threshold for merging close timing events + :return: StatusDistributionSummary with timing statistics by status + :raises ValueError: If input lists have mismatched lengths or invalid types + """ + if distribution_type not in {"concurrency", "rate"}: + raise ValueError( + f"Invalid distribution_type '{distribution_type}'. " + "Must be 'concurrency' or 'rate'." 
+ ) + + if any( + type_ not in {"successful", "incomplete", "error"} + for type_ in request_types + ): + raise ValueError( + "request_types must be one of 'successful', 'incomplete', or 'error'. " + f"Got {request_types} instead.", + ) + + if len(request_types) != len(requests): + raise ValueError( + "The length of request_types and requests must be the same. " + f"Got {len(request_types)} and {len(requests)} instead.", + ) + + _, successful_requests = ( + zip(*successful) + if ( + successful := list( + filter( + lambda val: val[0] == "successful", + zip(request_types, requests), + ) + ) + ) + else ([], []) + ) + _, incomplete_requests = ( + zip(*incomplete) + if ( + incomplete := list( + filter( + lambda val: val[0] == "incomplete", + zip(request_types, requests), + ) + ) + ) + else ([], []) + ) + _, errored_requests = ( + zip(*errored) + if ( + errored := list( + filter( + lambda val: val[0] == "error", + zip(request_types, requests), + ) + ) + ) + else ([], []) + ) + + return StatusDistributionSummary( + total=DistributionSummary.from_request_times( + requests, + distribution_type=distribution_type, + include_cdf=include_cdf, + epsilon=epsilon, + ), + successful=DistributionSummary.from_request_times( + successful_requests, # type: ignore[arg-type] + distribution_type=distribution_type, + include_cdf=include_cdf, + epsilon=epsilon, + ), + incomplete=DistributionSummary.from_request_times( + incomplete_requests, # type: ignore[arg-type] + distribution_type=distribution_type, + include_cdf=include_cdf, + epsilon=epsilon, + ), + errored=DistributionSummary.from_request_times( + errored_requests, # type: ignore[arg-type] + distribution_type=distribution_type, + include_cdf=include_cdf, + epsilon=epsilon, + ), + ) + + @staticmethod + def from_iterable_request_times( + request_types: list[Literal["successful", "incomplete", "error"]], + requests: list[tuple[float, float]], + first_iter_times: list[float], + iter_counts: list[int] | None = None, + first_iter_counts: list[int] | None = None, + include_cdf: bool = False, + epsilon: float = 1e-6, + ) -> StatusDistributionSummary: + """ + Create status-grouped statistical summary from iterative request timing data. + + Analyzes autoregressive request timings grouped by status to calculate + iteration rate distributions for each outcome category. Enables comparative + analysis of token generation or streaming response performance across + different request statuses. + + :param request_types: Status type for each request ("successful", + "incomplete", or "error") + :param requests: List of (start_time, end_time) tuples for each request + :param first_iter_times: Times when first iteration was received for + each request + :param iter_counts: Total iteration counts for each request (defaults to 1) + :param first_iter_counts: Iteration counts for first iteration (defaults + to 1) + :param include_cdf: Whether to include cumulative distribution functions + :param epsilon: Threshold for merging close timing events + :return: StatusDistributionSummary with iteration statistics by status + :raises ValueError: If input lists have mismatched lengths or invalid types + """ + if any( + type_ not in {"successful", "incomplete", "error"} + for type_ in request_types + ): + raise ValueError( + "request_types must be one of 'successful', 'incomplete', or 'error'. 
" + f"Got {request_types} instead.", + ) + + if iter_counts is None: + iter_counts = [1] * len(requests) + + if first_iter_counts is None: + first_iter_counts = [1] * len(requests) + + if ( + len(request_types) != len(requests) + or len(requests) != len(first_iter_times) + or len(requests) != len(iter_counts) + or len(requests) != len(first_iter_counts) + ): + raise ValueError( + "request_types, requests, first_iter_times, iter_counts, and " + "first_iter_counts must be the same length." + f"Given {len(request_types)}, {len(requests)}, " + f"{len(first_iter_times)}, {len(iter_counts)}, " + f"{len(first_iter_counts)}", + ) + + ( + _, + successful_requests, + successful_first_iter_times, + successful_iter_counts, + successful_first_iter_counts, + ) = ( + zip(*successful) + if ( + successful := list( + filter( + lambda val: val[0] == "successful", + zip( + request_types, + requests, + first_iter_times, + iter_counts, + first_iter_counts, + ), + ) + ) + ) + else ([], [], [], [], []) + ) + ( + _, + incomplete_requests, + incomplete_first_iter_times, + incomplete_iter_counts, + incomplete_first_iter_counts, + ) = ( + zip(*incomplete) + if ( + incomplete := list( + filter( + lambda val: val[0] == "incomplete", + zip( + request_types, + requests, + first_iter_times, + iter_counts, + first_iter_counts, + ), + ) + ) + ) + else ([], [], [], [], []) + ) + ( + _, + errored_requests, + errored_first_iter_times, + errored_iter_counts, + errored_first_iter_counts, + ) = ( + zip(*errored) + if ( + errored := list( + filter( + lambda val: val[0] == "error", + zip( + request_types, + requests, + first_iter_times, + iter_counts, + first_iter_counts, + ), + ) + ) + ) + else ([], [], [], [], []) + ) + + return StatusDistributionSummary( + total=DistributionSummary.from_iterable_request_times( + requests, + first_iter_times, + iter_counts, + first_iter_counts, + include_cdf=include_cdf, + epsilon=epsilon, + ), + successful=DistributionSummary.from_iterable_request_times( + successful_requests, # type: ignore[arg-type] + successful_first_iter_times, # type: ignore[arg-type] + successful_iter_counts, # type: ignore[arg-type] + successful_first_iter_counts, # type: ignore[arg-type] + include_cdf=include_cdf, + epsilon=epsilon, + ), + incomplete=DistributionSummary.from_iterable_request_times( + incomplete_requests, # type: ignore[arg-type] + incomplete_first_iter_times, # type: ignore[arg-type] + incomplete_iter_counts, # type: ignore[arg-type] + incomplete_first_iter_counts, # type: ignore[arg-type] + include_cdf=include_cdf, + epsilon=epsilon, + ), + errored=DistributionSummary.from_iterable_request_times( + errored_requests, # type: ignore[arg-type] + errored_first_iter_times, # type: ignore[arg-type] + errored_iter_counts, # type: ignore[arg-type] + errored_first_iter_counts, # type: ignore[arg-type] + include_cdf=include_cdf, + epsilon=epsilon, + ), + ) + + +class RunningStats(StandardBaseModel): + """ + Real-time statistics tracking for streaming numerical data. + + Maintains mean, rate, and cumulative statistics for continuous data streams + without storing individual values. Optimized for memory efficiency in + long-running monitoring applications. Supports arithmetic operators for + convenient value addition and provides computed properties for derived metrics. 
+ + Example: + :: + stats = RunningStats() + stats += 10.5 # Add value using operator + stats.update(20.0, count=3) # Add value with custom count + print(f"Mean: {stats.mean}, Rate: {stats.rate}") + """ + + start_time: float = Field( + default_factory=timer.time, + description=( + "The time the running statistics object was created. " + "This is used to calculate the rate of the statistics." + ), + ) + count: int = Field( + default=0, + description="The number of values added to the running statistics.", + ) + total: float = Field( + default=0.0, + description="The total sum of the values added to the running statistics.", + ) + last: float = Field( + default=0.0, + description="The last value added to the running statistics.", + ) + + @computed_field # type: ignore[misc] + @property + def mean(self) -> float: + """ + :return: The mean of the running statistics (total / count). + If count is 0, return 0.0. + """ + if self.count == 0: + return 0.0 + return self.total / self.count + + @computed_field # type: ignore[misc] + @property + def rate(self) -> float: + """ + :return: The rate of the running statistics + (total / (time.time() - start_time)). + If count is 0, return 0.0. + """ + if self.count == 0: + return 0.0 + return self.total / (timer.time() - self.start_time) + + def __add__(self, value: Any) -> float: + """ + Add value using + operator and return current mean. + + :param value: Numerical value to add to the running statistics + :return: Updated mean after adding the value + :raises ValueError: If value is not numeric (int or float) + """ + if not isinstance(value, (int, float)): + raise ValueError( + f"Value must be an int or float, got {type(value)} instead.", + ) + + self.update(value) + + return self.mean + + def __iadd__(self, value: Any) -> RunningStats: + """ + Add value using += operator and return updated instance. + + :param value: Numerical value to add to the running statistics + :return: Self reference for method chaining + :raises ValueError: If value is not numeric (int or float) + """ + if not isinstance(value, (int, float)): + raise ValueError( + f"Value must be an int or float, got {type(value)} instead.", + ) + + self.update(value) + + return self + + def update(self, value: float, count: int = 1) -> None: + """ + Update running statistics with new value and count. + + :param value: Numerical value to add to the running statistics + :param count: Number of occurrences to count for this value (defaults to 1) + """ + self.count += count + self.total += value + self.last = value + + +class TimeRunningStats(RunningStats): + """ + Specialized running statistics for time-based measurements. + + Extends RunningStats with time-specific computed properties for millisecond + conversions. Designed for tracking latency, duration, and timing metrics in + performance monitoring applications. + + Example: + :: + time_stats = TimeRunningStats() + time_stats += 0.125 # Add 125ms in seconds + print(f"Mean: {time_stats.mean_ms}ms, Total: {time_stats.total_ms}ms") + """ + + @computed_field # type: ignore[misc] + @property + def total_ms(self) -> float: + """ + :return: The total time multiplied by 1000.0 to convert to milliseconds. + """ + return self.total * 1000.0 + + @computed_field # type: ignore[misc] + @property + def last_ms(self) -> float: + """ + :return: The last time multiplied by 1000.0 to convert to milliseconds. 
+ """ + return self.last * 1000.0 + + @computed_field # type: ignore[misc] + @property + def mean_ms(self) -> float: + """ + :return: The mean time multiplied by 1000.0 to convert to milliseconds. + """ + return self.mean * 1000.0 + + @computed_field # type: ignore[misc] + @property + def rate_ms(self) -> float: + """ + :return: The rate of the running statistics multiplied by 1000.0 + to convert to milliseconds. + """ + return self.rate * 1000.0 diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py new file mode 100644 index 00000000..3bec0247 --- /dev/null +++ b/src/guidellm/utils/synchronous.py @@ -0,0 +1,161 @@ +""" +Async utilities for waiting on synchronization objects. + +This module provides async-compatible wrappers for threading and multiprocessing +synchronization primitives (Events and Barriers). These utilities enable async code +to wait for synchronization objects without blocking the event loop, essential for +coordinating between async and sync code or between processes in the guidellm system. +""" + +from __future__ import annotations + +import asyncio +import contextlib +from datetime import time +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from threading import Barrier as ThreadingBarrier +from threading import Event as ThreadingEvent +from typing import Annotated, Union + +from typing_extensions import TypeAlias + +__all__ = [ + "SyncObjectTypesAlias", + "wait_for_sync_barrier", + "wait_for_sync_event", + "wait_for_sync_objects", +] + + +SyncObjectTypesAlias: TypeAlias = Annotated[ + Union[ThreadingEvent, ProcessingEvent, ThreadingBarrier, ProcessingBarrier], + "Type alias for threading and multiprocessing synchronization object types", +] + + +async def wait_for_sync_event( + event: ThreadingEvent | ProcessingEvent, + poll_interval: float, +) -> None: + """ + Asynchronously wait for a threading or multiprocessing Event to be set. + + This function polls the event at regular intervals without blocking the async + event loop, allowing other async tasks to continue executing while waiting. + + :param event: The Event object to wait for (threading or multiprocessing) + :param poll_interval: Time in seconds between polling checks + :raises asyncio.CancelledError: If the async task is cancelled + """ + stop = ThreadingEvent() + + def _watch(): + try: + while not stop.is_set(): + if event.wait(timeout=poll_interval): + return + except Exception as err: # noqa: BLE001 + if stop.is_set(): + return # Ignore error if we should have stopped + raise err + + try: + await asyncio.to_thread(_watch) + except asyncio.CancelledError: + stop.set() + raise + + +async def wait_for_sync_barrier( + barrier: ThreadingBarrier | ProcessingBarrier, + poll_interval: float, +) -> None: + """ + Asynchronously wait for a threading or multiprocessing Barrier to be reached. + + This function polls the barrier at regular intervals without blocking the async + event loop, allowing other async tasks to continue executing while waiting. 
+ + :param barrier: The Barrier object to wait for (threading or multiprocessing) + :param poll_interval: Time in seconds between polling checks + :raises asyncio.CancelledError: If the async task is cancelled + """ + stop = ThreadingEvent() + barrier_broken = ThreadingEvent() + + def _wait_indefinite(): + try: + # wait forever, count on barrier broken event to exit + barrier.wait() + barrier_broken.set() + except Exception as err: + if stop.is_set(): + return # Ignore error if we should have stopped + raise err + + def _watch(): + while not barrier_broken.is_set(): + if stop.is_set(): + with contextlib.suppress(Exception): + if not barrier.broken: + barrier.abort() + break + time.sleep(poll_interval) + + try: + await asyncio.gather( + asyncio.to_thread(_wait_indefinite), + asyncio.to_thread(_watch), + ) + except asyncio.CancelledError: + stop.set() + raise + + +async def wait_for_sync_objects( + objects: SyncObjectTypesAlias + | list[SyncObjectTypesAlias] + | dict[str, SyncObjectTypesAlias], + poll_interval: float = 0.1, +) -> int | str: + """ + Asynchronously wait for the first synchronization object to complete. + + This function waits for the first Event to be set or Barrier to be reached + from a collection of synchronization objects. It returns immediately when + any object completes and cancels waiting on the remaining objects. + + :param objects: Single sync object, list of objects, or dict mapping names + to objects + :param poll_interval: Time in seconds between polling checks for each object + :return: Index (for list/single) or key name (for dict) of the first + completed object + :raises asyncio.CancelledError: If the async task is cancelled + """ + if isinstance(objects, dict): + keys = list(objects.keys()) + objects = list(objects.values()) + elif isinstance(objects, list): + keys = list(range(len(objects))) + else: + keys = [0] + objects = [objects] + + tasks = [ + asyncio.create_task( + wait_for_sync_barrier(obj, poll_interval) + if isinstance(obj, (ThreadingBarrier, ProcessingBarrier)) + else wait_for_sync_event(obj, poll_interval) + ) + for obj in objects + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # Cancel the remaining pending tasks + for pend in pending: + pend.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + return keys[tasks.index(list(done)[0])] diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index 3b9a2e26..519b46c3 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -1,9 +1,21 @@ +""" +Text processing utilities for content manipulation and formatting operations. + +Provides comprehensive text processing capabilities including cleaning, filtering, +splitting, loading from various sources, and formatting utilities. Supports loading +text from URLs, compressed files, package resources, and local files with automatic +encoding detection. Includes specialized formatting for display values and text +wrapping operations for consistent presentation across the system. 
+""" + +from __future__ import annotations + import gzip import re import textwrap from importlib.resources import as_file, files # type: ignore[attr-defined] from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import ftfy import httpx @@ -11,36 +23,86 @@ from guidellm import data as package_data from guidellm.settings import settings +from guidellm.utils.console import Colors __all__ = [ + "MAX_PATH_LENGTH", "EndlessTextCreator", - "camelize_str", "clean_text", "filter_text", + "format_value_display", "is_puncutation", "load_text", "split_text", "split_text_list_by_length", ] -MAX_PATH_LENGTH = 4096 +MAX_PATH_LENGTH: int = 4096 + + +def format_value_display( + value: float, + label: str, + units: str = "", + total_characters: int | None = None, + digits_places: int | None = None, + decimal_places: int | None = None, +) -> str: + """ + Format a numeric value with units and label for consistent display output. + + Creates standardized display strings for metrics and measurements with + configurable precision, width, and color formatting. Supports both + fixed-width and variable-width output for tabular displays. + + :param value: Numeric value to format and display + :param label: Descriptive label for the value + :param units: Units string to append after the value + :param total_characters: Total width for right-aligned output formatting + :param digits_places: Total number of digits for numeric formatting + :param decimal_places: Number of decimal places for numeric precision + :return: Formatted string with value, units, and colored label + """ + if decimal_places is None and digits_places is None: + formatted_number = f"{value}:.0f" + elif digits_places is None: + formatted_number = f"{value:.{decimal_places}f}" + elif decimal_places is None: + formatted_number = f"{value:>{digits_places}f}" + else: + formatted_number = f"{value:>{digits_places}.{decimal_places}f}" + + result = f"{formatted_number}{units} [{Colors.info}]{label}[/{Colors.info}]" + + if total_characters is not None: + total_characters += len(Colors.info) * 2 + 5 + + if len(result) < total_characters: + result = result.rjust(total_characters) + + return result def split_text_list_by_length( text_list: list[Any], - max_characters: Union[int, list[int]], + max_characters: int | list[int], pad_horizontal: bool = True, pad_vertical: bool = True, ) -> list[list[str]]: """ - Split a list of strings into a list of strings, - each with a maximum length of max_characters - - :param text_list: the list of strings to split - :param max_characters: the maximum length of each string - :param pad_horizontal: whether to pad the strings horizontally, defaults to True - :param pad_vertical: whether to pad the strings vertically, defaults to True - :return: a list of strings + Split text strings into wrapped lines with specified maximum character limits. + + Processes each string in the input list by wrapping text to fit within character + limits, with optional padding for consistent formatting in tabular displays. + Supports different character limits per string and uniform padding across results. 
+ + :param text_list: List of strings to process and wrap + :param max_characters: Maximum characters per line, either single value or + per-string limits + :param pad_horizontal: Right-align lines within their character limits + :param pad_vertical: Pad shorter results to match the longest wrapped result + :return: List of wrapped line lists, one per input string + :raises ValueError: If max_characters list length doesn't match text_list length """ if not isinstance(max_characters, list): max_characters = [max_characters] * len(text_list) @@ -76,16 +138,21 @@ def split_text_list_by_length( def filter_text( text: str, - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ) -> str: """ - Filter text by start and end strings or indices + Extract text substring using start and end markers or indices. + + Filters text content by locating string markers or using numeric indices + to extract specific portions. Supports flexible filtering for content + extraction and preprocessing operations. - :param text: the text to filter - :param filter_start: the start string or index to filter from - :param filter_end: the end string or index to filter to - :return: the filtered text + :param text: Source text to filter and extract from + :param filter_start: Starting marker string or index position + :param filter_end: Ending marker string or index position + :return: Filtered text substring between specified boundaries + :raises ValueError: If filter indices are invalid or markers not found """ filter_start_index = -1 filter_end_index = -1 @@ -113,10 +180,29 @@ def filter_text( def clean_text(text: str) -> str: + """ + Normalize text by fixing encoding issues and standardizing whitespace. + + Applies Unicode normalization and whitespace standardization for consistent + text processing. Removes excessive whitespace and fixes common encoding problems. + + :param text: Raw text string to clean and normalize + :return: Cleaned text with normalized encoding and whitespace + """ return re.sub(r"\s+", " ", ftfy.fix_text(text)).strip() def split_text(text: str, split_punctuation: bool = False) -> list[str]: + """ + Split text into tokens with optional punctuation separation. + + Tokenizes text into words and optionally separates punctuation marks + for detailed text analysis and processing operations. + + :param text: Text string to tokenize and split + :param split_punctuation: Separate punctuation marks as individual tokens + :return: List of text tokens + """ text = clean_text(text) if split_punctuation: @@ -125,16 +211,20 @@ def split_text(text: str, split_punctuation: bool = False) -> list[str]: return text.split() -def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: +def load_text(data: str | Path, encoding: str | None = None) -> str: """ - Load an HTML file from a path or URL - - :param data: the path or URL to load the HTML file from - :type data: Union[str, Path] - :param encoding: the encoding to use when reading the file - :type encoding: str - :return: the HTML content - :rtype: str + Load text content from various sources including URLs, files, and package data. + + Supports loading from HTTP/FTP URLs, local files, compressed archives, package + resources, and raw text strings. Automatically detects source type and applies + appropriate loading strategy with encoding support. 
+ + :param data: Source location or raw text - URL, file path, package resource + identifier, or text content + :param encoding: Character encoding for file reading operations + :return: Loaded text content as string + :raises FileNotFoundError: If local file path does not exist + :raises httpx.HTTPStatusError: If URL request fails """ logger.debug("Loading text: {}", data) @@ -180,35 +270,62 @@ def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: def is_puncutation(text: str) -> bool: """ - Check if the text is a punctuation + Check if a single character is a punctuation mark. + + Identifies punctuation characters by excluding alphanumeric characters + and whitespace from single-character strings. - :param text: the text to check - :type text: str - :return: True if the text is a punctuation, False otherwise - :rtype: bool + :param text: Single character string to test + :return: True if the character is punctuation, False otherwise """ return len(text) == 1 and not text.isalnum() and not text.isspace() -def camelize_str(snake_case_string: str) -> str: - return (words := snake_case_string.split("_"))[0].lower() + "".join( - word.capitalize() for word in words[1:] - ) +class EndlessTextCreator: + """ + Infinite text generator for load testing and content creation operations. + Provides deterministic text generation by cycling through preprocessed word + tokens from source content. Supports filtering and punctuation handling for + realistic text patterns in benchmarking scenarios. + + Example: + :: + creator = EndlessTextCreator("path/to/source.txt") + generated = creator.create_text(start=0, length=100) + more_text = creator.create_text(start=50, length=200) + """ -class EndlessTextCreator: def __init__( self, - data: Union[str, Path], - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + data: str | Path, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ): + """ + Initialize text creator with source content and optional filtering. + + :param data: Source text location or content - file path, URL, or raw text + :param filter_start: Starting marker or index for content filtering + :param filter_end: Ending marker or index for content filtering + """ self.data = data self.text = load_text(data) self.filtered_text = filter_text(self.text, filter_start, filter_end) self.words = split_text(self.filtered_text, split_punctuation=True) def create_text(self, start: int, length: int) -> str: + """ + Generate text by cycling through word tokens from the specified position. + + Creates deterministic text sequences by selecting consecutive tokens from + the preprocessed word list, wrapping around when reaching the end. + Maintains proper spacing and punctuation formatting. 
+ + :param start: Starting position in the token sequence + :param length: Number of tokens to include in generated text + :return: Generated text string with proper spacing and punctuation + """ text = "" for counter in range(length): diff --git a/tests/unit/utils/dict.py b/tests/unit/utils/dict.py deleted file mode 100644 index 09d93df6..00000000 --- a/tests/unit/utils/dict.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest - -from guidellm.utils.dict import recursive_key_update - - -def update_str(string): - return string + "_updated" - - -@pytest.mark.smoke -def test_recursive_key_update_updates_keys(): - my_dict = { - "my_key": { - "my_nested_key": {"my_double_nested_key": "someValue"}, - "my_other_nested_key": "someValue", - }, - "my_other_key": "value", - } - my_updated_dict = { - "my_key_updated": { - "my_nested_key_updated": {"my_double_nested_key_updated": "someValue"}, - "my_other_nested_key_updated": "someValue", - }, - "my_other_key_updated": "value", - } - recursive_key_update(my_dict, update_str) - assert my_dict == my_updated_dict - - -def truncate_str_to_ten(string): - return string[:10] - - -@pytest.mark.smoke -def test_recursive_key_update_leaves_unchanged_keys(): - my_dict = { - "my_key": { - "my_nested_key": {"my_double_nested_key": "someValue"}, - "my_other_nested_key": "someValue", - }, - "my_other_key": "value", - } - my_updated_dict = { - "my_key": { - "my_nested_": {"my_double_": "someValue"}, - "my_other_n": "someValue", - }, - "my_other_k": "value", - } - recursive_key_update(my_dict, truncate_str_to_ten) - assert my_dict == my_updated_dict - - -@pytest.mark.smoke -def test_recursive_key_update_updates_dicts_in_list(): - my_dict = { - "my_key": [ - {"my_list_item_key_1": "someValue"}, - {"my_list_item_key_2": "someValue"}, - {"my_list_item_key_3": "someValue"}, - ] - } - my_updated_dict = { - "my_key_updated": [ - {"my_list_item_key_1_updated": "someValue"}, - {"my_list_item_key_2_updated": "someValue"}, - {"my_list_item_key_3_updated": "someValue"}, - ] - } - recursive_key_update(my_dict, update_str) - assert my_dict == my_updated_dict diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py new file mode 100644 index 00000000..cc71bce3 --- /dev/null +++ b/tests/unit/utils/test_auto_importer.py @@ -0,0 +1,269 @@ +""" +Unit tests for the auto_importer module. 
+""" + +from __future__ import annotations + +from unittest import mock + +import pytest + +from guidellm.utils import AutoImporterMixin + + +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" + + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] + + return TestClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables + assert AutoImporterMixin.auto_package is None + assert AutoImporterMixin.auto_ignore_modules is None + assert AutoImporterMixin.auto_imported_modules is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert test_class.auto_imported_modules is None + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" + + class TestClass(AutoImporterMixin): + pass + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestClass.auto_import_package_modules() + + @pytest.mark.smoke + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, 
mock.MagicMock() + ) + + def walk_side_effect(path, prefix): + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] + + mock_walk.side_effect = walk_side_effect + + # Execute + test_class.auto_import_package_modules() + + # Verify + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Test import error handling + mock_import.side_effect = ImportError("Module not found") + + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_packages(self, mock_walk, mock_import): + """Test that packages (is_pkg=True) are skipped.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.subpackage", True), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + mock_import.assert_any_call("test.package.module") + # subpackage should not be imported + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.subpackage") + + @pytest.mark.sanity + @mock.patch("sys.modules", {"test.package.existing": mock.MagicMock()}) + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_already_imported_modules(self, mock_walk, mock_import): + """Test that modules already in sys.modules are tracked but not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + }.get(name, mock.MagicMock()) + + mock_walk.return_value = [ + (None, "test.package.existing", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.existing"] + mock_import.assert_called_once_with("test.package") + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.existing") + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_prevent_duplicate_module_imports(self, mock_walk, mock_import): + """Test that modules already in auto_imported_modules are not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup 
mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.module", False), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + assert mock_import.call_count == 2 # Package + module (not duplicate) + + +class MockHelper: + """Helper class to create consistent mock objects for testing.""" + + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package + + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py new file mode 100644 index 00000000..da1f63ee --- /dev/null +++ b/tests/unit/utils/test_encoding.py @@ -0,0 +1,556 @@ +from __future__ import annotations + +import uuid +from typing import Any, Generic, TypeVar + +import pytest +from pydantic import BaseModel, Field + +from guidellm.backend.objects import ( + GenerationRequest, + GenerationResponse, +) +from guidellm.scheduler.objects import RequestSchedulerTimings, ScheduledRequestInfo +from guidellm.utils.encoding import Encoder, MessageEncoding, Serializer + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str = Field(description="Name field for testing") + value: int = Field(description="Value field for testing") + + +class SampleModelSubclass(SampleModel): + """Subclass of SampleModel for testing.""" + + extra_field: str + + +SampleModelT = TypeVar("SampleModelT", bound=SampleModel) + + +class ComplexModel(BaseModel, Generic[SampleModelT]): + """Complex Pydantic model for testing.""" + + items: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + nested: SampleModelT | None = Field(default=None) + + +class GenricModelWrapper(Generic[SampleModelT]): + """Simulates a layered generic type.""" + + def method(self, **kwargs) -> ComplexModel[SampleModelT]: + return ComplexModel[SampleModelT](**kwargs) + + +class TestMessageEncoding: + """Test suite for MessageEncoding class.""" + + @pytest.fixture( + params=[ + {"serialization": None, "encoding": None}, + {"serialization": "dict", "encoding": None}, + {"serialization": "sequence", "encoding": None}, + {"serialization": None, "encoding": "msgpack"}, + {"serialization": "dict", "encoding": "msgpack"}, + {"serialization": "sequence", "encoding": "msgpack"}, + {"serialization": None, "encoding": "msgspec"}, + {"serialization": "dict", "encoding": "msgspec"}, + {"serialization": "sequence", "encoding": "msgspec"}, + {"serialization": None, "encoding": ["msgspec", "msgpack"]}, + {"serialization": "dict", "encoding": ["msgspec", "msgpack"]}, + ], + ids=[ + "no_serialization_no_encoding", + "dict_serialization_no_encoding", + "str_serialization_no_encoding", + "no_serialization_msgpack", + "dict_serialization_msgpack", + "str_serialization_msgpack", + "no_serialization_msgspec", + "dict_serialization_msgspec", + "str_serialization_msgspec", + 
"no_serialization_encoding_list", + "dict_serialization_encoding_list", + ], + ) + def valid_instances(self, request): + """Fixture providing test data for MessageEncoding.""" + constructor_args = request.param + try: + instance = MessageEncoding(**constructor_args) + return instance, constructor_args + except ImportError: + pytest.skip("Required encoding library not available") + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MessageEncoding inheritance and type relationships.""" + assert issubclass(MessageEncoding, Generic) + assert hasattr(MessageEncoding, "DEFAULT_ENCODING_PREFERENCE") + assert isinstance(MessageEncoding.DEFAULT_ENCODING_PREFERENCE, list) + assert MessageEncoding.DEFAULT_ENCODING_PREFERENCE == ["msgspec", "msgpack"] + + # Check classmethods + assert hasattr(MessageEncoding, "encode_message") + assert callable(MessageEncoding.encode_message) + assert hasattr(MessageEncoding, "decode_message") + assert callable(MessageEncoding.decode_message) + + # Check instance methods + assert hasattr(MessageEncoding, "__init__") + assert hasattr(MessageEncoding, "register_pydantic") + assert hasattr(MessageEncoding, "encode") + assert hasattr(MessageEncoding, "decode") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test MessageEncoding initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, MessageEncoding) + assert hasattr(instance, "serializer") + assert isinstance(instance.serializer, Serializer) + assert instance.serializer.serialization == constructor_args["serialization"] + assert hasattr(instance, "encoder") + assert isinstance(instance.encoder, Encoder) + + expected_encoding = constructor_args["encoding"] + if isinstance(expected_encoding, list): + assert instance.encoder.encoding in expected_encoding + else: + assert instance.encoder.encoding == expected_encoding + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + None, + 0, + 0.0, + "0.1.2.3", + [0, 0.0, "0.1.2.3", None], + (0, 0.0, "0.1.2.3", None), + {"key1": 0, "key2": 0.0, "key3": "0.1.2.3", "key4": None}, + ], + ) + def test_encode_decode_python(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with comprehensive data types.""" + instance, constructor_args = valid_instances + + message = instance.encode(obj) + decoded = instance.decode(message) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + SampleModel(name="sample", value=123), + ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + ( + SampleModel(name="sample", value=123), + None, + ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + ), + { + "key1": SampleModel(name="sample", value=123), + "key2": None, + "key3": ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + }, + ], + ) + def test_encode_decode_pydantic(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with Pydantic models.""" + instance, constructor_args = valid_instances + + if ( + constructor_args["serialization"] is None + and constructor_args["encoding"] is not None + ): + # msgpack/msgspec don't support Pydantic models natively + pytest.skip("Skipping unsupported Pydantic serialization/encoding combo") + + # Register 
Pydantic models for proper serialization + instance.register_pydantic(SampleModel) + instance.register_pydantic(ComplexModel) + + message = instance.encode(obj) + decoded = instance.decode(message) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + ( + None, + GenerationRequest(content="test content"), + ScheduledRequestInfo( + scheduler_timings=RequestSchedulerTimings( + targeted_start=1.0, + queued=0.1, + dequeued=0.2, + scheduled_at=0.3, + resolve_start=1.1, + resolve_end=1.5, + finalized=1.6, + ) + ), + ), + ( + GenerationResponse( + request_id=str(uuid.uuid4()), + request_args={}, + value="test response", + request_prompt_tokens=2, + request_output_tokens=3, + response_prompt_tokens=4, + response_output_tokens=6, + ), + GenerationRequest(content="test content"), + ScheduledRequestInfo( + scheduler_timings=RequestSchedulerTimings( + targeted_start=1.0, + queued=0.1, + dequeued=0.2, + scheduled_at=0.3, + resolve_start=1.1, + resolve_end=1.5, + finalized=1.6, + ) + ), + ), + ], + ) + def test_encode_decode_generative(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with generative models.""" + instance, constructor_args = valid_instances + + if ( + constructor_args["serialization"] is None + and constructor_args["encoding"] is not None + ): + # msgpack/msgspec don't support Pydantic models natively + pytest.skip("Skipping unsupported Pydantic serialization/encoding combo") + + instance.register_pydantic(GenerationRequest) + instance.register_pydantic(GenerationResponse) + instance.register_pydantic(ScheduledRequestInfo) + + message = instance.encode(obj) + decoded = instance.decode(message) + + assert list(decoded) == list(obj) + + @pytest.mark.smoke + @pytest.mark.parametrize( + "serialization", + [ + None, + "dict", + "sequence", + ], + ) + @pytest.mark.parametrize( + "encoding", + [None, "msgpack", "msgspec"], + ) + @pytest.mark.parametrize( + "obj", + [ + "0.1.2.3", + [0, 0.0, "0.1.2.3", None, SampleModel(name="sample", value=123)], + { + "key1": 0, + "key2": 0.0, + "key3": "0.1.2.3", + "key4": None, + "key5": ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + }, + ], + ) + def test_encode_decode_message(self, serialization, encoding, obj): + """Test MessageEncoding.encode_message and decode_message class methods.""" + if encoding is not None and serialization is None and obj != "0.1.2.3": + pytest.skip("Skipping unsupported serialization/encoding combo") + + try: + serializer = Serializer(serialization) if serialization else None + encoder = Encoder(encoding) if encoding else None + + message = MessageEncoding.encode_message(obj, serializer, encoder) + decoded = MessageEncoding.decode_message(message, serializer, encoder) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj + except ImportError: + pytest.skip("Required encoding library not available") + + @pytest.mark.smoke + def test_register_pydantic(self): + """Test MessageEncoding.register_pydantic functionality.""" + instance = MessageEncoding(serialization="dict", encoding=None) + assert len(instance.serializer.pydantic_registry) == 0 + instance.register_pydantic(SampleModel) + assert len(instance.serializer.pydantic_registry) == 1 + assert ( + instance.serializer.pydantic_registry.values().__iter__().__next__() + is SampleModel + ) + + @pytest.mark.sanity + def 
test_invalid_initialization_values(self): + """Test invalid initialization (unsupported encoding).""" + inst = MessageEncoding(serialization="dict", encoding=["invalid_encoding"]) # type: ignore[arg-type] + assert inst.encoder.encoding is None + with pytest.raises(ImportError): + MessageEncoding(serialization="dict", encoding="invalid") # type: ignore[arg-type] + + +class TestEncoder: + """Test suite for Encoder class.""" + + @pytest.fixture( + params=[ + None, + "msgpack", + "msgspec", + ["msgspec", "msgpack"], + ["msgpack", "msgspec"], + ], + ids=[ + "none", + "msgpack", + "msgspec", + "list_pref_msgspec_first", + "list_pref_msgpack_first", + ], + ) + def valid_instances(self, request): + args = request.param + try: + inst = Encoder(args) + except ImportError: + pytest.skip("Encoding backend missing") + return inst, args + + @pytest.mark.smoke + def test_class_signatures(self): + assert hasattr(Encoder, "encode") + assert hasattr(Encoder, "decode") + assert hasattr(Encoder, "_resolve_encoding") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + inst, args = valid_instances + assert isinstance(inst, Encoder) + if isinstance(args, list): + assert inst.encoding in args or inst.encoding is None + else: + assert inst.encoding == args + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + with pytest.raises(ImportError): + Encoder("invalid") # type: ignore[arg-type] + + @pytest.mark.smoke + @pytest.mark.parametrize("obj", [None, 0, 1.2, "text", [1, 2], {"a": 1}]) + def test_encode_decode(self, valid_instances, obj): + inst, _ = valid_instances + msg = inst.encode(obj) + out = inst.decode(msg) + assert out == obj + + +class TestSerializer: + """Test suite for Serializer class.""" + + @pytest.fixture(params=[None, "dict", "sequence"], ids=["none", "dict", "sequence"]) + def valid_instances(self, request): + inst = Serializer(request.param) + return inst, request.param + + @pytest.mark.smoke + def test_class_signatures(self): + assert hasattr(Serializer, "serialize") + assert hasattr(Serializer, "deserialize") + assert hasattr(Serializer, "register_pydantic") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + inst, mode = valid_instances + assert isinstance(inst, Serializer) + assert inst.serialization == mode + + @pytest.mark.smoke + def test_register_pydantic(self, valid_instances): + inst, _ = valid_instances + assert len(inst.pydantic_registry) == 0 + inst.register_pydantic(SampleModel) + assert len(inst.pydantic_registry) == 1 + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + 1, + "str_val", + [1, 2, 3], + SampleModel(name="x", value=1), + {"k": SampleModel(name="y", value=2)}, + ], + ) + def test_serialize_deserialize(self, valid_instances, obj): + inst, mode = valid_instances + inst.register_pydantic(SampleModel) + msg = inst.serialize(obj) + out = inst.deserialize(msg) + if isinstance(obj, list): + assert list(out) == obj + else: + assert out == obj + + @pytest.mark.regression + def test_sequence_mapping_roundtrip(self): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + data = { + "a": SampleModel(name="a", value=1), + "b": SampleModel(name="b", value=2), + } + msg = inst.serialize(data) + out = inst.deserialize(msg) + assert out == data + + @pytest.mark.sanity + def test_to_from_dict_variations(self): + inst = Serializer("dict") + inst.register_pydantic(SampleModel) + model = SampleModel(name="n", value=3) + lst = [model, 5] + mp = {"k1": model, "k2": 9} + assert 
inst.from_dict(inst.to_dict(model)) == model + assert inst.from_dict(inst.to_dict(lst)) == lst + assert inst.from_dict(inst.to_dict(mp)) == mp + + @pytest.mark.sanity + @pytest.mark.parametrize( + "collection", + [ + [SampleModel(name="x", value=1), 2, 3], + (SampleModel(name="y", value=2), None), + ], + ) + def test_to_from_sequence_collections(self, collection): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + seq = inst.to_sequence(collection) + out = inst.from_sequence(seq) + assert len(out) == len(collection) + assert all(a == b for a, b in zip(out, list(collection))) + + @pytest.mark.sanity + def test_to_from_sequence_mapping(self): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + data = {"k": SampleModel(name="z", value=7), "j": 1} + seq = inst.to_sequence(data) + out = inst.from_sequence(seq) + assert out == data + + @pytest.mark.sanity + def test_sequence_multiple_root_raises(self): + inst = Serializer("sequence") + part1 = inst.pack_next_sequence("python", inst.to_sequence_python(1), None) + part2 = inst.pack_next_sequence("python", inst.to_sequence_python(2), None) + with pytest.raises(ValueError): + inst.from_sequence(part1 + part2) # type: ignore[operator] + + @pytest.mark.sanity + def test_pack_next_sequence_type_mismatch(self): + inst = Serializer("sequence") + first_payload = inst.to_sequence_python(1) + first = inst.pack_next_sequence("python", first_payload, None) + bad_payload: Any = ( + first_payload.decode() if isinstance(first_payload, bytes) else b"1" + ) + with pytest.raises(ValueError): + inst.pack_next_sequence("python", bad_payload, first) + + @pytest.mark.sanity + def test_unpack_invalid(self): + inst = Serializer("sequence") + with pytest.raises(ValueError): + inst.unpack_next_sequence("X|3|abc") + with pytest.raises(ValueError): + inst.unpack_next_sequence("p?bad") + + @pytest.mark.sanity + def test_dynamic_import_load_pydantic(self, monkeypatch): + inst = Serializer("dict") + inst.pydantic_registry.clear() + sample = SampleModel(name="dyn", value=5) + dumped = inst.to_dict(sample) + inst.pydantic_registry.clear() + restored = inst.from_dict(dumped) + assert restored == sample + + @pytest.mark.sanity + def test_generic_model(self): + inst = Serializer("dict") + inst.register_pydantic(ComplexModel[SampleModelSubclass]) + nested = ComplexModel[SampleModelSubclass]( + items=["i1", "i2"], + metadata={"m": 1}, + nested=SampleModelSubclass(name="nested", value=10, extra_field="extra"), + ) + dumped = inst.to_dict(nested) + restored = inst.from_dict(dumped) + assert restored == nested + + @pytest.mark.sanity + @pytest.mark.xfail( + reason="A generic object returned by a generic method loses its type args" + ) + def test_generic_emitted_type(self): + generic_instance = GenricModelWrapper[SampleModelSubclass]() + + inst = Serializer("dict") + inst.register_pydantic(ComplexModel[SampleModelSubclass]) + nested = generic_instance.method( + items=["i1", "i2"], + metadata={"m": 1}, + nested=SampleModelSubclass(name="nested", value=10, extra_field="extra"), + ) + dumped = inst.to_dict(nested) + restored = inst.from_dict(dumped) + assert restored == nested diff --git a/tests/unit/utils/test_functions.py b/tests/unit/utils/test_functions.py new file mode 100644 index 00000000..3b353759 --- /dev/null +++ b/tests/unit/utils/test_functions.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from datetime import datetime + +import pytest + +from guidellm.utils.functions import ( + all_defined, + safe_add, + safe_divide, + 
safe_format_timestamp, + safe_getattr, + safe_multiply, +) + + +class TestAllDefined: + """Test suite for all_defined function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "expected"), + [ + ((1, 2, 3), True), + (("test", "hello"), True), + ((0, False, ""), True), + ((1, None, 3), False), + ((None,), False), + ((None, None), False), + ((), True), + ], + ) + def test_invocation(self, values, expected): + """Test all_defined with valid inputs.""" + result = all_defined(*values) + assert result == expected + + @pytest.mark.sanity + def test_mixed_types(self): + """Test all_defined with mixed data types.""" + result = all_defined(1, "test", [], {}, 0.0, False) + assert result is True + + result = all_defined(1, "test", None, {}) + assert result is False + + +class TestSafeGetattr: + """Test suite for safe_getattr function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj", "attr", "default", "expected"), + [ + (None, "any_attr", "default_val", "default_val"), + (None, "any_attr", None, None), + ("test_string", "nonexistent", "default_val", "default_val"), + ], + ) + def test_invocation(self, obj, attr, default, expected): + """Test safe_getattr with valid inputs.""" + result = safe_getattr(obj, attr, default) + assert result == expected + + @pytest.mark.smoke + def test_with_object(self): + """Test safe_getattr with actual object attributes.""" + + class TestObj: + test_attr = "test_value" + + obj = TestObj() + result = safe_getattr(obj, "test_attr", "default") + assert result == "test_value" + + result = safe_getattr(obj, "missing_attr", "default") + assert result == "default" + + # Test with method attribute + result = safe_getattr("test_string", "upper", None) + assert callable(result) + assert result() == "TEST_STRING" + + +class TestSafeDivide: + """Test suite for safe_divide function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("numerator", "denominator", "num_default", "den_default", "expected"), + [ + (10, 2, 0.0, 1.0, 5.0), + (None, 2, 6.0, 1.0, 3.0), + (10, None, 0.0, 5.0, 2.0), + (None, None, 8.0, 4.0, 2.0), + (10, 0, 0.0, 1.0, 10 / 1e-10), + ], + ) + def test_invocation( + self, numerator, denominator, num_default, den_default, expected + ): + """Test safe_divide with valid inputs.""" + result = safe_divide(numerator, denominator, num_default, den_default) + assert result == pytest.approx(expected, rel=1e-6) + + @pytest.mark.sanity + def test_zero_division_protection(self): + """Test safe_divide protection against zero division.""" + result = safe_divide(10, 0) + assert result == 10 / 1e-10 + + result = safe_divide(5, None, den_default=0) + assert result == 5 / 1e-10 + + +class TestSafeMultiply: + """Test suite for safe_multiply function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "default", "expected"), + [ + ((2, 3, 4), 1.0, 24.0), + ((2, None, 4), 1.0, 8.0), + ((None, None), 5.0, 5.0), + ((), 3.0, 3.0), + ((2, 3, None, 5), 2.0, 60.0), + ], + ) + def test_invocation(self, values, default, expected): + """Test safe_multiply with valid inputs.""" + result = safe_multiply(*values, default=default) + assert result == expected + + @pytest.mark.sanity + def test_with_zero(self): + """Test safe_multiply with zero values.""" + result = safe_multiply(2, 0, 3, default=1.0) + assert result == 0.0 + + result = safe_multiply(None, 0, None, default=5.0) + assert result == 0.0 + + +class TestSafeAdd: + """Test suite for safe_add function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "signs", 
"default", "expected"), + [ + ((1, 2, 3), None, 0.0, 6.0), + ((1, None, 3), None, 5.0, 9.0), + ((10, 5), [1, -1], 0.0, 5.0), + ((None, None), [1, -1], 2.0, 0.0), + ((), None, 3.0, 3.0), + ((1, 2, 3), [1, 1, -1], 0.0, 0.0), + ], + ) + def test_invocation(self, values, signs, default, expected): + """Test safe_add with valid inputs.""" + result = safe_add(*values, signs=signs, default=default) + assert result == expected + + @pytest.mark.sanity + def test_invalid_signs_length(self): + """Test safe_add with invalid signs length.""" + with pytest.raises( + ValueError, match="Length of signs must match length of values" + ): + safe_add(1, 2, 3, signs=[1, -1]) + + @pytest.mark.sanity + def test_single_value(self): + """Test safe_add with single value.""" + result = safe_add(5, default=1.0) + assert result == 5.0 + + result = safe_add(None, default=3.0) + assert result == 3.0 + + +class TestSafeFormatTimestamp: + """Test suite for safe_format_timestamp function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("timestamp", "format_", "default", "expected"), + [ + (1609459200.0, "%Y-%m-%d", "N/A", "2020-12-31"), + (1609459200.0, "%H:%M:%S", "N/A", "19:00:00"), + (None, "%H:%M:%S", "N/A", "N/A"), + (-1, "%H:%M:%S", "N/A", "N/A"), + (2**32, "%H:%M:%S", "N/A", "N/A"), + ], + ) + def test_invocation(self, timestamp, format_, default, expected): + """Test safe_format_timestamp with valid inputs.""" + result = safe_format_timestamp(timestamp, format_, default) + assert result == expected + + @pytest.mark.sanity + def test_edge_cases(self): + """Test safe_format_timestamp with edge case timestamps.""" + result = safe_format_timestamp(0.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(1.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(2**31 - 1, "%Y", "N/A") + expected_year = datetime.fromtimestamp(2**31 - 1).strftime("%Y") + assert result == expected_year + + @pytest.mark.sanity + def test_invalid_timestamp_ranges(self): + """Test safe_format_timestamp with invalid timestamp ranges.""" + result = safe_format_timestamp(2**31 + 1, "%Y", "ERROR") + assert result == "ERROR" + + result = safe_format_timestamp(-1000, "%Y", "ERROR") + assert result == "ERROR" diff --git a/tests/unit/utils/test_messaging.py b/tests/unit/utils/test_messaging.py new file mode 100644 index 00000000..d6627e88 --- /dev/null +++ b/tests/unit/utils/test_messaging.py @@ -0,0 +1,974 @@ +from __future__ import annotations + +import asyncio +import multiprocessing +import threading +from functools import wraps +from typing import Any, TypeVar + +import culsans +import pytest +from pydantic import BaseModel + +from guidellm.backend import ( + GenerationRequest, + GenerationResponse, +) +from guidellm.scheduler import ScheduledRequestInfo +from guidellm.utils import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, +) +from guidellm.utils.messaging import ReceiveMessageT, SendMessageT + + +def async_timeout(delay: float): + """Decorator to add timeout to async test functions.""" + + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockMessage(BaseModel): + content: str + num: int + + +class MockProcessTarget: + """Mock process target for testing.""" + + def __init__( + self, + messaging: InterProcessMessaging, + num_messages: int, + worker_index: int = 0, + ): + 
self.messaging = messaging + self.num_messages = num_messages + self.worker_index = worker_index + + def run(self): + loop = asyncio.new_event_loop() + + try: + asyncio.set_event_loop(loop) + asyncio.run(asyncio.wait_for(self._async_runner(), timeout=10.0)) + except RuntimeError: + pass + finally: + loop.close() + + async def _async_runner(self): + await self.messaging.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + + try: + for _ in range(self.num_messages): + obj = await self.messaging.get(timeout=2.0) + await self.messaging.put(obj, timeout=2.0) + finally: + await self.messaging.stop() + + +@pytest.fixture( + params=[ + {"ctx_name": "fork"}, + {"ctx_name": "spawn"}, + ], + ids=["fork_ctx", "spawn_ctx"], +) +def multiprocessing_contexts(request): + context = multiprocessing.get_context(request.param["ctx_name"]) + manager = context.Manager() + try: + yield manager, context + finally: + manager.shutdown() + + +def test_send_message_type(): + """Test that SendMessageT is filled out correctly as a TypeVar.""" + assert isinstance(SendMessageT, type(TypeVar("test"))) + assert SendMessageT.__name__ == "SendMessageT" + assert SendMessageT.__bound__ is Any + assert SendMessageT.__constraints__ == () + + +def test_receive_message_type(): + """Test that ReceiveMessageT is filled out correctly as a TypeVar.""" + assert isinstance(ReceiveMessageT, type(TypeVar("test"))) + assert ReceiveMessageT.__name__ == "ReceiveMessageT" + assert ReceiveMessageT.__bound__ is Any + assert ReceiveMessageT.__constraints__ == () + + +class TestInterProcessMessaging: + """Test suite for InterProcessMessaging abstract base class.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessaging abstract class signatures.""" + assert hasattr(InterProcessMessaging, "__init__") + assert hasattr(InterProcessMessaging, "create_worker_copy") + assert hasattr(InterProcessMessaging, "create_send_messages_threads") + assert hasattr(InterProcessMessaging, "create_receive_messages_threads") + assert hasattr(InterProcessMessaging, "start") + assert hasattr(InterProcessMessaging, "stop") + assert hasattr(InterProcessMessaging, "get") + assert hasattr(InterProcessMessaging, "put") + + # Check abstract methods + assert getattr( + InterProcessMessaging.create_worker_copy, "__isabstractmethod__", False + ) + assert getattr( + InterProcessMessaging.create_send_messages_threads, + "__isabstractmethod__", + False, + ) + assert getattr( + InterProcessMessaging.create_receive_messages_threads, + "__isabstractmethod__", + False, + ) + + @pytest.mark.smoke + def test_cannot_instantiate_directly(self): + """Test InterProcessMessaging cannot be instantiated directly.""" + with pytest.raises(TypeError): + InterProcessMessaging() + + +class TestInterProcessMessagingQueue: + """Test suite for InterProcessMessagingQueue.""" + + @pytest.fixture( + params=[ + { + "serialization": "dict", + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + { + "serialization": "sequence", + "encoding": None, + "max_pending_size": 10, + "max_buffer_send_size": 2, + "max_done_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "serialization": None, + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingQueue.""" + 
constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingQueue( + **constructor_args, poll_interval=0.01, mp_context=context + ) + + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingQueue inheritance and signatures.""" + assert issubclass(InterProcessMessagingQueue, InterProcessMessaging) + assert hasattr(InterProcessMessagingQueue, "__init__") + assert hasattr(InterProcessMessagingQueue, "create_worker_copy") + assert hasattr(InterProcessMessagingQueue, "create_send_messages_threads") + assert hasattr(InterProcessMessagingQueue, "create_receive_messages_threads") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingQueue initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingQueue) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_pending_size == constructor_args["max_pending_size"] + assert instance.max_done_size == constructor_args["max_done_size"] + assert hasattr(instance, "pending_queue") + assert hasattr(instance, "done_queue") + assert instance.running is False + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingQueue.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 42 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingQueue) + assert worker_copy.worker_index == worker_index + assert worker_copy.pending_queue is instance.pending_queue + assert worker_copy.done_queue is instance.done_queue + assert worker_copy.max_pending_size == instance.max_pending_size + assert worker_copy.max_done_size == instance.max_done_size + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "stop_events_lambda", + [ + list, + lambda: [threading.Event()], + lambda: [multiprocessing.Event()], + lambda: [threading.Event(), multiprocessing.Event()], + ], + ) + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances, stop_events_lambda): + """Test InterProcessMessagingQueue start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = stop_events_lambda() + + # Initially not running + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start( + send_stop_criteria=stop_events, receive_stop_criteria=stop_events + ) + assert instance.running is True + assert instance.send_stopped_event is not None + assert isinstance(instance.send_stopped_event, threading.Event) + assert instance.receive_stopped_event is not None + assert isinstance(instance.receive_stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert 
isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + if stop_events: + for event in stop_events: + event.set() + + await asyncio.sleep(0.1) + assert instance.send_stopped_event.is_set() + assert instance.receive_stopped_event.is_set() + assert instance.send_task.done() + assert instance.receive_task.done() + + await instance.stop() + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get_iter(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + 
instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + def _received_callback(msg): + if not isinstance(test_obj, tuple): + assert msg == test_obj + else: + assert list(msg) == list(test_obj) + return "changed_obj" + + # Local startup and wait + await instance.start( + send_items=[test_obj for _ in range(5)], + receive_callback=_received_callback, + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + val = await instance.get(timeout=2.0) + assert val == "changed_obj" + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + +class TestInterProcessMessagingManagerQueue: + """Test suite for InterProcessMessagingManagerQueue.""" + + @pytest.fixture( + params=[ + { + "serialization": "dict", + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + { + "serialization": "sequence", + "encoding": None, + "max_pending_size": 10, + "max_buffer_send_size": 2, + "max_done_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "serialization": None, + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingManagerQueue.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingManagerQueue( + **constructor_args, manager=manager, poll_interval=0.01 + ) + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingManagerQueue inheritance and signatures.""" + assert issubclass(InterProcessMessagingManagerQueue, InterProcessMessaging) + assert issubclass(InterProcessMessagingManagerQueue, InterProcessMessagingQueue) + assert hasattr(InterProcessMessagingManagerQueue, "__init__") + assert hasattr(InterProcessMessagingManagerQueue, "create_worker_copy") + assert hasattr(InterProcessMessagingManagerQueue, "_send_messages_task_thread") + assert hasattr( + InterProcessMessagingManagerQueue, "_receive_messages_task_thread" + ) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingManagerQueue initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingManagerQueue) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_pending_size == constructor_args["max_pending_size"] + assert instance.max_done_size == constructor_args["max_done_size"] + assert hasattr(instance, "pending_queue") + assert hasattr(instance, "done_queue") + assert instance.running is False + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingManagerQueue.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 42 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingManagerQueue) + assert worker_copy.worker_index == worker_index + assert worker_copy.pending_queue is instance.pending_queue + assert worker_copy.done_queue is instance.done_queue + assert worker_copy.max_pending_size == instance.max_pending_size
+ assert worker_copy.max_done_size == instance.max_done_size + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "stop_events_lambda", + [ + list, + lambda: [threading.Event()], + lambda: [multiprocessing.Event()], + lambda: [threading.Event(), multiprocessing.Event()], + ], + ) + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances, stop_events_lambda): + """Test InterProcessMessagingManagerQueue start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = stop_events_lambda() + + # Initially not running + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start( + send_stop_criteria=stop_events, receive_stop_criteria=stop_events + ) + assert instance.running is True + assert instance.send_stopped_event is not None + assert isinstance(instance.send_stopped_event, threading.Event) + assert instance.receive_stopped_event is not None + assert isinstance(instance.receive_stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + if stop_events: + for event in stop_events: + event.set() + + await asyncio.sleep(0.1) + assert instance.send_stopped_event.is_set() + assert instance.receive_stopped_event.is_set() + assert instance.send_task.done() + assert instance.receive_task.done() + + await instance.stop() + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, _, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run)
+ process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get_iter(self, valid_instances, test_obj): + instance, constructor_args, _, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + def _received_callback(msg): + if not isinstance(test_obj, tuple): + assert msg == test_obj + else: + assert list(msg) == list(test_obj) + return "changed_obj" + + # Local startup and wait + await instance.start( + send_items=[test_obj for _ in range(5)], + receive_callback=_received_callback, + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + val = await instance.get(timeout=2.0) + assert val == "changed_obj" + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + +class TestInterProcessMessagingPipe: + """Test suite for InterProcessMessagingPipe.""" + + @pytest.fixture( + params=[ + { + "num_workers": 2, + "serialization": "dict", + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + { + "num_workers": 1, + "serialization": "sequence", + "encoding": None, + "max_pending_size": 10, + "max_buffer_send_size": 2, + "max_done_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "num_workers": 1, + "serialization": None, + "encoding": None, + "max_pending_size": None, + "max_done_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingPipe.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingPipe(**constructor_args, poll_interval=0.01) + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingPipe inheritance and signatures.""" + assert issubclass(InterProcessMessagingPipe, InterProcessMessaging) + assert 
hasattr(InterProcessMessagingPipe, "__init__") + assert hasattr(InterProcessMessagingPipe, "create_worker_copy") + assert hasattr(InterProcessMessagingPipe, "_send_messages_task_thread") + assert hasattr(InterProcessMessagingPipe, "_receive_messages_task_thread") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingPipe initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingPipe) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_pending_size == constructor_args["max_pending_size"] + assert instance.max_done_size == constructor_args["max_done_size"] + assert instance.num_workers == constructor_args["num_workers"] + assert hasattr(instance, "pipes") + assert len(instance.pipes) == constructor_args["num_workers"] + assert instance.running is False + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("kwargs", "expected_error"), + [ + ({"invalid_param": "value"}, TypeError), + ({"num_workers": 1, "unknown_arg": "test"}, TypeError), + ], + ) + def test_invalid_initialization_values(self, kwargs, expected_error): + """Test InterProcessMessagingPipe with invalid field values.""" + with pytest.raises(expected_error): + InterProcessMessagingPipe(**kwargs) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test InterProcessMessagingPipe initialization without required field.""" + with pytest.raises(TypeError): + InterProcessMessagingPipe() + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingPipe.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 0 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingPipe) + assert worker_copy.worker_index == worker_index + assert worker_copy.pipes[0] is instance.pipes[worker_index] + assert worker_copy.max_pending_size == instance.max_pending_size + assert worker_copy.max_done_size == instance.max_done_size + assert worker_copy.num_workers == instance.num_workers + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances): + """Test InterProcessMessagingPipe start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = [] + + # Initially not running + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start( + send_stop_criteria=stop_events, receive_stop_criteria=stop_events + ) + assert instance.running is True + assert instance.send_stopped_event is not None + assert isinstance(instance.send_stopped_event, threading.Event) + assert instance.receive_stopped_event is not None + assert isinstance(instance.receive_stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) +
assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + await instance.stop() + assert instance.running is False + assert instance.send_stopped_event is None + assert instance.receive_stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo(), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + processes = [] + for index in range(constructor_args["num_workers"]): + process_target = MockProcessTarget( + instance.create_worker_copy(index), num_messages=5 + ) + process = context.Process(target=process_target.run) + processes.append(process) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo, + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5 * constructor_args["num_workers"]): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5 * constructor_args["num_workers"]): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + for process in processes: + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() diff --git a/tests/unit/utils/test_mixins.py b/tests/unit/utils/test_mixins.py new file mode 100644 index 00000000..cd8990de --- /dev/null +++ b/tests/unit/utils/test_mixins.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import pytest + +from guidellm.utils.mixins import InfoMixin + + +class TestInfoMixin: + """Test suite for InfoMixin.""" + + @pytest.fixture( + params=[ + {"attr_one": "test_value", "attr_two": 42}, + {"attr_one": "hello_world", "attr_two": 100, "attr_three": [1, 2, 3]}, + ], + ids=["basic_attributes", "extended_attributes"], + ) + def valid_instances(self, request): + """Fixture providing test data for InfoMixin.""" + constructor_args = request.param + + class TestClass(InfoMixin): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + instance = TestClass(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InfoMixin class signatures and methods.""" 
+ assert hasattr(InfoMixin, "extract_from_obj") + assert callable(InfoMixin.extract_from_obj) + assert hasattr(InfoMixin, "create_info_dict") + assert callable(InfoMixin.create_info_dict) + assert hasattr(InfoMixin, "info") + assert isinstance(InfoMixin.info, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InfoMixin initialization through inheritance.""" + instance, constructor_args = valid_instances + assert isinstance(instance, InfoMixin) + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.smoke + def test_info_property(self, valid_instances): + """Test InfoMixin.info property.""" + instance, constructor_args = valid_instances + result = instance.info + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "TestClass" + assert result["class"] == "TestClass" + assert isinstance(result["attributes"], dict) + for key, value in constructor_args.items(): + assert key in result["attributes"] + assert result["attributes"][key] == value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": True}), + ({"nested": {"key": "value"}}, {"nested": {"key": "value"}}), + ], + ) + def test_create_info_dict(self, obj_data, expected_attributes): + """Test InfoMixin.create_info_dict class method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.create_info_dict(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": True}), + ], + ) + def test_extract_from_obj_without_info(self, obj_data, expected_attributes): + """Test InfoMixin.extract_from_obj with objects without info method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + def test_extract_from_obj_with_info_method(self): + """Test InfoMixin.extract_from_obj with objects that have info method.""" + + class ObjectWithInfoMethod: + def info(self): + return {"custom": "info_method", "type": "custom_type"} + + obj = ObjectWithInfoMethod() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_method", "type": "custom_type"} + + @pytest.mark.smoke + def test_extract_from_obj_with_info_property(self): + """Test 
InfoMixin.extract_from_obj with objects that have info property.""" + + class ObjectWithInfoProperty: + @property + def info(self): + return {"custom": "info_property", "type": "custom_type"} + + obj = ObjectWithInfoProperty() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_property", "type": "custom_type"} + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("obj_type", "obj_value"), + [ + (str, "test_string"), + (int, 42), + (float, 3.14), + (list, [1, 2, 3]), + (dict, {"key": "value"}), + ], + ) + def test_extract_from_obj_builtin_types(self, obj_type, obj_value): + """Test InfoMixin.extract_from_obj with built-in types.""" + result = InfoMixin.extract_from_obj(obj_value) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert result["type"] == obj_type.__name__ + assert result["str"] == str(obj_value) + + @pytest.mark.sanity + def test_extract_from_obj_without_dict(self): + """Test InfoMixin.extract_from_obj with objects without __dict__.""" + obj = 42 + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "attributes" in result + assert result["attributes"] == {} + assert result["type"] == "int" + assert result["str"] == "42" + + @pytest.mark.sanity + def test_extract_from_obj_with_private_attributes(self): + """Test InfoMixin.extract_from_obj filters private attributes.""" + + class ObjectWithPrivate: + def __init__(self): + self.public_attr = "public" + self._private_attr = "private" + self.__very_private = "very_private" + + obj = ObjectWithPrivate() + result = InfoMixin.extract_from_obj(obj) + + assert "public_attr" in result["attributes"] + assert result["attributes"]["public_attr"] == "public" + assert "_private_attr" not in result["attributes"] + assert "__very_private" not in result["attributes"] + + @pytest.mark.sanity + def test_extract_from_obj_complex_attributes(self): + """Test InfoMixin.extract_from_obj with complex attribute types.""" + + class ComplexObject: + def __init__(self): + self.simple_str = "test" + self.simple_int = 42 + self.simple_list = [1, 2, 3] + self.simple_dict = {"key": "value"} + self.complex_object = object() + + obj = ComplexObject() + result = InfoMixin.extract_from_obj(obj) + + attributes = result["attributes"] + assert attributes["simple_str"] == "test" + assert attributes["simple_int"] == 42 + assert attributes["simple_list"] == [1, 2, 3] + assert attributes["simple_dict"] == {"key": "value"} + assert isinstance(attributes["complex_object"], str) + + @pytest.mark.regression + def test_create_info_dict_consistency(self, valid_instances): + """Test InfoMixin.create_info_dict produces consistent results.""" + instance, _ = valid_instances + + result1 = InfoMixin.create_info_dict(instance) + result2 = InfoMixin.create_info_dict(instance) + + assert result1 == result2 + assert result1 is not result2 + + @pytest.mark.regression + def test_info_property_uses_create_info_dict(self, valid_instances): + """Test InfoMixin.info property uses create_info_dict method.""" + instance, _ = valid_instances + + info_result = instance.info + create_result = InfoMixin.create_info_dict(instance) + + assert info_result == create_result diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py new file mode 100644 index 00000000..726b5ddf --- /dev/null +++ b/tests/unit/utils/test_pydantic_utils.py @@ -0,0 +1,1002 @@ +""" +Unit tests for the pydantic_utils module. 
+""" + +from __future__ import annotations + +from typing import ClassVar, TypeVar +from unittest import mock + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from guidellm.utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) +from guidellm.utils.pydantic_utils import ( + BaseModelT, + ErroredT, + IncompleteT, + RegisterClassT, + SuccessfulT, + TotalT, +) + + +@pytest.mark.smoke +def test_base_model_t(): + """Test that BaseModelT is configured correctly as a TypeVar.""" + assert isinstance(BaseModelT, type(TypeVar("test"))) + assert BaseModelT.__name__ == "BaseModelT" + assert BaseModelT.__bound__ is BaseModel + assert BaseModelT.__constraints__ == () + + +@pytest.mark.smoke +def test_register_class_t(): + """Test that RegisterClassT is configured correctly as a TypeVar.""" + assert isinstance(RegisterClassT, type(TypeVar("test"))) + assert RegisterClassT.__name__ == "RegisterClassT" + assert RegisterClassT.__bound__ is None + assert RegisterClassT.__constraints__ == () + + +@pytest.mark.smoke +def test_successful_t(): + """Test that SuccessfulT is configured correctly as a TypeVar.""" + assert isinstance(SuccessfulT, type(TypeVar("test"))) + assert SuccessfulT.__name__ == "SuccessfulT" + assert SuccessfulT.__bound__ is None + assert SuccessfulT.__constraints__ == () + + +@pytest.mark.smoke +def test_errored_t(): + """Test that ErroredT is configured correctly as a TypeVar.""" + assert isinstance(ErroredT, type(TypeVar("test"))) + assert ErroredT.__name__ == "ErroredT" + assert ErroredT.__bound__ is None + assert ErroredT.__constraints__ == () + + +@pytest.mark.smoke +def test_incomplete_t(): + """Test that IncompleteT is configured correctly as a TypeVar.""" + assert isinstance(IncompleteT, type(TypeVar("test"))) + assert IncompleteT.__name__ == "IncompleteT" + assert IncompleteT.__bound__ is None + assert IncompleteT.__constraints__ == () + + +@pytest.mark.smoke +def test_total_t(): + """Test that TotalT is configured correctly as a TypeVar.""" + assert isinstance(TotalT, type(TypeVar("test"))) + assert TotalT.__name__ == "TotalT" + assert TotalT.__bound__ is None + assert TotalT.__constraints__ == () + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": "hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + # Check model configuration + config = ReloadableBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, 
ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestStandardBaseModel: + """Test suite for StandardBaseModel.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "field_int": 42}, + {"field_str": "hello_world", "field_int": 100}, + {"field_str": "another_test", "field_int": 0}, + ], + ids=["basic_values", "positive_values", "zero_value"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseModel, dict[str, int | str]]: + """Fixture providing test data for StandardBaseModel.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseModel inheritance and class variables.""" + assert issubclass(StandardBaseModel, BaseModel) + assert hasattr(StandardBaseModel, "model_config") + assert hasattr(StandardBaseModel, "get_default") + + # Check model configuration + config = StandardBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["from_attributes"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseModel) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + assert instance.field_int == constructor_args["field_int"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ("field_int", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseModel with invalid field values.""" + + class 
TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + data = {field: value} + if field == "field_str": + data["field_int"] = 42 + else: + data["field_str"] = "test" + + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseModel initialization without required field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_get_default(self): + """Test StandardBaseModel.get_default method.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=42, description="Test integer field") + + default_value = TestModel.get_default("field_int") + assert default_value == 42 + + @pytest.mark.sanity + def test_get_default_invalid(self): + """Test StandardBaseModel.get_default with invalid field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + + with pytest.raises(KeyError): + TestModel.get_default("nonexistent_field") + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + assert data_dict["field_int"] == constructor_args["field_int"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + assert recreated.field_int == constructor_args["field_int"] + + +class TestStandardBaseDict: + """Test suite for StandardBaseDict.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "extra_field": "extra_value"}, + {"field_str": "hello_world", "another_extra": 123}, + {"field_str": "another_test", "complex_extra": {"nested": "value"}}, + ], + ids=["string_extra", "int_extra", "dict_extra"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseDict, dict[str, str | int | dict[str, str]]]: + """Fixture providing test data for StandardBaseDict.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseDict inheritance and class variables.""" + assert issubclass(StandardBaseDict, StandardBaseModel) + assert hasattr(StandardBaseDict, "model_config") + + # Check model configuration + config = StandardBaseDict.model_config + assert config["extra"] == "allow" + assert config["use_enum_values"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseDict initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseDict) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + + # Check extra fields are 
preserved + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseDict with invalid field values.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseDict initialization without required field.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseDict serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + + # Check extra fields are in the serialized data + for key, value in constructor_args.items(): + if key != "field_str": + assert key in data_dict + assert data_dict[key] == value + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + + # Check extra fields are preserved after deserialization + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(recreated, key) + assert getattr(recreated, key) == value + + +class TestStatusBreakdown: + """Test suite for StatusBreakdown.""" + + @pytest.fixture( + params=[ + {"successful": 100, "errored": 5, "incomplete": 10, "total": 115}, + { + "successful": "success_data", + "errored": "error_data", + "incomplete": "incomplete_data", + "total": "total_data", + }, + { + "successful": [1, 2, 3], + "errored": [4, 5], + "incomplete": [6], + "total": [1, 2, 3, 4, 5, 6], + }, + ], + ids=["int_values", "string_values", "list_values"], + ) + def valid_instances(self, request) -> tuple[StatusBreakdown, dict]: + """Fixture providing test data for StatusBreakdown.""" + constructor_args = request.param + instance = StatusBreakdown(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StatusBreakdown inheritance and type relationships.""" + assert issubclass(StatusBreakdown, BaseModel) + # Check if Generic is in the MRO (method resolution order) + assert any(cls.__name__ == "Generic" for cls in StatusBreakdown.__mro__) + assert "successful" in StatusBreakdown.model_fields + assert "errored" in StatusBreakdown.model_fields + assert "incomplete" in StatusBreakdown.model_fields + assert "total" in StatusBreakdown.model_fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StatusBreakdown initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StatusBreakdown) + assert instance.successful == constructor_args["successful"] + assert instance.errored == constructor_args["errored"] + assert instance.incomplete == constructor_args["incomplete"] + assert instance.total == constructor_args["total"] + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test 
StatusBreakdown initialization with default values.""" + instance: StatusBreakdown = StatusBreakdown() + assert instance.successful is None + assert instance.errored is None + assert instance.incomplete is None + assert instance.total is None + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StatusBreakdown serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["successful"] == constructor_args["successful"] + assert data_dict["errored"] == constructor_args["errored"] + assert data_dict["incomplete"] == constructor_args["incomplete"] + assert data_dict["total"] == constructor_args["total"] + + recreated: StatusBreakdown = StatusBreakdown.model_validate(data_dict) + assert isinstance(recreated, StatusBreakdown) + assert recreated.successful == constructor_args["successful"] + assert recreated.errored == constructor_args["errored"] + assert recreated.incomplete == constructor_args["incomplete"] + assert recreated.total == constructor_args["total"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, "schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + assert hasattr(PydanticClassRegistryMixin, "registered_classes") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class 
TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "TestSubModel" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["TestSubModel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class 
TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "guidellm.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.smoke + def test_registered_classes(self): + """Test PydanticClassRegistryMixin.registered_classes method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = False + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "test_sub_a" + value_a: str + + @TestBaseModel.register("test_sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "test_sub_b" + value_b: int + + # Test normal case with registered classes + registered = TestBaseModel.registered_classes() + assert isinstance(registered, tuple) + assert len(registered) == 2 + assert TestSubModelA in registered + assert TestSubModelB in registered + + @pytest.mark.sanity + def test_registered_classes_with_auto_discovery(self): + """Test PydanticClassRegistryMixin.registered_classes with auto discovery.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with mock.patch.object( + TestBaseModel, "auto_populate_registry" + ) as mock_auto_populate: + # Mock the registry to simulate registered classes + TestBaseModel.registry = {"test_class": type("TestClass", (), {})} + mock_auto_populate.return_value = False + + registered = TestBaseModel.registered_classes() + mock_auto_populate.assert_called_once() + assert isinstance(registered, tuple) + assert len(registered) == 1 + + @pytest.mark.sanity + def test_registered_classes_no_registry(self): + """Test PydanticClassRegistryMixin.registered_classes with no registry.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + # Ensure registry is None + TestBaseModel.registry = None + + with pytest.raises(ValueError) as exc_info: + TestBaseModel.registered_classes() + + assert "must be called after registering classes" in str(exc_info.value) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == 
constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) + + @pytest.mark.smoke + def test_register_preserves_pydantic_metadata(self): # noqa: C901 + """Test that registered Pydantic classes retain docs, types, and methods.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "model_type" + model_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + + return TestBaseModel + + @TestBaseModel.register("documented_model") + class DocumentedModel(TestBaseModel): + """This is a documented Pydantic model with methods and type hints.""" + + model_type: str = "documented_model" + value: int = Field(description="An integer value for the model") + + def get_value(self) -> int: + """Get the stored value. 
+ + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedModel: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedModel instance + """ + return cls(value=int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + def model_post_init(self, __context) -> None: + """Post-initialization processing. + + :param __context: Validation context + """ + if self.value < 0: + raise ValueError("Value must be non-negative") + + # Check that the class was registered + assert TestBaseModel.is_registered("documented_model") + registered_class = TestBaseModel.get_registered_object("documented_model") + assert registered_class is DocumentedModel + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented Pydantic model with methods" in registered_class.__doc__ + + # Check that methods retain their documentation + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + assert registered_class.model_post_init.__doc__ is not None + assert ( + "Post-initialization processing" in registered_class.model_post_init.__doc__ + ) + + # Check that methods are callable and work correctly + instance = DocumentedModel(value=42) + assert isinstance(instance, DocumentedModel) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + assert instance.model_type == "documented_model" + + # Check class methods work + instance2 = DocumentedModel.from_string("123") + assert instance2.get_value() == 123 + assert instance2.model_type == "documented_model" + + # Check static methods work + assert DocumentedModel.validate_value(10) is True + assert DocumentedModel.validate_value(-5) is False + + # Check that Pydantic functionality is preserved + data_dict = instance.model_dump() + assert data_dict["value"] == 100 + assert data_dict["model_type"] == "documented_model" + + recreated = DocumentedModel.model_validate(data_dict) + assert isinstance(recreated, DocumentedModel) + assert recreated.value == 100 + assert recreated.model_type == "documented_model" + + # Test field validation + with pytest.raises(ValidationError): + DocumentedModel(value="not_an_int") + + # Test post_init validation + with pytest.raises(ValueError, match="Value must be non-negative"): + DocumentedModel(value=-10) + + # Check that Pydantic field metadata is preserved + value_field = DocumentedModel.model_fields["value"] + assert value_field.description == "An integer value for the model" + + # Check that type annotations are preserved (if accessible) + import inspect + + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + 
annotations = inspect.get_annotations(DocumentedModel.get_value) + return_ann = annotations.get("return") + assert return_ann is int or return_ann == "int" + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert DocumentedModel.__name__ == "DocumentedModel" + assert DocumentedModel.__qualname__.endswith("DocumentedModel") + + # Verify that the class is still properly integrated with the registry system + all_registered = TestBaseModel.registered_classes() + assert DocumentedModel in all_registered + + # Test that the registered class is the same as the original + assert registered_class is DocumentedModel diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py new file mode 100644 index 00000000..eed126d3 --- /dev/null +++ b/tests/unit/utils/test_registry.py @@ -0,0 +1,593 @@ +""" +Unit tests for the registry module. +""" + +from __future__ import annotations + +import inspect +from typing import TypeVar +from unittest import mock + +import pytest + +from guidellm.utils import RegistryMixin +from guidellm.utils.registry import RegisterT, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is None + assert RegistryObjT.__constraints__ == () + + +def test_registered_type(): + """Test that RegisterT is configured correctly as a TypeVar.""" + assert isinstance(RegisterT, type(TypeVar("test"))) + assert RegisterT.__name__ == "RegisterT" + assert RegisterT.__bound__ is None + assert RegisterT.__constraints__ == () + + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = 
True + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances + + @registry_class.register(name) + class TestClass: + pass + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances + + # The register method returns a decorator, so we need to apply it to test + # validation + decorator = registry_class.register(invalid_name) + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + decorator(TestClass) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + registry_class.register_decorator(TestClass, name=name) + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" + + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True + + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() + + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + + class TestDisabledRegistry(RegistryMixin): + 
registry_auto_discovery = False + + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() + + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances + + @registry_class.register("class1") + class TestClass1: + pass + + @registry_class.register("class2") + class TestClass2: + pass + + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects + + @pytest.mark.sanity + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" + + class TestRegistryClass(RegistryMixin): + pass + + with pytest.raises( + ValueError, match="must be called after registering objects" + ): + TestRegistryClass.registered_objects() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + ("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.is_registered(check_name) + assert result == expected + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances + + @registry_class.register("valid_name") + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is None + + @pytest.mark.regression + def test_multiple_registries_isolation(self): + """Test that different registry classes maintain separate registries.""" + + class Registry1(RegistryMixin): + pass + + class Registry2(RegistryMixin): + pass + + @Registry1.register() + class TestClass1: + pass + + @Registry2.register() + class TestClass2: + pass + + assert Registry1.registry is not None + assert Registry2.registry is not None + assert Registry1.registry != Registry2.registry + assert "TestClass1" in Registry1.registry + assert "TestClass2" in Registry2.registry + assert "TestClass1" not in Registry2.registry + assert "TestClass2" not in Registry1.registry + + @pytest.mark.smoke + def test_auto_discovery_initialization(self): + """Test initialization of auto-discovery enabled registry.""" + + class TestAutoRegistry(RegistryMixin): + 
registry_auto_discovery = True + auto_package = "test_package.modules" + + assert TestAutoRegistry.registry is None + assert TestAutoRegistry.registry_populated is False + assert TestAutoRegistry.auto_package == "test_package.modules" + assert TestAutoRegistry.registry_auto_discovery is True + + @pytest.mark.smoke + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with mock.patch.object( + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") + + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass + + with pytest.raises(ValueError, match="already registered"): + + @registry_class.register("duplicate_name") + class TestClass2: + pass + + @pytest.mark.sanity + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances + + class TestClass1: + pass + + class TestClass2: + pass + + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") + + @pytest.mark.sanity + def test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + + @pytest.mark.sanity + def test_register_decorator_invalid_object(self, valid_instances): + """Test register_decorator with object lacking __name__ attribute.""" + registry_class, _ = valid_instances + + with pytest.raises(AttributeError): + registry_class.register_decorator("not_a_class") + + @pytest.mark.sanity + def test_register_decorator_empty_string_name(self, valid_instances): + """Test register_decorator with empty string name.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + registry_class.register_decorator(TestClass, name="") + assert "" in registry_class.registry + assert registry_class.registry[""] is TestClass + + @pytest.mark.sanity + def test_register_decorator_none_in_list(self, valid_instances): + """Test register_decorator with None in name list.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", None]) + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = 
valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None + + @pytest.mark.regression + def test_auto_registry_integration(self): + """Test complete auto-discovery workflow with mocked imports.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as mock_walk, + mock.patch("importlib.import_module") as mock_import, + ): + mock_package = mock.MagicMock() + mock_package.__path__ = ["test_package/modules"] + mock_package.__name__ = "test_package.modules" + + def import_module(name: str): + if name == "test_package.modules": + return mock_package + elif name == "test_package.modules.module1": + module = mock.MagicMock() + module.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + return module + else: + raise ImportError(f"No module named {name}") + + def walk_packages(package_path, package_name): + if package_name == "test_package.modules.": + return [(None, "test_package.modules.module1", False)] + else: + raise ValueError(f"Unknown package: {package_name}") + + mock_walk.side_effect = walk_packages + mock_import.side_effect = import_module + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None + assert "Module1Class" in TestAutoRegistry.registry + + @pytest.mark.smoke + def test_register_preserves_class_metadata(self): + """Test that registered classes retain docs, types, and methods.""" + + class TestRegistry(RegistryMixin): + pass + + @TestRegistry.register("documented_class") + class DocumentedClass: + """This is a documented class with methods and type hints.""" + + def __init__(self, value: int) -> None: + """Initialize with a value. + + :param value: An integer value + """ + self.value = value + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedClass: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedClass instance + """ + return cls(int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. 
+ + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + # Check that the class was registered + assert TestRegistry.is_registered("documented_class") + registered_class = TestRegistry.get_registered_object("documented_class") + assert registered_class is DocumentedClass + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented class with methods" in registered_class.__doc__ + assert registered_class.__init__.__doc__ is not None + assert "Initialize with a value" in registered_class.__init__.__doc__ + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + + # Check that methods are callable and work correctly + instance = registered_class(42) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + instance2 = registered_class.from_string("123") + assert instance2.get_value() == 123 + assert registered_class.validate_value(10) is True + assert registered_class.validate_value(-5) is False + + # Check that type annotations are preserved (if accessible) + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(registered_class.__init__) + assert "value" in annotations + assert annotations["value"] is int + return_ann = annotations.get("return") + assert return_ann is None or return_ann is type(None) + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert registered_class.__name__ == "DocumentedClass" + assert registered_class.__qualname__.endswith("DocumentedClass") diff --git a/tests/unit/utils/test_singleton.py b/tests/unit/utils/test_singleton.py new file mode 100644 index 00000000..ee01ead1 --- /dev/null +++ b/tests/unit/utils/test_singleton.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import threading +import time + +import pytest + +from guidellm.utils.singleton import SingletonMixin, ThreadSafeSingletonMixin + + +class TestSingletonMixin: + """Test suite for SingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "test_value"}, + {"init_value": "another_value"}, + ], + ids=["basic_singleton", "different_value"], + ) + def valid_instances(self, request): + """Provide parameterized test configurations for singleton testing.""" + config = request.param + + class TestSingleton(SingletonMixin): + def __init__(self): + # Check if we need to initialize before calling super().__init__() + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SingletonMixin inheritance and exposed attributes.""" + assert hasattr(SingletonMixin, "__new__") + assert hasattr(SingletonMixin, "__init__") + assert hasattr(SingletonMixin, "initialized") + assert isinstance(SingletonMixin.initialized, property) + + 
@pytest.mark.smoke
+    def test_initialization(self, valid_instances):
+        """Test SingletonMixin initialization."""
+        singleton_class, config = valid_instances
+
+        # Create first instance
+        instance1 = singleton_class()
+
+        assert isinstance(instance1, singleton_class)
+        assert instance1.initialized is True
+        assert hasattr(instance1, "value")
+        assert instance1.value == config["init_value"]
+
+        # Check that the class has the singleton instance stored
+        instance_attr = f"_singleton_instance_{singleton_class.__name__}"
+        assert hasattr(singleton_class, instance_attr)
+        assert getattr(singleton_class, instance_attr) is instance1
+
+    @pytest.mark.smoke
+    def test_singleton_behavior(self, valid_instances):
+        """Test that multiple instantiations return the same instance."""
+        singleton_class, config = valid_instances
+
+        # Create multiple instances
+        instance1 = singleton_class()
+        instance2 = singleton_class()
+        instance3 = singleton_class()
+
+        # All should be the same instance
+        assert instance1 is instance2
+        assert instance2 is instance3
+        assert instance1 is instance3
+
+        # Value should remain from first initialization
+        assert hasattr(instance1, "value")
+        assert instance1.value == config["init_value"]
+        assert instance2.value == config["init_value"]
+        assert instance3.value == config["init_value"]
+
+    @pytest.mark.sanity
+    def test_initialization_called_once(self, valid_instances):
+        """Test that __init__ is only called once despite multiple instantiations."""
+        singleton_class, config = valid_instances
+
+        class TestSingletonWithCounter(SingletonMixin):
+            init_count = 0
+
+            def __init__(self):
+                should_initialize = not getattr(self, "_singleton_initialized", False)
+                super().__init__()
+                if should_initialize:
+                    TestSingletonWithCounter.init_count += 1
+                    self.value = config["init_value"]
+
+        # Create multiple instances
+        instance1 = TestSingletonWithCounter()
+        instance2 = TestSingletonWithCounter()
+
+        assert TestSingletonWithCounter.init_count == 1
+        assert instance1 is instance2
+        assert hasattr(instance1, "value")
+        assert instance1.value == config["init_value"]
+
+    @pytest.mark.regression
+    def test_multiple_singleton_classes_isolation(self):
+        """Test that different singleton classes maintain separate instances."""
+
+        class Singleton1(SingletonMixin):
+            def __init__(self):
+                should_initialize = not getattr(self, "_singleton_initialized", False)
+                super().__init__()
+                if should_initialize:
+                    self.value = "value1"
+
+        class Singleton2(SingletonMixin):
+            def __init__(self):
+                should_initialize = not getattr(self, "_singleton_initialized", False)
+                super().__init__()
+                if should_initialize:
+                    self.value = "value2"
+
+        instance1a = Singleton1()
+        instance2a = Singleton2()
+        instance1b = Singleton1()
+        instance2b = Singleton2()
+
+        # Each class has its own singleton instance
+        assert instance1a is instance1b
+        assert instance2a is instance2b
+        assert instance1a is not instance2a
+
+        # Each maintains its own value
+        assert hasattr(instance1a, "value")
+        assert hasattr(instance2a, "value")
+        assert instance1a.value == "value1"
+        assert instance2a.value == "value2"
+
+    @pytest.mark.regression
+    def test_inheritance_singleton_sharing(self):
+        """Test that inherited singleton subclasses maintain separate instances."""
+
+        class BaseSingleton(SingletonMixin):
+            def __init__(self):
+                should_initialize = not getattr(self, "_singleton_initialized", False)
+                super().__init__()
+                if should_initialize:
+                    self.value = "base_value"
+
+        class ChildSingleton(BaseSingleton):
+            def __init__(self):
should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.extra = "extra_value" + + # Child classes now have separate singleton instances + base_instance = BaseSingleton() + child_instance = ChildSingleton() + + # They should be different instances now (fixed inheritance behavior) + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(child_instance, "value") + assert child_instance.value == "base_value" + assert hasattr(child_instance, "extra") + assert child_instance.extra == "extra_value" + + @pytest.mark.sanity + def test_without_super_init_call(self): + """Test singleton behavior when subclass doesn't call super().__init__().""" + + class BadSingleton(SingletonMixin): + def __init__(self): + # Not calling super().__init__() + self.value = "bad_value" + + instance1 = BadSingleton() + instance2 = BadSingleton() + + assert instance1 is instance2 + assert hasattr(instance1, "initialized") + assert instance1.initialized is False + + +class TestThreadSafeSingletonMixin: + """Test suite for ThreadSafeSingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "thread_safe_value"}, + {"init_value": "concurrent_value"}, + ], + ids=["basic_thread_safe", "concurrent_test"], + ) + def valid_instances(self, request): + """Fixture providing test data for ThreadSafeSingletonMixin subclasses.""" + config = request.param + + class TestThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestThreadSafeSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ThreadSafeSingletonMixin inheritance and exposed attributes.""" + assert issubclass(ThreadSafeSingletonMixin, SingletonMixin) + assert hasattr(ThreadSafeSingletonMixin, "get_singleton_lock") + assert hasattr(ThreadSafeSingletonMixin, "__new__") + assert hasattr(ThreadSafeSingletonMixin, "__init__") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ThreadSafeSingletonMixin initialization.""" + singleton_class, config = valid_instances + + instance = singleton_class() + + assert isinstance(instance, singleton_class) + assert instance.initialized is True + assert hasattr(instance, "value") + assert instance.value == config["init_value"] + assert hasattr(instance, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance.thread_lock, lock_type) + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test multiple instantiations return same instance with thread safety.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert hasattr(instance1, "thread_lock") + + @pytest.mark.regression + def test_thread_safety_concurrent_creation(self, valid_instances): + """Test thread safety during concurrent instance creation.""" + singleton_class, config = valid_instances + + instances = [] + exceptions = [] + creation_count = 0 + lock = threading.Lock() + + class ThreadSafeTestSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + 
if should_initialize: + nonlocal creation_count + with lock: + creation_count += 1 + + time.sleep(0.01) + self.value = config["init_value"] + + def create_instance(): + try: + instance = ThreadSafeTestSingleton() + instances.append(instance) + except (TypeError, ValueError, AttributeError) as exc: + exceptions.append(exc) + + threads = [] + for _ in range(10): + thread = threading.Thread(target=create_instance) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}" + + assert len(instances) == 10 + for instance in instances: + assert instance is instances[0] + + assert creation_count == 1 + assert all(instance.value == config["init_value"] for instance in instances) + + @pytest.mark.sanity + def test_thread_lock_creation(self, valid_instances): + """Test that thread_lock is created during initialization.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert hasattr(instance1, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance1.thread_lock, lock_type) + assert instance1.thread_lock is instance2.thread_lock + + @pytest.mark.regression + def test_multiple_thread_safe_classes_isolation(self): + """Test thread-safe singleton classes behavior with separate locks.""" + + class ThreadSafeSingleton1(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class ThreadSafeSingleton2(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1 = ThreadSafeSingleton1() + instance2 = ThreadSafeSingleton2() + + lock1 = ThreadSafeSingleton1.get_singleton_lock() + lock2 = ThreadSafeSingleton2.get_singleton_lock() + + assert lock1 is not None + assert lock2 is not None + assert lock1 is not lock2 + + assert instance1 is not instance2 + assert hasattr(instance1, "value") + assert hasattr(instance2, "value") + assert instance1.value == "value1" + assert instance2.value == "value2" + + @pytest.mark.sanity + def test_inheritance_with_thread_safety(self): + """Test inheritance behavior with thread-safe singletons.""" + + class BaseThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildThreadSafeSingleton(BaseThreadSafeSingleton): + def __init__(self): + super().__init__() + + base_instance = BaseThreadSafeSingleton() + child_instance = ChildThreadSafeSingleton() + + base_lock = BaseThreadSafeSingleton.get_singleton_lock() + child_lock = ChildThreadSafeSingleton.get_singleton_lock() + + assert base_lock is not None + assert child_lock is not None + assert base_lock is not child_lock + + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(base_instance, "thread_lock") diff --git a/tests/unit/utils/test_synchronous.py b/tests/unit/utils/test_synchronous.py new file mode 100644 index 00000000..1a9ea2c9 --- /dev/null +++ b/tests/unit/utils/test_synchronous.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import asyncio +import multiprocessing 
+import threading +from functools import wraps +from multiprocessing.synchronize import Barrier as ProcessingBarrier +from multiprocessing.synchronize import Event as ProcessingEvent +from typing import Union + +import pytest + +from guidellm.utils.synchronous import ( + SyncObjectTypesAlias, + wait_for_sync_barrier, + wait_for_sync_event, + wait_for_sync_objects, +) + + +def async_timeout(delay: float): + """Decorator to add timeout to async functions.""" + + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +def test_sync_object_types_alias(): + """Test that SyncObjectTypesAlias is defined correctly as a type alias.""" + assert hasattr(SyncObjectTypesAlias, "__origin__") + if hasattr(SyncObjectTypesAlias, "__args__"): + actual_type = SyncObjectTypesAlias.__args__[0] + assert hasattr(actual_type, "__origin__") + assert actual_type.__origin__ is Union + union_args = actual_type.__args__ + assert threading.Event in union_args + assert ProcessingEvent in union_args + assert threading.Barrier in union_args + assert ProcessingBarrier in union_args + + +class TestWaitForSyncEvent: + """Test suite for wait_for_sync_event function.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "event_type", + [threading.Event, multiprocessing.Event], + ids=["threading", "multiprocessing"], + ) + @async_timeout(2.0) + async def test_invocation(self, event_type): + """Test wait_for_sync_event with valid events that get set.""" + event: threading.Event | ProcessingEvent = event_type() + + async def set_event(): + await asyncio.sleep(0.01) + event.set() + + asyncio.create_task(set_event()) + await wait_for_sync_event(event, poll_interval=0.001) + assert event.is_set() + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + "event_type", + [threading.Event, multiprocessing.Event], + ids=["threading", "multiprocessing"], + ) + @async_timeout(2.0) + async def test_cancellation_stops_waiting(self, event_type): + """Test that cancelling the task stops waiting for the event.""" + event: threading.Event | ProcessingEvent = event_type() + + async def waiter(): + await wait_for_sync_event(event, poll_interval=0.001) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.02) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + +class TestWaitForSyncBarrier: + """Test suite for wait_for_sync_barrier function.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "barrier_type", + [threading.Barrier, multiprocessing.Barrier], + ids=["threading", "multiprocessing"], + ) + @async_timeout(5.0) + async def test_invocation(self, barrier_type): + """Test wait_for_sync_barrier with barrier that gets reached.""" + barrier: threading.Barrier | ProcessingBarrier = barrier_type(2) + + async def reach_barrier(): + await asyncio.sleep(0.01) + await asyncio.to_thread(barrier.wait) + + task = asyncio.create_task(reach_barrier()) + await wait_for_sync_barrier(barrier, poll_interval=0.01) + await task + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + "barrier_type", + [threading.Barrier, multiprocessing.Barrier], + ids=["threading", "multiprocessing"], + ) + @async_timeout(2.0) + async def test_cancellation_stops_waiting(self, barrier_type): + """Test that cancelling the task stops waiting for the barrier.""" + barrier: threading.Barrier | ProcessingBarrier = 
barrier_type(2) + + async def waiter(): + await wait_for_sync_barrier(barrier, 0.01) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.1) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + +class TestWaitForSyncObjects: + """Test suite for wait_for_sync_objects function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("objects_types", "expected_result"), + [ + (threading.Event, 0), + (multiprocessing.Event, 0), + (threading.Barrier, 0), + (multiprocessing.Barrier, 0), + ([threading.Event, multiprocessing.Barrier], 1), + ([multiprocessing.Event, threading.Barrier], 0), + ( + [ + threading.Event, + multiprocessing.Event, + threading.Barrier, + multiprocessing.Barrier, + ], + 2, + ), + ( + { + "multiprocessing.Event": multiprocessing.Event, + "threading.Barrier": threading.Barrier, + }, + "threading.Barrier", + ), + ( + { + "threading.Event": threading.Event, + "multiprocessing.Barrier": multiprocessing.Barrier, + }, + "threading.Event", + ), + ( + { + "multiprocessing.Event": multiprocessing.Event, + "threading.Event": threading.Event, + "multiprocessing.Barrier": multiprocessing.Barrier, + "threading.Barrier": threading.Barrier, + }, + "threading.Event", + ), + ], + ids=[ + "threading_event", + "multiprocessing_event", + "threading_barrier", + "multiprocessing_barrier", + "mixed_list_event_barrier_1", + "mixed_list_event_barrier_2", + "mixed_list_all", + "mixed_dict_event_barrier_1", + "mixed_dict_event_barrier_2", + "mixed_dict_all", + ], + ) + @pytest.mark.asyncio + @async_timeout(2.0) + async def test_invocation(self, objects_types, expected_result): + """Test wait_for_sync_objects with various object configurations.""" + if isinstance(objects_types, list): + objects = [ + obj() + if obj not in (threading.Barrier, multiprocessing.Barrier) + else obj(2) + for obj in objects_types + ] + elif isinstance(objects_types, dict): + objects = { + key: ( + obj() + if obj not in (threading.Barrier, multiprocessing.Barrier) + else obj(2) + ) + for key, obj in objects_types.items() + } + else: + objects = [ + objects_types() + if objects_types not in (threading.Barrier, multiprocessing.Barrier) + else objects_types(2) + ] + + async def set_target(): + await asyncio.sleep(0.01) + obj = objects[expected_result] + if isinstance(obj, (threading.Event, ProcessingEvent)): + obj.set() + else: + await asyncio.to_thread(obj.wait) + + task = asyncio.create_task(set_target()) + result = await wait_for_sync_objects(objects, poll_interval=0.001) + await task + + assert result == expected_result diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py new file mode 100644 index 00000000..50f18ce3 --- /dev/null +++ b/tests/unit/utils/test_text.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import gzip +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import httpx +import pytest + +from guidellm.utils.text import ( + MAX_PATH_LENGTH, + EndlessTextCreator, + clean_text, + filter_text, + format_value_display, + is_puncutation, + load_text, + split_text, + split_text_list_by_length, +) + + +def test_max_path_length(): + """Test that MAX_PATH_LENGTH is correctly defined.""" + assert isinstance(MAX_PATH_LENGTH, int) + assert MAX_PATH_LENGTH == 4096 + + +class TestFormatValueDisplay: + """Test suite for format_value_display.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "value", + "label", + "units", + "total_characters", + "digits_places", + "decimal_places", + "expected", + ), + [ + 
(42.0, "test", "", None, None, None, "42 [info]test[/info]"), + (42.5, "test", "ms", None, None, 1, "42.5ms [info]test[/info]"), + (42.123, "test", "", None, 5, 2, " 42.12 [info]test[/info]"), + ( + 42.0, + "test", + "ms", + 30, + None, + 0, + " 42ms [info]test[/info]", + ), + ], + ) + def test_invocation( + self, + value, + label, + units, + total_characters, + digits_places, + decimal_places, + expected, + ): + """Test format_value_display with various parameters.""" + result = format_value_display( + value=value, + label=label, + units=units, + total_characters=total_characters, + digits_places=digits_places, + decimal_places=decimal_places, + ) + assert label in result + assert units in result + value_check = ( + str(int(value)) + if decimal_places == 0 + else ( + f"{value:.{decimal_places}f}" + if decimal_places is not None + else str(value) + ) + ) + assert value_check in result or str(value) in result + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("value", "label"), + [ + (None, "test"), + (42.0, None), + ("not_number", "test"), + ], + ) + def test_invocation_with_none_values(self, value, label): + """Test format_value_display with None/invalid inputs still works.""" + result = format_value_display(value, label) + assert isinstance(result, str) + if label is not None: + assert str(label) in result + if value is not None: + assert str(value) in result + + +class TestSplitTextListByLength: + """Test suite for split_text_list_by_length.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "text_list", + "max_characters", + "pad_horizontal", + "pad_vertical", + "expected_structure", + ), + [ + ( + ["hello world", "test"], + 5, + False, + False, + [["hello", "world"], ["test"]], + ), + ( + ["short", "longer text"], + [5, 10], + True, + True, + [[" short"], ["longer", "text"]], + ), + ( + ["a", "b", "c"], + 10, + True, + True, + [[" a"], [" b"], [" c"]], + ), + ], + ) + def test_invocation( + self, + text_list, + max_characters, + pad_horizontal, + pad_vertical, + expected_structure, + ): + """Test split_text_list_by_length with various parameters.""" + result = split_text_list_by_length( + text_list, max_characters, pad_horizontal, pad_vertical + ) + assert len(result) == len(text_list) + if pad_vertical: + max_lines = max(len(lines) for lines in result) + assert all(len(lines) == max_lines for lines in result) + + @pytest.mark.sanity + def test_invalid_max_characters_length(self): + """Test split_text_list_by_length with mismatched max_characters length.""" + error_msg = "max_characters must be a list of the same length" + with pytest.raises(ValueError, match=error_msg): + split_text_list_by_length(["a", "b"], [5, 10, 15]) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text_list", "max_characters"), + [ + (None, 5), + (["test"], None), + (["test"], []), + ], + ) + def test_invalid_invocation(self, text_list, max_characters): + """Test split_text_list_by_length with invalid inputs.""" + with pytest.raises((TypeError, ValueError)): + split_text_list_by_length(text_list, max_characters) + + +class TestFilterText: + """Test suite for filter_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end", "expected"), + [ + ("hello world test", "world", None, "world test"), + ("hello world test", None, "world", "hello "), + ("hello world test", "hello", "test", "hello world "), + ("hello world test", 6, 11, "world test"), + ("hello world test", 0, 5, "hello"), + ("hello world test", None, None, "hello world test"), + ], + ) + def 
test_invocation(self, text, filter_start, filter_end, expected): + """Test filter_text with various start and end markers.""" + result = filter_text(text, filter_start, filter_end) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end"), + [ + ("hello", "notfound", None), + ("hello", None, "notfound"), + ("hello", "invalid_type", None), + ("hello", None, "invalid_type"), + ], + ) + def test_invalid_invocation(self, text, filter_start, filter_end): + """Test filter_text with invalid markers.""" + with pytest.raises((ValueError, TypeError)): + filter_text(text, filter_start, filter_end) + + +class TestCleanText: + """Test suite for clean_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + ("hello world", "hello world"), + (" hello\n\nworld ", "hello world"), + ("hello\tworld\r\ntest", "hello world test"), + ("", ""), + (" ", ""), + ], + ) + def test_invocation(self, text, expected): + """Test clean_text with various whitespace scenarios.""" + result = clean_text(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test clean_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + clean_text(text) + + +class TestSplitText: + """Test suite for split_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "split_punctuation", "expected"), + [ + ("hello world", False, ["hello", "world"]), + ("hello, world!", True, ["hello", ",", "world", "!"]), + ("test.example", False, ["test.example"]), + ("test.example", True, ["test", ".", "example"]), + ("", False, []), + ], + ) + def test_invocation(self, text, split_punctuation, expected): + """Test split_text with various punctuation options.""" + result = split_text(text, split_punctuation) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test split_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + split_text(text) + + +class TestLoadText: + """Test suite for load_text.""" + + @pytest.mark.smoke + def test_empty_data(self): + """Test load_text with empty data.""" + result = load_text("") + assert result == "" + + @pytest.mark.smoke + def test_raw_text(self): + """Test load_text with raw text that's not a file.""" + long_text = "a" * (MAX_PATH_LENGTH + 1) + result = load_text(long_text) + assert result == long_text + + @pytest.mark.smoke + def test_local_file(self): + """Test load_text with local file.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp: + test_content = "test file content" + tmp.write(test_content) + tmp.flush() + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + def test_gzipped_file(self): + """Test load_text with gzipped file.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as tmp: + test_content = "test gzipped content" + with gzip.open(tmp.name, "wt") as gzf: + gzf.write(test_content) + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + @patch("httpx.Client") + def test_url_loading(self, mock_client): + """Test load_text with HTTP URL.""" + mock_response = Mock() + mock_response.text = "url content" + 
mock_client.return_value.__enter__.return_value.get.return_value = mock_response + + result = load_text("http://example.com/test.txt") + assert result == "url content" + + @pytest.mark.smoke + @patch("guidellm.utils.text.files") + @patch("guidellm.utils.text.as_file") + def test_package_data_loading(self, mock_as_file, mock_files): + """Test load_text with package data.""" + mock_resource = Mock() + mock_files.return_value.joinpath.return_value = mock_resource + + mock_file = Mock() + mock_file.read.return_value = "package data content" + mock_as_file.return_value.__enter__.return_value = mock_file + + with patch("gzip.open") as mock_gzip: + mock_gzip.return_value.__enter__.return_value = mock_file + result = load_text("data:test.txt") + assert result == "package data content" + + @pytest.mark.sanity + def test_nonexistent_file(self): + """Test load_text with nonexistent file returns the path as raw text.""" + result = load_text("/nonexistent/path/file.txt") + assert result == "/nonexistent/path/file.txt" + + @pytest.mark.sanity + @patch("httpx.Client") + def test_url_error(self, mock_client): + """Test load_text with HTTP error.""" + mock_client.return_value.__enter__.return_value.get.side_effect = ( + httpx.HTTPStatusError("HTTP error", request=None, response=None) + ) + + with pytest.raises(httpx.HTTPStatusError): + load_text("http://example.com/error.txt") + + +class TestIsPuncutation: + """Test suite for is_puncutation.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + (".", True), + (",", True), + ("!", True), + ("?", True), + (";", True), + ("a", False), + ("1", False), + (" ", False), + ("ab", False), + ("", False), + ], + ) + def test_invocation(self, text, expected): + """Test is_puncutation with various characters.""" + result = is_puncutation(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test is_puncutation with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + is_puncutation(text) + + +class TestEndlessTextCreator: + """Test suite for EndlessTextCreator.""" + + @pytest.fixture( + params=[ + { + "data": "hello world test", + "filter_start": None, + "filter_end": None, + }, + { + "data": "hello world test", + "filter_start": "world", + "filter_end": None, + }, + {"data": "one two three four", "filter_start": 0, "filter_end": 9}, + ], + ids=["no_filter", "string_filter", "index_filter"], + ) + def valid_instances(self, request): + """Fixture providing test data for EndlessTextCreator.""" + constructor_args = request.param + instance = EndlessTextCreator(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test EndlessTextCreator signatures and methods.""" + assert hasattr(EndlessTextCreator, "__init__") + assert hasattr(EndlessTextCreator, "create_text") + instance = EndlessTextCreator("test") + assert hasattr(instance, "data") + assert hasattr(instance, "text") + assert hasattr(instance, "filtered_text") + assert hasattr(instance, "words") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test EndlessTextCreator initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, EndlessTextCreator) + assert instance.data == constructor_args["data"] + assert isinstance(instance.text, str) + assert isinstance(instance.filtered_text, str) + assert isinstance(instance.words, list) + + 
@pytest.mark.sanity + @pytest.mark.parametrize( + ("data", "filter_start", "filter_end"), + [ + ("test", "notfound", None), + ], + ) + def test_invalid_initialization_values(self, data, filter_start, filter_end): + """Test EndlessTextCreator with invalid initialization values.""" + with pytest.raises((TypeError, ValueError)): + EndlessTextCreator(data, filter_start, filter_end) + + @pytest.mark.smoke + def test_initialization_with_none(self): + """Test EndlessTextCreator handles None data gracefully.""" + instance = EndlessTextCreator(None) + assert isinstance(instance, EndlessTextCreator) + assert instance.data is None + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "expected_length"), + [ + (0, 5, 5), + (2, 3, 3), + (0, 0, 0), + ], + ) + def test_create_text(self, valid_instances, start, length, expected_length): + """Test EndlessTextCreator.create_text.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + if length > 0 and instance.words: + assert len(result) > 0 + + @pytest.mark.smoke + def test_create_text_cycling(self): + """Test EndlessTextCreator.create_text cycling behavior.""" + instance = EndlessTextCreator("one two three") + result1 = instance.create_text(0, 3) + result2 = instance.create_text(3, 3) + assert isinstance(result1, str) + assert isinstance(result2, str) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("start", "length"), + [ + ("invalid", 5), + (0, "invalid"), + ], + ) + def test_create_text_invalid(self, valid_instances, start, length): + """Test EndlessTextCreator.create_text with invalid inputs.""" + instance, constructor_args = valid_instances + with pytest.raises((TypeError, ValueError)): + instance.create_text(start, length) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "min_length"), + [ + (-1, 5, 0), + (0, -1, 0), + ], + ) + def test_create_text_edge_cases(self, valid_instances, start, length, min_length): + """Test EndlessTextCreator.create_text with edge cases.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + assert len(result) >= min_length diff --git a/tests/unit/utils/text.py b/tests/unit/utils/text.py deleted file mode 100644 index ae0fa52f..00000000 --- a/tests/unit/utils/text.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from guidellm.utils.text import camelize_str - - -@pytest.mark.smoke -def test_camelize_str_camelizes_string(): - assert camelize_str("no_longer_snake_case") == "noLongerSnakeCase" - - -@pytest.mark.smoke -def test_camelize_str_leaves_non_snake_case_text_untouched(): - assert camelize_str("notsnakecase") == "notsnakecase" From d15cf17f07ee438646bdd2736245fef540918ede Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 11:54:01 +0000 Subject: [PATCH 2/3] Remove old pydantic file that is now replaced Signed-off-by: Mark Kurtz --- src/guidellm/objects/__init__.py | 1 - src/guidellm/objects/pydantic.py | 89 -------------------------------- 2 files changed, 90 deletions(-) delete mode 100644 src/guidellm/objects/pydantic.py diff --git a/src/guidellm/objects/__init__.py b/src/guidellm/objects/__init__.py index 89e3c9b9..f97f1ef3 100644 --- a/src/guidellm/objects/__init__.py +++ b/src/guidellm/objects/__init__.py @@ -1,4 +1,3 @@ -from .pydantic import StandardBaseModel, StatusBreakdown from .statistics import ( DistributionSummary, Percentiles, diff --git a/src/guidellm/objects/pydantic.py 
b/src/guidellm/objects/pydantic.py deleted file mode 100644 index fcededcf..00000000 --- a/src/guidellm/objects/pydantic.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Generic, Optional, TypeVar - -import yaml -from loguru import logger -from pydantic import BaseModel, ConfigDict, Field - -__all__ = ["StandardBaseModel", "StatusBreakdown"] - -T = TypeVar("T", bound="StandardBaseModel") - - -class StandardBaseModel(BaseModel): - """ - A base class for Pydantic models throughout GuideLLM enabling standard - configuration and logging. - """ - - model_config = ConfigDict( - extra="ignore", - use_enum_values=True, - validate_assignment=True, - from_attributes=True, - ) - - def __init__(self, /, **data: Any) -> None: - super().__init__(**data) - logger.debug( - "Initialized new instance of {} with data: {}", - self.__class__.__name__, - data, - ) - - @classmethod - def get_default(cls: type[T], field: str) -> Any: - """Get default values for model fields""" - return cls.model_fields[field].default - - @classmethod - def from_file(cls: type[T], filename: Path, overrides: Optional[dict] = None) -> T: - """ - Attempt to create a new instance of the model using - data loaded from json or yaml file. - """ - try: - with filename.open() as f: - if str(filename).endswith(".json"): - data = json.load(f) - else: # Assume everything else is yaml - data = yaml.safe_load(f) - except (json.JSONDecodeError, yaml.YAMLError) as e: - logger.error(f"Failed to parse {filename} as type {cls.__name__}") - raise ValueError(f"Error when parsing file: {filename}") from e - - data.update(overrides) - return cls.model_validate(data) - - -SuccessfulT = TypeVar("SuccessfulT") -ErroredT = TypeVar("ErroredT") -IncompleteT = TypeVar("IncompleteT") -TotalT = TypeVar("TotalT") - - -class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): - """ - A base class for Pydantic models that are separated by statuses including - successful, incomplete, and errored. It additionally enables the inclusion - of total, which is intended as the combination of all statuses. - Total may or may not be used depending on if it duplicates information. 
- """ - - successful: SuccessfulT = Field( - description="The results with a successful status.", - default=None, # type: ignore[assignment] - ) - errored: ErroredT = Field( - description="The results with an errored status.", - default=None, # type: ignore[assignment] - ) - incomplete: IncompleteT = Field( - description="The results with an incomplete status.", - default=None, # type: ignore[assignment] - ) - total: TotalT = Field( - description="The combination of all statuses.", - default=None, # type: ignore[assignment] - ) From 5b83c2d371713bd821d6f8cda4ec2bb76a8b400c Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 12:05:28 +0000 Subject: [PATCH 3/3] fixes from copilot review Signed-off-by: Mark Kurtz --- src/guidellm/benchmark/progress.py | 2 +- src/guidellm/utils/__init__.py | 4 ++-- src/guidellm/utils/synchronous.py | 2 +- src/guidellm/utils/text.py | 8 ++++---- tests/unit/utils/test_text.py | 14 +++++++------- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py index d6f437e1..1232107b 100644 --- a/src/guidellm/benchmark/progress.py +++ b/src/guidellm/benchmark/progress.py @@ -253,7 +253,7 @@ def format_progress_display( decimal_places: Optional[int] = None, ) -> str: if decimal_places is None and digits_places is None: - formatted_number = f"{value}:.0f" + formatted_number = f"{value:.0f}" elif digits_places is None: formatted_number = f"{value:.{decimal_places}f}" elif decimal_places is None: diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 83a276b2..20daeea4 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -58,7 +58,7 @@ clean_text, filter_text, format_value_display, - is_puncutation, + is_punctuation, load_text, split_text, split_text_list_by_length, @@ -109,7 +109,7 @@ "filter_text", "format_value_display", "get_literal_vals", - "is_puncutation", + "is_punctuation", "load_text", "safe_add", "safe_divide", diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py index 3bec0247..14f3d908 100644 --- a/src/guidellm/utils/synchronous.py +++ b/src/guidellm/utils/synchronous.py @@ -11,7 +11,7 @@ import asyncio import contextlib -from datetime import time +import time from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent from threading import Barrier as ThreadingBarrier diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index 519b46c3..8385ec7b 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -31,7 +31,7 @@ "clean_text", "filter_text", "format_value_display", - "is_puncutation", + "is_punctuation", "load_text", "split_text", "split_text_list_by_length", @@ -64,7 +64,7 @@ def format_value_display( :return: Formatted string with value, units, and colored label """ if decimal_places is None and digits_places is None: - formatted_number = f"{value}:.0f" + formatted_number = f"{value:.0f}" elif digits_places is None: formatted_number = f"{value:.{decimal_places}f}" elif decimal_places is None: @@ -268,7 +268,7 @@ def load_text(data: str | Path, encoding: str | None = None) -> str: return data.read_text(encoding=encoding) -def is_puncutation(text: str) -> bool: +def is_punctuation(text: str) -> bool: """ Check if a single character is a punctuation mark. 
@@ -332,7 +332,7 @@ def create_text(self, start: int, length: int) -> str: index = (start + counter) % len(self.words) add_word = self.words[index] - if counter != 0 and not is_puncutation(add_word): + if counter != 0 and not is_punctuation(add_word): text += " " text += add_word diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py index 50f18ce3..3774ca1f 100644 --- a/tests/unit/utils/test_text.py +++ b/tests/unit/utils/test_text.py @@ -14,7 +14,7 @@ clean_text, filter_text, format_value_display, - is_puncutation, + is_punctuation, load_text, split_text, split_text_list_by_length, @@ -372,8 +372,8 @@ def test_url_error(self, mock_client): load_text("http://example.com/error.txt") -class TestIsPuncutation: - """Test suite for is_puncutation.""" +class TestIsPunctuation: + """Test suite for is_punctuation.""" @pytest.mark.smoke @pytest.mark.parametrize( @@ -392,8 +392,8 @@ class TestIsPuncutation: ], ) def test_invocation(self, text, expected): - """Test is_puncutation with various characters.""" - result = is_puncutation(text) + """Test is_punctuation with various characters.""" + result = is_punctuation(text) assert result == expected @pytest.mark.sanity @@ -405,9 +405,9 @@ def test_invocation(self, text, expected): ], ) def test_invalid_invocation(self, text): - """Test is_puncutation with invalid inputs.""" + """Test is_punctuation with invalid inputs.""" with pytest.raises((TypeError, AttributeError)): - is_puncutation(text) + is_punctuation(text) class TestEndlessTextCreator:
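
For quick reference, below is a minimal usage sketch of the text helpers exercised by the tests above. It assumes only the import path and keyword names the tests themselves demonstrate; the "latency" label and the exact rendered output are illustrative, not taken from the patch.

# Hedged sketch: imports and keyword names follow tests/unit/utils/test_text.py;
# rendered output (padding, [info] markup) depends on the actual implementation.
from guidellm.utils import (
    EndlessTextCreator,
    format_value_display,
    is_punctuation,
    split_text,
)

# Numeric value rendered with units and a colored label, e.g. "42.5ms [info]latency[/info]".
print(format_value_display(value=42.5, label="latency", units="ms", decimal_places=1))

# Tokenize text; the second argument splits punctuation into separate tokens.
print(split_text("hello, world!", True))  # ["hello", ",", "world", "!"]

# True only for a single punctuation character.
print(is_punctuation("!"))

# Build prompt text of a given word length, wrapping around the source words.
creator = EndlessTextCreator("one two three")
print(creator.create_text(0, 5))  # e.g. "one two three one two"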