Skip to content

Commit 80beb67

Browse files
committed
quick utils enhancements
1 parent 7360368 commit 80beb67

File tree

4 files changed

+200
-137
lines changed

4 files changed

+200
-137
lines changed

src/guidellm/utils/pydantic_utils.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929

3030
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
31+
RegisterClassT = TypeVar("RegisterClassT", bound=type[BaseModelT])
3132
SuccessfulT = TypeVar("SuccessfulT")
3233
ErroredT = TypeVar("ErroredT")
3334
IncompleteT = TypeVar("IncompleteT")
@@ -130,7 +131,7 @@ class StatusBreakdown(BaseModel, Generic[SuccessfulT, ErroredT, IncompleteT, Tot
130131
131132
Example:
132133
::
133-
from guidellm.utils.pydantic_utils import StatusBreakdown
134+
from guidellm.utils import StatusBreakdown
134135
135136
# Define a breakdown for request counts
136137
breakdown = StatusBreakdown[int, int, int, int](
@@ -172,7 +173,7 @@ class PydanticClassRegistryMixin(
172173
173174
Example:
174175
::
175-
from guidellm.utils.pydantic_utils import PydanticClassRegistryMixin
176+
from speculators.utils import PydanticClassRegistryMixin
176177
177178
class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]):
178179
schema_discriminator: ClassVar[str] = "config_type"
@@ -200,8 +201,8 @@ class DatabaseConfig(BaseConfig):
200201

201202
@classmethod
202203
def register_decorator(
203-
cls, clazz: type[BaseModelT], name: str | list[str] | None = None
204-
) -> type[BaseModelT]:
204+
cls, clazz: RegisterClassT, name: str | list[str] | None = None
205+
) -> RegisterClassT:
205206
"""
206207
Register a Pydantic model class with type validation and schema reload.
207208
@@ -300,3 +301,25 @@ def auto_populate_registry(cls) -> bool:
300301
cls.reload_schema()
301302

302303
return populated
304+
305+
@classmethod
306+
def registered_classes(cls) -> tuple[type[Any], ...]:
307+
"""
308+
Get all registered pydantic classes from the registry.
309+
310+
Automatically triggers auto-discovery if registry_auto_discovery is enabled
311+
to ensure all available implementations are included.
312+
313+
:return: Tuple of all registered classes including auto-discovered ones
314+
:raises ValueError: If called before any objects have been registered
315+
"""
316+
if cls.registry_auto_discovery:
317+
cls.auto_populate_registry()
318+
319+
if cls.registry is None:
320+
raise ValueError(
321+
"ClassRegistryMixin.registered_classes() must be called after "
322+
"registering classes with ClassRegistryMixin.register()."
323+
)
324+
325+
return tuple(cls.registry.values())

src/guidellm/utils/registry.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,23 @@
1414

1515
from guidellm.utils.auto_importer import AutoImporterMixin
1616

17-
__all__ = ["RegistryMixin", "RegistryObjT"]
17+
__all__ = ["RegisterT", "RegistryMixin", "RegistryObjT"]
1818

1919

2020
RegistryObjT = TypeVar("RegistryObjT", bound=Any)
21-
"""
22-
Generic type variable for objects managed by the registry system.
23-
"""
21+
"""Generic type variable for objects managed by the registry system."""
22+
RegisterT = TypeVar("RegisterT", bound=RegistryObjT)
23+
"""Generic type variable for the args and return values within the registry."""
2424

2525

2626
class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin):
2727
"""
2828
Generic mixin for creating object registries with optional auto-discovery.
2929
30-
Enables classes to maintain separate registries of objects that can be
31-
dynamically discovered and instantiated through decorators and module imports.
32-
Supports both manual registration via decorators and automatic discovery
33-
through package scanning for extensible plugin architectures.
30+
Enables classes to maintain separate registries of objects that can be dynamically
31+
discovered and instantiated through decorators and module imports. Supports both
32+
manual registration via decorators and automatic discovery through package scanning
33+
for extensible plugin architectures.
3434
3535
Example:
3636
::
@@ -69,14 +69,14 @@ class TokenProposal(RegistryMixin):
6969
@classmethod
7070
def register(
7171
cls, name: str | list[str] | None = None
72-
) -> Callable[[RegistryObjT], RegistryObjT]:
72+
) -> Callable[[RegisterT], RegisterT]:
7373
"""
74-
Decorator that registers an object with the registry.
74+
Decorator for registering objects with the registry.
7575
7676
:param name: Optional name(s) to register the object under.
77-
If None, the object name is used as the registry key.
78-
:return: A decorator function that registers the decorated object.
79-
:raises ValueError: If name is provided but is not a string or list of strings.
77+
If None, uses the object's __name__ attribute
78+
:return: Decorator function that registers the decorated object
79+
:raises ValueError: If name is not a string, list of strings, or None
8080
"""
8181
if name is not None and not isinstance(name, (str, list)):
8282
raise ValueError(
@@ -88,19 +88,19 @@ def register(
8888

8989
@classmethod
9090
def register_decorator(
91-
cls, obj: RegistryObjT, name: str | list[str] | None = None
92-
) -> RegistryObjT:
91+
cls, obj: RegisterT, name: str | list[str] | None = None
92+
) -> RegisterT:
9393
"""
94-
Direct decorator that registers an object with the registry.
94+
Register an object directly with the registry.
9595
96-
:param obj: The object to register.
96+
:param obj: The object to register
9797
:param name: Optional name(s) to register the object under.
98-
If None, the object name is used as the registry key.
99-
:return: The registered object.
100-
:raises ValueError: If the object is already registered or if name is invalid.
98+
If None, uses the object's __name__ attribute
99+
:return: The registered object
100+
:raises ValueError: If the object is already registered or name is invalid
101101
"""
102102

103-
if not name:
103+
if name is None:
104104
name = obj.__name__
105105
elif not isinstance(name, (str, list)):
106106
raise ValueError(
@@ -127,20 +127,20 @@ def register_decorator(
127127
"registered."
128128
)
129129

130-
cls.registry[register_name.lower()] = obj
130+
cls.registry[register_name] = obj
131131

132132
return obj
133133

134134
@classmethod
135135
def auto_populate_registry(cls) -> bool:
136136
"""
137-
Import and register all modules from the specified auto_package.
137+
Import and register all modules from the auto_package.
138138
139139
Automatically called by registered_objects when registry_auto_discovery is True
140-
to ensure all available implementations are discovered before returning results.
140+
to ensure all available implementations are discovered.
141141
142-
:return: True if the registry was populated, False if already populated.
143-
:raises ValueError: If called when registry_auto_discovery is False.
142+
:return: True if registry was populated, False if already populated
143+
:raises ValueError: If called when registry_auto_discovery is False
144144
"""
145145
if not cls.registry_auto_discovery:
146146
raise ValueError(
@@ -165,8 +165,8 @@ def registered_objects(cls) -> tuple[RegistryObjT, ...]:
165165
Automatically triggers auto-discovery if registry_auto_discovery is enabled
166166
to ensure all available implementations are included.
167167
168-
:return: Tuple of all registered objects including auto-discovered ones.
169-
:raises ValueError: If called before any objects have been registered.
168+
:return: Tuple of all registered objects including auto-discovered ones
169+
:raises ValueError: If called before any objects have been registered
170170
"""
171171
if cls.registry_auto_discovery:
172172
cls.auto_populate_registry()
@@ -183,24 +183,33 @@ def registered_objects(cls) -> tuple[RegistryObjT, ...]:
183183
def is_registered(cls, name: str) -> bool:
184184
"""
185185
Check if an object is registered under the given name.
186+
It matches first by exact name, then by str.lower().
186187
187188
:param name: The name to check for registration.
188189
:return: True if the object is registered, False otherwise.
189190
"""
190191
if cls.registry is None:
191192
return False
192193

193-
return name.lower() in cls.registry
194+
return name in cls.registry or name.lower() in [
195+
key.lower() for key in cls.registry
196+
]
194197

195198
@classmethod
196199
def get_registered_object(cls, name: str) -> RegistryObjT | None:
197200
"""
198-
Get a registered object by its name.
201+
Get a registered object by its name. It matches first by exact name,
202+
then by str.lower().
199203
200204
:param name: The name of the registered object.
201205
:return: The registered object if found, None otherwise.
202206
"""
203207
if cls.registry is None:
204208
return None
205209

206-
return cls.registry.get(name.lower())
210+
if name in cls.registry:
211+
return cls.registry[name]
212+
213+
lower_key_map = {key.lower(): key for key in cls.registry}
214+
215+
return cls.registry.get(lower_key_map.get(name.lower()))

tests/unit/utils/test_pydantic_utils.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111
from pydantic import BaseModel, Field, ValidationError
1212

13-
from guidellm.utils.pydantic_utils import (
13+
from guidellm.utils import (
1414
PydanticClassRegistryMixin,
1515
ReloadableBaseModel,
1616
StandardBaseDict,
@@ -459,6 +459,7 @@ def test_class_signatures(self):
459459
assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__")
460460
assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__")
461461
assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry")
462+
assert hasattr(PydanticClassRegistryMixin, "registered_classes")
462463

463464
@pytest.mark.smoke
464465
def test_initialization(self, valid_instances):
@@ -547,8 +548,8 @@ class TestSubModel(TestBaseModel):
547548
value: str
548549

549550
assert TestBaseModel.registry is not None # type: ignore[misc]
550-
assert "testsubmodel" in TestBaseModel.registry # type: ignore[misc]
551-
assert TestBaseModel.registry["testsubmodel"] is TestSubModel # type: ignore[misc]
551+
assert "TestSubModel" in TestBaseModel.registry # type: ignore[misc]
552+
assert TestBaseModel.registry["TestSubModel"] is TestSubModel # type: ignore[misc]
552553

553554
@pytest.mark.sanity
554555
def test_register_decorator_with_name(self):
@@ -621,6 +622,87 @@ def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]:
621622
assert result is True
622623
mock_reload.assert_called_once()
623624

625+
@pytest.mark.smoke
626+
def test_registered_classes(self):
627+
"""Test PydanticClassRegistryMixin.registered_classes method."""
628+
629+
class TestBaseModel(PydanticClassRegistryMixin):
630+
schema_discriminator: ClassVar[str] = "test_type"
631+
test_type: str
632+
registry_auto_discovery: ClassVar[bool] = False
633+
634+
@classmethod
635+
def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]:
636+
if cls.__name__ == "TestBaseModel":
637+
return cls
638+
return TestBaseModel
639+
640+
@TestBaseModel.register("test_sub_a")
641+
class TestSubModelA(TestBaseModel):
642+
test_type: str = "test_sub_a"
643+
value_a: str
644+
645+
@TestBaseModel.register("test_sub_b")
646+
class TestSubModelB(TestBaseModel):
647+
test_type: str = "test_sub_b"
648+
value_b: int
649+
650+
# Test normal case with registered classes
651+
registered = TestBaseModel.registered_classes()
652+
assert isinstance(registered, tuple)
653+
assert len(registered) == 2
654+
assert TestSubModelA in registered
655+
assert TestSubModelB in registered
656+
657+
@pytest.mark.sanity
658+
def test_registered_classes_with_auto_discovery(self):
659+
"""Test PydanticClassRegistryMixin.registered_classes with auto discovery."""
660+
661+
class TestBaseModel(PydanticClassRegistryMixin):
662+
schema_discriminator: ClassVar[str] = "test_type"
663+
test_type: str
664+
registry_auto_discovery: ClassVar[bool] = True
665+
666+
@classmethod
667+
def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]:
668+
if cls.__name__ == "TestBaseModel":
669+
return cls
670+
return TestBaseModel
671+
672+
with mock.patch.object(
673+
TestBaseModel, "auto_populate_registry"
674+
) as mock_auto_populate:
675+
# Mock the registry to simulate registered classes
676+
TestBaseModel.registry = {"test_class": type("TestClass", (), {})}
677+
mock_auto_populate.return_value = False
678+
679+
registered = TestBaseModel.registered_classes()
680+
mock_auto_populate.assert_called_once()
681+
assert isinstance(registered, tuple)
682+
assert len(registered) == 1
683+
684+
@pytest.mark.sanity
685+
def test_registered_classes_no_registry(self):
686+
"""Test PydanticClassRegistryMixin.registered_classes with no registry."""
687+
688+
class TestBaseModel(PydanticClassRegistryMixin):
689+
schema_discriminator: ClassVar[str] = "test_type"
690+
test_type: str
691+
692+
@classmethod
693+
def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]:
694+
if cls.__name__ == "TestBaseModel":
695+
return cls
696+
return TestBaseModel
697+
698+
# Ensure registry is None
699+
TestBaseModel.registry = None
700+
701+
with pytest.raises(ValueError) as exc_info:
702+
TestBaseModel.registered_classes()
703+
704+
assert "must be called after registering classes" in str(exc_info.value)
705+
624706
@pytest.mark.sanity
625707
def test_marshalling(self, valid_instances):
626708
"""Test PydanticClassRegistryMixin serialization and deserialization."""
@@ -707,4 +789,4 @@ class ContainerModel(BaseModel):
707789
assert isinstance(recreated.model, TestSubModelA)
708790
assert len(recreated.models) == 2
709791
assert isinstance(recreated.models[0], TestSubModelA)
710-
assert isinstance(recreated.models[1], TestSubModelB)
792+
assert isinstance(recreated.models[1], TestSubModelB)

0 commit comments

Comments
 (0)