diff --git a/pyproject.toml b/pyproject.toml index 9fdc70ad..17380312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ keywords = [ "inference", "language-models", "large-language-model", + "load-generation", "llm", "machine-learning", "model-benchmark", diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py index ee2101d7..bfa738d8 100644 --- a/src/guidellm/backend/response.py +++ b/src/guidellm/backend/response.py @@ -3,7 +3,7 @@ from pydantic import computed_field from guidellm.config import settings -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = [ "RequestArgs", diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index af7f1a13..450b536a 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -22,12 +22,6 @@ GenerativeTextResponseStats, ) from guidellm.config import settings -from guidellm.objects import ( - RunningStats, - StandardBaseModel, - StatusBreakdown, - TimeRunningStats, -) from guidellm.request import ( GenerationRequest, GenerativeRequestLoaderDescription, @@ -40,7 +34,13 @@ SchedulerRequestResult, WorkerDescription, ) -from guidellm.utils import check_load_processor +from guidellm.utils import ( + RunningStats, + StandardBaseModel, + StatusBreakdown, + TimeRunningStats, + check_load_processor, +) __all__ = [ "AggregatorT", diff --git a/src/guidellm/benchmark/benchmark.py b/src/guidellm/benchmark/benchmark.py index 02eea02b..eadcf984 100644 --- a/src/guidellm/benchmark/benchmark.py +++ b/src/guidellm/benchmark/benchmark.py @@ -12,11 +12,6 @@ SynchronousProfile, ThroughputProfile, ) -from guidellm.objects import ( - StandardBaseModel, - StatusBreakdown, - StatusDistributionSummary, -) from guidellm.request import ( GenerativeRequestLoaderDescription, RequestLoaderDescription, @@ -32,6 +27,11 @@ ThroughputStrategy, WorkerDescription, ) +from guidellm.utils import ( + StandardBaseModel, + StatusBreakdown, + StatusDistributionSummary, +) __all__ = [ "Benchmark", diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 11b6d245..876e6f43 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -22,7 +22,6 @@ ) from guidellm.benchmark.benchmark import BenchmarkArgs, GenerativeBenchmark from guidellm.benchmark.profile import Profile -from guidellm.objects import StandardBaseModel from guidellm.request import ( GenerationRequest, GenerativeRequestLoaderDescription, @@ -37,6 +36,7 @@ SchedulerRequestResult, SchedulingStrategy, ) +from guidellm.utils import StandardBaseModel __all__ = ["Benchmarker", "BenchmarkerResult", "GenerativeBenchmarker"] diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 8a113f72..225ed2b1 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -21,15 +21,16 @@ ThroughputProfile, ) from guidellm.config import settings -from guidellm.objects import ( +from guidellm.presentation import UIDataBuilder +from guidellm.presentation.injector import create_report +from guidellm.scheduler import strategy_display_str +from guidellm.utils import ( + Colors, DistributionSummary, StandardBaseModel, StatusDistributionSummary, + split_text_list_by_length, ) -from guidellm.presentation import UIDataBuilder -from guidellm.presentation.injector import create_report -from guidellm.scheduler import strategy_display_str -from guidellm.utils import 
Colors, split_text_list_by_length __all__ = [ "GenerativeBenchmarksConsole", diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 642cb7a8..d46f2b16 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -5,7 +5,6 @@ from pydantic import Field, computed_field from guidellm.config import settings -from guidellm.objects import StandardBaseModel from guidellm.scheduler import ( AsyncConstantStrategy, AsyncPoissonStrategy, @@ -15,6 +14,7 @@ SynchronousStrategy, ThroughputStrategy, ) +from guidellm.utils import StandardBaseModel __all__ = [ "AsyncProfile", diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index af43e426..57dfa98b 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -11,8 +11,8 @@ from guidellm.backend.backend import BackendType from guidellm.benchmark.profile import ProfileType -from guidellm.objects.pydantic import StandardBaseModel from guidellm.scheduler.strategy import StrategyType +from guidellm.utils import StandardBaseModel __ALL__ = ["Scenario", "GenerativeTextScenario", "get_builtin_scenarios"] diff --git a/src/guidellm/objects/__init__.py b/src/guidellm/objects/__init__.py deleted file mode 100644 index 89e3c9b9..00000000 --- a/src/guidellm/objects/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .pydantic import StandardBaseModel, StatusBreakdown -from .statistics import ( - DistributionSummary, - Percentiles, - RunningStats, - StatusDistributionSummary, - TimeRunningStats, -) - -__all__ = [ - "DistributionSummary", - "Percentiles", - "RunningStats", - "StandardBaseModel", - "StatusBreakdown", - "StatusDistributionSummary", - "TimeRunningStats", -] diff --git a/src/guidellm/objects/pydantic.py b/src/guidellm/objects/pydantic.py deleted file mode 100644 index fcededcf..00000000 --- a/src/guidellm/objects/pydantic.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Generic, Optional, TypeVar - -import yaml -from loguru import logger -from pydantic import BaseModel, ConfigDict, Field - -__all__ = ["StandardBaseModel", "StatusBreakdown"] - -T = TypeVar("T", bound="StandardBaseModel") - - -class StandardBaseModel(BaseModel): - """ - A base class for Pydantic models throughout GuideLLM enabling standard - configuration and logging. - """ - - model_config = ConfigDict( - extra="ignore", - use_enum_values=True, - validate_assignment=True, - from_attributes=True, - ) - - def __init__(self, /, **data: Any) -> None: - super().__init__(**data) - logger.debug( - "Initialized new instance of {} with data: {}", - self.__class__.__name__, - data, - ) - - @classmethod - def get_default(cls: type[T], field: str) -> Any: - """Get default values for model fields""" - return cls.model_fields[field].default - - @classmethod - def from_file(cls: type[T], filename: Path, overrides: Optional[dict] = None) -> T: - """ - Attempt to create a new instance of the model using - data loaded from json or yaml file. 
- """ - try: - with filename.open() as f: - if str(filename).endswith(".json"): - data = json.load(f) - else: # Assume everything else is yaml - data = yaml.safe_load(f) - except (json.JSONDecodeError, yaml.YAMLError) as e: - logger.error(f"Failed to parse {filename} as type {cls.__name__}") - raise ValueError(f"Error when parsing file: {filename}") from e - - data.update(overrides) - return cls.model_validate(data) - - -SuccessfulT = TypeVar("SuccessfulT") -ErroredT = TypeVar("ErroredT") -IncompleteT = TypeVar("IncompleteT") -TotalT = TypeVar("TotalT") - - -class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): - """ - A base class for Pydantic models that are separated by statuses including - successful, incomplete, and errored. It additionally enables the inclusion - of total, which is intended as the combination of all statuses. - Total may or may not be used depending on if it duplicates information. - """ - - successful: SuccessfulT = Field( - description="The results with a successful status.", - default=None, # type: ignore[assignment] - ) - errored: ErroredT = Field( - description="The results with an errored status.", - default=None, # type: ignore[assignment] - ) - incomplete: IncompleteT = Field( - description="The results with an incomplete status.", - default=None, # type: ignore[assignment] - ) - total: TotalT = Field( - description="The combination of all statuses.", - default=None, # type: ignore[assignment] - ) diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py index ff5221e3..3164dc86 100644 --- a/src/guidellm/presentation/data_models.py +++ b/src/guidellm/presentation/data_models.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from guidellm.benchmark.benchmark import GenerativeBenchmark -from guidellm.objects.statistics import DistributionSummary +from guidellm.utils.statistics import DistributionSummary class Bucket(BaseModel): diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 50ab3cca..2eff87d5 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -13,8 +13,8 @@ from guidellm.config import settings from guidellm.dataset import ColumnInputTypes, load_dataset -from guidellm.objects import StandardBaseModel from guidellm.request.request import GenerationRequest +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestLoader", diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py index 81c8cabd..bf4e59fb 100644 --- a/src/guidellm/request/request.py +++ b/src/guidellm/request/request.py @@ -3,7 +3,7 @@ from pydantic import Field -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = ["GenerationRequest"] diff --git a/src/guidellm/scheduler/result.py b/src/guidellm/scheduler/result.py index 0f12687f..0cca530b 100644 --- a/src/guidellm/scheduler/result.py +++ b/src/guidellm/scheduler/result.py @@ -4,9 +4,9 @@ Optional, ) -from guidellm.objects import StandardBaseModel from guidellm.scheduler.strategy import SchedulingStrategy from guidellm.scheduler.types import RequestT, ResponseT +from guidellm.utils import StandardBaseModel __all__ = [ "SchedulerRequestInfo", diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py index 200c799e..d4c065da 100644 --- a/src/guidellm/scheduler/strategy.py +++ b/src/guidellm/scheduler/strategy.py @@ -12,7 +12,7 @@ from pydantic import Field from guidellm.config import settings -from 
guidellm.objects import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = [ "AsyncConstantStrategy", diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index a53b14c2..ab16e4db 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -24,10 +24,10 @@ ResponseSummary, StreamingTextResponse, ) -from guidellm.objects import StandardBaseModel from guidellm.request import GenerationRequest from guidellm.scheduler.result import SchedulerRequestInfo from guidellm.scheduler.types import RequestT, ResponseT +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestsWorker", diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index fb9262c3..576fe64d 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,5 +1,14 @@ +from .auto_importer import AutoImporterMixin from .colors import Colors from .default_group import DefaultGroupHandler +from .functions import ( + all_defined, + safe_add, + safe_divide, + safe_format_timestamp, + safe_getattr, + safe_multiply, +) from .hf_datasets import ( SUPPORTED_TYPES, save_dataset_to_file, @@ -7,12 +16,28 @@ from .hf_transformers import ( check_load_processor, ) +from .pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) from .random import IntegerRangeSampler +from .registry import RegistryMixin +from .singleton import SingletonMixin, ThreadSafeSingletonMixin +from .statistics import ( + DistributionSummary, + Percentiles, + RunningStats, + StatusDistributionSummary, + TimeRunningStats, +) from .text import ( EndlessTextCreator, clean_text, filter_text, - is_puncutation, + is_punctuation, load_text, split_text, split_text_list_by_length, @@ -20,15 +45,35 @@ __all__ = [ "SUPPORTED_TYPES", + "AutoImporterMixin", "Colors", "DefaultGroupHandler", + "DistributionSummary", "EndlessTextCreator", "IntegerRangeSampler", + "Percentiles", + "PydanticClassRegistryMixin", + "RegistryMixin", + "ReloadableBaseModel", + "RunningStats", + "SingletonMixin", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", + "StatusDistributionSummary", + "ThreadSafeSingletonMixin", + "TimeRunningStats", + "all_defined", "check_load_processor", "clean_text", "filter_text", - "is_puncutation", + "is_punctuation", "load_text", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", "save_dataset_to_file", "split_text", "split_text_list_by_length", diff --git a/src/guidellm/utils/auto_importer.py b/src/guidellm/utils/auto_importer.py new file mode 100644 index 00000000..5b939014 --- /dev/null +++ b/src/guidellm/utils/auto_importer.py @@ -0,0 +1,98 @@ +""" +Automatic module importing utilities for dynamic class discovery. + +This module provides a mixin class for automatic module importing within a package, +enabling dynamic discovery of classes and implementations without explicit imports. +It is particularly useful for auto-registering classes in a registry pattern where +subclasses need to be discoverable at runtime. + +The AutoImporterMixin can be combined with registration mechanisms to create +extensible systems where new implementations are automatically discovered and +registered when they are placed in the correct package structure. 
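A minimal usage sketch of the AutoImporterMixin described above; the package names are hypothetical stand-ins for a real plugin package:

from guidellm.utils import AutoImporterMixin


class PluginLoader(AutoImporterMixin):
    # Hypothetical package holding plugin modules; point this at a real package.
    auto_package = "my_plugins.backends"
    auto_ignore_modules = ("my_plugins.backends.experimental",)


# Importing every module under the package lets decorator-based registration run
# as an import side effect, so implementations are discovered without explicit imports.
PluginLoader.auto_import_package_modules()
print(PluginLoader.auto_imported_modules)  # e.g. ["my_plugins.backends.http", ...]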
+""" + +from __future__ import annotations + +import importlib +import pkgutil +import sys +from typing import ClassVar + +__all__ = ["AutoImporterMixin"] + + +class AutoImporterMixin: + """ + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations without + explicit imports by automatically importing all modules within specified + packages. It is designed for use with class registration mechanisms to enable + automatic discovery and registration of classes when they are placed in the + correct package structure. + + Example: + :: + from guidellm.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported + """ + + auto_package: ClassVar[str | tuple[str, ...] | None] = None + auto_ignore_modules: ClassVar[tuple[str, ...] | None] = None + auto_imported_modules: ClassVar[list[str] | None] = None + + @classmethod + def auto_import_package_modules(cls) -> None: + """ + Automatically import all modules within the specified package(s). + + Scans the package(s) defined in the `auto_package` class variable and imports + all modules found, tracking them in `auto_imported_modules`. Skips packages + (directories) and any modules listed in `auto_ignore_modules`. + + :raises ValueError: If the `auto_package` class variable is not set + """ + if cls.auto_package is None: + raise ValueError( + "The class variable 'auto_package' must be set to the package name to " + "import modules from." + ) + + cls.auto_imported_modules = [] + packages = ( + cls.auto_package + if isinstance(cls.auto_package, tuple) + else (cls.auto_package,) + ) + + for package_name in packages: + package = importlib.import_module(package_name) + + for _, module_name, is_pkg in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + if ( + is_pkg + or ( + cls.auto_ignore_modules is not None + and module_name in cls.auto_ignore_modules + ) + or module_name in cls.auto_imported_modules + ): + # Skip packages and ignored modules + continue + + if module_name in sys.modules: + # Avoid circular imports + cls.auto_imported_modules.append(module_name) + else: + importlib.import_module(module_name) + cls.auto_imported_modules.append(module_name) diff --git a/src/guidellm/utils/functions.py b/src/guidellm/utils/functions.py new file mode 100644 index 00000000..b28aa21e --- /dev/null +++ b/src/guidellm/utils/functions.py @@ -0,0 +1,130 @@ +""" +Utility functions for safe operations and value handling. + +Provides defensive programming utilities for common operations that may encounter +None values, invalid inputs, or edge cases. Includes safe arithmetic operations, +attribute access, and timestamp formatting. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +__all__ = [ + "all_defined", + "safe_add", + "safe_divide", + "safe_format_timestamp", + "safe_getattr", + "safe_multiply", +] + + +def safe_getattr(obj: Any | None, attr: str, default: Any = None) -> Any: + """ + Safely get an attribute from an object with None handling. 
+ + :param obj: Object to get the attribute from, or None + :param attr: Name of the attribute to retrieve + :param default: Value to return if object is None or attribute doesn't exist + :return: Attribute value or default if not found or object is None + """ + if obj is None: + return default + + return getattr(obj, attr, default) + + +def all_defined(*values: Any | None) -> bool: + """ + Check if all provided values are defined (not None). + + :param values: Variable number of values to check for None + :return: True if all values are not None, False otherwise + """ + return all(value is not None for value in values) + + +def safe_divide( + numerator: int | float | None, + denominator: int | float | None, + num_default: float = 0.0, + den_default: float = 1.0, +) -> float: + """ + Safely divide two numbers with None handling and zero protection. + + :param numerator: Number to divide, or None to use num_default + :param denominator: Number to divide by, or None to use den_default + :param num_default: Default value for numerator if None + :param den_default: Default value for denominator if None + :return: Division result with protection against division by zero + """ + numerator = numerator if numerator is not None else num_default + denominator = denominator if denominator is not None else den_default + + return numerator / (denominator or 1e-10) + + +def safe_multiply(*values: int | float | None, default: float = 1.0) -> float: + """ + Safely multiply multiple numbers with None handling. + + :param values: Variable number of values to multiply, None values treated as 1.0 + :param default: Starting value for multiplication + :return: Product of all non-None values multiplied by default + """ + result = default + for val in values: + result *= val if val is not None else 1.0 + return result + + +def safe_add( + *values: int | float | None, signs: list[int] | None = None, default: float = 0.0 +) -> float: + """ + Safely add multiple numbers with None handling and optional signs. + + :param values: Variable number of values to add, None values use default + :param signs: Optional list of 1 (add) or -1 (subtract) for each value. + If None, all values are added. Must match length of values. + :param default: Value to substitute for None values + :return: Result of adding all values safely (default used when value is None) + """ + if not values: + return default + + values = list(values) + + if signs is None: + signs = [1] * len(values) + + if len(signs) != len(values): + raise ValueError("Length of signs must match length of values") + + result = values[0] if values[0] is not None else default + + for ind in range(1, len(values)): + val = values[ind] if values[ind] is not None else default + result += signs[ind] * val + + return result + + +def safe_format_timestamp( + timestamp: float | None, format_: str = "%H:%M:%S", default: str = "N/A" +) -> str: + """ + Safely format a timestamp with error handling and validation. 
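A quick sketch of how the safe helpers above behave with None and zero inputs; the values are illustrative:

from guidellm.utils import all_defined, safe_add, safe_divide, safe_multiply

# None values fall back to defaults instead of raising TypeError.
print(safe_divide(10, None))      # denominator defaults to 1.0 -> 10.0
print(safe_divide(5, 0))          # zero denominator is guarded -> 5e10, not ZeroDivisionError
print(safe_multiply(2, None, 3))  # None treated as 1.0 -> 6.0
print(safe_add(10, None, 4, signs=[1, -1, -1]))  # 10 - 0.0 - 4 -> 6.0
print(all_defined(1, "x", 0))     # 0 counts as defined; only None fails -> True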
+ + :param timestamp: Unix timestamp to format, or None + :param format_: Strftime format string for timestamp formatting + :param default: Value to return if timestamp is invalid or None + :return: Formatted timestamp string or default value + """ + try: + return datetime.fromtimestamp(timestamp).strftime(format_) + except (ValueError, TypeError, OverflowError, OSError): + return default diff --git a/src/guidellm/utils/mixins.py b/src/guidellm/utils/mixins.py new file mode 100644 index 00000000..b001ff2d --- /dev/null +++ b/src/guidellm/utils/mixins.py @@ -0,0 +1,115 @@ +""" +Mixin classes for common metadata extraction and object introspection. + +Provides reusable mixins for extracting structured metadata from objects, +enabling consistent information exposure across different class hierarchies. +""" + +from __future__ import annotations + +from typing import Any + +__all__ = ["InfoMixin"] + + +PYTHON_PRIMITIVES = (str, int, float, bool, list, tuple, dict) +"""Type alias for serialized object representations""" + + +class InfoMixin: + """ + Mixin class providing standardized metadata extraction for introspection. + + Enables consistent object metadata extraction patterns across different + class hierarchies for debugging, serialization, and runtime analysis. + Provides both instance and class-level methods for extracting structured + information from arbitrary objects with fallback handling for objects + without built-in info capabilities. + + Example: + :: + from guidellm.utils.mixins import InfoMixin + + class ConfiguredClass(InfoMixin): + def __init__(self, setting: str): + self.setting = setting + + obj = ConfiguredClass("value") + # Returns {'str': 'ConfiguredClass(...)', 'type': 'ConfiguredClass', ...} + print(obj.info) + """ + + @classmethod + def extract_from_obj(cls, obj: Any) -> dict[str, Any]: + """ + Extract structured metadata from any object. + + Attempts to use the object's own `info` method or property if available, + otherwise constructs metadata from object attributes and type information. + Provides consistent metadata format across different object types. + + :param obj: Object to extract metadata from + :return: Dictionary containing object metadata including type, class, + module, and public attributes + """ + if hasattr(obj, "info"): + return obj.info() if callable(obj.info) else obj.info + + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val if isinstance(val, PYTHON_PRIMITIVES) else repr(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @classmethod + def create_info_dict(cls, obj: Any) -> dict[str, Any]: + """ + Create a structured info dictionary for the given object. + + Builds standardized metadata dictionary containing object identification, + type information, and accessible attributes. Used internally by other + info extraction methods and available for direct metadata construction. 
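A small sketch applying the extraction helper above to an arbitrary object; the Endpoint class is made up for illustration:

from guidellm.utils.mixins import InfoMixin


class Endpoint:
    def __init__(self, url: str, retries: int = 3):
        self.url = url
        self.retries = retries
        self._token = "secret"  # underscore-prefixed attributes are excluded


info = InfoMixin.extract_from_obj(Endpoint("http://localhost:8000", retries=5))
# Yields type/class/module metadata plus public attributes, e.g.:
# info["attributes"] == {"url": "http://localhost:8000", "retries": 5}
print(info["type"], info["attributes"])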
+ + :param obj: Object to extract info from + :return: Dictionary containing structured metadata about the object + """ + return { + "str": str(obj), + "type": type(obj).__name__, + "class": obj.__class__.__name__ if hasattr(obj, "__class__") else None, + "module": obj.__class__.__module__ if hasattr(obj, "__class__") else None, + "attributes": ( + { + key: val + if isinstance(val, (str, int, float, bool, list, dict)) + else repr(val) + for key, val in obj.__dict__.items() + if not key.startswith("_") + } + if hasattr(obj, "__dict__") + else {} + ), + } + + @property + def info(self) -> dict[str, Any]: + """ + Return structured metadata about this instance. + + Provides consistent access to object metadata for debugging, serialization, + and introspection. Uses the create_info_dict method to generate standardized + metadata format including class information and public attributes. + + :return: Dictionary containing class name, module, and public attributes + """ + return self.create_info_dict(self) diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py new file mode 100644 index 00000000..52bf6564 --- /dev/null +++ b/src/guidellm/utils/pydantic_utils.py @@ -0,0 +1,302 @@ +""" +Pydantic utilities for polymorphic model serialization and registry integration. + +Provides integration between Pydantic and the registry system, enabling +polymorphic serialization and deserialization of Pydantic models using +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + +from guidellm.utils.registry import RegistryMixin + +__all__ = [ + "PydanticClassRegistryMixin", + "ReloadableBaseModel", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", +] + + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) +SuccessfulT = TypeVar("SuccessfulT") +ErroredT = TypeVar("ErroredT") +IncompleteT = TypeVar("IncompleteT") +TotalT = TypeVar("TotalT") + + +class ReloadableBaseModel(BaseModel): + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + @classmethod + def reload_schema(cls) -> None: + """ + Reload the class schema with updated registry information. + + Forces a complete rebuild of the Pydantic model schema to incorporate + any changes made to associated registries or validation rules. + """ + cls.model_rebuild(force=True) + + +class StandardBaseModel(BaseModel): + """ + Base Pydantic model with standardized configuration for GuideLLM. + + Provides consistent validation behavior and configuration settings across + all Pydantic models in the application, including field validation, + attribute conversion, and default value handling. 
+ + Example: + :: + class MyModel(StandardBaseModel): + name: str + value: int = 42 + + # Access default values + default_value = MyModel.get_default("value") # Returns 42 + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + ) + + @classmethod + def get_default(cls: type[BaseModelT], field: str) -> Any: + """ + Get default value for a model field. + + :param field: Name of the field to get the default value for + :return: Default value of the specified field + :raises KeyError: If the field does not exist in the model + """ + return cls.model_fields[field].default + + +class StandardBaseDict(StandardBaseModel): + """ + Base Pydantic model allowing arbitrary additional fields. + + Extends StandardBaseModel to accept extra fields beyond those explicitly + defined in the model schema. Useful for flexible data structures that + need to accommodate varying or unknown field sets while maintaining + type safety for known fields. + """ + + model_config = ConfigDict( + extra="allow", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + +class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): + """ + Generic model for organizing results by processing status. + + Provides structured categorization of results into successful, errored, + incomplete, and total status groups. Supports flexible typing for each + status category to accommodate different result types while maintaining + consistent organization patterns across the application. + + Example: + :: + from guidellm.utils.pydantic_utils import StatusBreakdown + + # Define a breakdown for request counts + breakdown = StatusBreakdown[int, int, int, int]( + successful=150, + errored=5, + incomplete=10, + total=165 + ) + """ + + successful: SuccessfulT = Field( + description="Results or metrics for requests with successful completion status", + default=None, # type: ignore[assignment] + ) + errored: ErroredT = Field( + description="Results or metrics for requests with error completion status", + default=None, # type: ignore[assignment] + ) + incomplete: IncompleteT = Field( + description="Results or metrics for requests with incomplete processing status", + default=None, # type: ignore[assignment] + ) + total: TotalT = Field( + description="Aggregated results or metrics combining all status categories", + default=None, # type: ignore[assignment] + ) + + +class PydanticClassRegistryMixin( + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] +): + """ + Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. + + Integrates Pydantic validation with the registry system to enable polymorphic + serialization and deserialization based on a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings, + providing a foundation for extensible plugin-style architectures. 
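A brief sketch contrasting the two base models defined above; the Metric models are hypothetical:

from guidellm.utils import StandardBaseDict, StandardBaseModel


class Metric(StandardBaseModel):
    name: str
    value: float = 0.0


class MetricExtra(StandardBaseDict):
    name: str


# StandardBaseModel ignores unknown fields; StandardBaseDict keeps them.
strict = Metric(name="ttft", value=1.5, source="worker-1")
print(strict.model_dump())  # {"name": "ttft", "value": 1.5}
loose = MetricExtra(name="ttft", source="worker-1")
print(loose.model_dump())   # {"name": "ttft", "source": "worker-1"}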
+ + Example: + :: + from guidellm.utils.pydantic_utils import PydanticClassRegistryMixin + + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): + schema_discriminator: ClassVar[str] = "config_type" + config_type: str = Field(description="Configuration type identifier") + + @classmethod + def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: + return BaseConfig + + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") + + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name used for polymorphic type discrimination + """ + + schema_discriminator: ClassVar[str] = "model_type" + + @classmethod + def register_decorator( + cls, clazz: type[BaseModelT], name: str | list[str] | None = None + ) -> type[BaseModelT]: + """ + Register a Pydantic model class with type validation and schema reload. + + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. + + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass + """ + if not issubclass(clazz, BaseModel): + raise TypeError( + f"Cannot register {clazz.__name__} as it is not a subclass of " + "Pydantic BaseModel" + ) + + dec_clazz = super().register_decorator(clazz, name=name) + cls.reload_schema() + + return dec_clazz + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate polymorphic validation schema for dynamic type instantiation. + + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. + + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema + """ + if source_type == cls.__pydantic_schema_base_type__(): + if not cls.registry: + return cls.__pydantic_generate_base_schema__(handler) + + choices = { + name: handler(model_class) for name, model_class in cls.registry.items() + } + + return core_schema.tagged_union_schema( + choices=choices, + discriminator=cls.schema_discriminator, + ) + + return handler(cls) + + @classmethod + @abstractmethod + def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: + """ + Define the base type for polymorphic validation hierarchy. + + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. + + :return: Base class type for the polymorphic model hierarchy + """ + ... + + @classmethod + def __pydantic_generate_base_schema__( + cls, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate fallback schema for polymorphic models without registry. + + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. 
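A sketch of the polymorphic round-trip this mixin is meant to enable, mirroring the docstring example above; the Strategy classes and field names are hypothetical:

from typing import ClassVar

from pydantic import Field

from guidellm.utils import PydanticClassRegistryMixin


class Strategy(PydanticClassRegistryMixin["Strategy"]):
    schema_discriminator: ClassVar[str] = "strategy_type"
    strategy_type: str = Field(description="Discriminator selecting the subclass")

    @classmethod
    def __pydantic_schema_base_type__(cls) -> type["Strategy"]:
        return Strategy


@Strategy.register("constant")
class ConstantStrategy(Strategy):
    strategy_type: str = "constant"
    rate: float = 1.0


# Serialization keeps the discriminator, so validating against the base class
# restores the concrete subclass through the tagged-union schema.
payload = ConstantStrategy(rate=2.5).model_dump()
restored = Strategy.model_validate(payload)
assert isinstance(restored, ConstantStrategy) and restored.rate == 2.5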
Used as fallback during + schema generation when the registry has not been populated. + + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input + """ + return core_schema.any_schema() + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Initialize registry with auto-discovery and reload validation schema. + + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. + + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled + """ + populated = super().auto_populate_registry() + cls.reload_schema() + + return populated diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py new file mode 100644 index 00000000..5d4bc055 --- /dev/null +++ b/src/guidellm/utils/registry.py @@ -0,0 +1,206 @@ +""" +Registry system for dynamic object registration and discovery. + +Provides a flexible object registration system with optional auto-discovery +capabilities through decorators and module imports. Enables dynamic discovery +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. +""" + +from __future__ import annotations + +from typing import Any, Callable, ClassVar, Generic, TypeVar + +from guidellm.utils.auto_importer import AutoImporterMixin + +__all__ = ["RegistryMixin", "RegistryObjT"] + + +RegistryObjT = TypeVar("RegistryObjT", bound=Any) +""" +Generic type variable for objects managed by the registry system. +""" + + +class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): + """ + Generic mixin for creating object registries with optional auto-discovery. + + Enables classes to maintain separate registries of objects that can be + dynamically discovered and instantiated through decorators and module imports. + Supports both manual registration via decorators and automatic discovery + through package scanning for extensible plugin architectures. + + Example: + :: + class BaseAlgorithm(RegistryMixin): + pass + + @BaseAlgorithm.register() + class ConcreteAlgorithm(BaseAlgorithm): + pass + + @BaseAlgorithm.register("custom_name") + class AnotherAlgorithm(BaseAlgorithm): + pass + + # Get all registered implementations + algorithms = BaseAlgorithm.registered_objects() + + Example with auto-discovery: + :: + class TokenProposal(RegistryMixin): + registry_auto_discovery = True + auto_package = "mypackage.proposals" + + # Automatically imports and registers decorated objects + proposals = TokenProposal.registered_objects() + + :cvar registry: Dictionary mapping names to registered objects + :cvar registry_auto_discovery: Enable automatic package-based discovery + :cvar registry_populated: Track whether auto-discovery has completed + """ + + registry: ClassVar[dict[str, RegistryObjT] | None] = None + registry_auto_discovery: ClassVar[bool] = False + registry_populated: ClassVar[bool] = False + + @classmethod + def register( + cls, name: str | list[str] | None = None + ) -> Callable[[RegistryObjT], RegistryObjT]: + """ + Decorator that registers an object with the registry. + + :param name: Optional name(s) to register the object under. + If None, the object name is used as the registry key. 
+ :return: A decorator function that registers the decorated object. + :raises ValueError: If name is provided but is not a string or list of strings. + """ + if name is not None and not isinstance(name, (str, list)): + raise ValueError( + "RegistryMixin.register() name must be a string, list of strings, " + f"or None. Got {name}." + ) + + return lambda obj: cls.register_decorator(obj, name=name) + + @classmethod + def register_decorator( + cls, obj: RegistryObjT, name: str | list[str] | None = None + ) -> RegistryObjT: + """ + Direct decorator that registers an object with the registry. + + :param obj: The object to register. + :param name: Optional name(s) to register the object under. + If None, the object name is used as the registry key. + :return: The registered object. + :raises ValueError: If the object is already registered or if name is invalid. + """ + + if not name: + name = obj.__name__ + elif not isinstance(name, (str, list)): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"an iterable of strings. Got {name}." + ) + + if cls.registry is None: + cls.registry = {} + + names = [name] if isinstance(name, str) else list(name) + + for register_name in names: + if not isinstance(register_name, str): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"a list of strings. Got {register_name}." + ) + + if register_name in cls.registry: + raise ValueError( + f"RegistryMixin.register_decorator cannot register an object " + f"{obj} with the name {register_name} because it is already " + "registered." + ) + + cls.registry[register_name.lower()] = obj + + return obj + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Import and register all modules from the specified auto_package. + + Automatically called by registered_objects when registry_auto_discovery is True + to ensure all available implementations are discovered before returning results. + + :return: True if the registry was populated, False if already populated. + :raises ValueError: If called when registry_auto_discovery is False. + """ + if not cls.registry_auto_discovery: + raise ValueError( + "RegistryMixin.auto_populate_registry() cannot be called " + "because registry_auto_discovery is set to False. " + "Set registry_auto_discovery to True to enable auto-discovery." + ) + + if cls.registry_populated: + return False + + cls.auto_import_package_modules() + cls.registry_populated = True + + return True + + @classmethod + def registered_objects(cls) -> tuple[RegistryObjT, ...]: + """ + Get all registered objects from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered objects including auto-discovered ones. + :raises ValueError: If called before any objects have been registered. + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "RegistryMixin.registered_objects() must be called after " + "registering objects with RegistryMixin.register()." + ) + + return tuple(cls.registry.values()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if an object is registered under the given name. + + :param name: The name to check for registration. + :return: True if the object is registered, False otherwise. 
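A compact sketch of manual registration and lookup with the registry mixin above; the backend classes are illustrative:

from guidellm.utils import RegistryMixin


class Backend(RegistryMixin):
    pass


@Backend.register()  # registered under the class name, stored lowercased
class OpenAIBackend(Backend):
    pass


@Backend.register(["vllm", "vLLM-server"])  # multiple aliases for one class
class VLLMBackend(Backend):
    pass


# VLLMBackend appears once per alias it was registered under.
print(Backend.registered_objects())  # (OpenAIBackend, VLLMBackend, VLLMBackend)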
+ """ + if cls.registry is None: + return False + + return name.lower() in cls.registry + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT | None: + """ + Get a registered object by its name. + + :param name: The name of the registered object. + :return: The registered object if found, None otherwise. + """ + if cls.registry is None: + return None + + return cls.registry.get(name.lower()) diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py new file mode 100644 index 00000000..3ec10f79 --- /dev/null +++ b/src/guidellm/utils/singleton.py @@ -0,0 +1,130 @@ +""" +Singleton pattern implementations for ensuring single instance classes. + +Provides singleton mixins for creating classes that maintain a single instance +throughout the application lifecycle, with support for both basic and thread-safe +implementations. These mixins integrate with the scheduler and other system components +to ensure consistent state management and prevent duplicate resource allocation. +""" + +from __future__ import annotations + +import threading + +__all__ = ["SingletonMixin", "ThreadSafeSingletonMixin"] + + +class SingletonMixin: + """ + Basic singleton mixin ensuring single instance per class. + + Implements the singleton pattern using class variables to control instance + creation. Subclasses must call super().__init__() for proper initialization + state management. Suitable for single-threaded environments or when external + synchronization is provided. + + Example: + :: + class ConfigManager(SingletonMixin): + def __init__(self, config_path: str): + super().__init__() + if not self.initialized: + self.config = load_config(config_path) + + manager1 = ConfigManager("config.json") + manager2 = ConfigManager("config.json") + assert manager1 is manager2 + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance. + + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific attribute name to avoid inheritance issues + attr_name = f"_singleton_instance_{cls.__name__}" + + if not hasattr(cls, attr_name) or getattr(cls, attr_name) is None: + instance = super().__new__(cls) + setattr(cls, attr_name, instance) + instance._singleton_initialized = False + return getattr(cls, attr_name) + + def __init__(self): + """Initialize the singleton instance exactly once.""" + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def initialized(self): + """Return True if the singleton has been initialized.""" + return getattr(self, "_singleton_initialized", False) + + +class ThreadSafeSingletonMixin(SingletonMixin): + """ + Thread-safe singleton mixin with locking mechanisms. + + Extends SingletonMixin with thread safety using locks to prevent race + conditions during instance creation in multi-threaded environments. Essential + for scheduler components and other shared resources accessed concurrently. + + Example: + :: + class SchedulerResource(ThreadSafeSingletonMixin): + def __init__(self): + super().__init__() + if not self.initialized: + self.resource_pool = initialize_resources() + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance with thread safety. 
+ + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific lock and instance names to avoid inheritance issues + lock_attr_name = f"_singleton_lock_{cls.__name__}" + instance_attr_name = f"_singleton_instance_{cls.__name__}" + + with getattr(cls, lock_attr_name): + instance_exists = ( + hasattr(cls, instance_attr_name) + and getattr(cls, instance_attr_name) is not None + ) + if not instance_exists: + instance = super(SingletonMixin, cls).__new__(cls) + setattr(cls, instance_attr_name, instance) + instance._singleton_initialized = False + instance._init_lock = threading.Lock() + return getattr(cls, instance_attr_name) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + lock_attr_name = f"_singleton_lock_{cls.__name__}" + setattr(cls, lock_attr_name, threading.Lock()) + + def __init__(self): + """Initialize the singleton instance with thread-safe initialization.""" + with self._init_lock: + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def thread_lock(self): + """Return the thread lock for this singleton instance.""" + return getattr(self, "_init_lock", None) + + @classmethod + def get_singleton_lock(cls): + """Get the class-specific singleton creation lock.""" + lock_attr_name = f"_singleton_lock_{cls.__name__}" + return getattr(cls, lock_attr_name, None) diff --git a/src/guidellm/objects/statistics.py b/src/guidellm/utils/statistics.py similarity index 99% rename from src/guidellm/objects/statistics.py rename to src/guidellm/utils/statistics.py index 8ba504be..669aef6d 100644 --- a/src/guidellm/objects/statistics.py +++ b/src/guidellm/utils/statistics.py @@ -6,7 +6,7 @@ import numpy as np from pydantic import Field, computed_field -from guidellm.objects.pydantic import StandardBaseModel, StatusBreakdown +from guidellm.utils import StandardBaseModel, StatusBreakdown __all__ = [ "DistributionSummary", diff --git a/src/guidellm/utils/text.py b/src/guidellm/utils/text.py index cdefaa14..beebfe37 100644 --- a/src/guidellm/utils/text.py +++ b/src/guidellm/utils/text.py @@ -1,9 +1,21 @@ +""" +Text processing utilities for content manipulation and formatting operations. + +Provides comprehensive text processing capabilities including cleaning, filtering, +splitting, loading from various sources, and formatting utilities. Supports loading +text from URLs, compressed files, package resources, and local files with automatic +encoding detection. Includes specialized formatting for display values and text +wrapping operations for consistent presentation across the system. 
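A sketch of concurrent construction with the thread-safe singleton variant above; RunState is a hypothetical subclass:

import threading

from guidellm.utils import ThreadSafeSingletonMixin


class RunState(ThreadSafeSingletonMixin):
    def __init__(self):
        super().__init__()
        if not hasattr(self, "counter"):  # guard so repeated __init__ calls don't reset state
            self.counter = 0


instances = []


def build():
    instances.append(RunState())


threads = [threading.Thread(target=build) for _ in range(8)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

# Every thread received the same instance; creation was serialized by the class lock.
assert all(obj is instances[0] for obj in instances)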
+""" + +from __future__ import annotations + import gzip import re import textwrap from importlib.resources import as_file, files # type: ignore[attr-defined] from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import ftfy import httpx @@ -11,35 +23,86 @@ from guidellm import data as package_data from guidellm.config import settings +from guidellm.utils.colors import Colors __all__ = [ + "MAX_PATH_LENGTH", "EndlessTextCreator", "clean_text", "filter_text", - "is_puncutation", + "format_value_display", + "is_punctuation", "load_text", "split_text", "split_text_list_by_length", ] -MAX_PATH_LENGTH = 4096 +MAX_PATH_LENGTH: int = 4096 + + +def format_value_display( + value: float, + label: str, + units: str = "", + total_characters: int | None = None, + digits_places: int | None = None, + decimal_places: int | None = None, +) -> str: + """ + Format a numeric value with units and label for consistent display output. + + Creates standardized display strings for metrics and measurements with + configurable precision, width, and color formatting. Supports both + fixed-width and variable-width output for tabular displays. + + :param value: Numeric value to format and display + :param label: Descriptive label for the value + :param units: Units string to append after the value + :param total_characters: Total width for right-aligned output formatting + :param digits_places: Total number of digits for numeric formatting + :param decimal_places: Number of decimal places for numeric precision + :return: Formatted string with value, units, and colored label + """ + if decimal_places is None and digits_places is None: + formatted_number = f"{value}:.0f" + elif digits_places is None: + formatted_number = f"{value:.{decimal_places}f}" + elif decimal_places is None: + formatted_number = f"{value:>{digits_places}f}" + else: + formatted_number = f"{value:>{digits_places}.{decimal_places}f}" + + result = f"{formatted_number}{units} [{Colors.info}]{label}[/{Colors.info}]" + + if total_characters is not None: + total_characters += len(Colors.info) * 2 + 5 + + if len(result) < total_characters: + result = result.rjust(total_characters) + + return result def split_text_list_by_length( text_list: list[Any], - max_characters: Union[int, list[int]], + max_characters: int | list[int], pad_horizontal: bool = True, pad_vertical: bool = True, ) -> list[list[str]]: """ - Split a list of strings into a list of strings, - each with a maximum length of max_characters - - :param text_list: the list of strings to split - :param max_characters: the maximum length of each string - :param pad_horizontal: whether to pad the strings horizontally, defaults to True - :param pad_vertical: whether to pad the strings vertically, defaults to True - :return: a list of strings + Split text strings into wrapped lines with specified maximum character limits. + + Processes each string in the input list by wrapping text to fit within character + limits, with optional padding for consistent formatting in tabular displays. + Supports different character limits per string and uniform padding across results. 
+ + :param text_list: List of strings to process and wrap + :param max_characters: Maximum characters per line, either single value or + per-string limits + :param pad_horizontal: Right-align lines within their character limits + :param pad_vertical: Pad shorter results to match the longest wrapped result + :return: List of wrapped line lists, one per input string + :raises ValueError: If max_characters list length doesn't match text_list length """ if not isinstance(max_characters, list): max_characters = [max_characters] * len(text_list) @@ -75,16 +138,21 @@ def split_text_list_by_length( def filter_text( text: str, - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ) -> str: """ - Filter text by start and end strings or indices + Extract text substring using start and end markers or indices. + + Filters text content by locating string markers or using numeric indices + to extract specific portions. Supports flexible filtering for content + extraction and preprocessing operations. - :param text: the text to filter - :param filter_start: the start string or index to filter from - :param filter_end: the end string or index to filter to - :return: the filtered text + :param text: Source text to filter and extract from + :param filter_start: Starting marker string or index position + :param filter_end: Ending marker string or index position + :return: Filtered text substring between specified boundaries + :raises ValueError: If filter indices are invalid or markers not found """ filter_start_index = -1 filter_end_index = -1 @@ -112,10 +180,29 @@ def filter_text( def clean_text(text: str) -> str: + """ + Normalize text by fixing encoding issues and standardizing whitespace. + + Applies Unicode normalization and whitespace standardization for consistent + text processing. Removes excessive whitespace and fixes common encoding problems. + + :param text: Raw text string to clean and normalize + :return: Cleaned text with normalized encoding and whitespace + """ return re.sub(r"\s+", " ", ftfy.fix_text(text)).strip() def split_text(text: str, split_punctuation: bool = False) -> list[str]: + """ + Split text into tokens with optional punctuation separation. + + Tokenizes text into words and optionally separates punctuation marks + for detailed text analysis and processing operations. + + :param text: Text string to tokenize and split + :param split_punctuation: Separate punctuation marks as individual tokens + :return: List of text tokens + """ text = clean_text(text) if split_punctuation: @@ -124,16 +211,20 @@ def split_text(text: str, split_punctuation: bool = False) -> list[str]: return text.split() -def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: +def load_text(data: str | Path, encoding: str | None = None) -> str: """ - Load an HTML file from a path or URL - - :param data: the path or URL to load the HTML file from - :type data: Union[str, Path] - :param encoding: the encoding to use when reading the file - :type encoding: str - :return: the HTML content - :rtype: str + Load text content from various sources including URLs, files, and package data. + + Supports loading from HTTP/FTP URLs, local files, compressed archives, package + resources, and raw text strings. Automatically detects source type and applies + appropriate loading strategy with encoding support. 
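A quick sketch of the tokenization helpers above:

from guidellm.utils import clean_text, split_text

raw = "GuideLLM\u00a0 benchmarks   LLM deployments."
print(clean_text(raw))  # "GuideLLM benchmarks LLM deployments."
print(split_text(raw))  # ["GuideLLM", "benchmarks", "LLM", "deployments."]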
+ + :param data: Source location or raw text - URL, file path, package resource + identifier, or text content + :param encoding: Character encoding for file reading operations + :return: Loaded text content as string + :raises FileNotFoundError: If local file path does not exist + :raises httpx.HTTPStatusError: If URL request fails """ logger.debug("Loading text: {}", data) @@ -177,38 +268,71 @@ def load_text(data: Union[str, Path], encoding: Optional[str] = None) -> str: return data.read_text(encoding=encoding) -def is_puncutation(text: str) -> bool: +def is_punctuation(text: str) -> bool: """ - Check if the text is a punctuation + Check if a single character is a punctuation mark. + + Identifies punctuation characters by excluding alphanumeric characters + and whitespace from single-character strings. - :param text: the text to check - :type text: str - :return: True if the text is a punctuation, False otherwise - :rtype: bool + :param text: Single character string to test + :return: True if the character is punctuation, False otherwise """ return len(text) == 1 and not text.isalnum() and not text.isspace() class EndlessTextCreator: + """ + Infinite text generator for load testing and content creation operations. + + Provides deterministic text generation by cycling through preprocessed word + tokens from source content. Supports filtering and punctuation handling for + realistic text patterns in benchmarking scenarios. + + Example: + :: + creator = EndlessTextCreator("path/to/source.txt") + generated = creator.create_text(start=0, length=100) + more_text = creator.create_text(start=50, length=200) + """ + def __init__( self, - data: Union[str, Path], - filter_start: Optional[Union[str, int]] = None, - filter_end: Optional[Union[str, int]] = None, + data: str | Path, + filter_start: str | int | None = None, + filter_end: str | int | None = None, ): + """ + Initialize text creator with source content and optional filtering. + + :param data: Source text location or content - file path, URL, or raw text + :param filter_start: Starting marker or index for content filtering + :param filter_end: Ending marker or index for content filtering + """ self.data = data self.text = load_text(data) self.filtered_text = filter_text(self.text, filter_start, filter_end) self.words = split_text(self.filtered_text, split_punctuation=True) def create_text(self, start: int, length: int) -> str: + """ + Generate text by cycling through word tokens from the specified position. + + Creates deterministic text sequences by selecting consecutive tokens from + the preprocessed word list, wrapping around when reaching the end. + Maintains proper spacing and punctuation formatting. 
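The renamed predicate above, shown on a few sample characters:

from guidellm.utils import is_punctuation

print(is_punctuation(","))    # True
print(is_punctuation("a"))    # False: alphanumeric
print(is_punctuation(" "))    # False: whitespace
print(is_punctuation("..."))  # False: only single characters qualify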
+ + :param start: Starting position in the token sequence + :param length: Number of tokens to include in generated text + :return: Generated text string with proper spacing and punctuation + """ text = "" for counter in range(length): index = (start + counter) % len(self.words) add_word = self.words[index] - if counter != 0 and not is_puncutation(add_word): + if counter != 0 and not is_punctuation(add_word): text += " " text += add_word diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index 81364fa1..29c092c8 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -6,13 +6,13 @@ GenerativeTextResponseStats, SynchronousProfile, ) -from guidellm.objects import StatusBreakdown from guidellm.request import GenerativeRequestLoaderDescription from guidellm.scheduler import ( GenerativeRequestsWorkerDescription, SchedulerRequestInfo, SynchronousStrategy, ) +from guidellm.utils import StatusBreakdown __all__ = ["mock_generative_benchmark"] diff --git a/tests/unit/objects/__init__.py b/tests/unit/objects/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit/objects/test_pydantic.py b/tests/unit/objects/test_pydantic.py deleted file mode 100644 index cb7f438f..00000000 --- a/tests/unit/objects/test_pydantic.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -from pydantic import computed_field - -from guidellm.objects.pydantic import StandardBaseModel - - -class ExampleModel(StandardBaseModel): - name: str - age: int - - @computed_field # type: ignore[misc] - @property - def computed(self) -> str: - return self.name + " " + str(self.age) - - -@pytest.mark.smoke -def test_standard_base_model_initialization(): - example = ExampleModel(name="John Doe", age=30) - assert example.name == "John Doe" - assert example.age == 30 - assert example.computed == "John Doe 30" - - -@pytest.mark.smoke -def test_standard_base_model_invalid_initialization(): - with pytest.raises(ValueError): - ExampleModel(name="John Doe", age="thirty") # type: ignore[arg-type] - - -@pytest.mark.smoke -def test_standard_base_model_marshalling(): - example = ExampleModel(name="John Doe", age=30) - serialized = example.model_dump() - assert serialized["name"] == "John Doe" - assert serialized["age"] == 30 - assert serialized["computed"] == "John Doe 30" - - serialized["computed"] = "Jane Doe 40" - deserialized = ExampleModel.model_validate(serialized) - assert deserialized.name == "John Doe" - assert deserialized.age == 30 - assert deserialized.computed == "John Doe 30" diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py new file mode 100644 index 00000000..cc71bce3 --- /dev/null +++ b/tests/unit/utils/test_auto_importer.py @@ -0,0 +1,269 @@ +""" +Unit tests for the auto_importer module. 
+""" + +from __future__ import annotations + +from unittest import mock + +import pytest + +from guidellm.utils import AutoImporterMixin + + +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" + + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] + + return TestClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables + assert AutoImporterMixin.auto_package is None + assert AutoImporterMixin.auto_ignore_modules is None + assert AutoImporterMixin.auto_imported_modules is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert test_class.auto_imported_modules is None + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" + + class TestClass(AutoImporterMixin): + pass + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestClass.auto_import_package_modules() + + @pytest.mark.smoke + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, 
mock.MagicMock() + ) + + def walk_side_effect(path, prefix): + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] + + mock_walk.side_effect = walk_side_effect + + # Execute + test_class.auto_import_package_modules() + + # Verify + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Test import error handling + mock_import.side_effect = ImportError("Module not found") + + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_packages(self, mock_walk, mock_import): + """Test that packages (is_pkg=True) are skipped.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.subpackage", True), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + mock_import.assert_any_call("test.package.module") + # subpackage should not be imported + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.subpackage") + + @pytest.mark.sanity + @mock.patch("sys.modules", {"test.package.existing": mock.MagicMock()}) + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_already_imported_modules(self, mock_walk, mock_import): + """Test that modules already in sys.modules are tracked but not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + }.get(name, mock.MagicMock()) + + mock_walk.return_value = [ + (None, "test.package.existing", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.existing"] + mock_import.assert_called_once_with("test.package") + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.existing") + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_prevent_duplicate_module_imports(self, mock_walk, mock_import): + """Test that modules already in auto_imported_modules are not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup 
mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.module", False), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + assert mock_import.call_count == 2 # Package + module (not duplicate) + + +class MockHelper: + """Helper class to create consistent mock objects for testing.""" + + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package + + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_functions.py b/tests/unit/utils/test_functions.py new file mode 100644 index 00000000..3b353759 --- /dev/null +++ b/tests/unit/utils/test_functions.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from datetime import datetime + +import pytest + +from guidellm.utils.functions import ( + all_defined, + safe_add, + safe_divide, + safe_format_timestamp, + safe_getattr, + safe_multiply, +) + + +class TestAllDefined: + """Test suite for all_defined function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "expected"), + [ + ((1, 2, 3), True), + (("test", "hello"), True), + ((0, False, ""), True), + ((1, None, 3), False), + ((None,), False), + ((None, None), False), + ((), True), + ], + ) + def test_invocation(self, values, expected): + """Test all_defined with valid inputs.""" + result = all_defined(*values) + assert result == expected + + @pytest.mark.sanity + def test_mixed_types(self): + """Test all_defined with mixed data types.""" + result = all_defined(1, "test", [], {}, 0.0, False) + assert result is True + + result = all_defined(1, "test", None, {}) + assert result is False + + +class TestSafeGetattr: + """Test suite for safe_getattr function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj", "attr", "default", "expected"), + [ + (None, "any_attr", "default_val", "default_val"), + (None, "any_attr", None, None), + ("test_string", "nonexistent", "default_val", "default_val"), + ], + ) + def test_invocation(self, obj, attr, default, expected): + """Test safe_getattr with valid inputs.""" + result = safe_getattr(obj, attr, default) + assert result == expected + + @pytest.mark.smoke + def test_with_object(self): + """Test safe_getattr with actual object attributes.""" + + class TestObj: + test_attr = "test_value" + + obj = TestObj() + result = safe_getattr(obj, "test_attr", "default") + assert result == "test_value" + + result = safe_getattr(obj, "missing_attr", "default") + assert result == "default" + + # Test with method attribute + result = safe_getattr("test_string", "upper", None) + assert callable(result) + assert result() == "TEST_STRING" + + +class TestSafeDivide: + """Test suite for safe_divide function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("numerator", "denominator", "num_default", "den_default", "expected"), + [ + (10, 2, 0.0, 1.0, 5.0), + (None, 2, 6.0, 1.0, 3.0), + (10, None, 0.0, 5.0, 2.0), + (None, 
None, 8.0, 4.0, 2.0), + (10, 0, 0.0, 1.0, 10 / 1e-10), + ], + ) + def test_invocation( + self, numerator, denominator, num_default, den_default, expected + ): + """Test safe_divide with valid inputs.""" + result = safe_divide(numerator, denominator, num_default, den_default) + assert result == pytest.approx(expected, rel=1e-6) + + @pytest.mark.sanity + def test_zero_division_protection(self): + """Test safe_divide protection against zero division.""" + result = safe_divide(10, 0) + assert result == 10 / 1e-10 + + result = safe_divide(5, None, den_default=0) + assert result == 5 / 1e-10 + + +class TestSafeMultiply: + """Test suite for safe_multiply function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "default", "expected"), + [ + ((2, 3, 4), 1.0, 24.0), + ((2, None, 4), 1.0, 8.0), + ((None, None), 5.0, 5.0), + ((), 3.0, 3.0), + ((2, 3, None, 5), 2.0, 60.0), + ], + ) + def test_invocation(self, values, default, expected): + """Test safe_multiply with valid inputs.""" + result = safe_multiply(*values, default=default) + assert result == expected + + @pytest.mark.sanity + def test_with_zero(self): + """Test safe_multiply with zero values.""" + result = safe_multiply(2, 0, 3, default=1.0) + assert result == 0.0 + + result = safe_multiply(None, 0, None, default=5.0) + assert result == 0.0 + + +class TestSafeAdd: + """Test suite for safe_add function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("values", "signs", "default", "expected"), + [ + ((1, 2, 3), None, 0.0, 6.0), + ((1, None, 3), None, 5.0, 9.0), + ((10, 5), [1, -1], 0.0, 5.0), + ((None, None), [1, -1], 2.0, 0.0), + ((), None, 3.0, 3.0), + ((1, 2, 3), [1, 1, -1], 0.0, 0.0), + ], + ) + def test_invocation(self, values, signs, default, expected): + """Test safe_add with valid inputs.""" + result = safe_add(*values, signs=signs, default=default) + assert result == expected + + @pytest.mark.sanity + def test_invalid_signs_length(self): + """Test safe_add with invalid signs length.""" + with pytest.raises( + ValueError, match="Length of signs must match length of values" + ): + safe_add(1, 2, 3, signs=[1, -1]) + + @pytest.mark.sanity + def test_single_value(self): + """Test safe_add with single value.""" + result = safe_add(5, default=1.0) + assert result == 5.0 + + result = safe_add(None, default=3.0) + assert result == 3.0 + + +class TestSafeFormatTimestamp: + """Test suite for safe_format_timestamp function.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("timestamp", "format_", "default", "expected"), + [ + (1609459200.0, "%Y-%m-%d", "N/A", "2020-12-31"), + (1609459200.0, "%H:%M:%S", "N/A", "19:00:00"), + (None, "%H:%M:%S", "N/A", "N/A"), + (-1, "%H:%M:%S", "N/A", "N/A"), + (2**32, "%H:%M:%S", "N/A", "N/A"), + ], + ) + def test_invocation(self, timestamp, format_, default, expected): + """Test safe_format_timestamp with valid inputs.""" + result = safe_format_timestamp(timestamp, format_, default) + assert result == expected + + @pytest.mark.sanity + def test_edge_cases(self): + """Test safe_format_timestamp with edge case timestamps.""" + result = safe_format_timestamp(0.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(1.0, "%Y", "N/A") + assert result == "1969" + + result = safe_format_timestamp(2**31 - 1, "%Y", "N/A") + expected_year = datetime.fromtimestamp(2**31 - 1).strftime("%Y") + assert result == expected_year + + @pytest.mark.sanity + def test_invalid_timestamp_ranges(self): + """Test safe_format_timestamp with invalid timestamp ranges.""" + result = 
safe_format_timestamp(2**31 + 1, "%Y", "ERROR") + assert result == "ERROR" + + result = safe_format_timestamp(-1000, "%Y", "ERROR") + assert result == "ERROR" diff --git a/tests/unit/utils/test_mixins.py b/tests/unit/utils/test_mixins.py new file mode 100644 index 00000000..cd8990de --- /dev/null +++ b/tests/unit/utils/test_mixins.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import pytest + +from guidellm.utils.mixins import InfoMixin + + +class TestInfoMixin: + """Test suite for InfoMixin.""" + + @pytest.fixture( + params=[ + {"attr_one": "test_value", "attr_two": 42}, + {"attr_one": "hello_world", "attr_two": 100, "attr_three": [1, 2, 3]}, + ], + ids=["basic_attributes", "extended_attributes"], + ) + def valid_instances(self, request): + """Fixture providing test data for InfoMixin.""" + constructor_args = request.param + + class TestClass(InfoMixin): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + instance = TestClass(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test InfoMixin class signatures and methods.""" + assert hasattr(InfoMixin, "extract_from_obj") + assert callable(InfoMixin.extract_from_obj) + assert hasattr(InfoMixin, "create_info_dict") + assert callable(InfoMixin.create_info_dict) + assert hasattr(InfoMixin, "info") + assert isinstance(InfoMixin.info, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test InfoMixin initialization through inheritance.""" + instance, constructor_args = valid_instances + assert isinstance(instance, InfoMixin) + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.smoke + def test_info_property(self, valid_instances): + """Test InfoMixin.info property.""" + instance, constructor_args = valid_instances + result = instance.info + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "TestClass" + assert result["class"] == "TestClass" + assert isinstance(result["attributes"], dict) + for key, value in constructor_args.items(): + assert key in result["attributes"] + assert result["attributes"][key] == value + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": True}), + ({"nested": {"key": "value"}}, {"nested": {"key": "value"}}), + ], + ) + def test_create_info_dict(self, obj_data, expected_attributes): + """Test InfoMixin.create_info_dict class method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.create_info_dict(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("obj_data", "expected_attributes"), + [ + ({"name": "test", "value": 42}, {"name": "test", "value": 42}), + ({"data": [1, 2, 3], "flag": True}, {"data": [1, 2, 3], "flag": 
True}), + ], + ) + def test_extract_from_obj_without_info(self, obj_data, expected_attributes): + """Test InfoMixin.extract_from_obj with objects without info method.""" + + class SimpleObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + obj = SimpleObject(**obj_data) + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert "class" in result + assert "module" in result + assert "attributes" in result + assert result["type"] == "SimpleObject" + assert result["class"] == "SimpleObject" + assert result["attributes"] == expected_attributes + + @pytest.mark.smoke + def test_extract_from_obj_with_info_method(self): + """Test InfoMixin.extract_from_obj with objects that have info method.""" + + class ObjectWithInfoMethod: + def info(self): + return {"custom": "info_method", "type": "custom_type"} + + obj = ObjectWithInfoMethod() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_method", "type": "custom_type"} + + @pytest.mark.smoke + def test_extract_from_obj_with_info_property(self): + """Test InfoMixin.extract_from_obj with objects that have info property.""" + + class ObjectWithInfoProperty: + @property + def info(self): + return {"custom": "info_property", "type": "custom_type"} + + obj = ObjectWithInfoProperty() + result = InfoMixin.extract_from_obj(obj) + + assert result == {"custom": "info_property", "type": "custom_type"} + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("obj_type", "obj_value"), + [ + (str, "test_string"), + (int, 42), + (float, 3.14), + (list, [1, 2, 3]), + (dict, {"key": "value"}), + ], + ) + def test_extract_from_obj_builtin_types(self, obj_type, obj_value): + """Test InfoMixin.extract_from_obj with built-in types.""" + result = InfoMixin.extract_from_obj(obj_value) + + assert isinstance(result, dict) + assert "str" in result + assert "type" in result + assert result["type"] == obj_type.__name__ + assert result["str"] == str(obj_value) + + @pytest.mark.sanity + def test_extract_from_obj_without_dict(self): + """Test InfoMixin.extract_from_obj with objects without __dict__.""" + obj = 42 + result = InfoMixin.extract_from_obj(obj) + + assert isinstance(result, dict) + assert "attributes" in result + assert result["attributes"] == {} + assert result["type"] == "int" + assert result["str"] == "42" + + @pytest.mark.sanity + def test_extract_from_obj_with_private_attributes(self): + """Test InfoMixin.extract_from_obj filters private attributes.""" + + class ObjectWithPrivate: + def __init__(self): + self.public_attr = "public" + self._private_attr = "private" + self.__very_private = "very_private" + + obj = ObjectWithPrivate() + result = InfoMixin.extract_from_obj(obj) + + assert "public_attr" in result["attributes"] + assert result["attributes"]["public_attr"] == "public" + assert "_private_attr" not in result["attributes"] + assert "__very_private" not in result["attributes"] + + @pytest.mark.sanity + def test_extract_from_obj_complex_attributes(self): + """Test InfoMixin.extract_from_obj with complex attribute types.""" + + class ComplexObject: + def __init__(self): + self.simple_str = "test" + self.simple_int = 42 + self.simple_list = [1, 2, 3] + self.simple_dict = {"key": "value"} + self.complex_object = object() + + obj = ComplexObject() + result = InfoMixin.extract_from_obj(obj) + + attributes = result["attributes"] + assert attributes["simple_str"] == "test" + assert attributes["simple_int"] == 
42 + assert attributes["simple_list"] == [1, 2, 3] + assert attributes["simple_dict"] == {"key": "value"} + assert isinstance(attributes["complex_object"], str) + + @pytest.mark.regression + def test_create_info_dict_consistency(self, valid_instances): + """Test InfoMixin.create_info_dict produces consistent results.""" + instance, _ = valid_instances + + result1 = InfoMixin.create_info_dict(instance) + result2 = InfoMixin.create_info_dict(instance) + + assert result1 == result2 + assert result1 is not result2 + + @pytest.mark.regression + def test_info_property_uses_create_info_dict(self, valid_instances): + """Test InfoMixin.info property uses create_info_dict method.""" + instance, _ = valid_instances + + info_result = instance.info + create_result = InfoMixin.create_info_dict(instance) + + assert info_result == create_result diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py new file mode 100644 index 00000000..8683604b --- /dev/null +++ b/tests/unit/utils/test_pydantic_utils.py @@ -0,0 +1,710 @@ +""" +Unit tests for the pydantic_utils module. +""" + +from __future__ import annotations + +from typing import ClassVar +from unittest import mock + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from guidellm.utils.pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": "hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + # Check model configuration + config = ReloadableBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # 
type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestStandardBaseModel: + """Test suite for StandardBaseModel.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "field_int": 42}, + {"field_str": "hello_world", "field_int": 100}, + {"field_str": "another_test", "field_int": 0}, + ], + ids=["basic_values", "positive_values", "zero_value"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseModel, dict[str, int | str]]: + """Fixture providing test data for StandardBaseModel.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseModel inheritance and class variables.""" + assert issubclass(StandardBaseModel, BaseModel) + assert hasattr(StandardBaseModel, "model_config") + assert hasattr(StandardBaseModel, "get_default") + + # Check model configuration + config = StandardBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseModel) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + assert instance.field_int == constructor_args["field_int"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ("field_int", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseModel with invalid field values.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + data = {field: value} + if field == "field_str": + data["field_int"] = 42 + else: + data["field_str"] = "test" + + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseModel initialization without required field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + with pytest.raises(ValidationError): + TestModel() # type: 
ignore[call-arg] + + @pytest.mark.smoke + def test_get_default(self): + """Test StandardBaseModel.get_default method.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=42, description="Test integer field") + + default_value = TestModel.get_default("field_int") + assert default_value == 42 + + @pytest.mark.sanity + def test_get_default_invalid(self): + """Test StandardBaseModel.get_default with invalid field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + + with pytest.raises(KeyError): + TestModel.get_default("nonexistent_field") + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + assert data_dict["field_int"] == constructor_args["field_int"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + assert recreated.field_int == constructor_args["field_int"] + + +class TestStandardBaseDict: + """Test suite for StandardBaseDict.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "extra_field": "extra_value"}, + {"field_str": "hello_world", "another_extra": 123}, + {"field_str": "another_test", "complex_extra": {"nested": "value"}}, + ], + ids=["string_extra", "int_extra", "dict_extra"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseDict, dict[str, str | int | dict[str, str]]]: + """Fixture providing test data for StandardBaseDict.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseDict inheritance and class variables.""" + assert issubclass(StandardBaseDict, StandardBaseModel) + assert hasattr(StandardBaseDict, "model_config") + + # Check model configuration + config = StandardBaseDict.model_config + assert config["extra"] == "allow" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseDict initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseDict) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + + # Check extra fields are preserved + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseDict with invalid field values.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + 
"""Test StandardBaseDict initialization without required field.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseDict serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + + # Check extra fields are in the serialized data + for key, value in constructor_args.items(): + if key != "field_str": + assert key in data_dict + assert data_dict[key] == value + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + + # Check extra fields are preserved after deserialization + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(recreated, key) + assert getattr(recreated, key) == value + + +class TestStatusBreakdown: + """Test suite for StatusBreakdown.""" + + @pytest.fixture( + params=[ + {"successful": 100, "errored": 5, "incomplete": 10, "total": 115}, + { + "successful": "success_data", + "errored": "error_data", + "incomplete": "incomplete_data", + "total": "total_data", + }, + { + "successful": [1, 2, 3], + "errored": [4, 5], + "incomplete": [6], + "total": [1, 2, 3, 4, 5, 6], + }, + ], + ids=["int_values", "string_values", "list_values"], + ) + def valid_instances(self, request) -> tuple[StatusBreakdown, dict]: + """Fixture providing test data for StatusBreakdown.""" + constructor_args = request.param + instance = StatusBreakdown(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StatusBreakdown inheritance and type relationships.""" + assert issubclass(StatusBreakdown, BaseModel) + # Check if Generic is in the MRO (method resolution order) + assert any(cls.__name__ == "Generic" for cls in StatusBreakdown.__mro__) + assert "successful" in StatusBreakdown.model_fields + assert "errored" in StatusBreakdown.model_fields + assert "incomplete" in StatusBreakdown.model_fields + assert "total" in StatusBreakdown.model_fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StatusBreakdown initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StatusBreakdown) + assert instance.successful == constructor_args["successful"] + assert instance.errored == constructor_args["errored"] + assert instance.incomplete == constructor_args["incomplete"] + assert instance.total == constructor_args["total"] + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test StatusBreakdown initialization with default values.""" + instance: StatusBreakdown = StatusBreakdown() + assert instance.successful is None + assert instance.errored is None + assert instance.incomplete is None + assert instance.total is None + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StatusBreakdown serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["successful"] == constructor_args["successful"] + assert data_dict["errored"] == constructor_args["errored"] + assert data_dict["incomplete"] == 
constructor_args["incomplete"] + assert data_dict["total"] == constructor_args["total"] + + recreated: StatusBreakdown = StatusBreakdown.model_validate(data_dict) + assert isinstance(recreated, StatusBreakdown) + assert recreated.successful == constructor_args["successful"] + assert recreated.errored == constructor_args["errored"] + assert recreated.incomplete == constructor_args["incomplete"] + assert recreated.total == constructor_args["total"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, "schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class 
TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "testsubmodel" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["testsubmodel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "guidellm.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + 
instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py new file mode 100644 index 00000000..b5c17975 --- /dev/null +++ b/tests/unit/utils/test_registry.py @@ -0,0 +1,533 @@ +""" +Unit tests for the registry module. 
+""" + +from __future__ import annotations + +from typing import TypeVar +from unittest import mock + +import pytest + +from guidellm.utils.registry import RegistryMixin, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is not None # bound to Any + assert RegistryObjT.__constraints__ == () + + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = True + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, None), # Uses class name + ], + ) + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances + + if name is None: + + @registry_class.register() + class TestClass: + pass + + expected_key = "testclass" + else: + + @registry_class.register(name) + class TestClass: + pass + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances 
+ + with pytest.raises(ValueError, match="name must be a string, list of strings"): + registry_class.register(invalid_name) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "testclass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + registry_class.register_decorator(TestClass, name=name) + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" + + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True + + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() # Should not be called again + + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + + class TestDisabledRegistry(RegistryMixin): + registry_auto_discovery = False + + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() + + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances + + @registry_class.register("class1") + class TestClass1: + pass + + @registry_class.register("class2") + class TestClass2: + pass + + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects + + @pytest.mark.sanity + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" + + class TestRegistryClass(RegistryMixin): + pass + + with pytest.raises( + ValueError, match="must be called after registering objects" + ): + TestRegistryClass.registered_objects() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + 
("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.is_registered(check_name) + assert result == expected + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances + + @registry_class.register("valid_name") + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is None + + @pytest.mark.regression + def test_multiple_registries_isolation(self): + """Test that different registry classes maintain separate registries.""" + + class Registry1(RegistryMixin): + pass + + class Registry2(RegistryMixin): + pass + + @Registry1.register() + class TestClass1: + pass + + @Registry2.register() + class TestClass2: + pass + + assert Registry1.registry is not None + assert Registry2.registry is not None + assert Registry1.registry != Registry2.registry + assert "testclass1" in Registry1.registry + assert "testclass2" in Registry2.registry + assert "testclass1" not in Registry2.registry + assert "testclass2" not in Registry1.registry + + @pytest.mark.regression + def test_inheritance_registry_sharing(self): + """Test that inherited registry classes share the same registry.""" + + class BaseRegistry(RegistryMixin): + pass + + class ChildRegistry(BaseRegistry): + pass + + @BaseRegistry.register() + class BaseClass: + pass + + @ChildRegistry.register() + class ChildClass: + pass + + # Child classes share the same registry as their parent + assert BaseRegistry.registry is ChildRegistry.registry + + # Both classes can see all registered objects + base_objects = BaseRegistry.registered_objects() + child_objects = ChildRegistry.registered_objects() + + assert len(base_objects) == 2 + assert len(child_objects) == 2 + assert base_objects == child_objects + assert BaseClass in base_objects + assert ChildClass in base_objects + + @pytest.mark.smoke + def test_auto_discovery_initialization(self): + """Test initialization of auto-discovery enabled registry.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + assert TestAutoRegistry.registry is None + assert TestAutoRegistry.registry_populated is False + assert TestAutoRegistry.auto_package == "test_package.modules" + assert TestAutoRegistry.registry_auto_discovery is True + + @pytest.mark.smoke + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" + + class TestAutoRegistry(RegistryMixin): + 
registry_auto_discovery = True + auto_package = "test_package.modules" + + with mock.patch.object( + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") + + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass + + with pytest.raises(ValueError, match="already registered"): + + @registry_class.register("duplicate_name") + class TestClass2: + pass + + @pytest.mark.sanity + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances + + class TestClass1: + pass + + class TestClass2: + pass + + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") + + @pytest.mark.sanity + def test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + + @pytest.mark.sanity + def test_register_decorator_invalid_object(self, valid_instances): + """Test register_decorator with object lacking __name__ attribute.""" + registry_class, _ = valid_instances + + with pytest.raises(AttributeError): + registry_class.register_decorator("not_a_class") + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None + + @pytest.mark.regression + def test_auto_registry_integration(self): + """Test complete auto-discovery workflow with mocked imports.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as walk_mock, + mock.patch("importlib.import_module") as import_mock, + ): + # Setup mock package + package_mock = mock.MagicMock() + package_mock.__path__ = ["test_package/modules"] + package_mock.__name__ = "test_package.modules" + + # Setup mock module with test class + module_mock = mock.MagicMock() + module_mock.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + + # Setup import behavior + import_mock.side_effect = lambda name: ( + package_mock + if name == "test_package.modules" + else module_mock + if name == "test_package.modules.module1" + else (_ for _ in ()).throw(ImportError(f"No module named {name}")) + ) + + # Setup package walking behavior + walk_mock.side_effect = lambda path, prefix: ( + [(None, 
"test_package.modules.module1", False)] + if prefix == "test_package.modules." + else (_ for _ in ()).throw(ValueError(f"Unknown package: {prefix}")) + ) + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None + assert "module1class" in TestAutoRegistry.registry + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as mock_walk, + mock.patch("importlib.import_module") as mock_import, + ): + mock_package = mock.MagicMock() + mock_package.__path__ = ["test_package/modules"] + mock_package.__name__ = "test_package.modules" + + def import_module(name: str): + if name == "test_package.modules": + return mock_package + elif name == "test_package.modules.module1": + module = mock.MagicMock() + module.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + return module + else: + raise ImportError(f"No module named {name}") + + def walk_packages(package_path, package_name): + if package_name == "test_package.modules.": + return [(None, "test_package.modules.module1", False)] + else: + raise ValueError(f"Unknown package: {package_name}") + + mock_walk.side_effect = walk_packages + mock_import.side_effect = import_module + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None diff --git a/tests/unit/utils/test_singleton.py b/tests/unit/utils/test_singleton.py new file mode 100644 index 00000000..ee01ead1 --- /dev/null +++ b/tests/unit/utils/test_singleton.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import threading +import time + +import pytest + +from guidellm.utils.singleton import SingletonMixin, ThreadSafeSingletonMixin + + +class TestSingletonMixin: + """Test suite for SingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "test_value"}, + {"init_value": "another_value"}, + ], + ids=["basic_singleton", "different_value"], + ) + def valid_instances(self, request): + """Provide parameterized test configurations for singleton testing.""" + config = request.param + + class TestSingleton(SingletonMixin): + def __init__(self): + # Check if we need to initialize before calling super().__init__() + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SingletonMixin inheritance and exposed attributes.""" + assert hasattr(SingletonMixin, "__new__") + assert hasattr(SingletonMixin, "__init__") + assert hasattr(SingletonMixin, "initialized") + assert isinstance(SingletonMixin.initialized, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test SingletonMixin initialization.""" + singleton_class, config = valid_instances + + # Create first instance + instance1 = singleton_class() + + assert isinstance(instance1, singleton_class) + assert instance1.initialized is True + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + # Check that the class has the singleton instance stored + instance_attr = f"_singleton_instance_{singleton_class.__name__}" + assert 
hasattr(singleton_class, instance_attr) + assert getattr(singleton_class, instance_attr) is instance1 + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test that multiple instantiations return the same instance.""" + singleton_class, config = valid_instances + + # Create multiple instances + instance1 = singleton_class() + instance2 = singleton_class() + instance3 = singleton_class() + + # All should be the same instance + assert instance1 is instance2 + assert instance2 is instance3 + assert instance1 is instance3 + + # Value should remain from first initialization + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert instance2.value == config["init_value"] + assert instance3.value == config["init_value"] + + @pytest.mark.sanity + def test_initialization_called_once(self, valid_instances): + """Test that __init__ is only called once despite multiple instantiations.""" + singleton_class, config = valid_instances + + class TestSingletonWithCounter(SingletonMixin): + init_count = 0 + + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + TestSingletonWithCounter.init_count += 1 + self.value = config["init_value"] + + # Create multiple instances + instance1 = TestSingletonWithCounter() + instance2 = TestSingletonWithCounter() + + assert TestSingletonWithCounter.init_count == 1 + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + @pytest.mark.regression + def test_multiple_singleton_classes_isolation(self): + """Test that different singleton classes maintain separate instances.""" + + class Singleton1(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class Singleton2(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1a = Singleton1() + instance2a = Singleton2() + instance1b = Singleton1() + instance2b = Singleton2() + + # Each class has its own singleton instance + assert instance1a is instance1b + assert instance2a is instance2b + assert instance1a is not instance2a + + # Each maintains its own value + assert hasattr(instance1a, "value") + assert hasattr(instance2a, "value") + assert instance1a.value == "value1" + assert instance2a.value == "value2" + + @pytest.mark.regression + def test_inheritance_singleton_sharing(self): + """Test that each class in an inheritance chain keeps its own singleton instance.""" + + class BaseSingleton(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildSingleton(BaseSingleton): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.extra = "extra_value" + + # Child classes now have separate singleton instances + base_instance = BaseSingleton() + child_instance = ChildSingleton() + + # They should be different instances now (fixed inheritance behavior) + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(child_instance, "value") + assert 
child_instance.value == "base_value" + assert hasattr(child_instance, "extra") + assert child_instance.extra == "extra_value" + + @pytest.mark.sanity + def test_without_super_init_call(self): + """Test singleton behavior when subclass doesn't call super().__init__().""" + + class BadSingleton(SingletonMixin): + def __init__(self): + # Not calling super().__init__() + self.value = "bad_value" + + instance1 = BadSingleton() + instance2 = BadSingleton() + + assert instance1 is instance2 + assert hasattr(instance1, "initialized") + assert instance1.initialized is False + + +class TestThreadSafeSingletonMixin: + """Test suite for ThreadSafeSingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "thread_safe_value"}, + {"init_value": "concurrent_value"}, + ], + ids=["basic_thread_safe", "concurrent_test"], + ) + def valid_instances(self, request): + """Fixture providing test data for ThreadSafeSingletonMixin subclasses.""" + config = request.param + + class TestThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestThreadSafeSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ThreadSafeSingletonMixin inheritance and exposed attributes.""" + assert issubclass(ThreadSafeSingletonMixin, SingletonMixin) + assert hasattr(ThreadSafeSingletonMixin, "get_singleton_lock") + assert hasattr(ThreadSafeSingletonMixin, "__new__") + assert hasattr(ThreadSafeSingletonMixin, "__init__") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ThreadSafeSingletonMixin initialization.""" + singleton_class, config = valid_instances + + instance = singleton_class() + + assert isinstance(instance, singleton_class) + assert instance.initialized is True + assert hasattr(instance, "value") + assert instance.value == config["init_value"] + assert hasattr(instance, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance.thread_lock, lock_type) + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test multiple instantiations return same instance with thread safety.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert hasattr(instance1, "thread_lock") + + @pytest.mark.regression + def test_thread_safety_concurrent_creation(self, valid_instances): + """Test thread safety during concurrent instance creation.""" + singleton_class, config = valid_instances + + instances = [] + exceptions = [] + creation_count = 0 + lock = threading.Lock() + + class ThreadSafeTestSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + nonlocal creation_count + with lock: + creation_count += 1 + + time.sleep(0.01) + self.value = config["init_value"] + + def create_instance(): + try: + instance = ThreadSafeTestSingleton() + instances.append(instance) + except (TypeError, ValueError, AttributeError) as exc: + exceptions.append(exc) + + threads = [] + for _ in range(10): + thread = threading.Thread(target=create_instance) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + 
+ assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}" + + assert len(instances) == 10 + for instance in instances: + assert instance is instances[0] + + assert creation_count == 1 + assert all(instance.value == config["init_value"] for instance in instances) + + @pytest.mark.sanity + def test_thread_lock_creation(self, valid_instances): + """Test that thread_lock is created during initialization.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert hasattr(instance1, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance1.thread_lock, lock_type) + assert instance1.thread_lock is instance2.thread_lock + + @pytest.mark.regression + def test_multiple_thread_safe_classes_isolation(self): + """Test thread-safe singleton classes behavior with separate locks.""" + + class ThreadSafeSingleton1(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class ThreadSafeSingleton2(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1 = ThreadSafeSingleton1() + instance2 = ThreadSafeSingleton2() + + lock1 = ThreadSafeSingleton1.get_singleton_lock() + lock2 = ThreadSafeSingleton2.get_singleton_lock() + + assert lock1 is not None + assert lock2 is not None + assert lock1 is not lock2 + + assert instance1 is not instance2 + assert hasattr(instance1, "value") + assert hasattr(instance2, "value") + assert instance1.value == "value1" + assert instance2.value == "value2" + + @pytest.mark.sanity + def test_inheritance_with_thread_safety(self): + """Test inheritance behavior with thread-safe singletons.""" + + class BaseThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildThreadSafeSingleton(BaseThreadSafeSingleton): + def __init__(self): + super().__init__() + + base_instance = BaseThreadSafeSingleton() + child_instance = ChildThreadSafeSingleton() + + base_lock = BaseThreadSafeSingleton.get_singleton_lock() + child_lock = ChildThreadSafeSingleton.get_singleton_lock() + + assert base_lock is not None + assert child_lock is not None + assert base_lock is not child_lock + + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(base_instance, "thread_lock") diff --git a/tests/unit/objects/test_statistics.py b/tests/unit/utils/test_statistics.py similarity index 90% rename from tests/unit/objects/test_statistics.py rename to tests/unit/utils/test_statistics.py index ede77175..fa8cccd0 100644 --- a/tests/unit/objects/test_statistics.py +++ b/tests/unit/utils/test_statistics.py @@ -704,82 +704,3 @@ def test_time_running_stats_update(): assert time_running_stats.rate_ms == pytest.approx( 3000 / (time.time() - time_running_stats.start_time), abs=0.1 ) - - -@pytest.mark.regression -def test_distribution_summary_concurrency_double_counting_regression(): - """Specific regression test for the double-counting bug in concurrency calculation. 
- - Before the fix, when events were merged due to epsilon, the deltas were summed - but then the active count wasn't properly accumulated, causing incorrect results. - - ### WRITTEN BY AI ### - """ - epsilon = 1e-6 - - # Create a scenario where multiple requests start at exactly the same time - # This should result in events being merged, testing the accumulation logic - same_start_time = 1.0 - requests = [ - (same_start_time, 3.0), - (same_start_time, 4.0), - (same_start_time, 5.0), - (same_start_time + epsilon / 3, 6.0), # Very close start (within epsilon) - ] - - distribution_summary = DistributionSummary.from_request_times( - requests, distribution_type="concurrency", epsilon=epsilon - ) - - # All requests start at the same time (or within epsilon), so they should - # all be considered concurrent from the start - # Expected timeline: - # - t=1.0-3.0: 4 concurrent requests - # - t=3.0-4.0: 3 concurrent requests - # - t=4.0-5.0: 2 concurrent requests - # - t=5.0-6.0: 1 concurrent request - - assert distribution_summary.max == 4.0 # All 4 requests concurrent at start - assert distribution_summary.min == 1.0 # 1 request still running at the end - - -@pytest.mark.sanity -def test_distribution_summary_concurrency_epsilon_edge_case(): - """Test the exact epsilon boundary condition. - - ### WRITTEN BY AI ### - """ - epsilon = 1e-6 - - # Test requests that are exactly epsilon apart - should be merged - requests_exactly_epsilon = [ - (1.0, 2.0), - (1.0 + epsilon, 2.5), # Exactly epsilon apart - (2.0, 2.5), # Another close request - ] - - dist_epsilon = DistributionSummary.from_request_times( - requests_exactly_epsilon, distribution_type="concurrency", epsilon=epsilon - ) - - # Should be treated as concurrent (merged events) - assert dist_epsilon.max == 2.0 - assert dist_epsilon.min == 2.0 - - # Test requests that are just over epsilon apart - should NOT be merged - requests_over_epsilon = [ - (1.0, 2.0), - (1.0 + epsilon * 1.1, 2.5), # Just over epsilon apart - (2.0, 2.5), # Another close request - ] - - dist_over_epsilon = DistributionSummary.from_request_times( - requests_over_epsilon, distribution_type="concurrency", epsilon=epsilon - ) - - # These should be treated separately, so max concurrency depends on overlap - # At t=1.0 to 1.0+epsilon*1.1: 1 concurrent - # At t=1.0+epsilon*1.1 to 2.0: 2 concurrent - # At t=2.0 to 2.5: 1 concurrent - assert dist_over_epsilon.max == 2.0 - assert dist_over_epsilon.min == 1.0 diff --git a/tests/unit/utils/test_text.py b/tests/unit/utils/test_text.py new file mode 100644 index 00000000..2f363c46 --- /dev/null +++ b/tests/unit/utils/test_text.py @@ -0,0 +1,531 @@ +from __future__ import annotations + +import gzip +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import httpx +import pytest + +from guidellm.utils.text import ( + MAX_PATH_LENGTH, + EndlessTextCreator, + clean_text, + filter_text, + format_value_display, + is_punctuation, + load_text, + split_text, + split_text_list_by_length, +) + + +def test_max_path_length(): + """Test that MAX_PATH_LENGTH is correctly defined.""" + assert isinstance(MAX_PATH_LENGTH, int) + assert MAX_PATH_LENGTH == 4096 + + +class TestFormatValueDisplay: + """Test suite for format_value_display.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "value", + "label", + "units", + "total_characters", + "digits_places", + "decimal_places", + "expected", + ), + [ + (42.0, "test", "", None, None, None, "42 [info]test[/info]"), + (42.5, "test", "ms", None, None, 1, "42.5ms 
[info]test[/info]"), + (42.123, "test", "", None, 5, 2, " 42.12 [info]test[/info]"), + ( + 42.0, + "test", + "ms", + 30, + None, + 0, + " 42ms [info]test[/info]", + ), + ], + ) + def test_invocation( + self, + value, + label, + units, + total_characters, + digits_places, + decimal_places, + expected, + ): + """Test format_value_display with various parameters.""" + result = format_value_display( + value=value, + label=label, + units=units, + total_characters=total_characters, + digits_places=digits_places, + decimal_places=decimal_places, + ) + assert label in result + assert units in result + value_check = ( + str(int(value)) + if decimal_places == 0 + else ( + f"{value:.{decimal_places}f}" + if decimal_places is not None + else str(value) + ) + ) + assert value_check in result or str(value) in result + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("value", "label"), + [ + (None, "test"), + (42.0, None), + ("not_number", "test"), + ], + ) + def test_invocation_with_none_values(self, value, label): + """Test format_value_display with None/invalid inputs still works.""" + result = format_value_display(value, label) + assert isinstance(result, str) + if label is not None: + assert str(label) in result + if value is not None: + assert str(value) in result + + +class TestSplitTextListByLength: + """Test suite for split_text_list_by_length.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ( + "text_list", + "max_characters", + "pad_horizontal", + "pad_vertical", + "expected_structure", + ), + [ + ( + ["hello world", "test"], + 5, + False, + False, + [["hello", "world"], ["test"]], + ), + ( + ["short", "longer text"], + [5, 10], + True, + True, + [[" short"], ["longer", "text"]], + ), + ( + ["a", "b", "c"], + 10, + True, + True, + [[" a"], [" b"], [" c"]], + ), + ], + ) + def test_invocation( + self, + text_list, + max_characters, + pad_horizontal, + pad_vertical, + expected_structure, + ): + """Test split_text_list_by_length with various parameters.""" + result = split_text_list_by_length( + text_list, max_characters, pad_horizontal, pad_vertical + ) + assert len(result) == len(text_list) + if pad_vertical: + max_lines = max(len(lines) for lines in result) + assert all(len(lines) == max_lines for lines in result) + + @pytest.mark.sanity + def test_invalid_max_characters_length(self): + """Test split_text_list_by_length with mismatched max_characters length.""" + error_msg = "max_characters must be a list of the same length" + with pytest.raises(ValueError, match=error_msg): + split_text_list_by_length(["a", "b"], [5, 10, 15]) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text_list", "max_characters"), + [ + (None, 5), + (["test"], None), + (["test"], []), + ], + ) + def test_invalid_invocation(self, text_list, max_characters): + """Test split_text_list_by_length with invalid inputs.""" + with pytest.raises((TypeError, ValueError)): + split_text_list_by_length(text_list, max_characters) + + +class TestFilterText: + """Test suite for filter_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end", "expected"), + [ + ("hello world test", "world", None, "world test"), + ("hello world test", None, "world", "hello "), + ("hello world test", "hello", "test", "hello world "), + ("hello world test", 6, 11, "world test"), + ("hello world test", 0, 5, "hello"), + ("hello world test", None, None, "hello world test"), + ], + ) + def test_invocation(self, text, filter_start, filter_end, expected): + """Test filter_text with various start and 
end markers.""" + result = filter_text(text, filter_start, filter_end) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("text", "filter_start", "filter_end"), + [ + ("hello", "notfound", None), + ("hello", None, "notfound"), + ("hello", "invalid_type", None), + ("hello", None, "invalid_type"), + ], + ) + def test_invalid_invocation(self, text, filter_start, filter_end): + """Test filter_text with invalid markers.""" + with pytest.raises((ValueError, TypeError)): + filter_text(text, filter_start, filter_end) + + +class TestCleanText: + """Test suite for clean_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + ("hello world", "hello world"), + (" hello\n\nworld ", "hello world"), + ("hello\tworld\r\ntest", "hello world test"), + ("", ""), + (" ", ""), + ], + ) + def test_invocation(self, text, expected): + """Test clean_text with various whitespace scenarios.""" + result = clean_text(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test clean_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + clean_text(text) + + +class TestSplitText: + """Test suite for split_text.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "split_punctuation", "expected"), + [ + ("hello world", False, ["hello", "world"]), + ("hello, world!", True, ["hello", ",", "world", "!"]), + ("test.example", False, ["test.example"]), + ("test.example", True, ["test", ".", "example"]), + ("", False, []), + ], + ) + def test_invocation(self, text, split_punctuation, expected): + """Test split_text with various punctuation options.""" + result = split_text(text, split_punctuation) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test split_text with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + split_text(text) + + +class TestLoadText: + """Test suite for load_text.""" + + @pytest.mark.smoke + def test_empty_data(self): + """Test load_text with empty data.""" + result = load_text("") + assert result == "" + + @pytest.mark.smoke + def test_raw_text(self): + """Test load_text with raw text that's not a file.""" + long_text = "a" * (MAX_PATH_LENGTH + 1) + result = load_text(long_text) + assert result == long_text + + @pytest.mark.smoke + def test_local_file(self): + """Test load_text with local file.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp: + test_content = "test file content" + tmp.write(test_content) + tmp.flush() + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + def test_gzipped_file(self): + """Test load_text with gzipped file.""" + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as tmp: + test_content = "test gzipped content" + with gzip.open(tmp.name, "wt") as gzf: + gzf.write(test_content) + + result = load_text(tmp.name) + assert result == test_content + + Path(tmp.name).unlink() + + @pytest.mark.smoke + @patch("httpx.Client") + def test_url_loading(self, mock_client): + """Test load_text with HTTP URL.""" + mock_response = Mock() + mock_response.text = "url content" + mock_client.return_value.__enter__.return_value.get.return_value = mock_response + + result = load_text("http://example.com/test.txt") + assert result == "url 
content" + + @pytest.mark.smoke + @patch("guidellm.utils.text.files") + @patch("guidellm.utils.text.as_file") + def test_package_data_loading(self, mock_as_file, mock_files): + """Test load_text with package data.""" + mock_resource = Mock() + mock_files.return_value.joinpath.return_value = mock_resource + + mock_file = Mock() + mock_file.read.return_value = "package data content" + mock_as_file.return_value.__enter__.return_value = mock_file + + with patch("gzip.open") as mock_gzip: + mock_gzip.return_value.__enter__.return_value = mock_file + result = load_text("data:test.txt") + assert result == "package data content" + + @pytest.mark.sanity + def test_nonexistent_file(self): + """Test load_text with nonexistent file returns the path as raw text.""" + result = load_text("/nonexistent/path/file.txt") + assert result == "/nonexistent/path/file.txt" + + @pytest.mark.sanity + @patch("httpx.Client") + def test_url_error(self, mock_client): + """Test load_text with HTTP error.""" + mock_client.return_value.__enter__.return_value.get.side_effect = ( + httpx.HTTPStatusError("HTTP error", request=None, response=None) + ) + + with pytest.raises(httpx.HTTPStatusError): + load_text("http://example.com/error.txt") + + +class TestIsPunctuation: + """Test suite for is_puncutation.""" + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("text", "expected"), + [ + (".", True), + (",", True), + ("!", True), + ("?", True), + (";", True), + ("a", False), + ("1", False), + (" ", False), + ("ab", False), + ("", False), + ], + ) + def test_invocation(self, text, expected): + """Test is_punctuation with various characters.""" + result = is_punctuation(text) + assert result == expected + + @pytest.mark.sanity + @pytest.mark.parametrize( + "text", + [ + None, + 123, + ], + ) + def test_invalid_invocation(self, text): + """Test is_punctuation with invalid inputs.""" + with pytest.raises((TypeError, AttributeError)): + is_punctuation(text) + + +class TestEndlessTextCreator: + """Test suite for EndlessTextCreator.""" + + @pytest.fixture( + params=[ + { + "data": "hello world test", + "filter_start": None, + "filter_end": None, + }, + { + "data": "hello world test", + "filter_start": "world", + "filter_end": None, + }, + {"data": "one two three four", "filter_start": 0, "filter_end": 9}, + ], + ids=["no_filter", "string_filter", "index_filter"], + ) + def valid_instances(self, request): + """Fixture providing test data for EndlessTextCreator.""" + constructor_args = request.param + instance = EndlessTextCreator(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test EndlessTextCreator signatures and methods.""" + assert hasattr(EndlessTextCreator, "__init__") + assert hasattr(EndlessTextCreator, "create_text") + instance = EndlessTextCreator("test") + assert hasattr(instance, "data") + assert hasattr(instance, "text") + assert hasattr(instance, "filtered_text") + assert hasattr(instance, "words") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test EndlessTextCreator initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, EndlessTextCreator) + assert instance.data == constructor_args["data"] + assert isinstance(instance.text, str) + assert isinstance(instance.filtered_text, str) + assert isinstance(instance.words, list) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("data", "filter_start", "filter_end"), + [ + ("test", "notfound", None), + ], + ) + def 
test_invalid_initialization_values(self, data, filter_start, filter_end): + """Test EndlessTextCreator with invalid initialization values.""" + with pytest.raises((TypeError, ValueError)): + EndlessTextCreator(data, filter_start, filter_end) + + @pytest.mark.smoke + def test_initialization_with_none(self): + """Test EndlessTextCreator handles None data gracefully.""" + instance = EndlessTextCreator(None) + assert isinstance(instance, EndlessTextCreator) + assert instance.data is None + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "expected_length"), + [ + (0, 5, 5), + (2, 3, 3), + (0, 0, 0), + ], + ) + def test_create_text(self, valid_instances, start, length, expected_length): + """Test EndlessTextCreator.create_text.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + if length > 0 and instance.words: + assert len(result) > 0 + + @pytest.mark.smoke + def test_create_text_cycling(self): + """Test EndlessTextCreator.create_text cycling behavior.""" + instance = EndlessTextCreator("one two three") + result1 = instance.create_text(0, 3) + result2 = instance.create_text(3, 3) + assert isinstance(result1, str) + assert isinstance(result2, str) + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("start", "length"), + [ + ("invalid", 5), + (0, "invalid"), + ], + ) + def test_create_text_invalid(self, valid_instances, start, length): + """Test EndlessTextCreator.create_text with invalid inputs.""" + instance, constructor_args = valid_instances + with pytest.raises((TypeError, ValueError)): + instance.create_text(start, length) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("start", "length", "min_length"), + [ + (-1, 5, 0), + (0, -1, 0), + ], + ) + def test_create_text_edge_cases(self, valid_instances, start, length, min_length): + """Test EndlessTextCreator.create_text with edge cases.""" + instance, constructor_args = valid_instances + result = instance.create_text(start, length) + assert isinstance(result, str) + assert len(result) >= min_length