diff --git a/pyproject.toml b/pyproject.toml index 17380312..783292dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ keywords = [ ] dependencies = [ "click>=8.0.0,<8.2.0", + "culsans~=0.9.0", "datasets", "ftfy>=6.0.0", "httpx[http2]<1.0.0", @@ -157,15 +158,16 @@ indent-style = "space" [tool.ruff.lint] ignore = [ - "PLR0913", - "TC001", - "COM812", - "ISC001", - "TC002", + "COM812", # ignore trailing comma errors due to older Python versions + "PD011", # ignore .values usage since ruff assumes it's a Pandas DataFrame + "PLR0913", # ignore too many arguments in function definitions "PLW1514", # allow Path.open without encoding "RET505", # allow `else` blocks "RET506", # allow `else` blocks - "PD011", # ignore .values usage since ruff assumes it's a Pandas DataFrame + "S311", # allow standard pseudo-random generators + "TC001", # ignore imports used only for type checking + "TC002", # ignore imports used only for type checking + "TC003", # ignore imports used only for type checking ] select = [ # Rules reference: https://docs.astral.sh/ruff/rules/ diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index 576fe64d..6c8561ac 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,6 +1,13 @@ from .auto_importer import AutoImporterMixin from .colors import Colors from .default_group import DefaultGroupHandler +from .encoding import ( + Encoder, + EncodingTypesAlias, + MessageEncoding, + SerializationTypesAlias, + Serializer, +) from .functions import ( all_defined, safe_add, @@ -16,6 +23,12 @@ from .hf_transformers import ( check_load_processor, ) +from .messaging import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, +) from .pydantic_utils import ( PydanticClassRegistryMixin, ReloadableBaseModel, @@ -49,13 +62,22 @@ "Colors", "DefaultGroupHandler", "DistributionSummary", + "Encoder", + "EncodingTypesAlias", 
"EndlessTextCreator", "IntegerRangeSampler", + "InterProcessMessaging", + "InterProcessMessagingManagerQueue", + "InterProcessMessagingPipe", + "InterProcessMessagingQueue", + "MessageEncoding", "Percentiles", "PydanticClassRegistryMixin", "RegistryMixin", "ReloadableBaseModel", "RunningStats", + "SerializationTypesAlias", + "Serializer", "SingletonMixin", "StandardBaseDict", "StandardBaseModel", diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py new file mode 100644 index 00000000..86743c84 --- /dev/null +++ b/src/guidellm/utils/encoding.py @@ -0,0 +1,797 @@ +""" +Message encoding utilities for multiprocess communication with Pydantic model support. + +Provides binary serialization and deserialization of Python objects using various +serialization formats and encoding packages to enable performance configurations +for distributed scheduler operations. Supports configurable two-stage processing +pipeline: object serialization (to dict/sequence) followed by binary encoding +(msgpack/msgspec) with specialized Pydantic model handling for type preservation. 
+""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar + +try: + import msgpack + from msgpack import Packer, Unpacker + + HAS_MSGPACK = True +except ImportError: + msgpack = Packer = Unpacker = None + HAS_MSGPACK = False + +try: + from msgspec.msgpack import Decoder as MsgspecDecoder + from msgspec.msgpack import Encoder as MsgspecEncoder + + HAS_MSGSPEC = True +except ImportError: + MsgspecDecoder = MsgspecEncoder = None + HAS_MSGSPEC = False + +try: + import orjson + + HAS_ORJSON = True +except ImportError: + orjson = None + HAS_ORJSON = False + +from pydantic import BaseModel +from typing_extensions import TypeAlias + +__all__ = [ + "Encoder", + "EncodingTypesAlias", + "MessageEncoding", + "MsgT", + "ObjT", + "SerializationTypesAlias", + "Serializer", +] + +ObjT = TypeVar("ObjT") +MsgT = TypeVar("MsgT") + +SerializationTypesAlias: TypeAlias = Annotated[ + Optional[Literal["dict", "sequence"]], + "Type alias for available serialization strategies", +] +EncodingTypesAlias: TypeAlias = Annotated[ + Optional[Literal["msgpack", "msgspec"]], + "Type alias for available binary encoding formats", +] + + +class MessageEncoding(Generic[ObjT, MsgT]): + """ + High-performance message encoding and decoding for multiprocessing communication. + + Supports configurable object serialization and binary encoding with specialized + handling for Pydantic models. Provides a two-stage pipeline of serialization + (object to dict/str) followed by encoding (dict/str to binary) for optimal + performance and compatibility across different transport mechanisms used in + distributed scheduler operations. + + For performance reasons, this encoding only supports Python primitives + (int, float, str), Pydantic models, and collections of mixed Pydantic and Python + models for a single level of nesting (list[Pydantic, int, float, str], tuple, dict). 
+ Nested combinations of mixtures of Pydantic and Python models are not supported. + + Example: + :: + from guidellm.utils.encoding import MessageEncoding + from pydantic import BaseModel + + class DataModel(BaseModel): + name: str + value: int + + # Configure with dict serialization and msgpack encoding + encoding = MessageEncoding(serialization="dict", encoding="msgpack") + encoding.register_pydantic(DataModel) + + # Encode and decode objects + data = DataModel(name="test", value=42) + encoded_msg = encoding.encode(data) + decoded_data = encoding.decode(encoded_msg) + + :cvar DEFAULT_ENCODING_PREFERENCE: Preferred encoding formats in priority order + """ + + DEFAULT_ENCODING_PREFERENCE: ClassVar[list[str]] = ["msgspec", "msgpack"] + + @classmethod + def encode_message( + cls, + obj: ObjT, + serializer: Serializer | None, + encoder: Encoder | None, + ) -> MsgT: + """ + Encode object using specified serializer and encoder. + + :param obj: Object to encode + :param serializer: Serializer for object conversion, None for no serialization + :param encoder: Encoder for binary conversion, None for no encoding + :return: Encoded message ready for transport + """ + serialized = serializer.serialize(obj) if serializer else obj + + return encoder.encode(serialized) if encoder else serialized + + @classmethod + def decode_message( + cls, + message: MsgT, + serializer: Serializer | None, + encoder: Encoder | None, + ) -> ObjT: + """ + Decode message using specified serializer and encoder. + Must match the encoding configuration originally used. 
+ + :param message: Encoded message to decode + :param serializer: Serializer for object reconstruction, None for no + serialization + :param encoder: Encoder for binary decoding, None for no encoding + :return: Reconstructed object + """ + serialized = encoder.decode(message) if encoder else message + + return serializer.deserialize(serialized) if serializer else serialized + + def __init__( + self, + serialization: SerializationTypesAlias = None, + encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None, + pydantic_models: list[type[BaseModel]] | None = None, + ) -> None: + """ + Initialize MessageEncoding with serialization and encoding strategies. + + :param serialization: Serialization strategy (None, "dict", or "sequence") + :param encoding: Encoding strategy (None, "msgpack", "msgspec", or + preference list) + """ + self.serializer = Serializer(serialization, pydantic_models) + self.encoder = Encoder(encoding) + + def register_pydantic(self, model: type[BaseModel]) -> None: + """ + Register Pydantic model for specialized serialization handling. + + :param model: Pydantic model class to register for type preservation + """ + self.serializer.register_pydantic(model) + + def encode(self, obj: ObjT) -> MsgT: + """ + Encode object using instance configuration. + + :param obj: Object to encode using configured serialization and encoding + :return: Encoded message ready for transport + """ + return self.encode_message( + obj=obj, + serializer=self.serializer, + encoder=self.encoder, + ) + + def decode(self, message: MsgT) -> ObjT: + """ + Decode message using instance configuration. + + :param message: Encoded message to decode using configured strategies + :return: Reconstructed object + """ + return self.decode_message( + message=message, + serializer=self.serializer, + encoder=self.encoder, + ) + + +class Encoder: + """ + Binary encoding and decoding using MessagePack or msgspec formats. 
+ + Handles binary serialization of Python objects using configurable encoding + strategies with automatic fallback when dependencies are unavailable. Supports + both standalone instances and pooled encoder/decoder pairs for performance + optimization in high-throughput scenarios. + """ + + def __init__( + self, encoding: EncodingTypesAlias | list[EncodingTypesAlias] = None + ) -> None: + """ + Initialize encoder with specified encoding strategy. + + :param encoding: Encoding format preference (None, "msgpack", "msgspec", or + preference list) + """ + self.encoding, self.encoder, self.decoder = self._resolve_encoding(encoding) + + def encode(self, obj: Any) -> bytes | Any: + """ + Encode object to binary format using configured encoding strategy. + + :param obj: Object to encode (must be serializable by chosen format) + :return: Encoded bytes or original object if no encoding configured + :raises ImportError: If required encoding library is not available + """ + if self.encoding == "msgpack": + if not HAS_MSGPACK: + raise ImportError("msgpack is not available") + + return self.encoder.pack(obj) if self.encoder else msgpack.packb(obj) + + if self.encoding == "msgspec": + if not HAS_MSGSPEC: + raise ImportError("msgspec is not available") + + return ( + self.encoder.encode(obj) + if self.encoder + else MsgspecEncoder().encode(obj) + ) + + return obj + + def decode(self, data: bytes | Any) -> Any: + """ + Decode binary data using configured encoding strategy. 
+ + :param data: Binary data to decode or object if no encoding configured + :return: Decoded Python object + :raises ImportError: If required encoding library is not available + """ + if self.encoding == "msgpack": + if not HAS_MSGPACK: + raise ImportError("msgpack is not available") + + if self.decoder is not None: + self.decoder.feed(data) + return self.decoder.unpack() + + return msgpack.unpackb(data, raw=False) + + if self.encoding == "msgspec": + if not HAS_MSGSPEC: + raise ImportError("msgspec is not available") + + if self.decoder is not None: + return self.decoder.decode(data) + + return MsgspecDecoder().decode(data) + + return data + + def _resolve_encoding( + self, encoding: EncodingTypesAlias | list[EncodingTypesAlias] | None + ) -> tuple[EncodingTypesAlias, Any, Any]: + def _get_available_encoder_decoder( + encoding: EncodingTypesAlias, + ) -> tuple[Any, Any]: + if encoding == "msgpack" and HAS_MSGPACK: + return Packer(), Unpacker(raw=False) + if encoding == "msgspec" and HAS_MSGSPEC: + return MsgspecEncoder(), MsgspecDecoder() + return None, None + + if not isinstance(encoding, list): + if encoding is None: + return None, None, None + + encoder, decoder = _get_available_encoder_decoder(encoding) + if encoder is None or decoder is None: + raise ImportError(f"Encoding '{encoding}' is not available.") + + return encoding, encoder, decoder + + for test_encoding in encoding: + encoder, decoder = _get_available_encoder_decoder(test_encoding) + if encoder is not None and decoder is not None: + return test_encoding, encoder, decoder + + return None, None, None + + +class Serializer: + """ + Object serialization with specialized Pydantic model support. + + Converts Python objects to serializable formats (dict/sequence) with type + preservation for Pydantic models. Maintains object integrity through + encoding/decoding cycles by storing class metadata and enabling proper + reconstruction of complex objects. 
Supports both dictionary-based and + sequence-based serialization strategies for different use cases. + + For performance reasons, this serializer only supports Python primitives + (int, float, str), Pydantic models, and collections of mixed Pydantic and Python + models for a single level of nesting (list[Pydantic, int, float, str], tuple, dict). + Nested combinations of mixtures of Pydantic and Python models are not supported. + """ + + def __init__( + self, + serialization: SerializationTypesAlias = None, + pydantic_models: list[type[BaseModel]] | None = None, + ): + """ + Initialize serializer with strategy and Pydantic registry. + + :param serialization: Default serialization strategy for this instance + """ + self.serialization = serialization + self.pydantic_registry: dict[tuple[str, str], type[BaseModel]] = {} + if pydantic_models: + for model in pydantic_models: + self.register_pydantic(model) + + def register_pydantic(self, model: type[BaseModel]) -> None: + """ + Register Pydantic model for specialized serialization handling. + + :param model: Pydantic model class to register for type preservation + """ + key = (model.__module__, model.__name__) + self.pydantic_registry[key] = model + + def load_pydantic(self, type_name: str, module_name: str) -> type[BaseModel]: + """ + Load Pydantic class by name with registry fallback to dynamic import. + + :param type_name: Class name to load + :param module_name: Module containing the class + :return: Loaded Pydantic model class + """ + key = (module_name, type_name) + + if key in self.pydantic_registry: + return self.pydantic_registry[key] + + # Dynamic import fallback; need to update to better handle generics + module = __import__(module_name, fromlist=[type_name]) + pydantic_class = getattr(module, type_name) + self.pydantic_registry[key] = pydantic_class + + return pydantic_class + + def serialize(self, obj: Any) -> Any: + """ + Serialize object using specified or configured strategy. 
+ + :param obj: Object to serialize + :return: Serialized representation (dict, str, or original object) + """ + if self.serialization == "dict": + return self.to_dict(obj) + elif self.serialization == "sequence": + return self.to_sequence(obj) + + return obj + + def deserialize(self, msg: Any) -> Any: + """ + Deserialize object using specified or configured strategy. + + :param msg: Serialized message to deserialize + :return: Reconstructed object + """ + if self.serialization == "dict": + return self.from_dict(msg) + elif self.serialization == "sequence": + return self.from_sequence(msg) + + return msg + + def to_dict(self, obj: Any) -> Any: + """ + Convert object to dictionary with Pydantic model type preservation. + + :param obj: Object to convert (BaseModel, collections, or primitive) + :return: Dictionary representation with type metadata for Pydantic models + """ + if isinstance(obj, BaseModel): + return self.to_dict_pydantic(obj) + + if isinstance(obj, (list, tuple)) and any( + isinstance(item, BaseModel) for item in obj + ): + return [ + self.to_dict_pydantic(item) if isinstance(item, BaseModel) else item + for item in obj + ] + + if isinstance(obj, dict) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + return { + key: self.to_dict_pydantic(value) + if isinstance(value, BaseModel) + else value + for key, value in obj.items() + } + + return obj + + def from_dict(self, data: Any) -> Any: + """ + Reconstruct object from dictionary with Pydantic model type restoration. 
+ + :param data: Dictionary representation possibly containing type metadata + :return: Reconstructed object with proper types restored + """ + if isinstance(data, (list, tuple)): + return [ + self.from_dict_pydantic(item) + if isinstance(item, dict) and "*PYD*" in item + else item + for item in data + ] + elif isinstance(data, dict) and data: + if "*PYD*" in data: + return self.from_dict_pydantic(data) + + return { + key: self.from_dict_pydantic(value) + if isinstance(value, dict) and "*PYD*" in value + else value + for key, value in data.items() + } + + return data + + def to_dict_pydantic(self, item: Any) -> Any: + """ + Convert item to dictionary with Pydantic type metadata. + + :param item: Item to convert (may or may not be a Pydantic model) + :return: Dictionary with type preservation metadata + """ + return { + "*PYD*": True, + "typ": item.__class__.__name__, + "mod": item.__class__.__module__, + "dat": item.model_dump(mode="python"), + } + + def from_dict_pydantic(self, item: dict[str, Any]) -> Any: + """ + Reconstruct object from dictionary with Pydantic type metadata. + + :param item: Dictionary containing type metadata and data + :return: Reconstructed Pydantic model or original data + """ + type_name = item["typ"] + module_name = item["mod"] + model_class = self.load_pydantic(type_name, module_name) + + return model_class.model_validate(item["dat"]) + + def to_sequence(self, obj: Any) -> str | Any: + """ + Convert object to sequence format with type-aware serialization. + + Handles Pydantic models, collections, and mappings with proper type + preservation through structured sequence encoding. 
+ + :param obj: Object to serialize to sequence format + :return: Serialized sequence string or bytes + """ + if isinstance(obj, BaseModel): + payload_type = "pydantic" + payload = self.to_sequence_pydantic(obj) + elif isinstance(obj, (list, tuple)) and any( + isinstance(item, BaseModel) for item in obj + ): + payload_type = "collection_sequence" + payload = None + + for item in obj: + is_pydantic = isinstance(item, BaseModel) + payload = self.pack_next_sequence( + type_="pydantic" if is_pydantic else "python", + payload=( + self.to_sequence_pydantic(item) + if is_pydantic + else self.to_sequence_python(item) + ), + current=payload, + ) + elif isinstance(obj, Mapping) and any( + isinstance(value, BaseModel) for value in obj.values() + ): + payload_type = "collection_mapping" + keys = ",".join(str(key) for key in obj) + payload = keys.encode() + b"|" if HAS_ORJSON else keys + "|" + for item in obj.values(): + is_pydantic = isinstance(item, BaseModel) + payload = self.pack_next_sequence( + type_="pydantic" if is_pydantic else "python", + payload=( + self.to_sequence_pydantic(item) + if is_pydantic + else self.to_sequence_python(item) + ), + current=payload, + ) + else: + payload_type = "python" + payload = self.to_sequence_python(obj) + + return self.pack_next_sequence(payload_type, payload, None) + + def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912 + """ + Reconstruct object from sequence format with type restoration. + + Handles deserialization of objects encoded with to_sequence, properly + restoring Pydantic models and collection structures. 
+ + :param data: Serialized sequence data to reconstruct + :return: Reconstructed object with proper types + :raises ValueError: If sequence format is invalid or contains multiple + packed sequences + """ + type_, payload, remaining = self.unpack_next_sequence(data) + if remaining is not None: + raise ValueError("Data contains multiple packed sequences; expected one.") + + if type_ == "pydantic": + return self.from_sequence_pydantic(payload) + + if type_ == "python": + return self.from_sequence_python(payload) + + if type_ in {"collection_sequence", "collection_tuple"}: + items = [] + while payload: + type_, item_payload, payload = self.unpack_next_sequence(payload) + if type_ == "pydantic": + items.append(self.from_sequence_pydantic(item_payload)) + elif type_ == "python": + items.append(self.from_sequence_python(item_payload)) + else: + raise ValueError("Invalid type in collection sequence") + return items + + if type_ != "collection_mapping": + raise ValueError(f"Invalid type for mapping sequence: {type_}") + + if isinstance(payload, bytes): + keys_end = payload.index(b"|") + keys = payload[:keys_end].decode().split(",") + payload = payload[keys_end + 1 :] + else: + keys_end = payload.index("|") + keys = payload[:keys_end].split(",") + payload = payload[keys_end + 1 :] + + items = {} + index = 0 + while payload: + type_, item_payload, payload = self.unpack_next_sequence(payload) + if type_ == "pydantic": + items[keys[index]] = self.from_sequence_pydantic(item_payload) + elif type_ == "python": + items[keys[index]] = self.from_sequence_python(item_payload) + else: + raise ValueError("Invalid type in mapping sequence") + index += 1 + return items + + def to_sequence_pydantic(self, obj: BaseModel) -> str | bytes: + """ + Serialize Pydantic model to sequence format with class metadata. 
+ + :param obj: Pydantic model instance to serialize + :return: Sequence string or bytes containing class info and JSON data + """ + class_name: str = obj.__class__.__name__ + class_module: str = obj.__class__.__module__ + json_data = obj.__pydantic_serializer__.to_json(obj) + + return ( + (class_name.encode() + b"|" + class_module.encode() + b"|" + json_data) + if HAS_ORJSON + else ( + class_name + "|" + class_module + "|" + json_data.decode() + if isinstance(json_data, bytes) + else json_data + ) + ) + + def from_sequence_pydantic(self, data: str | bytes) -> BaseModel: + """ + Reconstruct Pydantic model from sequence format. + + :param data: Sequence data containing class metadata and JSON + :return: Reconstructed Pydantic model instance + """ + if isinstance(data, bytes): + class_name_end = data.index(b"|") + class_name = data[:class_name_end].decode() + module_name_end = data.index(b"|", class_name_end + 1) + module_name = data[class_name_end + 1 : module_name_end].decode() + json_data = data[module_name_end + 1 :] + else: + class_name_end = data.index("|") + class_name = data[:class_name_end] + module_name_end = data.index("|", class_name_end + 1) + module_name = data[class_name_end + 1 : module_name_end] + json_data = data[module_name_end + 1 :] + + model_class = self.load_pydantic(class_name, module_name) + + return model_class.model_validate_json(json_data) + + def to_sequence_python(self, obj: Any) -> str | bytes: + """ + Serialize Python object to JSON format. + + :param obj: Python object to serialize + :return: JSON string or bytes representation + """ + return orjson.dumps(obj) if HAS_ORJSON else json.dumps(obj) + + def from_sequence_python(self, data: str | bytes) -> Any: + """ + Deserialize Python object from JSON format. 
+ + :param data: JSON string or bytes to deserialize + :return: Reconstructed Python object + :raises ImportError: If orjson is required but not available + """ + if isinstance(data, bytes): + if not HAS_ORJSON: + raise ImportError("orjson is not available, cannot deserialize bytes") + return orjson.loads(data) + + return json.loads(data) + + def pack_next_sequence( # noqa: C901, PLR0912 + self, + type_: Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", + ], + payload: str | bytes, + current: str | bytes | None, + ) -> str | bytes: + """ + Pack payload into sequence format with type and length metadata. + + :param type_: Type identifier for the payload + :param payload: Data to pack into sequence + :param current: Current sequence data to append to (unused but maintained + for signature compatibility) + :return: Packed sequence with type, length, and payload + :raises ValueError: If payload type doesn't match current type or unknown + type specified + """ + if current is not None and type(payload) is not type(current): + raise ValueError("Payload and current must be of the same type") + + payload_len = len(payload) + + if isinstance(payload, bytes): + payload_len = payload_len.to_bytes( + length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1, + byteorder="big", + ) + if type_ == "pydantic": + payload_type = b"P" + elif type_ == "python": + payload_type = b"p" + elif type_ == "collection_tuple": + payload_type = b"T" + elif type_ == "collection_sequence": + payload_type = b"S" + elif type_ == "collection_mapping": + payload_type = b"M" + else: + raise ValueError(f"Unknown type for packing: {type_}") + delimiter = b"|" + else: + payload_len = str(payload_len) + if type_ == "pydantic": + payload_type = "P" + elif type_ == "python": + payload_type = "p" + elif type_ == "collection_tuple": + payload_type = "T" + elif type_ == "collection_sequence": + payload_type = "S" + elif type_ == 
"collection_mapping": + payload_type = "M" + else: + raise ValueError(f"Unknown type for packing: {type_}") + delimiter = "|" + + next_sequence = payload_type + delimiter + payload_len + delimiter + payload + + return current + next_sequence if current else next_sequence + + def unpack_next_sequence( # noqa: C901, PLR0912 + self, data: str | bytes + ) -> tuple[ + Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", + ], + str | bytes, + str | bytes | None, + ]: + """ + Unpack sequence format to extract type, payload, and remaining data. + + :param data: Packed sequence data to unpack + :return: Tuple of (type, payload, remaining_data) + :raises ValueError: If sequence format is invalid or unknown type character + """ + if isinstance(data, bytes): + if len(data) < len(b"T|N") or data[1:2] != b"|": + raise ValueError("Invalid packed data format") + + type_char = data[0:1] + if type_char == b"P": + type_ = "pydantic" + elif type_char == b"p": + type_ = "python" + elif type_char == b"T": + type_ = "collection_tuple" + elif type_char == b"S": + type_ = "collection_sequence" + elif type_char == b"M": + type_ = "collection_mapping" + else: + raise ValueError("Unknown type character in packed data") + + len_end = data.index(b"|", 2) + payload_len = int.from_bytes(data[2:len_end], "big") + payload = data[len_end + 1 : len_end + 1 + payload_len] + remaining = ( + data[len_end + 1 + payload_len :] + if len_end + 1 + payload_len < len(data) + else None + ) + + return type_, payload, remaining + + if len(data) < len("T|N") or data[1] != "|": + raise ValueError("Invalid packed data format") + + type_char = data[0] + if type_char == "P": + type_ = "pydantic" + elif type_char == "p": + type_ = "python" + elif type_char == "S": + type_ = "collection_sequence" + elif type_char == "M": + type_ = "collection_mapping" + else: + raise ValueError("Unknown type character in packed data") + + len_end = data.index("|", 2) + payload_len = 
int(data[2:len_end]) + payload = data[len_end + 1 : len_end + 1 + payload_len] + remaining = ( + data[len_end + 1 + payload_len :] + if len_end + 1 + payload_len < len(data) + else None + ) + + return type_, payload, remaining diff --git a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py new file mode 100644 index 00000000..700f41e0 --- /dev/null +++ b/src/guidellm/utils/messaging.py @@ -0,0 +1,1056 @@ +""" +Inter-process messaging abstractions for distributed scheduler coordination. + +Provides high-level interfaces for asynchronous message passing between worker +processes using various transport mechanisms including queues and pipes. Supports +configurable encoding, serialization, error handling, and flow control with +buffering and stop event coordination for distributed scheduler operations. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import multiprocessing +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections.abc import Iterable +from multiprocessing.connection import Connection +from multiprocessing.connection import Pipe as ProcessingPipe +from multiprocessing.context import BaseContext +from multiprocessing.synchronize import Event as ProcessingEvent +from threading import Event as ThreadingEvent +from typing import Any, Callable, Generic, Literal, TypeVar + +import culsans +from pydantic import BaseModel + +from guidellm.utils.encoding import ( + EncodingTypesAlias, + MessageEncoding, + SerializationTypesAlias, +) + +__all__ = [ + "InterProcessMessaging", + "InterProcessMessagingManagerQueue", + "InterProcessMessagingPipe", + "InterProcessMessagingQueue", + "ReceiveMessageT", + "SendMessageT", +] + +SendMessageT = TypeVar("SendMessageT", bound=Any) +"""Generic type variable for messages sent through the messaging system""" +ReceiveMessageT = TypeVar("ReceiveMessageT", bound=Any) +"""Generic type variable for messages received through the messaging system""" 
+ + +class InterProcessMessaging(Generic[SendMessageT, ReceiveMessageT], ABC): + """ + Abstract base for inter-process messaging in distributed scheduler coordination. + + Provides unified interface for asynchronous message passing between scheduler + components using configurable transport mechanisms, encoding schemes, and + flow control policies. Manages buffering, serialization, error handling, + and coordinated shutdown across worker processes for distributed load testing. + + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingQueue + + messaging = InterProcessMessagingQueue( + serialization="pickle", + on_stop_action="stop_after_empty" + ) + + await messaging.start() + await messaging.put(request_data) + response = await messaging.get(timeout=5.0) + await messaging.stop() + """ + + def __init__( + self, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_send_size: int | None = None, + max_buffer_send_size: int | None = None, + max_receive_size: int | None = None, + max_buffer_receive_size: int | None = None, + on_stop_action: Literal[ + "ignore", "stop", "stop_after_empty", "error" + ] = "stop_after_empty", + on_empty_action: Literal["ignore", "stop", "error"] = "ignore", + on_full_action: Literal["ignore", "stop", "error"] = "ignore", + poll_interval: float = 0.1, + worker_index: int | None = None, + ): + """ + Initialize inter-process messaging coordinator. 
+ + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_send_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_receive_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param on_stop_action: Behavior when stop events are triggered + :param on_empty_action: Behavior when message queues become empty + :param on_full_action: Behavior when message queues become full + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + """ + self.worker_index: int | None = worker_index + self.serialization = serialization + self.encoding = encoding + self.max_send_size = max_send_size + self.max_buffer_send_size = max_buffer_send_size + self.max_receive_size = max_receive_size + self.max_buffer_receive_size = max_buffer_receive_size + self.on_stop_action = on_stop_action + self.on_empty_action = on_empty_action + self.on_full_action = on_full_action + self.poll_interval = poll_interval + + self.stopped_event: ThreadingEvent = None + self.shutdown_event: ThreadingEvent = None + self.buffer_send_queue: culsans.Queue[SendMessageT] = None + self.buffer_receive_queue: culsans.Queue[ReceiveMessageT] = None + self.send_task: asyncio.Task = None + self.receive_task: asyncio.Task = None + self.running = False + + @abstractmethod + def create_worker_copy( + self, worker_index: int + ) -> InterProcessMessaging[SendMessageT, ReceiveMessageT]: + """ + Create worker-specific copy for distributed process coordination. + + :param worker_index: Index of the worker process for message routing + :return: Configured messaging instance for the specified worker + """ + ... 
+ + @abstractmethod + async def send_messages_task( + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + send_items: Iterable[Any] | None, + ): + """ + Execute asynchronous message sending task for process coordination. + + :param message_encoding: Encoding configuration for message serialization + :param stop_events: Events that trigger task termination + :param send_items: Optional collection of items to send to other processes + """ + ... + + @abstractmethod + async def receive_messages_task( + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + receive_callback: Callable[[Any], None] | None, + ): + """ + Execute asynchronous message receiving task for process coordination. + + :param message_encoding: Encoding configuration for message deserialization + :param stop_events: Events that trigger task termination + :param receive_callback: Optional callback to process received messages + """ + ... + + async def start( + self, + send_items: Iterable[Any] | None = None, + receive_callback: Callable[[Any], None] | None = None, + stop_events: list[ThreadingEvent | ProcessingEvent] | None = None, + send_stop_events: list[ThreadingEvent | ProcessingEvent] | None = None, + receive_stop_events: list[ThreadingEvent | ProcessingEvent] | None = None, + pydantic_models: list[type[BaseModel]] | None = None, + ): + """ + Start asynchronous message processing tasks with buffering. 
+ + :param send_items: Optional collection of items to send during processing + :param receive_callback: Optional callback for processing received messages + :param stop_events: External events that trigger messaging shutdown + :param send_stop_events: Events that trigger send task shutdown + :param receive_stop_events: Events that trigger receive task shutdown + :param pydantic_models: Optional list of Pydantic models for serialization + """ + self.running = True + self.stopped_event = ThreadingEvent() + self.shutdown_event = ThreadingEvent() + self.buffer_send_queue = culsans.Queue[SendMessageT]() + self.buffer_receive_queue = culsans.Queue[ReceiveMessageT]() + + message_encoding = MessageEncoding( + serialization=self.serialization, + encoding=self.encoding, + pydantic_models=pydantic_models, + ) + if send_stop_events is None: + send_stop_events = [] + if receive_stop_events is None: + receive_stop_events = [] + if stop_events: + send_stop_events.extend(stop_events) + receive_stop_events.extend(stop_events) + + self.send_task = asyncio.create_task( + self.send_messages_task( + message_encoding=message_encoding, + stop_events=send_stop_events, + send_items=send_items, + ) + ) + self.receive_task = asyncio.create_task( + self.receive_messages_task( + message_encoding=message_encoding, + stop_events=receive_stop_events, + receive_callback=receive_callback, + ) + ) + + async def stop(self): + """ + Stop message processing tasks and clean up resources. 
+ """ + self.shutdown_event.set() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather( + self.send_task, self.receive_task, return_exceptions=True + ) + self.send_task = None + self.receive_task = None + await self.buffer_send_queue.aclose() + await self.buffer_receive_queue.aclose() + self.buffer_send_queue = None + self.buffer_receive_queue = None + self.stopped_event = None + self.shutdown_event = None + self.running = False + + async def get(self, timeout: float | None = None) -> ReceiveMessageT: + """ + Retrieve message from receive buffer with optional timeout. + + :param timeout: Maximum time to wait for a message + :return: Decoded message from the receive buffer + """ + return await asyncio.wait_for( + self.buffer_receive_queue.async_get(), timeout=timeout + ) + + def get_sync(self, timeout: float | None = None) -> ReceiveMessageT: + """ + Retrieve message from receive buffer synchronously with optional timeout. + + :param timeout: Maximum time to wait for a message, if <=0 uses get_nowait + :return: Decoded message from the receive buffer + """ + if timeout is not None and timeout <= 0: + return self.buffer_receive_queue.get_nowait() + else: + return self.buffer_receive_queue.sync_get(timeout=timeout) + + async def put(self, item: SendMessageT, timeout: float | None = None): + """ + Add message to send buffer with optional timeout. + + :param item: Message item to add to the send buffer + :param timeout: Maximum time to wait for buffer space + """ + await asyncio.wait_for(self.buffer_send_queue.async_put(item), timeout=timeout) + + def put_sync(self, item: SendMessageT, timeout: float | None = None): + """ + Add message to send buffer synchronously with optional timeout. 
+ + :param item: Message item to add to the send buffer + :param timeout: Maximum time to wait for buffer space, if <=0 uses put_nowait + """ + if timeout is not None and timeout <= 0: + self.buffer_send_queue.put_nowait(item) + else: + self.buffer_send_queue.sync_put(item, timeout=timeout) + + def check_on_stop_action( + self, + pending: Any | None, + queue_empty: bool, + stop_events: list[ThreadingEvent | ProcessingEvent], + ) -> bool: + """ + Check if messaging should stop based on configured stop action. + + :param pending: Currently pending message being processed + :param queue_empty: Whether the message queue is currently empty + :param stop_events: Events that indicate stop condition + :return: True if messaging should stop, False otherwise + :raises RuntimeError: When stop action is 'error' and stop event is set + """ + shutdown_set = self.shutdown_event.is_set() + + if self.on_stop_action == "ignore": + return shutdown_set and pending is None + + stop_set = any(event.is_set() for event in stop_events) + + if self.on_stop_action == "error": + if stop_set: + raise RuntimeError("Stop event set (on_stop_action='error').") + return shutdown_set and pending is None + + return ( + ( + self.on_stop_action == "stop" + or (self.on_stop_action == "stop_after_empty" and queue_empty) + ) + and (shutdown_set or stop_set) + and pending is None + ) + + def check_on_queue_empty_action(self, pending: Any | None) -> bool: + """ + Check if messaging should stop based on empty queue action. 
+ + :param pending: Currently pending message being processed + :return: True if messaging should stop, False otherwise + :raises RuntimeError: When empty action is 'error' and queue is empty + """ + if self.on_empty_action == "error": + raise RuntimeError("Queue empty (on_empty_action='error').") + + return self.on_empty_action == "stop" and pending is None + + def check_on_queue_full_action(self, pending: Any | None) -> bool: + """ + Check if messaging should stop based on full queue action. + + :param pending: Currently pending message being processed + :return: True if messaging should stop, False otherwise + :raises RuntimeError: When full action is 'error' and queue is full + """ + if self.on_full_action == "error": + raise RuntimeError("Queue full (on_full_action='error').") + + return self.on_full_action == "stop" and pending is None + + +class InterProcessMessagingQueue(InterProcessMessaging[SendMessageT, ReceiveMessageT]): + """ + Queue-based inter-process messaging for distributed scheduler coordination. + + Provides message passing using multiprocessing.Queue objects for communication + between scheduler workers and main process. Handles message encoding, buffering, + flow control, and coordinated shutdown with configurable queue behavior and + error handling policies for distributed load testing operations. 
+ + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingQueue + + messaging = InterProcessMessagingQueue( + serialization="pickle", + max_send_size=100, + on_stop_action="stop_after_empty" + ) + + # Create worker copy for distributed processing + worker_messaging = messaging.create_worker_copy(worker_index=0) + """ + + def __init__( + self, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_send_size: int | None = None, + max_buffer_send_size: int | None = None, + max_receive_size: int | None = None, + max_buffer_receive_size: int | None = None, + on_stop_action: Literal[ + "ignore", "stop", "stop_after_empty", "error" + ] = "stop_after_empty", + on_empty_action: Literal["ignore", "stop", "error"] = "ignore", + on_full_action: Literal["ignore", "stop", "error"] = "ignore", + poll_interval: float = 0.1, + worker_index: int | None = None, + send_queue: multiprocessing.Queue | None = None, + done_queue: multiprocessing.Queue | None = None, + ): + """ + Initialize queue-based messaging for inter-process communication. 
+ + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_send_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_receive_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param on_stop_action: Behavior when stop events are triggered + :param on_empty_action: Behavior when message queues become empty + :param on_full_action: Behavior when message queues become full + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param send_queue: Multiprocessing queue for sending messages + :param done_queue: Multiprocessing queue for receiving completed messages + """ + super().__init__( + serialization=serialization, + encoding=encoding, + max_send_size=max_send_size, + max_buffer_send_size=max_buffer_send_size, + max_receive_size=max_receive_size, + max_buffer_receive_size=max_buffer_receive_size, + on_stop_action=on_stop_action, + on_empty_action=on_empty_action, + on_full_action=on_full_action, + poll_interval=poll_interval, + worker_index=worker_index, + ) + self.send_queue = send_queue or multiprocessing.Queue( + maxsize=max_send_size or 0 + ) + self.done_queue = done_queue or multiprocessing.Queue( + maxsize=max_receive_size or 0 + ) + + def create_worker_copy( + self, worker_index: int + ) -> InterProcessMessagingQueue[SendMessageT, ReceiveMessageT]: + """ + Create worker-specific copy for distributed queue-based coordination. 
+ + :param worker_index: Index of the worker process for message routing + :return: Configured queue messaging instance for the specified worker + """ + return InterProcessMessagingQueue( + serialization=self.serialization, + encoding=self.encoding, + max_send_size=self.max_send_size, + max_buffer_send_size=self.max_buffer_send_size, + max_receive_size=self.max_receive_size, + max_buffer_receive_size=self.max_buffer_receive_size, + on_stop_action=self.on_stop_action, + on_empty_action=self.on_empty_action, + on_full_action=self.on_full_action, + poll_interval=self.poll_interval, + worker_index=worker_index, + send_queue=self.send_queue, + done_queue=self.done_queue, + ) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await super().stop() + if self.worker_index is None: + # only main process should close the queues + self.send_queue.close() + self.done_queue.close() + self.send_queue = None + self.done_queue = None + + async def send_messages_task( + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + send_items: Iterable[Any] | None, + ): + """ + Execute asynchronous queue-based message sending task. + + :param message_encoding: Encoding configuration for message serialization + :param stop_events: Events that trigger task termination + :param send_items: Optional collection of items to send via queues + """ + canceled_event = ThreadingEvent() + + try: + await asyncio.to_thread( + self._send_messages_task_thread, + message_encoding, + stop_events, + send_items, + canceled_event, + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.stopped_event.set() + + async def receive_messages_task( + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + receive_callback: Callable[[Any], None] | None, + ): + """ + Execute asynchronous queue-based message receiving task. 
+ + :param message_encoding: Encoding configuration for message deserialization + :param stop_events: Events that trigger task termination + :param receive_callback: Optional callback to process received messages + """ + canceled_event = ThreadingEvent() + + try: + return await asyncio.to_thread( + self._receive_messages_task_thread, + message_encoding, + stop_events, + receive_callback, + canceled_event, + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.stopped_event.set() + + def _send_messages_task_thread( # noqa: C901, PLR0912 + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + send_items: Iterable[Any] | None, + canceled_event: ThreadingEvent, + ): + send_items_iter = iter(send_items) if send_items is not None else None + pending_item = None + queue_empty_reported = False + + while not canceled_event.is_set(): + if self.check_on_stop_action( + pending_item, queue_empty_reported, stop_events + ): + break + + queue_empty_reported = False + + if pending_item is None: + try: + if send_items_iter is not None: + item = next(send_items_iter) + else: + item = self.buffer_send_queue.sync_get( + timeout=self.poll_interval + ) + pending_item = message_encoding.encode(item) + except (culsans.QueueEmpty, queue.Empty, StopIteration): + queue_empty_reported = True + if self.check_on_queue_empty_action(pending_item): + break + + if pending_item is not None: + try: + if self.worker_index is None: + # Main publisher + self.send_queue.put(pending_item, timeout=self.poll_interval) + else: + # Worker + self.done_queue.put(pending_item, timeout=self.poll_interval) + if send_items_iter is None: + self.buffer_send_queue.task_done() + pending_item = None + except (culsans.QueueFull, queue.Full): + if self.check_on_queue_full_action(pending_item): + break + + def _receive_messages_task_thread( # noqa: C901 + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | 
ProcessingEvent], + receive_callback: Callable[[Any], None] | None, + canceled_event: ThreadingEvent, + ): + pending_item = None + received_item = None + queue_empty_reported = False + + while not canceled_event.is_set(): + if self.check_on_stop_action( + pending_item, queue_empty_reported, stop_events + ): + break + + if pending_item is None: + try: + if self.worker_index is None: + # Main publisher + item = self.done_queue.get(timeout=self.poll_interval) + else: + # Worker + item = self.send_queue.get(timeout=self.poll_interval) + pending_item = message_encoding.decode(item) + except (culsans.QueueEmpty, queue.Empty): + queue_empty_reported = True + if self.check_on_queue_empty_action(pending_item): + break + + if pending_item is not None or received_item is not None: + try: + if received_item is None: + received_item = ( + pending_item + if not receive_callback + else receive_callback(pending_item) + ) + + self.buffer_receive_queue.sync_put(received_item) + pending_item = None + received_item = None + except (culsans.QueueFull, queue.Full): + if self.check_on_queue_full_action(pending_item): + break + + +class InterProcessMessagingManagerQueue( + InterProcessMessagingQueue[SendMessageT, ReceiveMessageT] +): + """ + Manager-based queue messaging for inter-process scheduler coordination. + + Extends queue-based messaging with multiprocessing.Manager support for + shared state coordination across worker processes. Provides managed queues + for reliable message passing in distributed scheduler environments with + enhanced process synchronization and resource management capabilities. 
+ + Example: + :: + import multiprocessing + from guidellm.utils.messaging import InterProcessMessagingManagerQueue + + manager = multiprocessing.Manager() + messaging = InterProcessMessagingManagerQueue( + manager=manager, + serialization="pickle" + ) + """ + + def __init__( + self, + manager: BaseContext, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_send_size: int | None = None, + max_buffer_send_size: int | None = None, + max_receive_size: int | None = None, + max_buffer_receive_size: int | None = None, + on_stop_action: Literal[ + "ignore", "stop", "stop_after_empty", "error" + ] = "stop_after_empty", + on_empty_action: Literal["ignore", "stop", "error"] = "ignore", + on_full_action: Literal["ignore", "stop", "error"] = "ignore", + poll_interval: float = 0.1, + worker_index: int | None = None, + send_queue: multiprocessing.Queue | None = None, + done_queue: multiprocessing.Queue | None = None, + ): + """ + Initialize manager-based queue messaging for inter-process communication. 
+ + :param manager: Multiprocessing manager for shared queue creation + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_send_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_receive_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param on_stop_action: Behavior when stop events are triggered + :param on_empty_action: Behavior when message queues become empty + :param on_full_action: Behavior when message queues become full + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param send_queue: Managed multiprocessing queue for sending messages + :param done_queue: Managed multiprocessing queue for receiving completed + messages + """ + super().__init__( + serialization=serialization, + encoding=encoding, + max_send_size=max_send_size, + max_buffer_send_size=max_buffer_send_size, + max_receive_size=max_receive_size, + max_buffer_receive_size=max_buffer_receive_size, + on_stop_action=on_stop_action, + on_empty_action=on_empty_action, + on_full_action=on_full_action, + poll_interval=poll_interval, + worker_index=worker_index, + send_queue=send_queue or manager.Queue(maxsize=max_send_size or 0), + done_queue=done_queue or manager.Queue(maxsize=max_receive_size or 0), + ) + + def create_worker_copy( + self, worker_index: int + ) -> InterProcessMessagingManagerQueue[SendMessageT, ReceiveMessageT]: + """ + Create worker-specific copy for managed queue-based coordination. 
+ + :param worker_index: Index of the worker process for message routing + :return: Configured manager queue messaging instance for the specified worker + """ + return InterProcessMessagingManagerQueue( + manager=None, + serialization=self.serialization, + encoding=self.encoding, + max_send_size=self.max_send_size, + max_buffer_send_size=self.max_buffer_send_size, + max_receive_size=self.max_receive_size, + max_buffer_receive_size=self.max_buffer_receive_size, + on_stop_action=self.on_stop_action, + on_empty_action=self.on_empty_action, + on_full_action=self.on_full_action, + poll_interval=self.poll_interval, + worker_index=worker_index, + send_queue=self.send_queue, + done_queue=self.done_queue, + ) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await InterProcessMessaging.stop(self) + self.send_queue = None + self.done_queue = None + + +class InterProcessMessagingPipe(InterProcessMessaging[SendMessageT, ReceiveMessageT]): + """ + Pipe-based inter-process messaging for distributed scheduler coordination. + + Provides message passing using multiprocessing.Pipe objects for direct + communication between scheduler workers and main process. Offers lower + latency than queue-based messaging with duplex communication channels + for high-performance distributed load testing operations. 
+ + Example: + :: + from guidellm.utils.messaging import InterProcessMessagingPipe + + messaging = InterProcessMessagingPipe( + num_workers=4, + serialization="pickle", + poll_interval=0.05 + ) + + # Create worker copy for specific worker process + worker_messaging = messaging.create_worker_copy(worker_index=0) + """ + + def __init__( + self, + num_workers: int, + serialization: SerializationTypesAlias = "dict", + encoding: EncodingTypesAlias = None, + max_send_size: int | None = None, + max_buffer_send_size: int | None = None, + max_receive_size: int | None = None, + max_buffer_receive_size: int | None = None, + on_stop_action: Literal[ + "ignore", "stop", "stop_after_empty", "error" + ] = "stop_after_empty", + on_empty_action: Literal["ignore", "stop", "error"] = "ignore", + on_full_action: Literal["ignore", "stop", "error"] = "ignore", + poll_interval: float = 0.1, + worker_index: int | None = None, + pipe: tuple[Connection, Connection] | None = None, + ): + """ + Initialize pipe-based messaging for inter-process communication. 
+ + :param num_workers: Number of worker processes requiring pipe connections + :param serialization: Message serialization method for transport encoding + :param encoding: Optional encoding scheme for serialized message data + :param max_send_size: Maximum items in send queue before blocking + :param max_buffer_send_size: Maximum items in buffer send queue + :param max_receive_size: Maximum items in receive queue before blocking + :param max_buffer_receive_size: Maximum items in buffer receive queue + :param on_stop_action: Behavior when stop events are triggered + :param on_empty_action: Behavior when message queues become empty + :param on_full_action: Behavior when message queues become full + :param poll_interval: Time interval for checking queue status and events + :param worker_index: Index identifying this worker in the process group + :param pipe: Existing pipe connection for worker-specific instances + """ + super().__init__( + serialization=serialization, + encoding=encoding, + max_send_size=max_send_size, + max_buffer_send_size=max_buffer_send_size, + max_receive_size=max_receive_size, + max_buffer_receive_size=max_buffer_receive_size, + on_stop_action=on_stop_action, + on_empty_action=on_empty_action, + on_full_action=on_full_action, + poll_interval=poll_interval, + worker_index=worker_index, + ) + self.num_workers = num_workers + + if pipe is None: + self.pipes: list[tuple[Connection, Connection]] = [ + ProcessingPipe(duplex=True) for _ in range(num_workers) + ] + else: + self.pipes: list[tuple[Connection, Connection]] = [pipe] + + def create_worker_copy( + self, worker_index: int + ) -> InterProcessMessagingPipe[SendMessageT, ReceiveMessageT]: + """ + Create worker-specific copy for pipe-based coordination. 
+ + :param worker_index: Index of the worker process for pipe routing + :return: Configured pipe messaging instance for the specified worker + """ + return InterProcessMessagingPipe( + num_workers=self.num_workers, + serialization=self.serialization, + encoding=self.encoding, + max_send_size=self.max_send_size, + max_receive_size=self.max_receive_size, + on_stop_action=self.on_stop_action, + on_empty_action=self.on_empty_action, + on_full_action=self.on_full_action, + poll_interval=self.poll_interval, + worker_index=worker_index, + pipe=self.pipes[worker_index], + ) + + async def stop(self): + """ + Stop the messaging system and wait for all tasks to complete. + """ + await super().stop() + if self.worker_index is None: + # Only main process should close the pipes + for main_con, worker_con in self.pipes: + main_con.close() + worker_con.close() + + async def send_messages_task( + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + send_items: Iterable[Any] | None, + ): + """ + Execute asynchronous pipe-based message sending task. 
+ + :param message_encoding: Encoding configuration for message serialization + :param stop_events: Events that trigger task termination + :param send_items: Optional collection of items to send via pipes + """ + canceled_event = ThreadingEvent() + + try: + if self.worker_index is None: + # Create a separate task for each worker's pipe + await asyncio.gather( + *[ + asyncio.to_thread( + self._send_messages_task_thread, + self.pipes[index], + message_encoding, + stop_events, + send_items, + canceled_event, + ) + for index in range(self.num_workers) + ] + ) + else: + await asyncio.to_thread( + self._send_messages_task_thread, + self.pipes[0], + message_encoding, + stop_events, + send_items, + canceled_event, + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.stopped_event.set() + + async def receive_messages_task( + self, + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + receive_callback: Callable[[Any], None] | None, + ): + """ + Execute asynchronous pipe-based message receiving task. 
+ + :param message_encoding: Encoding configuration for message deserialization + :param stop_events: Events that trigger task termination + :param receive_callback: Optional callback to process received messages + """ + canceled_event = ThreadingEvent() + + try: + if self.worker_index is None: + # Create a separate task for each worker's pipe + await asyncio.gather( + *[ + asyncio.to_thread( + self._receive_messages_task_thread, + self.pipes[index], + message_encoding, + stop_events, + receive_callback, + canceled_event, + ) + for index in range(self.num_workers) + ] + ) + else: + await asyncio.to_thread( + self._receive_messages_task_thread, + self.pipes[0], + message_encoding, + stop_events, + receive_callback, + canceled_event, + ) + except asyncio.CancelledError: + canceled_event.set() + raise + finally: + self.stopped_event.set() + + def _send_messages_task_thread( # noqa: C901, PLR0912 + self, + pipe: tuple[Connection, Connection], + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + send_items: Iterable[Any] | None, + canceled_event: ThreadingEvent, + ): + send_connection: Connection = pipe[0] if self.worker_index is None else pipe[1] + send_items_iter = iter(send_items) if send_items is not None else None + pending_item = None + queue_empty_reported = False + pipe_item = None + pipe_lock = threading.Lock() + + def _background_pipe_recv(): + nonlocal pipe_item + + while ( + not canceled_event.is_set() + and self.stopped_event is not None + and not self.stopped_event.is_set() + ): + try: + with pipe_lock: + pending = pipe_item + pipe_item = None + + if pending is not None: + send_connection.send(pending) + except (EOFError, ConnectionResetError): + break + + if send_items_iter is None: + threading.Thread(target=_background_pipe_recv, daemon=True).start() + + while not canceled_event.is_set(): + if self.check_on_stop_action( + pending_item, queue_empty_reported, stop_events + ): + break + + queue_empty_reported = False 
+ + if pending_item is None: + try: + if send_items_iter is not None: + item = next(send_items_iter) + else: + item = self.buffer_send_queue.sync_get( + timeout=self.poll_interval + ) + pending_item = message_encoding.encode(item) + except (culsans.QueueEmpty, queue.Empty, StopIteration): + queue_empty_reported = True + if self.check_on_queue_empty_action(pending_item): + break + + if pending_item is not None: + try: + with pipe_lock: + if pipe_item is not None: + time.sleep(self.poll_interval / 100) + raise queue.Full + else: + pipe_item = pending_item + if send_items_iter is None: + self.buffer_send_queue.task_done() + pending_item = None + except (culsans.QueueFull, queue.Full): + if self.check_on_queue_full_action(pending_item): + break + + def _receive_messages_task_thread( # noqa: C901 + self, + pipe: tuple[Connection, Connection], + message_encoding: MessageEncoding, + stop_events: list[ThreadingEvent | ProcessingEvent], + receive_callback: Callable[[Any], None] | None, + canceled_event: ThreadingEvent, + ): + receive_connection: Connection = ( + pipe[0] if self.worker_index is not None else pipe[1] + ) + pending_item = None + received_item = None + queue_empty_reported = False + + while not canceled_event.is_set(): + if self.check_on_stop_action( + pending_item, queue_empty_reported, stop_events + ): + break + + if pending_item is None: + try: + if receive_connection.poll(self.poll_interval): + item = receive_connection.recv() + pending_item = message_encoding.decode(item) + else: + raise queue.Empty + except (culsans.QueueEmpty, queue.Empty): + queue_empty_reported = True + if self.check_on_queue_empty_action(pending_item): + break + + if pending_item is not None or received_item is not None: + try: + if received_item is None: + received_item = ( + pending_item + if not receive_callback + else receive_callback(pending_item) + ) + + self.buffer_receive_queue.sync_put(received_item) + pending_item = None + received_item = None + except (culsans.QueueFull, 
queue.Full): + if self.check_on_queue_full_action(pending_item): + break diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py new file mode 100644 index 00000000..763f390d --- /dev/null +++ b/tests/unit/utils/test_encoding.py @@ -0,0 +1,510 @@ +from __future__ import annotations + +import uuid +from typing import Any, Generic + +import pytest +from pydantic import BaseModel, Field + +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.scheduler.objects import RequestSchedulerTimings, ScheduledRequestInfo +from guidellm.utils.encoding import Encoder, MessageEncoding, Serializer + + +class SampleModel(BaseModel): + """Sample Pydantic model for testing.""" + + name: str = Field(description="Name field for testing") + value: int = Field(description="Value field for testing") + + +class ComplexModel(BaseModel): + """Complex Pydantic model for testing.""" + + items: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + nested: SampleModel | None = Field(default=None) + + +class TestMessageEncoding: + """Test suite for MessageEncoding class.""" + + @pytest.fixture( + params=[ + {"serialization": None, "encoding": None}, + {"serialization": "dict", "encoding": None}, + {"serialization": "sequence", "encoding": None}, + {"serialization": None, "encoding": "msgpack"}, + {"serialization": "dict", "encoding": "msgpack"}, + {"serialization": "sequence", "encoding": "msgpack"}, + {"serialization": None, "encoding": "msgspec"}, + {"serialization": "dict", "encoding": "msgspec"}, + {"serialization": "sequence", "encoding": "msgspec"}, + {"serialization": None, "encoding": ["msgspec", "msgpack"]}, + {"serialization": "dict", "encoding": ["msgspec", "msgpack"]}, + ], + ids=[ + "no_serialization_no_encoding", + "dict_serialization_no_encoding", + "str_serialization_no_encoding", + "no_serialization_msgpack", + 
"dict_serialization_msgpack", + "str_serialization_msgpack", + "no_serialization_msgspec", + "dict_serialization_msgspec", + "str_serialization_msgspec", + "no_serialization_encoding_list", + "dict_serialization_encoding_list", + ], + ) + def valid_instances(self, request): + """Fixture providing test data for MessageEncoding.""" + constructor_args = request.param + try: + instance = MessageEncoding(**constructor_args) + return instance, constructor_args + except ImportError: + pytest.skip("Required encoding library not available") + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MessageEncoding inheritance and type relationships.""" + assert issubclass(MessageEncoding, Generic) + assert hasattr(MessageEncoding, "DEFAULT_ENCODING_PREFERENCE") + assert isinstance(MessageEncoding.DEFAULT_ENCODING_PREFERENCE, list) + assert MessageEncoding.DEFAULT_ENCODING_PREFERENCE == ["msgspec", "msgpack"] + + # Check classmethods + assert hasattr(MessageEncoding, "encode_message") + assert callable(MessageEncoding.encode_message) + assert hasattr(MessageEncoding, "decode_message") + assert callable(MessageEncoding.decode_message) + + # Check instance methods + assert hasattr(MessageEncoding, "__init__") + assert hasattr(MessageEncoding, "register_pydantic") + assert hasattr(MessageEncoding, "encode") + assert hasattr(MessageEncoding, "decode") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test MessageEncoding initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, MessageEncoding) + assert hasattr(instance, "serializer") + assert isinstance(instance.serializer, Serializer) + assert instance.serializer.serialization == constructor_args["serialization"] + assert hasattr(instance, "encoder") + assert isinstance(instance.encoder, Encoder) + + expected_encoding = constructor_args["encoding"] + if isinstance(expected_encoding, list): + assert instance.encoder.encoding in expected_encoding + else: 
+ assert instance.encoder.encoding == expected_encoding + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + None, + 0, + 0.0, + "0.1.2.3", + [0, 0.0, "0.1.2.3", None], + (0, 0.0, "0.1.2.3", None), + {"key1": 0, "key2": 0.0, "key3": "0.1.2.3", "key4": None}, + ], + ) + def test_encode_decode_python(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with comprehensive data types.""" + instance, constructor_args = valid_instances + + message = instance.encode(obj) + decoded = instance.decode(message) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + SampleModel(name="sample", value=123), + ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + ( + SampleModel(name="sample", value=123), + None, + ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + ), + { + "key1": SampleModel(name="sample", value=123), + "key2": None, + "key3": ComplexModel( + items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + }, + ], + ) + def test_encode_decode_pydantic(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with Pydantic models.""" + instance, constructor_args = valid_instances + + if ( + constructor_args["serialization"] is None + and constructor_args["encoding"] is not None + ): + # msgpack/msgspec don't support Pydantic models natively + pytest.skip("Skipping unsupported Pydantic serialization/encoding combo") + + # Register Pydantic models for proper serialization + instance.register_pydantic(SampleModel) + instance.register_pydantic(ComplexModel) + + message = instance.encode(obj) + decoded = instance.decode(message) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == 
obj + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + ( + None, + GenerationRequest(content="test content"), + ScheduledRequestInfo[GenerationRequestTimings]( + scheduler_timings=RequestSchedulerTimings( + targeted_start=1.0, + queued=0.1, + dequeued=0.2, + scheduled_at=0.3, + resolve_start=1.1, + resolve_end=1.5, + finalized=1.6, + ) + ), + ), + ( + GenerationResponse( + request_id=str(uuid.uuid4()), + request_args={}, + value="test response", + request_prompt_tokens=2, + request_output_tokens=3, + response_prompt_tokens=4, + response_output_tokens=6, + ), + GenerationRequest(content="test content"), + ScheduledRequestInfo[GenerationRequestTimings]( + scheduler_timings=RequestSchedulerTimings( + targeted_start=1.0, + queued=0.1, + dequeued=0.2, + scheduled_at=0.3, + resolve_start=1.1, + resolve_end=1.5, + finalized=1.6, + ) + ), + ), + ], + ) + def test_encode_decode_generative(self, valid_instances, obj: Any): + """Test MessageEncoding encode/decode with generative models.""" + instance, constructor_args = valid_instances + + if ( + constructor_args["serialization"] is None + and constructor_args["encoding"] is not None + ): + # msgpack/msgspec don't support Pydantic models natively + pytest.skip("Skipping unsupported Pydantic serialization/encoding combo") + + instance.register_pydantic(GenerationRequest) + instance.register_pydantic(GenerationResponse) + instance.register_pydantic(ScheduledRequestInfo[GenerationRequestTimings]) + + message = instance.encode(obj) + decoded = instance.decode(message) + + assert list(decoded) == list(obj) + + @pytest.mark.smoke + @pytest.mark.parametrize( + "serialization", + [ + None, + "dict", + "sequence", + ], + ) + @pytest.mark.parametrize( + "encoding", + [None, "msgpack", "msgspec"], + ) + @pytest.mark.parametrize( + "obj", + [ + "0.1.2.3", + [0, 0.0, "0.1.2.3", None, SampleModel(name="sample", value=123)], + { + "key1": 0, + "key2": 0.0, + "key3": "0.1.2.3", + "key4": None, + "key5": ComplexModel( + 
items=["item1", "item2"], + metadata={"key": "value"}, + nested=SampleModel(name="sample", value=123), + ), + }, + ], + ) + def test_encode_decode_message(self, serialization, encoding, obj): + """Test MessageEncoding.encode_message and decode_message class methods.""" + if encoding is not None and serialization is None and obj != "0.1.2.3": + pytest.skip("Skipping unsupported serialization/encoding combo") + + try: + serializer = Serializer(serialization) if serialization else None + encoder = Encoder(encoding) if encoding else None + + message = MessageEncoding.encode_message(obj, serializer, encoder) + decoded = MessageEncoding.decode_message(message, serializer, encoder) + + if isinstance(obj, tuple): + assert list(decoded) == list(obj) + else: + assert decoded == obj + except ImportError: + pytest.skip("Required encoding library not available") + + @pytest.mark.smoke + def test_register_pydantic(self): + """Test MessageEncoding.register_pydantic functionality.""" + instance = MessageEncoding(serialization="dict", encoding=None) + assert len(instance.serializer.pydantic_registry) == 0 + instance.register_pydantic(SampleModel) + assert len(instance.serializer.pydantic_registry) == 1 + assert ( + instance.serializer.pydantic_registry.values().__iter__().__next__() + is SampleModel + ) + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + """Test invalid initialization (unsupported encoding).""" + inst = MessageEncoding(serialization="dict", encoding=["invalid_encoding"]) # type: ignore[arg-type] + assert inst.encoder.encoding is None + with pytest.raises(ImportError): + MessageEncoding(serialization="dict", encoding="invalid") # type: ignore[arg-type] + + +class TestEncoder: + """Test suite for Encoder class.""" + + @pytest.fixture( + params=[ + None, + "msgpack", + "msgspec", + ["msgspec", "msgpack"], + ["msgpack", "msgspec"], + ], + ids=[ + "none", + "msgpack", + "msgspec", + "list_pref_msgspec_first", + "list_pref_msgpack_first", + ], + ) + 
def valid_instances(self, request): + args = request.param + try: + inst = Encoder(args) + except ImportError: + pytest.skip("Encoding backend missing") + return inst, args + + @pytest.mark.smoke + def test_class_signatures(self): + assert hasattr(Encoder, "encode") + assert hasattr(Encoder, "decode") + assert hasattr(Encoder, "_resolve_encoding") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + inst, args = valid_instances + assert isinstance(inst, Encoder) + if isinstance(args, list): + assert inst.encoding in args or inst.encoding is None + else: + assert inst.encoding == args + + @pytest.mark.sanity + def test_invalid_initialization_values(self): + with pytest.raises(ImportError): + Encoder("invalid") # type: ignore[arg-type] + + @pytest.mark.smoke + @pytest.mark.parametrize("obj", [None, 0, 1.2, "text", [1, 2], {"a": 1}]) + def test_encode_decode(self, valid_instances, obj): + inst, _ = valid_instances + msg = inst.encode(obj) + out = inst.decode(msg) + assert out == obj + + +class TestSerializer: + """Test suite for Serializer class.""" + + @pytest.fixture(params=[None, "dict", "sequence"], ids=["none", "dict", "sequence"]) + def valid_instances(self, request): + inst = Serializer(request.param) + return inst, request.param + + @pytest.mark.smoke + def test_class_signatures(self): + assert hasattr(Serializer, "serialize") + assert hasattr(Serializer, "deserialize") + assert hasattr(Serializer, "register_pydantic") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + inst, mode = valid_instances + assert isinstance(inst, Serializer) + assert inst.serialization == mode + + @pytest.mark.smoke + def test_register_pydantic(self, valid_instances): + inst, _ = valid_instances + assert len(inst.pydantic_registry) == 0 + inst.register_pydantic(SampleModel) + assert len(inst.pydantic_registry) == 1 + + @pytest.mark.smoke + @pytest.mark.parametrize( + "obj", + [ + 1, + "str_val", + [1, 2, 3], + SampleModel(name="x", 
value=1), + {"k": SampleModel(name="y", value=2)}, + ], + ) + def test_serialize_deserialize(self, valid_instances, obj): + inst, mode = valid_instances + inst.register_pydantic(SampleModel) + msg = inst.serialize(obj) + out = inst.deserialize(msg) + if isinstance(obj, list): + assert list(out) == obj + else: + assert out == obj + + @pytest.mark.regression + def test_sequence_mapping_roundtrip(self): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + data = { + "a": SampleModel(name="a", value=1), + "b": SampleModel(name="b", value=2), + } + msg = inst.serialize(data) + out = inst.deserialize(msg) + assert out == data + + @pytest.mark.sanity + def test_to_from_dict_variations(self): + inst = Serializer("dict") + inst.register_pydantic(SampleModel) + model = SampleModel(name="n", value=3) + lst = [model, 5] + mp = {"k1": model, "k2": 9} + assert inst.from_dict(inst.to_dict(model)) == model + assert inst.from_dict(inst.to_dict(lst)) == lst + assert inst.from_dict(inst.to_dict(mp)) == mp + + @pytest.mark.sanity + @pytest.mark.parametrize( + "collection", + [ + [SampleModel(name="x", value=1), 2, 3], + (SampleModel(name="y", value=2), None), + ], + ) + def test_to_from_sequence_collections(self, collection): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + seq = inst.to_sequence(collection) + out = inst.from_sequence(seq) + assert len(out) == len(collection) + assert all(a == b for a, b in zip(out, list(collection))) + + @pytest.mark.sanity + def test_to_from_sequence_mapping(self): + inst = Serializer("sequence") + inst.register_pydantic(SampleModel) + data = {"k": SampleModel(name="z", value=7), "j": 1} + seq = inst.to_sequence(data) + out = inst.from_sequence(seq) + assert out == data + + @pytest.mark.sanity + def test_sequence_multiple_root_raises(self): + inst = Serializer("sequence") + part1 = inst.pack_next_sequence("python", inst.to_sequence_python(1), None) + part2 = inst.pack_next_sequence("python", 
inst.to_sequence_python(2), None) + with pytest.raises(ValueError): + inst.from_sequence(part1 + part2) # type: ignore[operator] + + @pytest.mark.sanity + def test_pack_next_sequence_type_mismatch(self): + inst = Serializer("sequence") + first_payload = inst.to_sequence_python(1) + first = inst.pack_next_sequence("python", first_payload, None) + bad_payload: Any = ( + first_payload.decode() if isinstance(first_payload, bytes) else b"1" + ) + with pytest.raises(ValueError): + inst.pack_next_sequence("python", bad_payload, first) + + @pytest.mark.sanity + def test_unpack_invalid(self): + inst = Serializer("sequence") + with pytest.raises(ValueError): + inst.unpack_next_sequence("X|3|abc") + with pytest.raises(ValueError): + inst.unpack_next_sequence("p?bad") + + @pytest.mark.sanity + def test_dynamic_import_load_pydantic(self, monkeypatch): + inst = Serializer("dict") + inst.pydantic_registry.clear() + sample = SampleModel(name="dyn", value=5) + dumped = inst.to_dict(sample) + inst.pydantic_registry.clear() + restored = inst.from_dict(dumped) + assert restored == sample diff --git a/tests/unit/utils/test_messaging.py b/tests/unit/utils/test_messaging.py new file mode 100644 index 00000000..fc6155f8 --- /dev/null +++ b/tests/unit/utils/test_messaging.py @@ -0,0 +1,1108 @@ +from __future__ import annotations + +import asyncio +import multiprocessing +import threading +from functools import wraps +from typing import Any, TypeVar + +import culsans +import pytest +from pydantic import BaseModel + +from guidellm.backend import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, +) +from guidellm.scheduler import ScheduledRequestInfo +from guidellm.utils import ( + InterProcessMessaging, + InterProcessMessagingManagerQueue, + InterProcessMessagingPipe, + InterProcessMessagingQueue, +) +from guidellm.utils.messaging import ReceiveMessageT, SendMessageT + + +def async_timeout(delay: float): + """Decorator to add timeout to async test functions.""" + + def 
decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockMessage(BaseModel): + content: str + num: int + + +class MockProcessTarget: + """Mock process target for testing.""" + + def __init__( + self, + messaging: InterProcessMessaging, + num_messages: int, + worker_index: int = 0, + ): + self.messaging = messaging + self.num_messages = num_messages + self.worker_index = worker_index + + def run(self): + loop = asyncio.new_event_loop() + + try: + asyncio.set_event_loop(loop) + asyncio.run(asyncio.wait_for(self._async_runner(), timeout=10.0)) + except RuntimeError: + pass + finally: + loop.close() + + async def _async_runner(self): + await self.messaging.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo[GenerationRequestTimings], + ], + ) + + try: + for _ in range(self.num_messages): + obj = await self.messaging.get(timeout=2.0) + await self.messaging.put(obj, timeout=2.0) + finally: + await self.messaging.stop() + + +@pytest.fixture( + params=[ + {"ctx_name": "fork"}, + {"ctx_name": "spawn"}, + ], + ids=["fork_ctx", "spawn_ctx"], +) +def multiprocessing_contexts(request): + context = multiprocessing.get_context(request.param["ctx_name"]) + manager = context.Manager() + try: + yield manager, context + finally: + manager.shutdown() + + +def test_send_message_type(): + """Test that SendMessageT is filled out correctly as a TypeVar.""" + assert isinstance(SendMessageT, type(TypeVar("test"))) + assert SendMessageT.__name__ == "SendMessageT" + assert SendMessageT.__bound__ is Any + assert SendMessageT.__constraints__ == () + + +def test_receive_message_type(): + """Test that ReceiveMessageT is filled out correctly as a TypeVar.""" + assert isinstance(ReceiveMessageT, type(TypeVar("test"))) + assert ReceiveMessageT.__name__ == "ReceiveMessageT" + assert ReceiveMessageT.__bound__ 
is Any + assert ReceiveMessageT.__constraints__ == () + + +class TestInterProcessMessaging: + """Test suite for InterProcessMessaging abstract base class.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessaging abstract class signatures.""" + assert hasattr(InterProcessMessaging, "__init__") + assert hasattr(InterProcessMessaging, "create_worker_copy") + assert hasattr(InterProcessMessaging, "send_messages_task") + assert hasattr(InterProcessMessaging, "receive_messages_task") + assert hasattr(InterProcessMessaging, "start") + assert hasattr(InterProcessMessaging, "stop") + assert hasattr(InterProcessMessaging, "get") + assert hasattr(InterProcessMessaging, "put") + + # Check abstract methods + assert getattr( + InterProcessMessaging.create_worker_copy, "__isabstractmethod__", False + ) + assert getattr( + InterProcessMessaging.send_messages_task, "__isabstractmethod__", False + ) + assert getattr( + InterProcessMessaging.receive_messages_task, "__isabstractmethod__", False + ) + + @pytest.mark.smoke + def test_cannot_instantiate_directly(self): + """Test InterProcessMessaging cannot be instantiated directly.""" + with pytest.raises(TypeError): + InterProcessMessaging() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "on_stop_action", + "pending", + "queue_empty", + "stop_event_set", + "shutdown_event_set", + "expected_result", + "expect_error", + ), + [ + ("ignore", None, False, False, False, False, False), + ("ignore", None, False, True, False, False, False), + ("ignore", None, False, False, True, True, False), + ("ignore", "pending", False, False, True, False, False), + ("stop", None, False, True, False, True, False), + ("stop", None, False, False, True, True, False), + ("stop", "pending", False, True, False, False, False), + ("stop_after_empty", None, True, True, False, True, False), + ("stop_after_empty", None, False, True, False, False, False), + ("stop_after_empty", None, True, False, True, True, False), + 
("error", None, False, True, False, None, True), + ("error", None, False, False, True, True, False), + ], + ) + def test_check_on_stop_action( + self, + on_stop_action, + pending, + queue_empty, + stop_event_set, + shutdown_event_set, + expected_result, + expect_error, + ): + """Test InterProcessMessaging check_on_stop_action behavior.""" + # Create a concrete implementation for testing + messaging = InterProcessMessagingQueue(on_stop_action=on_stop_action) + + # Set up events + stop_event = threading.Event() + if stop_event_set: + stop_event.set() + + shutdown_event = threading.Event() + if shutdown_event_set: + shutdown_event.set() + + messaging.shutdown_event = shutdown_event + + # Test the method + if expect_error: + with pytest.raises(RuntimeError): + messaging.check_on_stop_action(pending, queue_empty, [stop_event]) + else: + result = messaging.check_on_stop_action(pending, queue_empty, [stop_event]) + assert result == expected_result + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "on_empty_action", + "pending", + "stop_event_set", + "shutdown_event_set", + "expected_result", + "expect_error", + ), + [ + ("ignore", None, False, False, False, False), + ("ignore", None, True, False, False, False), + ("ignore", "pending", True, False, False, False), + ("stop", None, True, False, True, False), + ("stop", None, False, True, True, False), + ("stop", "pending", True, False, False, False), + ("error", None, False, False, None, True), + ], + ) + def test_check_on_queue_empty_action( + self, + on_empty_action, + pending, + stop_event_set, + shutdown_event_set, + expected_result, + expect_error, + ): + """Test InterProcessMessaging check_on_queue_empty_action behavior.""" + messaging = InterProcessMessagingQueue(on_empty_action=on_empty_action) + + # Set up events + stop_event = threading.Event() + if stop_event_set: + stop_event.set() + + shutdown_event = threading.Event() + if shutdown_event_set: + shutdown_event.set() + + messaging.shutdown_event = 
shutdown_event + + # Test the method + if expect_error: + with pytest.raises(RuntimeError): + messaging.check_on_queue_empty_action(pending) + else: + result = messaging.check_on_queue_empty_action(pending) + assert result == expected_result + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "on_full_action", + "pending", + "stop_event_set", + "shutdown_event_set", + "expected_result", + "expect_error", + ), + [ + ("ignore", None, False, False, False, False), + ("ignore", None, True, False, False, False), + ("ignore", "pending", True, False, False, False), + ("stop", None, True, False, True, False), + ("stop", None, False, True, True, False), + ("stop", "pending", True, False, False, False), + ("error", None, False, False, None, True), + ], + ) + def test_check_on_queue_full_action( + self, + on_full_action, + pending, + stop_event_set, + shutdown_event_set, + expected_result, + expect_error, + ): + """Test InterProcessMessaging check_on_queue_full_action behavior.""" + messaging = InterProcessMessagingQueue(on_full_action=on_full_action) + + # Set up events + stop_event = threading.Event() + if stop_event_set: + stop_event.set() + + shutdown_event = threading.Event() + if shutdown_event_set: + shutdown_event.set() + + messaging.shutdown_event = shutdown_event + + # Test the method + if expect_error: + with pytest.raises(RuntimeError): + messaging.check_on_queue_full_action(pending) + else: + result = messaging.check_on_queue_full_action(pending) + assert result == expected_result + + +class TestInterProcessMessagingQueue: + """Test suite for InterProcessMessagingQueue.""" + + @pytest.fixture( + params=[ + { + "serialization": "dict", + "encoding": None, + "max_send_size": None, + "max_receive_size": None, + "worker_index": None, + }, + { + "serialization": "sequence", + "encoding": None, + "max_send_size": 10, + "max_buffer_send_size": 2, + "max_receive_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "serialization": None, + 
"encoding": None, + "max_send_size": None, + "max_receive_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingQueue.""" + constructor_args = request.param + instance = InterProcessMessagingQueue(**constructor_args, poll_interval=0.01) + manager, context = multiprocessing_contexts + + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingQueue inheritance and signatures.""" + assert issubclass(InterProcessMessagingQueue, InterProcessMessaging) + assert hasattr(InterProcessMessagingQueue, "__init__") + assert hasattr(InterProcessMessagingQueue, "create_worker_copy") + assert hasattr(InterProcessMessagingQueue, "send_messages_task") + assert hasattr(InterProcessMessagingQueue, "receive_messages_task") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingQueue initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingQueue) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_send_size == constructor_args["max_send_size"] + assert instance.max_receive_size == constructor_args["max_receive_size"] + assert hasattr(instance, "send_queue") + assert hasattr(instance, "done_queue") + assert instance.running is False + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingQueue.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 42 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingQueue) + assert worker_copy.worker_index == worker_index + assert worker_copy.send_queue is instance.send_queue + assert worker_copy.done_queue is instance.done_queue + assert worker_copy.max_send_size 
== instance.max_send_size + assert worker_copy.max_receive_size == instance.max_receive_size + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "stop_events_lambda", + [ + list, + lambda: [threading.Event()], + lambda: [multiprocessing.Event()], + lambda: [threading.Event(), multiprocessing.Event()], + ], + ) + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances, stop_events_lambda): + """Test InterProcessMessagingQueue start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = stop_events_lambda() + + # Initially not running + assert instance.running is False + assert instance.stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start(stop_events=stop_events) + assert instance.running is True + assert instance.stopped_event is not None + assert isinstance(instance.stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + if stop_events: + for event in stop_events: + event.set() + + await asyncio.sleep(0.1) + assert instance.stopped_event.is_set() + assert instance.send_task.done() + assert instance.receive_task.done() + + await instance.stop() + assert instance.running is False + assert instance.stopped_event is None + assert instance.shutdown_event is None + 
assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo[GenerationRequestTimings], + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + 
process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get_iter(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + def _received_callback(msg): + if not isinstance(test_obj, tuple): + assert msg == test_obj + else: + assert list(msg) == list(test_obj) + return "changed_obj" + + # Local startup and wait + await instance.start( + send_items=[test_obj for _ in range(5)], + receive_callback=_received_callback, + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo[GenerationRequestTimings], + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + val = await instance.get(timeout=2.0) + assert val == "changed_obj" + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + +class TestInterProcessMessagingManagerQueue: + """Test suite for InterProcessMessagingManagerQueue.""" + 
+ @pytest.fixture( + params=[ + { + "serialization": "dict", + "encoding": None, + "max_send_size": None, + "max_receive_size": None, + "worker_index": None, + }, + { + "serialization": "sequence", + "encoding": None, + "max_send_size": 10, + "max_buffer_send_size": 2, + "max_receive_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "serialization": None, + "encoding": None, + "max_send_size": None, + "max_receive_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingManagerQueue.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingManagerQueue( + **constructor_args, manager=manager, poll_interval=0.01 + ) + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingManagerQueue inheritance and signatures.""" + assert issubclass(InterProcessMessagingManagerQueue, InterProcessMessaging) + assert issubclass(InterProcessMessagingManagerQueue, InterProcessMessagingQueue) + assert hasattr(InterProcessMessagingManagerQueue, "__init__") + assert hasattr(InterProcessMessagingManagerQueue, "create_worker_copy") + assert hasattr(InterProcessMessagingManagerQueue, "send_messages_task") + assert hasattr(InterProcessMessagingManagerQueue, "receive_messages_task") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingManagerQueue initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, InterProcessMessagingManagerQueue) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_send_size == constructor_args["max_send_size"] + assert instance.max_receive_size == constructor_args["max_receive_size"] + assert hasattr(instance, "send_queue") + assert hasattr(instance, 
"done_queue")
+        assert instance.running is False
+
+    @pytest.mark.smoke
+    def test_create_worker_copy(self, valid_instances):
+        """Test InterProcessMessagingManagerQueue.create_worker_copy."""
+        instance, _, _, _ = valid_instances
+        worker_index = 42
+
+        worker_copy = instance.create_worker_copy(worker_index)
+
+        assert isinstance(worker_copy, InterProcessMessagingManagerQueue)
+        assert worker_copy.worker_index == worker_index
+        assert worker_copy.send_queue is instance.send_queue
+        assert worker_copy.done_queue is instance.done_queue
+        assert worker_copy.max_send_size == instance.max_send_size
+        assert worker_copy.max_receive_size == instance.max_receive_size
+
+    @pytest.mark.smoke
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "stop_events_lambda",
+        [
+            list,
+            lambda: [threading.Event()],
+            lambda: [multiprocessing.Event()],
+            lambda: [threading.Event(), multiprocessing.Event()],
+        ],
+    )
+    @async_timeout(5.0)
+    async def test_start_stop_lifecycle(self, valid_instances, stop_events_lambda):
+        """Test InterProcessMessagingManagerQueue start/stop lifecycle."""
+        instance, _, _, _ = valid_instances
+        stop_events = stop_events_lambda()
+
+        # Initially not running
+        assert instance.running is False
+        assert instance.stopped_event is None
+        assert instance.shutdown_event is None
+        assert instance.buffer_send_queue is None
+        assert instance.buffer_receive_queue is None
+        assert instance.send_task is None
+        assert instance.receive_task is None
+
+        # Start should work
+        await instance.start(stop_events=stop_events)
+        assert instance.running is True
+        assert instance.stopped_event is not None
+        assert isinstance(instance.stopped_event, threading.Event)
+        assert instance.shutdown_event is not None
+        assert isinstance(instance.shutdown_event, threading.Event)
+        assert instance.buffer_send_queue is not None
+        assert isinstance(instance.buffer_send_queue, culsans.Queue)
+        assert instance.buffer_receive_queue is not None
+        assert isinstance(instance.buffer_receive_queue, 
culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + if stop_events: + for event in stop_events: + event.set() + + await asyncio.sleep(0.1) + assert instance.stopped_event.is_set() + assert instance.send_task.done() + assert instance.receive_task.done() + + await instance.stop() + assert instance.running is False + assert instance.stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, _, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + 
GenerationResponse, + ScheduledRequestInfo[GenerationRequestTimings], + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ], + ) + @async_timeout(10.0) + async def test_lifecycle_put_get_iter(self, valid_instances, test_obj): + instance, constructor_args, _, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + # Handle case where ScheduledRequestInfo is not pickleable + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + process_target = MockProcessTarget( + instance.create_worker_copy(0), num_messages=5 + ) + process = context.Process(target=process_target.run) + process.start() + + def _received_callback(msg): + if not isinstance(test_obj, tuple): + assert msg == test_obj + else: + assert list(msg) == list(test_obj) + return "changed_obj" + + # Local startup and wait + await instance.start( + send_items=[test_obj for _ in range(5)], + receive_callback=_received_callback, + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + 
ScheduledRequestInfo[GenerationRequestTimings], + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5): + val = await instance.get(timeout=2.0) + assert val == "changed_obj" + finally: + # Clean up + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() + + +class TestInterProcessMessagingPipe: + """Test suite for InterProcessMessagingPipe.""" + + @pytest.fixture( + params=[ + { + "num_workers": 2, + "serialization": "dict", + "encoding": None, + "max_send_size": None, + "max_receive_size": None, + "worker_index": None, + }, + { + "num_workers": 1, + "serialization": "sequence", + "encoding": None, + "max_send_size": 10, + "max_buffer_send_size": 2, + "max_receive_size": 5, + "max_buffer_receive_size": 3, + "worker_index": None, + }, + { + "num_workers": 1, + "serialization": None, + "encoding": None, + "max_send_size": None, + "max_receive_size": None, + "worker_index": None, + }, + ], + ) + def valid_instances(self, multiprocessing_contexts, request): + """Fixture providing test data for InterProcessMessagingPipe.""" + constructor_args = request.param + manager, context = multiprocessing_contexts + instance = InterProcessMessagingPipe(**constructor_args, poll_interval=0.01) + return instance, constructor_args, manager, context + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InterProcessMessagingPipe inheritance and signatures.""" + assert issubclass(InterProcessMessagingPipe, InterProcessMessaging) + assert hasattr(InterProcessMessagingPipe, "__init__") + assert hasattr(InterProcessMessagingPipe, "create_worker_copy") + assert hasattr(InterProcessMessagingPipe, "send_messages_task") + assert hasattr(InterProcessMessagingPipe, "receive_messages_task") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InterProcessMessagingPipe initialization.""" + instance, constructor_args, _, _ = valid_instances + + assert isinstance(instance, 
InterProcessMessagingPipe) + assert instance.worker_index == constructor_args["worker_index"] + assert instance.max_send_size == constructor_args["max_send_size"] + assert instance.max_receive_size == constructor_args["max_receive_size"] + assert instance.num_workers == constructor_args["num_workers"] + assert hasattr(instance, "pipes") + assert len(instance.pipes) == constructor_args["num_workers"] + assert len(instance.pipes) == constructor_args["num_workers"] + assert instance.running is False + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("kwargs", "expected_error"), + [ + ({"invalid_param": "value"}, TypeError), + ({"num_workers": 1, "unknown_arg": "test"}, TypeError), + ], + ) + def test_invalid_initialization_values(self, kwargs, expected_error): + """Test InterProcessMessagingPipe with invalid field values.""" + with pytest.raises(expected_error): + InterProcessMessagingPipe(**kwargs) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test InterProcessMessagingPipe initialization without required field.""" + with pytest.raises(TypeError): + InterProcessMessagingPipe() + + @pytest.mark.smoke + def test_create_worker_copy(self, valid_instances): + """Test InterProcessMessagingPipe.create_worker_copy.""" + instance, _, _, _ = valid_instances + worker_index = 0 + + worker_copy = instance.create_worker_copy(worker_index) + + assert isinstance(worker_copy, InterProcessMessagingPipe) + assert worker_copy.worker_index == worker_index + assert worker_copy.pipes[0] is instance.pipes[worker_index] + assert worker_copy.max_send_size == instance.max_send_size + assert worker_copy.max_receive_size == instance.max_receive_size + assert worker_copy.num_workers == instance.num_workers + + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(5.0) + async def test_start_stop_lifecycle(self, valid_instances): + """Test InterProcessMessagingPipe start/stop lifecycle.""" + instance, _, _, _ = valid_instances + stop_events = [] + + # 
Initially not running + assert instance.running is False + assert instance.stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + # Start should work + await instance.start(stop_events=stop_events) + assert instance.running is True + assert instance.stopped_event is not None + assert isinstance(instance.stopped_event, threading.Event) + assert instance.shutdown_event is not None + assert isinstance(instance.shutdown_event, threading.Event) + assert instance.buffer_send_queue is not None + assert isinstance(instance.buffer_send_queue, culsans.Queue) + assert instance.buffer_receive_queue is not None + assert isinstance(instance.buffer_receive_queue, culsans.Queue) + assert instance.send_task is not None + assert isinstance(instance.send_task, asyncio.Task) + assert instance.receive_task is not None + assert isinstance(instance.receive_task, asyncio.Task) + + # Stop should work + await instance.stop() + assert instance.running is False + assert instance.stopped_event is None + assert instance.shutdown_event is None + assert instance.buffer_send_queue is None + assert instance.buffer_receive_queue is None + assert instance.send_task is None + assert instance.receive_task is None + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + "test_obj", + [ + 123451, + "asdfghjkl", + [None, 123, 45.67, "string", {"key": "value"}, [1, 2, 3]], + {"key": "value", "another_key": 123.456, "yet_another_key": [1, 2, 3]}, + MockMessage(content="hello", num=42), + ( + None, + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ( + GenerationResponse(request_id="id", request_args={}), + GenerationRequest(content="asdfkj;"), + ScheduledRequestInfo[GenerationRequestTimings](), + ), + ], + ) + @async_timeout(10.0) + async def 
test_lifecycle_put_get(self, valid_instances, test_obj): + instance, constructor_args, manager, context = valid_instances + + if ( + ( + isinstance(test_obj, ScheduledRequestInfo) + or ( + isinstance(test_obj, tuple) + and any(isinstance(item, ScheduledRequestInfo) for item in test_obj) + ) + ) + and constructor_args["serialization"] is None + and constructor_args["encoding"] is None + ): + pytest.skip("ScheduledRequestInfo is not pickleable") + + # Worker setup + processes = [] + for index in range(constructor_args["num_workers"]): + process_target = MockProcessTarget( + instance.create_worker_copy(index), num_messages=5 + ) + process = context.Process(target=process_target.run) + processes.append(process) + process.start() + + # Local startup and wait + await instance.start( + pydantic_models=[ + MockMessage, + GenerationRequest, + GenerationResponse, + ScheduledRequestInfo[GenerationRequestTimings], + ], + ) + await asyncio.sleep(0.1) + + try: + for _ in range(5 * constructor_args["num_workers"]): + await instance.put(test_obj, timeout=2.0) + + for _ in range(5 * constructor_args["num_workers"]): + val = await instance.get(timeout=2.0) + if not isinstance(test_obj, tuple): + assert val == test_obj + else: + assert list(val) == list(test_obj) + finally: + # Clean up + for process in processes: + process.join(timeout=2.0) + if process.is_alive(): + process.terminate() + process.join(timeout=1.0) + + await instance.stop() diff --git a/tests/unit/utils/test_statistics.py b/tests/unit/utils/test_statistics.py index fa8cccd0..855bfa5f 100644 --- a/tests/unit/utils/test_statistics.py +++ b/tests/unit/utils/test_statistics.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from guidellm.objects import ( +from guidellm.utils import ( DistributionSummary, Percentiles, RunningStats, diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py index 2f363c46..3774ca1f 100644 --- a/tests/unit/utils/test_text.py +++ b/tests/unit/utils/test_text.py @@ -373,7 
+373,7 @@ def test_url_error(self, mock_client): class TestIsPunctuation: - """Test suite for is_puncutation.""" + """Test suite for is_punctuation.""" @pytest.mark.smoke @pytest.mark.parametrize(