Skip to content

Commit 037c68c

Browse files
committed
Move type registry into documents, fix circular import
1 parent 7e01290 commit 037c68c

File tree

5 files changed

+46
-53
lines changed

5 files changed

+46
-53
lines changed

packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ...interfaces import URI, Endpoint, TypedProperties
77
from ...interfaces import StreamingBlob as SyncStreamingBlob
8-
from ...type_registry import TypeRegistry
8+
from ...documents import TypeRegistry
99

1010
if TYPE_CHECKING:
1111
from ...schemas import APIOperation

packages/smithy-core/src/smithy_core/documents.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,3 +638,24 @@ def read_document(self, schema: "Schema") -> Document:
638638
@override
639639
def read_timestamp(self, schema: "Schema") -> datetime.datetime:
640640
return self._value.as_timestamp()
641+
642+
643+
# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers.
644+
class TypeRegistry:
645+
def __init__(
646+
self,
647+
types: dict[ShapeID, type[DeserializeableShape]],
648+
sub_registry: "TypeRegistry | None" = None,
649+
):
650+
self._types = types
651+
self._sub_registry = sub_registry
652+
653+
def get(self, shape: ShapeID) -> type[DeserializeableShape]:
654+
if shape in self._types:
655+
return self._types[shape]
656+
if self._sub_registry is not None:
657+
return self._sub_registry.get(shape)
658+
raise KeyError(f"Unknown shape: {shape}")
659+
660+
def deserialize(self, document: Document) -> DeserializeableShape:
661+
return document.as_shape(self.get(document.discriminator))

packages/smithy-core/src/smithy_core/schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from .exceptions import ExpectationNotMetException, SmithyException
88
from .shapes import ShapeID, ShapeType
99
from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait
10-
from .type_registry import TypeRegistry
1110

1211
if TYPE_CHECKING:
12+
from .documents import TypeRegistry
1313
from .serializers import SerializeableShape
1414
from .deserializers import DeserializeableShape
1515

@@ -289,7 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]:
289289
output_schema: Schema
290290
"""The schema of the operation's output shape."""
291291

292-
error_registry: TypeRegistry
292+
error_registry: "TypeRegistry"
293293
"""A TypeRegistry used to create errors."""
294294

295295
effective_auth_schemes: Sequence[ShapeID]

packages/smithy-core/src/smithy_core/type_registry.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

packages/smithy-core/tests/unit/test_type_registry.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,41 @@
11
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
2-
from smithy_core.documents import Document
2+
from smithy_core.documents import Document, TypeRegistry
33
from smithy_core.schemas import Schema
44
from smithy_core.shapes import ShapeID, ShapeType
5-
from smithy_core.type_registry import TypeRegistry
65
import pytest
76

87

9-
class TestTypeRegistry:
10-
def test_get(self):
11-
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
8+
def test_get():
9+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
1210

13-
result = registry.get(ShapeID("com.example#Test"))
11+
result = registry.get(ShapeID("com.example#Test"))
1412

15-
assert result == TestShape
13+
assert result == TestShape
1614

17-
def test_get_sub_registry(self):
18-
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
19-
registry = TypeRegistry({}, sub_registry)
2015

21-
result = registry.get(ShapeID("com.example#Test"))
16+
def test_get_sub_registry():
17+
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
18+
registry = TypeRegistry({}, sub_registry)
2219

23-
assert result == TestShape
20+
result = registry.get(ShapeID("com.example#Test"))
2421

25-
def test_get_no_match(self):
26-
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
22+
assert result == TestShape
2723

28-
with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
29-
registry.get(ShapeID("com.example#Test2"))
3024

31-
def test_deserialize(self):
32-
shape_id = ShapeID("com.example#Test")
33-
registry = TypeRegistry({shape_id: TestShape})
25+
def test_get_no_match():
26+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
3427

35-
result = registry.deserialize(Document("abc123", schema=TestShape.schema))
28+
with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
29+
registry.get(ShapeID("com.example#Test2"))
3630

37-
assert isinstance(result, TestShape) and result.value == "abc123"
31+
32+
def test_deserialize():
33+
shape_id = ShapeID("com.example#Test")
34+
registry = TypeRegistry({shape_id: TestShape})
35+
36+
result = registry.deserialize(Document("abc123", schema=TestShape.schema))
37+
38+
assert isinstance(result, TestShape) and result.value == "abc123"
3839

3940

4041
class TestShape(DeserializeableShape):

0 commit comments

Comments
 (0)