diff --git a/src/guidellm/backend/response.py b/src/guidellm/backend/response.py index ee2101d7..bfa738d8 100644 --- a/src/guidellm/backend/response.py +++ b/src/guidellm/backend/response.py @@ -3,7 +3,7 @@ from pydantic import computed_field from guidellm.config import settings -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = [ "RequestArgs", diff --git a/src/guidellm/benchmark/aggregator.py b/src/guidellm/benchmark/aggregator.py index af7f1a13..b322eadd 100644 --- a/src/guidellm/benchmark/aggregator.py +++ b/src/guidellm/benchmark/aggregator.py @@ -24,8 +24,6 @@ from guidellm.config import settings from guidellm.objects import ( RunningStats, - StandardBaseModel, - StatusBreakdown, TimeRunningStats, ) from guidellm.request import ( @@ -40,7 +38,7 @@ SchedulerRequestResult, WorkerDescription, ) -from guidellm.utils import check_load_processor +from guidellm.utils import StandardBaseModel, StatusBreakdown, check_load_processor __all__ = [ "AggregatorT", diff --git a/src/guidellm/benchmark/benchmark.py b/src/guidellm/benchmark/benchmark.py index 1e2a5f4b..7fd6ad1b 100644 --- a/src/guidellm/benchmark/benchmark.py +++ b/src/guidellm/benchmark/benchmark.py @@ -13,8 +13,6 @@ ThroughputProfile, ) from guidellm.objects import ( - StandardBaseModel, - StatusBreakdown, StatusDistributionSummary, ) from guidellm.request import ( @@ -32,6 +30,7 @@ ThroughputStrategy, WorkerDescription, ) +from guidellm.utils import StandardBaseModel, StatusBreakdown __all__ = [ "Benchmark", diff --git a/src/guidellm/benchmark/benchmarker.py b/src/guidellm/benchmark/benchmarker.py index 11b6d245..876e6f43 100644 --- a/src/guidellm/benchmark/benchmarker.py +++ b/src/guidellm/benchmark/benchmarker.py @@ -22,7 +22,6 @@ ) from guidellm.benchmark.benchmark import BenchmarkArgs, GenerativeBenchmark from guidellm.benchmark.profile import Profile -from guidellm.objects import StandardBaseModel from guidellm.request import ( GenerationRequest, GenerativeRequestLoaderDescription, @@ -37,6 +36,7 @@ SchedulerRequestResult, SchedulingStrategy, ) +from guidellm.utils import StandardBaseModel __all__ = ["Benchmarker", "BenchmarkerResult", "GenerativeBenchmarker"] diff --git a/src/guidellm/benchmark/output.py b/src/guidellm/benchmark/output.py index 8a113f72..dd94f899 100644 --- a/src/guidellm/benchmark/output.py +++ b/src/guidellm/benchmark/output.py @@ -23,13 +23,12 @@ from guidellm.config import settings from guidellm.objects import ( DistributionSummary, - StandardBaseModel, StatusDistributionSummary, ) from guidellm.presentation import UIDataBuilder from guidellm.presentation.injector import create_report from guidellm.scheduler import strategy_display_str -from guidellm.utils import Colors, split_text_list_by_length +from guidellm.utils import Colors, StandardBaseModel, split_text_list_by_length __all__ = [ "GenerativeBenchmarksConsole", diff --git a/src/guidellm/benchmark/profile.py b/src/guidellm/benchmark/profile.py index 642cb7a8..d46f2b16 100644 --- a/src/guidellm/benchmark/profile.py +++ b/src/guidellm/benchmark/profile.py @@ -5,7 +5,6 @@ from pydantic import Field, computed_field from guidellm.config import settings -from guidellm.objects import StandardBaseModel from guidellm.scheduler import ( AsyncConstantStrategy, AsyncPoissonStrategy, @@ -15,6 +14,7 @@ SynchronousStrategy, ThroughputStrategy, ) +from guidellm.utils import StandardBaseModel __all__ = [ "AsyncProfile", diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index af43e426..57dfa98b 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -11,8 +11,8 @@ from guidellm.backend.backend import BackendType from guidellm.benchmark.profile import ProfileType -from guidellm.objects.pydantic import StandardBaseModel from guidellm.scheduler.strategy import StrategyType +from guidellm.utils import StandardBaseModel __ALL__ = ["Scenario", "GenerativeTextScenario", "get_builtin_scenarios"] diff --git a/src/guidellm/objects/__init__.py b/src/guidellm/objects/__init__.py index 89e3c9b9..119ac6e7 100644 --- a/src/guidellm/objects/__init__.py +++ b/src/guidellm/objects/__init__.py @@ -1,4 +1,3 @@ -from .pydantic import StandardBaseModel, StatusBreakdown from .statistics import ( DistributionSummary, Percentiles, @@ -11,8 +10,6 @@ "DistributionSummary", "Percentiles", "RunningStats", - "StandardBaseModel", - "StatusBreakdown", "StatusDistributionSummary", "TimeRunningStats", ] diff --git a/src/guidellm/objects/pydantic.py b/src/guidellm/objects/pydantic.py deleted file mode 100644 index fcededcf..00000000 --- a/src/guidellm/objects/pydantic.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Generic, Optional, TypeVar - -import yaml -from loguru import logger -from pydantic import BaseModel, ConfigDict, Field - -__all__ = ["StandardBaseModel", "StatusBreakdown"] - -T = TypeVar("T", bound="StandardBaseModel") - - -class StandardBaseModel(BaseModel): - """ - A base class for Pydantic models throughout GuideLLM enabling standard - configuration and logging. - """ - - model_config = ConfigDict( - extra="ignore", - use_enum_values=True, - validate_assignment=True, - from_attributes=True, - ) - - def __init__(self, /, **data: Any) -> None: - super().__init__(**data) - logger.debug( - "Initialized new instance of {} with data: {}", - self.__class__.__name__, - data, - ) - - @classmethod - def get_default(cls: type[T], field: str) -> Any: - """Get default values for model fields""" - return cls.model_fields[field].default - - @classmethod - def from_file(cls: type[T], filename: Path, overrides: Optional[dict] = None) -> T: - """ - Attempt to create a new instance of the model using - data loaded from json or yaml file. - """ - try: - with filename.open() as f: - if str(filename).endswith(".json"): - data = json.load(f) - else: # Assume everything else is yaml - data = yaml.safe_load(f) - except (json.JSONDecodeError, yaml.YAMLError) as e: - logger.error(f"Failed to parse {filename} as type {cls.__name__}") - raise ValueError(f"Error when parsing file: {filename}") from e - - data.update(overrides) - return cls.model_validate(data) - - -SuccessfulT = TypeVar("SuccessfulT") -ErroredT = TypeVar("ErroredT") -IncompleteT = TypeVar("IncompleteT") -TotalT = TypeVar("TotalT") - - -class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): - """ - A base class for Pydantic models that are separated by statuses including - successful, incomplete, and errored. It additionally enables the inclusion - of total, which is intended as the combination of all statuses. - Total may or may not be used depending on if it duplicates information. - """ - - successful: SuccessfulT = Field( - description="The results with a successful status.", - default=None, # type: ignore[assignment] - ) - errored: ErroredT = Field( - description="The results with an errored status.", - default=None, # type: ignore[assignment] - ) - incomplete: IncompleteT = Field( - description="The results with an incomplete status.", - default=None, # type: ignore[assignment] - ) - total: TotalT = Field( - description="The combination of all statuses.", - default=None, # type: ignore[assignment] - ) diff --git a/src/guidellm/objects/statistics.py b/src/guidellm/objects/statistics.py index 7831b2cf..7989f7c4 100644 --- a/src/guidellm/objects/statistics.py +++ b/src/guidellm/objects/statistics.py @@ -6,7 +6,7 @@ import numpy as np from pydantic import Field, computed_field -from guidellm.objects.pydantic import StandardBaseModel, StatusBreakdown +from guidellm.utils import StandardBaseModel, StatusBreakdown __all__ = [ "DistributionSummary", diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 50ab3cca..2eff87d5 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -13,8 +13,8 @@ from guidellm.config import settings from guidellm.dataset import ColumnInputTypes, load_dataset -from guidellm.objects import StandardBaseModel from guidellm.request.request import GenerationRequest +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestLoader", diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py index 81c8cabd..bf4e59fb 100644 --- a/src/guidellm/request/request.py +++ b/src/guidellm/request/request.py @@ -3,7 +3,7 @@ from pydantic import Field -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = ["GenerationRequest"] diff --git a/src/guidellm/scheduler/result.py b/src/guidellm/scheduler/result.py index 0f12687f..0cca530b 100644 --- a/src/guidellm/scheduler/result.py +++ b/src/guidellm/scheduler/result.py @@ -4,9 +4,9 @@ Optional, ) -from guidellm.objects import StandardBaseModel from guidellm.scheduler.strategy import SchedulingStrategy from guidellm.scheduler.types import RequestT, ResponseT +from guidellm.utils import StandardBaseModel __all__ = [ "SchedulerRequestInfo", diff --git a/src/guidellm/scheduler/strategy.py b/src/guidellm/scheduler/strategy.py index 200c799e..d4c065da 100644 --- a/src/guidellm/scheduler/strategy.py +++ b/src/guidellm/scheduler/strategy.py @@ -12,7 +12,7 @@ from pydantic import Field from guidellm.config import settings -from guidellm.objects import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = [ "AsyncConstantStrategy", diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index a53b14c2..ab16e4db 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -24,10 +24,10 @@ ResponseSummary, StreamingTextResponse, ) -from guidellm.objects import StandardBaseModel from guidellm.request import GenerationRequest from guidellm.scheduler.result import SchedulerRequestInfo from guidellm.scheduler.types import RequestT, ResponseT +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestsWorker", diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index fb9262c3..98ac1c36 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -1,3 +1,4 @@ +from .auto_importer import AutoImporterMixin from .colors import Colors from .default_group import DefaultGroupHandler from .hf_datasets import ( @@ -7,7 +8,16 @@ from .hf_transformers import ( check_load_processor, ) +from .pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) from .random import IntegerRangeSampler +from .registry import RegistryMixin +from .singleton import SingletonMixin, ThreadSafeSingletonMixin from .text import ( EndlessTextCreator, clean_text, @@ -20,10 +30,19 @@ __all__ = [ "SUPPORTED_TYPES", + "AutoImporterMixin", "Colors", "DefaultGroupHandler", "EndlessTextCreator", "IntegerRangeSampler", + "PydanticClassRegistryMixin", + "RegistryMixin", + "ReloadableBaseModel", + "SingletonMixin", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", + "ThreadSafeSingletonMixin", "check_load_processor", "clean_text", "filter_text", diff --git a/src/guidellm/utils/auto_importer.py b/src/guidellm/utils/auto_importer.py new file mode 100644 index 00000000..5b939014 --- /dev/null +++ b/src/guidellm/utils/auto_importer.py @@ -0,0 +1,98 @@ +""" +Automatic module importing utilities for dynamic class discovery. + +This module provides a mixin class for automatic module importing within a package, +enabling dynamic discovery of classes and implementations without explicit imports. +It is particularly useful for auto-registering classes in a registry pattern where +subclasses need to be discoverable at runtime. + +The AutoImporterMixin can be combined with registration mechanisms to create +extensible systems where new implementations are automatically discovered and +registered when they are placed in the correct package structure. +""" + +from __future__ import annotations + +import importlib +import pkgutil +import sys +from typing import ClassVar + +__all__ = ["AutoImporterMixin"] + + +class AutoImporterMixin: + """ + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations without + explicit imports by automatically importing all modules within specified + packages. It is designed for use with class registration mechanisms to enable + automatic discovery and registration of classes when they are placed in the + correct package structure. + + Example: + :: + from guidellm.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported + """ + + auto_package: ClassVar[str | tuple[str, ...] | None] = None + auto_ignore_modules: ClassVar[tuple[str, ...] | None] = None + auto_imported_modules: ClassVar[list[str] | None] = None + + @classmethod + def auto_import_package_modules(cls) -> None: + """ + Automatically import all modules within the specified package(s). + + Scans the package(s) defined in the `auto_package` class variable and imports + all modules found, tracking them in `auto_imported_modules`. Skips packages + (directories) and any modules listed in `auto_ignore_modules`. + + :raises ValueError: If the `auto_package` class variable is not set + """ + if cls.auto_package is None: + raise ValueError( + "The class variable 'auto_package' must be set to the package name to " + "import modules from." + ) + + cls.auto_imported_modules = [] + packages = ( + cls.auto_package + if isinstance(cls.auto_package, tuple) + else (cls.auto_package,) + ) + + for package_name in packages: + package = importlib.import_module(package_name) + + for _, module_name, is_pkg in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + if ( + is_pkg + or ( + cls.auto_ignore_modules is not None + and module_name in cls.auto_ignore_modules + ) + or module_name in cls.auto_imported_modules + ): + # Skip packages and ignored modules + continue + + if module_name in sys.modules: + # Avoid circular imports + cls.auto_imported_modules.append(module_name) + else: + importlib.import_module(module_name) + cls.auto_imported_modules.append(module_name) diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py new file mode 100644 index 00000000..52bf6564 --- /dev/null +++ b/src/guidellm/utils/pydantic_utils.py @@ -0,0 +1,302 @@ +""" +Pydantic utilities for polymorphic model serialization and registry integration. + +Provides integration between Pydantic and the registry system, enabling +polymorphic serialization and deserialization of Pydantic models using +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema + +from guidellm.utils.registry import RegistryMixin + +__all__ = [ + "PydanticClassRegistryMixin", + "ReloadableBaseModel", + "StandardBaseDict", + "StandardBaseModel", + "StatusBreakdown", +] + + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) +SuccessfulT = TypeVar("SuccessfulT") +ErroredT = TypeVar("ErroredT") +IncompleteT = TypeVar("IncompleteT") +TotalT = TypeVar("TotalT") + + +class ReloadableBaseModel(BaseModel): + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + @classmethod + def reload_schema(cls) -> None: + """ + Reload the class schema with updated registry information. + + Forces a complete rebuild of the Pydantic model schema to incorporate + any changes made to associated registries or validation rules. + """ + cls.model_rebuild(force=True) + + +class StandardBaseModel(BaseModel): + """ + Base Pydantic model with standardized configuration for GuideLLM. + + Provides consistent validation behavior and configuration settings across + all Pydantic models in the application, including field validation, + attribute conversion, and default value handling. + + Example: + :: + class MyModel(StandardBaseModel): + name: str + value: int = 42 + + # Access default values + default_value = MyModel.get_default("value") # Returns 42 + """ + + model_config = ConfigDict( + extra="ignore", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + ) + + @classmethod + def get_default(cls: type[BaseModelT], field: str) -> Any: + """ + Get default value for a model field. + + :param field: Name of the field to get the default value for + :return: Default value of the specified field + :raises KeyError: If the field does not exist in the model + """ + return cls.model_fields[field].default + + +class StandardBaseDict(StandardBaseModel): + """ + Base Pydantic model allowing arbitrary additional fields. + + Extends StandardBaseModel to accept extra fields beyond those explicitly + defined in the model schema. Useful for flexible data structures that + need to accommodate varying or unknown field sets while maintaining + type safety for known fields. + """ + + model_config = ConfigDict( + extra="allow", + use_enum_values=True, + validate_assignment=True, + from_attributes=True, + arbitrary_types_allowed=True, + ) + + +class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, TotalT]): + """ + Generic model for organizing results by processing status. + + Provides structured categorization of results into successful, errored, + incomplete, and total status groups. Supports flexible typing for each + status category to accommodate different result types while maintaining + consistent organization patterns across the application. + + Example: + :: + from guidellm.utils.pydantic_utils import StatusBreakdown + + # Define a breakdown for request counts + breakdown = StatusBreakdown[int, int, int, int]( + successful=150, + errored=5, + incomplete=10, + total=165 + ) + """ + + successful: SuccessfulT = Field( + description="Results or metrics for requests with successful completion status", + default=None, # type: ignore[assignment] + ) + errored: ErroredT = Field( + description="Results or metrics for requests with error completion status", + default=None, # type: ignore[assignment] + ) + incomplete: IncompleteT = Field( + description="Results or metrics for requests with incomplete processing status", + default=None, # type: ignore[assignment] + ) + total: TotalT = Field( + description="Aggregated results or metrics combining all status categories", + default=None, # type: ignore[assignment] + ) + + +class PydanticClassRegistryMixin( + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] +): + """ + Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. + + Integrates Pydantic validation with the registry system to enable polymorphic + serialization and deserialization based on a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings, + providing a foundation for extensible plugin-style architectures. + + Example: + :: + from guidellm.utils.pydantic_utils import PydanticClassRegistryMixin + + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): + schema_discriminator: ClassVar[str] = "config_type" + config_type: str = Field(description="Configuration type identifier") + + @classmethod + def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: + return BaseConfig + + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") + + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name used for polymorphic type discrimination + """ + + schema_discriminator: ClassVar[str] = "model_type" + + @classmethod + def register_decorator( + cls, clazz: type[BaseModelT], name: str | list[str] | None = None + ) -> type[BaseModelT]: + """ + Register a Pydantic model class with type validation and schema reload. + + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. + + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass + """ + if not issubclass(clazz, BaseModel): + raise TypeError( + f"Cannot register {clazz.__name__} as it is not a subclass of " + "Pydantic BaseModel" + ) + + dec_clazz = super().register_decorator(clazz, name=name) + cls.reload_schema() + + return dec_clazz + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate polymorphic validation schema for dynamic type instantiation. + + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. + + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema + """ + if source_type == cls.__pydantic_schema_base_type__(): + if not cls.registry: + return cls.__pydantic_generate_base_schema__(handler) + + choices = { + name: handler(model_class) for name, model_class in cls.registry.items() + } + + return core_schema.tagged_union_schema( + choices=choices, + discriminator=cls.schema_discriminator, + ) + + return handler(cls) + + @classmethod + @abstractmethod + def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: + """ + Define the base type for polymorphic validation hierarchy. + + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. + + :return: Base class type for the polymorphic model hierarchy + """ + ... + + @classmethod + def __pydantic_generate_base_schema__( + cls, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + Generate fallback schema for polymorphic models without registry. + + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. Used as fallback during + schema generation when the registry has not been populated. + + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input + """ + return core_schema.any_schema() + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Initialize registry with auto-discovery and reload validation schema. + + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. + + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled + """ + populated = super().auto_populate_registry() + cls.reload_schema() + + return populated diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py new file mode 100644 index 00000000..5d4bc055 --- /dev/null +++ b/src/guidellm/utils/registry.py @@ -0,0 +1,206 @@ +""" +Registry system for dynamic object registration and discovery. + +Provides a flexible object registration system with optional auto-discovery +capabilities through decorators and module imports. Enables dynamic discovery +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. +""" + +from __future__ import annotations + +from typing import Any, Callable, ClassVar, Generic, TypeVar + +from guidellm.utils.auto_importer import AutoImporterMixin + +__all__ = ["RegistryMixin", "RegistryObjT"] + + +RegistryObjT = TypeVar("RegistryObjT", bound=Any) +""" +Generic type variable for objects managed by the registry system. +""" + + +class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): + """ + Generic mixin for creating object registries with optional auto-discovery. + + Enables classes to maintain separate registries of objects that can be + dynamically discovered and instantiated through decorators and module imports. + Supports both manual registration via decorators and automatic discovery + through package scanning for extensible plugin architectures. + + Example: + :: + class BaseAlgorithm(RegistryMixin): + pass + + @BaseAlgorithm.register() + class ConcreteAlgorithm(BaseAlgorithm): + pass + + @BaseAlgorithm.register("custom_name") + class AnotherAlgorithm(BaseAlgorithm): + pass + + # Get all registered implementations + algorithms = BaseAlgorithm.registered_objects() + + Example with auto-discovery: + :: + class TokenProposal(RegistryMixin): + registry_auto_discovery = True + auto_package = "mypackage.proposals" + + # Automatically imports and registers decorated objects + proposals = TokenProposal.registered_objects() + + :cvar registry: Dictionary mapping names to registered objects + :cvar registry_auto_discovery: Enable automatic package-based discovery + :cvar registry_populated: Track whether auto-discovery has completed + """ + + registry: ClassVar[dict[str, RegistryObjT] | None] = None + registry_auto_discovery: ClassVar[bool] = False + registry_populated: ClassVar[bool] = False + + @classmethod + def register( + cls, name: str | list[str] | None = None + ) -> Callable[[RegistryObjT], RegistryObjT]: + """ + Decorator that registers an object with the registry. + + :param name: Optional name(s) to register the object under. + If None, the object name is used as the registry key. + :return: A decorator function that registers the decorated object. + :raises ValueError: If name is provided but is not a string or list of strings. + """ + if name is not None and not isinstance(name, (str, list)): + raise ValueError( + "RegistryMixin.register() name must be a string, list of strings, " + f"or None. Got {name}." + ) + + return lambda obj: cls.register_decorator(obj, name=name) + + @classmethod + def register_decorator( + cls, obj: RegistryObjT, name: str | list[str] | None = None + ) -> RegistryObjT: + """ + Direct decorator that registers an object with the registry. + + :param obj: The object to register. + :param name: Optional name(s) to register the object under. + If None, the object name is used as the registry key. + :return: The registered object. + :raises ValueError: If the object is already registered or if name is invalid. + """ + + if not name: + name = obj.__name__ + elif not isinstance(name, (str, list)): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"an iterable of strings. Got {name}." + ) + + if cls.registry is None: + cls.registry = {} + + names = [name] if isinstance(name, str) else list(name) + + for register_name in names: + if not isinstance(register_name, str): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"a list of strings. Got {register_name}." + ) + + if register_name in cls.registry: + raise ValueError( + f"RegistryMixin.register_decorator cannot register an object " + f"{obj} with the name {register_name} because it is already " + "registered." + ) + + cls.registry[register_name.lower()] = obj + + return obj + + @classmethod + def auto_populate_registry(cls) -> bool: + """ + Import and register all modules from the specified auto_package. + + Automatically called by registered_objects when registry_auto_discovery is True + to ensure all available implementations are discovered before returning results. + + :return: True if the registry was populated, False if already populated. + :raises ValueError: If called when registry_auto_discovery is False. + """ + if not cls.registry_auto_discovery: + raise ValueError( + "RegistryMixin.auto_populate_registry() cannot be called " + "because registry_auto_discovery is set to False. " + "Set registry_auto_discovery to True to enable auto-discovery." + ) + + if cls.registry_populated: + return False + + cls.auto_import_package_modules() + cls.registry_populated = True + + return True + + @classmethod + def registered_objects(cls) -> tuple[RegistryObjT, ...]: + """ + Get all registered objects from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered objects including auto-discovered ones. + :raises ValueError: If called before any objects have been registered. + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "RegistryMixin.registered_objects() must be called after " + "registering objects with RegistryMixin.register()." + ) + + return tuple(cls.registry.values()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if an object is registered under the given name. + + :param name: The name to check for registration. + :return: True if the object is registered, False otherwise. + """ + if cls.registry is None: + return False + + return name.lower() in cls.registry + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT | None: + """ + Get a registered object by its name. + + :param name: The name of the registered object. + :return: The registered object if found, None otherwise. + """ + if cls.registry is None: + return None + + return cls.registry.get(name.lower()) diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py new file mode 100644 index 00000000..3ec10f79 --- /dev/null +++ b/src/guidellm/utils/singleton.py @@ -0,0 +1,130 @@ +""" +Singleton pattern implementations for ensuring single instance classes. + +Provides singleton mixins for creating classes that maintain a single instance +throughout the application lifecycle, with support for both basic and thread-safe +implementations. These mixins integrate with the scheduler and other system components +to ensure consistent state management and prevent duplicate resource allocation. +""" + +from __future__ import annotations + +import threading + +__all__ = ["SingletonMixin", "ThreadSafeSingletonMixin"] + + +class SingletonMixin: + """ + Basic singleton mixin ensuring single instance per class. + + Implements the singleton pattern using class variables to control instance + creation. Subclasses must call super().__init__() for proper initialization + state management. Suitable for single-threaded environments or when external + synchronization is provided. + + Example: + :: + class ConfigManager(SingletonMixin): + def __init__(self, config_path: str): + super().__init__() + if not self.initialized: + self.config = load_config(config_path) + + manager1 = ConfigManager("config.json") + manager2 = ConfigManager("config.json") + assert manager1 is manager2 + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance. + + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific attribute name to avoid inheritance issues + attr_name = f"_singleton_instance_{cls.__name__}" + + if not hasattr(cls, attr_name) or getattr(cls, attr_name) is None: + instance = super().__new__(cls) + setattr(cls, attr_name, instance) + instance._singleton_initialized = False + return getattr(cls, attr_name) + + def __init__(self): + """Initialize the singleton instance exactly once.""" + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def initialized(self): + """Return True if the singleton has been initialized.""" + return getattr(self, "_singleton_initialized", False) + + +class ThreadSafeSingletonMixin(SingletonMixin): + """ + Thread-safe singleton mixin with locking mechanisms. + + Extends SingletonMixin with thread safety using locks to prevent race + conditions during instance creation in multi-threaded environments. Essential + for scheduler components and other shared resources accessed concurrently. + + Example: + :: + class SchedulerResource(ThreadSafeSingletonMixin): + def __init__(self): + super().__init__() + if not self.initialized: + self.resource_pool = initialize_resources() + """ + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + """ + Create or return the singleton instance with thread safety. + + :param args: Positional arguments passed to the constructor + :param kwargs: Keyword arguments passed to the constructor + :return: The singleton instance of the class + """ + # Use class-specific lock and instance names to avoid inheritance issues + lock_attr_name = f"_singleton_lock_{cls.__name__}" + instance_attr_name = f"_singleton_instance_{cls.__name__}" + + with getattr(cls, lock_attr_name): + instance_exists = ( + hasattr(cls, instance_attr_name) + and getattr(cls, instance_attr_name) is not None + ) + if not instance_exists: + instance = super(SingletonMixin, cls).__new__(cls) + setattr(cls, instance_attr_name, instance) + instance._singleton_initialized = False + instance._init_lock = threading.Lock() + return getattr(cls, instance_attr_name) + + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) + lock_attr_name = f"_singleton_lock_{cls.__name__}" + setattr(cls, lock_attr_name, threading.Lock()) + + def __init__(self): + """Initialize the singleton instance with thread-safe initialization.""" + with self._init_lock: + if hasattr(self, "_singleton_initialized") and self._singleton_initialized: + return + self._singleton_initialized = True + + @property + def thread_lock(self): + """Return the thread lock for this singleton instance.""" + return getattr(self, "_init_lock", None) + + @classmethod + def get_singleton_lock(cls): + """Get the class-specific singleton creation lock.""" + lock_attr_name = f"_singleton_lock_{cls.__name__}" + return getattr(cls, lock_attr_name, None) diff --git a/tests/unit/objects/test_pydantic.py b/tests/unit/objects/test_pydantic.py deleted file mode 100644 index cb7f438f..00000000 --- a/tests/unit/objects/test_pydantic.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -from pydantic import computed_field - -from guidellm.objects.pydantic import StandardBaseModel - - -class ExampleModel(StandardBaseModel): - name: str - age: int - - @computed_field # type: ignore[misc] - @property - def computed(self) -> str: - return self.name + " " + str(self.age) - - -@pytest.mark.smoke -def test_standard_base_model_initialization(): - example = ExampleModel(name="John Doe", age=30) - assert example.name == "John Doe" - assert example.age == 30 - assert example.computed == "John Doe 30" - - -@pytest.mark.smoke -def test_standard_base_model_invalid_initialization(): - with pytest.raises(ValueError): - ExampleModel(name="John Doe", age="thirty") # type: ignore[arg-type] - - -@pytest.mark.smoke -def test_standard_base_model_marshalling(): - example = ExampleModel(name="John Doe", age=30) - serialized = example.model_dump() - assert serialized["name"] == "John Doe" - assert serialized["age"] == 30 - assert serialized["computed"] == "John Doe 30" - - serialized["computed"] = "Jane Doe 40" - deserialized = ExampleModel.model_validate(serialized) - assert deserialized.name == "John Doe" - assert deserialized.age == 30 - assert deserialized.computed == "John Doe 30" diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py new file mode 100644 index 00000000..cc71bce3 --- /dev/null +++ b/tests/unit/utils/test_auto_importer.py @@ -0,0 +1,269 @@ +""" +Unit tests for the auto_importer module. +""" + +from __future__ import annotations + +from unittest import mock + +import pytest + +from guidellm.utils import AutoImporterMixin + + +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" + + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] + + return TestClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables + assert AutoImporterMixin.auto_package is None + assert AutoImporterMixin.auto_ignore_modules is None + assert AutoImporterMixin.auto_imported_modules is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert test_class.auto_imported_modules is None + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" + + class TestClass(AutoImporterMixin): + pass + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestClass.auto_import_package_modules() + + @pytest.mark.smoke + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, mock.MagicMock() + ) + + def walk_side_effect(path, prefix): + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] + + mock_walk.side_effect = walk_side_effect + + # Execute + test_class.auto_import_package_modules() + + # Verify + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Test import error handling + mock_import.side_effect = ImportError("Module not found") + + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_packages(self, mock_walk, mock_import): + """Test that packages (is_pkg=True) are skipped.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.subpackage", True), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + mock_import.assert_any_call("test.package.module") + # subpackage should not be imported + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.subpackage") + + @pytest.mark.sanity + @mock.patch("sys.modules", {"test.package.existing": mock.MagicMock()}) + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_already_imported_modules(self, mock_walk, mock_import): + """Test that modules already in sys.modules are tracked but not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + }.get(name, mock.MagicMock()) + + mock_walk.return_value = [ + (None, "test.package.existing", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.existing"] + mock_import.assert_called_once_with("test.package") + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.existing") + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_prevent_duplicate_module_imports(self, mock_walk, mock_import): + """Test that modules already in auto_imported_modules are not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.module", False), + (None, "test.package.module", False), + ] + + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + assert mock_import.call_count == 2 # Package + module (not duplicate) + + +class MockHelper: + """Helper class to create consistent mock objects for testing.""" + + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package + + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py new file mode 100644 index 00000000..8683604b --- /dev/null +++ b/tests/unit/utils/test_pydantic_utils.py @@ -0,0 +1,710 @@ +""" +Unit tests for the pydantic_utils module. +""" + +from __future__ import annotations + +from typing import ClassVar +from unittest import mock + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from guidellm.utils.pydantic_utils import ( + PydanticClassRegistryMixin, + ReloadableBaseModel, + StandardBaseDict, + StandardBaseModel, + StatusBreakdown, +) + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": "hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + # Check model configuration + config = ReloadableBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestStandardBaseModel: + """Test suite for StandardBaseModel.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "field_int": 42}, + {"field_str": "hello_world", "field_int": 100}, + {"field_str": "another_test", "field_int": 0}, + ], + ids=["basic_values", "positive_values", "zero_value"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseModel, dict[str, int | str]]: + """Fixture providing test data for StandardBaseModel.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseModel inheritance and class variables.""" + assert issubclass(StandardBaseModel, BaseModel) + assert hasattr(StandardBaseModel, "model_config") + assert hasattr(StandardBaseModel, "get_default") + + # Check model configuration + config = StandardBaseModel.model_config + assert config["extra"] == "ignore" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseModel) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + assert instance.field_int == constructor_args["field_int"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ("field_int", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseModel with invalid field values.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + data = {field: value} + if field == "field_str": + data["field_int"] = 42 + else: + data["field_str"] = "test" + + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseModel initialization without required field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=10, description="Test integer field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_get_default(self): + """Test StandardBaseModel.get_default method.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + field_int: int = Field(default=42, description="Test integer field") + + default_value = TestModel.get_default("field_int") + assert default_value == 42 + + @pytest.mark.sanity + def test_get_default_invalid(self): + """Test StandardBaseModel.get_default with invalid field.""" + + class TestModel(StandardBaseModel): + field_str: str = Field(description="Test string field") + + with pytest.raises(KeyError): + TestModel.get_default("nonexistent_field") + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + assert data_dict["field_int"] == constructor_args["field_int"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + assert recreated.field_int == constructor_args["field_int"] + + +class TestStandardBaseDict: + """Test suite for StandardBaseDict.""" + + @pytest.fixture( + params=[ + {"field_str": "test_value", "extra_field": "extra_value"}, + {"field_str": "hello_world", "another_extra": 123}, + {"field_str": "another_test", "complex_extra": {"nested": "value"}}, + ], + ids=["string_extra", "int_extra", "dict_extra"], + ) + def valid_instances( + self, request + ) -> tuple[StandardBaseDict, dict[str, str | int | dict[str, str]]]: + """Fixture providing test data for StandardBaseDict.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StandardBaseDict inheritance and class variables.""" + assert issubclass(StandardBaseDict, StandardBaseModel) + assert hasattr(StandardBaseDict, "model_config") + + # Check model configuration + config = StandardBaseDict.model_config + assert config["extra"] == "allow" + assert config["use_enum_values"] is True + assert config["validate_assignment"] is True + assert config["from_attributes"] is True + assert config["arbitrary_types_allowed"] is True + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StandardBaseDict initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StandardBaseDict) + assert instance.field_str == constructor_args["field_str"] # type: ignore[attr-defined] + + # Check extra fields are preserved + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("field_str", None), + ("field_str", 123), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test StandardBaseDict with invalid field values.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test StandardBaseDict initialization without required field.""" + + class TestModel(StandardBaseDict): + field_str: str = Field(description="Test string field") + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StandardBaseDict serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["field_str"] == constructor_args["field_str"] + + # Check extra fields are in the serialized data + for key, value in constructor_args.items(): + if key != "field_str": + assert key in data_dict + assert data_dict[key] == value + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.field_str == constructor_args["field_str"] + + # Check extra fields are preserved after deserialization + for key, value in constructor_args.items(): + if key != "field_str": + assert hasattr(recreated, key) + assert getattr(recreated, key) == value + + +class TestStatusBreakdown: + """Test suite for StatusBreakdown.""" + + @pytest.fixture( + params=[ + {"successful": 100, "errored": 5, "incomplete": 10, "total": 115}, + { + "successful": "success_data", + "errored": "error_data", + "incomplete": "incomplete_data", + "total": "total_data", + }, + { + "successful": [1, 2, 3], + "errored": [4, 5], + "incomplete": [6], + "total": [1, 2, 3, 4, 5, 6], + }, + ], + ids=["int_values", "string_values", "list_values"], + ) + def valid_instances(self, request) -> tuple[StatusBreakdown, dict]: + """Fixture providing test data for StatusBreakdown.""" + constructor_args = request.param + instance = StatusBreakdown(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test StatusBreakdown inheritance and type relationships.""" + assert issubclass(StatusBreakdown, BaseModel) + # Check if Generic is in the MRO (method resolution order) + assert any(cls.__name__ == "Generic" for cls in StatusBreakdown.__mro__) + assert "successful" in StatusBreakdown.model_fields + assert "errored" in StatusBreakdown.model_fields + assert "incomplete" in StatusBreakdown.model_fields + assert "total" in StatusBreakdown.model_fields + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test StatusBreakdown initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, StatusBreakdown) + assert instance.successful == constructor_args["successful"] + assert instance.errored == constructor_args["errored"] + assert instance.incomplete == constructor_args["incomplete"] + assert instance.total == constructor_args["total"] + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test StatusBreakdown initialization with default values.""" + instance: StatusBreakdown = StatusBreakdown() + assert instance.successful is None + assert instance.errored is None + assert instance.incomplete is None + assert instance.total is None + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test StatusBreakdown serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["successful"] == constructor_args["successful"] + assert data_dict["errored"] == constructor_args["errored"] + assert data_dict["incomplete"] == constructor_args["incomplete"] + assert data_dict["total"] == constructor_args["total"] + + recreated: StatusBreakdown = StatusBreakdown.model_validate(data_dict) + assert isinstance(recreated, StatusBreakdown) + assert recreated.successful == constructor_args["successful"] + assert recreated.errored == constructor_args["errored"] + assert recreated.incomplete == constructor_args["incomplete"] + assert recreated.total == constructor_args["total"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, "schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "testsubmodel" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["testsubmodel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "guidellm.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py new file mode 100644 index 00000000..b5c17975 --- /dev/null +++ b/tests/unit/utils/test_registry.py @@ -0,0 +1,533 @@ +""" +Unit tests for the registry module. +""" + +from __future__ import annotations + +from typing import TypeVar +from unittest import mock + +import pytest + +from guidellm.utils.registry import RegistryMixin, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is not None # bound to Any + assert RegistryObjT.__constraints__ == () + + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = True + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, None), # Uses class name + ], + ) + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances + + if name is None: + + @registry_class.register() + class TestClass: + pass + + expected_key = "testclass" + else: + + @registry_class.register(name) + class TestClass: + pass + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances + + with pytest.raises(ValueError, match="name must be a string, list of strings"): + registry_class.register(invalid_name) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "testclass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + registry_class.register_decorator(TestClass, name=name) + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" + + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True + + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() # Should not be called again + + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + + class TestDisabledRegistry(RegistryMixin): + registry_auto_discovery = False + + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() + + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances + + @registry_class.register("class1") + class TestClass1: + pass + + @registry_class.register("class2") + class TestClass2: + pass + + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects + + @pytest.mark.sanity + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" + + class TestRegistryClass(RegistryMixin): + pass + + with pytest.raises( + ValueError, match="must be called after registering objects" + ): + TestRegistryClass.registered_objects() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + ("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.is_registered(check_name) + assert result == expected + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances + + @registry_class.register("valid_name") + class TestClass: + pass + + result = registry_class.get_registered_object(lookup_name) + assert result is None + + @pytest.mark.regression + def test_multiple_registries_isolation(self): + """Test that different registry classes maintain separate registries.""" + + class Registry1(RegistryMixin): + pass + + class Registry2(RegistryMixin): + pass + + @Registry1.register() + class TestClass1: + pass + + @Registry2.register() + class TestClass2: + pass + + assert Registry1.registry is not None + assert Registry2.registry is not None + assert Registry1.registry != Registry2.registry + assert "testclass1" in Registry1.registry + assert "testclass2" in Registry2.registry + assert "testclass1" not in Registry2.registry + assert "testclass2" not in Registry1.registry + + @pytest.mark.regression + def test_inheritance_registry_sharing(self): + """Test that inherited registry classes share the same registry.""" + + class BaseRegistry(RegistryMixin): + pass + + class ChildRegistry(BaseRegistry): + pass + + @BaseRegistry.register() + class BaseClass: + pass + + @ChildRegistry.register() + class ChildClass: + pass + + # Child classes share the same registry as their parent + assert BaseRegistry.registry is ChildRegistry.registry + + # Both classes can see all registered objects + base_objects = BaseRegistry.registered_objects() + child_objects = ChildRegistry.registered_objects() + + assert len(base_objects) == 2 + assert len(child_objects) == 2 + assert base_objects == child_objects + assert BaseClass in base_objects + assert ChildClass in base_objects + + @pytest.mark.smoke + def test_auto_discovery_initialization(self): + """Test initialization of auto-discovery enabled registry.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + assert TestAutoRegistry.registry is None + assert TestAutoRegistry.registry_populated is False + assert TestAutoRegistry.auto_package == "test_package.modules" + assert TestAutoRegistry.registry_auto_discovery is True + + @pytest.mark.smoke + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with mock.patch.object( + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") + + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass + + with pytest.raises(ValueError, match="already registered"): + + @registry_class.register("duplicate_name") + class TestClass2: + pass + + @pytest.mark.sanity + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances + + class TestClass1: + pass + + class TestClass2: + pass + + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") + + @pytest.mark.sanity + def test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances + + class TestClass: + pass + + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + + @pytest.mark.sanity + def test_register_decorator_invalid_object(self, valid_instances): + """Test register_decorator with object lacking __name__ attribute.""" + registry_class, _ = valid_instances + + with pytest.raises(AttributeError): + registry_class.register_decorator("not_a_class") + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None + + @pytest.mark.regression + def test_auto_registry_integration(self): + """Test complete auto-discovery workflow with mocked imports.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as walk_mock, + mock.patch("importlib.import_module") as import_mock, + ): + # Setup mock package + package_mock = mock.MagicMock() + package_mock.__path__ = ["test_package/modules"] + package_mock.__name__ = "test_package.modules" + + # Setup mock module with test class + module_mock = mock.MagicMock() + module_mock.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + + # Setup import behavior + import_mock.side_effect = lambda name: ( + package_mock + if name == "test_package.modules" + else module_mock + if name == "test_package.modules.module1" + else (_ for _ in ()).throw(ImportError(f"No module named {name}")) + ) + + # Setup package walking behavior + walk_mock.side_effect = lambda path, prefix: ( + [(None, "test_package.modules.module1", False)] + if prefix == "test_package.modules." + else (_ for _ in ()).throw(ValueError(f"Unknown package: {prefix}")) + ) + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None + assert "module1class" in TestAutoRegistry.registry + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as mock_walk, + mock.patch("importlib.import_module") as mock_import, + ): + mock_package = mock.MagicMock() + mock_package.__path__ = ["test_package/modules"] + mock_package.__name__ = "test_package.modules" + + def import_module(name: str): + if name == "test_package.modules": + return mock_package + elif name == "test_package.modules.module1": + module = mock.MagicMock() + module.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + return module + else: + raise ImportError(f"No module named {name}") + + def walk_packages(package_path, package_name): + if package_name == "test_package.modules.": + return [(None, "test_package.modules.module1", False)] + else: + raise ValueError(f"Unknown package: {package_name}") + + mock_walk.side_effect = walk_packages + mock_import.side_effect = import_module + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None diff --git a/tests/unit/utils/test_singleton.py b/tests/unit/utils/test_singleton.py new file mode 100644 index 00000000..ee01ead1 --- /dev/null +++ b/tests/unit/utils/test_singleton.py @@ -0,0 +1,371 @@ +from __future__ import annotations + +import threading +import time + +import pytest + +from guidellm.utils.singleton import SingletonMixin, ThreadSafeSingletonMixin + + +class TestSingletonMixin: + """Test suite for SingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "test_value"}, + {"init_value": "another_value"}, + ], + ids=["basic_singleton", "different_value"], + ) + def valid_instances(self, request): + """Provide parameterized test configurations for singleton testing.""" + config = request.param + + class TestSingleton(SingletonMixin): + def __init__(self): + # Check if we need to initialize before calling super().__init__() + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test SingletonMixin inheritance and exposed attributes.""" + assert hasattr(SingletonMixin, "__new__") + assert hasattr(SingletonMixin, "__init__") + assert hasattr(SingletonMixin, "initialized") + assert isinstance(SingletonMixin.initialized, property) + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test SingletonMixin initialization.""" + singleton_class, config = valid_instances + + # Create first instance + instance1 = singleton_class() + + assert isinstance(instance1, singleton_class) + assert instance1.initialized is True + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + # Check that the class has the singleton instance stored + instance_attr = f"_singleton_instance_{singleton_class.__name__}" + assert hasattr(singleton_class, instance_attr) + assert getattr(singleton_class, instance_attr) is instance1 + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test that multiple instantiations return the same instance.""" + singleton_class, config = valid_instances + + # Create multiple instances + instance1 = singleton_class() + instance2 = singleton_class() + instance3 = singleton_class() + + # All should be the same instance + assert instance1 is instance2 + assert instance2 is instance3 + assert instance1 is instance3 + + # Value should remain from first initialization + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert instance2.value == config["init_value"] + assert instance3.value == config["init_value"] + + @pytest.mark.sanity + def test_initialization_called_once(self, valid_instances): + """Test that __init__ is only called once despite multiple instantiations.""" + singleton_class, config = valid_instances + + class TestSingletonWithCounter(SingletonMixin): + init_count = 0 + + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + TestSingletonWithCounter.init_count += 1 + self.value = config["init_value"] + + # Create multiple instances + instance1 = TestSingletonWithCounter() + instance2 = TestSingletonWithCounter() + + assert TestSingletonWithCounter.init_count == 1 + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + + @pytest.mark.regression + def test_multiple_singleton_classes_isolation(self): + """Test that different singleton classes maintain separate instances.""" + + class Singleton1(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class Singleton2(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1a = Singleton1() + instance2a = Singleton2() + instance1b = Singleton1() + instance2b = Singleton2() + + # Each class has its own singleton instance + assert instance1a is instance1b + assert instance2a is instance2b + assert instance1a is not instance2a + + # Each maintains its own value + assert hasattr(instance1a, "value") + assert hasattr(instance2a, "value") + assert instance1a.value == "value1" + assert instance2a.value == "value2" + + @pytest.mark.regression + def test_inheritance_singleton_sharing(self): + """Test that inherited singleton classes share the same singleton_instance.""" + + class BaseSingleton(SingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildSingleton(BaseSingleton): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.extra = "extra_value" + + # Child classes now have separate singleton instances + base_instance = BaseSingleton() + child_instance = ChildSingleton() + + # They should be different instances now (fixed inheritance behavior) + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(child_instance, "value") + assert child_instance.value == "base_value" + assert hasattr(child_instance, "extra") + assert child_instance.extra == "extra_value" + + @pytest.mark.sanity + def test_without_super_init_call(self): + """Test singleton behavior when subclass doesn't call super().__init__().""" + + class BadSingleton(SingletonMixin): + def __init__(self): + # Not calling super().__init__() + self.value = "bad_value" + + instance1 = BadSingleton() + instance2 = BadSingleton() + + assert instance1 is instance2 + assert hasattr(instance1, "initialized") + assert instance1.initialized is False + + +class TestThreadSafeSingletonMixin: + """Test suite for ThreadSafeSingletonMixin class.""" + + @pytest.fixture( + params=[ + {"init_value": "thread_safe_value"}, + {"init_value": "concurrent_value"}, + ], + ids=["basic_thread_safe", "concurrent_test"], + ) + def valid_instances(self, request): + """Fixture providing test data for ThreadSafeSingletonMixin subclasses.""" + config = request.param + + class TestThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = config["init_value"] + + return TestThreadSafeSingleton, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ThreadSafeSingletonMixin inheritance and exposed attributes.""" + assert issubclass(ThreadSafeSingletonMixin, SingletonMixin) + assert hasattr(ThreadSafeSingletonMixin, "get_singleton_lock") + assert hasattr(ThreadSafeSingletonMixin, "__new__") + assert hasattr(ThreadSafeSingletonMixin, "__init__") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ThreadSafeSingletonMixin initialization.""" + singleton_class, config = valid_instances + + instance = singleton_class() + + assert isinstance(instance, singleton_class) + assert instance.initialized is True + assert hasattr(instance, "value") + assert instance.value == config["init_value"] + assert hasattr(instance, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance.thread_lock, lock_type) + + @pytest.mark.smoke + def test_singleton_behavior(self, valid_instances): + """Test multiple instantiations return same instance with thread safety.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert instance1 is instance2 + assert hasattr(instance1, "value") + assert instance1.value == config["init_value"] + assert hasattr(instance1, "thread_lock") + + @pytest.mark.regression + def test_thread_safety_concurrent_creation(self, valid_instances): + """Test thread safety during concurrent instance creation.""" + singleton_class, config = valid_instances + + instances = [] + exceptions = [] + creation_count = 0 + lock = threading.Lock() + + class ThreadSafeTestSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + nonlocal creation_count + with lock: + creation_count += 1 + + time.sleep(0.01) + self.value = config["init_value"] + + def create_instance(): + try: + instance = ThreadSafeTestSingleton() + instances.append(instance) + except (TypeError, ValueError, AttributeError) as exc: + exceptions.append(exc) + + threads = [] + for _ in range(10): + thread = threading.Thread(target=create_instance) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert len(exceptions) == 0, f"Exceptions occurred: {exceptions}" + + assert len(instances) == 10 + for instance in instances: + assert instance is instances[0] + + assert creation_count == 1 + assert all(instance.value == config["init_value"] for instance in instances) + + @pytest.mark.sanity + def test_thread_lock_creation(self, valid_instances): + """Test that thread_lock is created during initialization.""" + singleton_class, config = valid_instances + + instance1 = singleton_class() + instance2 = singleton_class() + + assert hasattr(instance1, "thread_lock") + lock_type = type(threading.Lock()) + assert isinstance(instance1.thread_lock, lock_type) + assert instance1.thread_lock is instance2.thread_lock + + @pytest.mark.regression + def test_multiple_thread_safe_classes_isolation(self): + """Test thread-safe singleton classes behavior with separate locks.""" + + class ThreadSafeSingleton1(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value1" + + class ThreadSafeSingleton2(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "value2" + + instance1 = ThreadSafeSingleton1() + instance2 = ThreadSafeSingleton2() + + lock1 = ThreadSafeSingleton1.get_singleton_lock() + lock2 = ThreadSafeSingleton2.get_singleton_lock() + + assert lock1 is not None + assert lock2 is not None + assert lock1 is not lock2 + + assert instance1 is not instance2 + assert hasattr(instance1, "value") + assert hasattr(instance2, "value") + assert instance1.value == "value1" + assert instance2.value == "value2" + + @pytest.mark.sanity + def test_inheritance_with_thread_safety(self): + """Test inheritance behavior with thread-safe singletons.""" + + class BaseThreadSafeSingleton(ThreadSafeSingletonMixin): + def __init__(self): + should_initialize = not getattr(self, "_singleton_initialized", False) + super().__init__() + if should_initialize: + self.value = "base_value" + + class ChildThreadSafeSingleton(BaseThreadSafeSingleton): + def __init__(self): + super().__init__() + + base_instance = BaseThreadSafeSingleton() + child_instance = ChildThreadSafeSingleton() + + base_lock = BaseThreadSafeSingleton.get_singleton_lock() + child_lock = ChildThreadSafeSingleton.get_singleton_lock() + + assert base_lock is not None + assert child_lock is not None + assert base_lock is not child_lock + + assert base_instance is not child_instance + assert hasattr(base_instance, "value") + assert base_instance.value == "base_value" + assert hasattr(base_instance, "thread_lock")