Skip to content

Commit 6af3305

Browse files
committed
Address comments
1 parent 037c68c commit 6af3305

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

packages/smithy-core/src/smithy_core/documents.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,22 +640,53 @@ def read_timestamp(self, schema: "Schema") -> datetime.datetime:
640640
return self._value.as_timestamp()
641641

642642

643-
# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers.
644643
class TypeRegistry:
644+
"""A registry for on-demand deserialization of types by using a mapping of shape IDs
645+
to their deserializers."""
646+
645647
def __init__(
646648
self,
647649
types: dict[ShapeID, type[DeserializeableShape]],
648650
sub_registry: "TypeRegistry | None" = None,
649651
):
652+
"""Initialize a TypeRegistry.
653+
654+
:param types: A mapping of ShapeID to the shapes they deserialize to.
655+
:param sub_registry: A registry to delegate to if an ID is not found in types.
656+
"""
650657
self._types = types
651658
self._sub_registry = sub_registry
652659

653660
def get(self, shape: ShapeID) -> type[DeserializeableShape]:
661+
"""Get the deserializable shape for the given shape ID.
662+
663+
:param shape: The shape ID to get from the registry.
664+
:returns: The corresponding deserializable shape.
665+
:raises KeyError: If the shape ID is not found in the registry.
666+
"""
654667
if shape in self._types:
655668
return self._types[shape]
656669
if self._sub_registry is not None:
657670
return self._sub_registry.get(shape)
658671
raise KeyError(f"Unknown shape: {shape}")
659672

673+
def __getitem__(self, shape: ShapeID):
674+
"""Get the deserializable shape for the given shape ID.
675+
676+
:param shape: The shape ID to get from the registry.
677+
:returns: The corresponding deserializable shape.
678+
:raises KeyError: If the shape ID is not found in the registry.
679+
"""
680+
return self.get(shape)
681+
682+
def __contains__(self, shape: ShapeID):
683+
"""Check if the registry contains the given shape.
684+
685+
:param shape: The shape ID to check for.
686+
"""
687+
return shape in self._types or (
688+
self._sub_registry is not None and shape in self._sub_registry
689+
)
690+
660691
def deserialize(self, document: Document) -> DeserializeableShape:
661692
return document.as_shape(self.get(document.discriminator))

packages/smithy-core/tests/unit/test_type_registry.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,44 @@
88
def test_get():
99
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
1010

11-
result = registry.get(ShapeID("com.example#Test"))
11+
result = registry[ShapeID("com.example#Test")]
1212

1313
assert result == TestShape
1414

1515

16+
def test_contains():
17+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
18+
19+
assert ShapeID("com.example#Test") in registry
20+
21+
1622
def test_get_sub_registry():
1723
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
1824
registry = TypeRegistry({}, sub_registry)
1925

20-
result = registry.get(ShapeID("com.example#Test"))
26+
result = registry[ShapeID("com.example#Test")]
2127

2228
assert result == TestShape
2329

2430

31+
def test_contains_sub_registry():
32+
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
33+
registry = TypeRegistry({}, sub_registry)
34+
35+
assert ShapeID("com.example#Test") in registry
36+
37+
2538
def test_get_no_match():
2639
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
2740

2841
with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
29-
registry.get(ShapeID("com.example#Test2"))
42+
registry[ShapeID("com.example#Test2")]
43+
44+
45+
def test_contains_no_match():
46+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
47+
48+
assert ShapeID("com.example#Test2") not in registry
3049

3150

3251
def test_deserialize():
@@ -39,6 +58,7 @@ def test_deserialize():
3958

4059

4160
class TestShape(DeserializeableShape):
61+
__test__ = False
4262
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)
4363

4464
def __init__(self, value: str):

0 commit comments

Comments
 (0)