Skip to content

Commit 31ec8e2

Browse files
authored
Add a type registry (#444)
1 parent d65bf2e commit 31ec8e2

File tree

4 files changed

+131
-6
lines changed

4 files changed

+131
-6
lines changed

packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
from collections.abc import AsyncIterable
4-
from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any
4+
from typing import Protocol, runtime_checkable, TYPE_CHECKING
55

66
from ...interfaces import URI, Endpoint, TypedProperties
77
from ...interfaces import StreamingBlob as SyncStreamingBlob
8-
8+
from ...documents import TypeRegistry
99

1010
if TYPE_CHECKING:
1111
from ...schemas import APIOperation
@@ -126,7 +126,7 @@ async def deserialize_response[
126126
operation: "APIOperation[OperationInput, OperationOutput]",
127127
request: I,
128128
response: O,
129-
error_registry: Any, # TODO: add error registry
129+
error_registry: TypeRegistry,
130130
context: TypedProperties,
131131
) -> OperationOutput:
132132
"""Deserializes the output from the tranport response or throws an exception.

packages/smithy-core/src/smithy_core/documents.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType:
143143
"""The Smithy data model type for the underlying contents of the document."""
144144
return self._type
145145

146+
@property
147+
def discriminator(self) -> ShapeID:
148+
"""The shape ID that corresponds to the contents of the document."""
149+
return self._schema.id
150+
146151
def is_none(self) -> bool:
147152
"""Indicates whether the document contains a null value."""
148153
return self._value is None and self._raw_value is None
@@ -633,3 +638,55 @@ def read_document(self, schema: "Schema") -> Document:
633638
@override
634639
def read_timestamp(self, schema: "Schema") -> datetime.datetime:
635640
return self._value.as_timestamp()
641+
642+
643+
class TypeRegistry:
644+
"""A registry for on-demand deserialization of types by using a mapping of shape IDs
645+
to their deserializers."""
646+
647+
def __init__(
648+
self,
649+
types: dict[ShapeID, type[DeserializeableShape]],
650+
sub_registry: "TypeRegistry | None" = None,
651+
):
652+
"""Initialize a TypeRegistry.
653+
654+
:param types: A mapping of ShapeID to the shapes they deserialize to.
655+
:param sub_registry: A registry to delegate to if an ID is not found in types.
656+
"""
657+
self._types = types
658+
self._sub_registry = sub_registry
659+
660+
def get(self, shape: ShapeID) -> type[DeserializeableShape]:
661+
"""Get the deserializable shape for the given shape ID.
662+
663+
:param shape: The shape ID to get from the registry.
664+
:returns: The corresponding deserializable shape.
665+
:raises KeyError: If the shape ID is not found in the registry.
666+
"""
667+
if shape in self._types:
668+
return self._types[shape]
669+
if self._sub_registry is not None:
670+
return self._sub_registry.get(shape)
671+
raise KeyError(f"Unknown shape: {shape}")
672+
673+
def __getitem__(self, shape: ShapeID):
674+
"""Get the deserializable shape for the given shape ID.
675+
676+
:param shape: The shape ID to get from the registry.
677+
:returns: The corresponding deserializable shape.
678+
:raises KeyError: If the shape ID is not found in the registry.
679+
"""
680+
return self.get(shape)
681+
682+
def __contains__(self, item: object, /):
683+
"""Check if the registry contains the given shape.
684+
685+
:param shape: The shape ID to check for.
686+
"""
687+
return item in self._types or (
688+
self._sub_registry is not None and item in self._sub_registry
689+
)
690+
691+
def deserialize(self, document: Document) -> DeserializeableShape:
692+
return document.as_shape(self.get(document.discriminator))

packages/smithy-core/src/smithy_core/schemas.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from .shapes import ShapeID, ShapeType
99
from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait
1010

11-
1211
if TYPE_CHECKING:
12+
from .documents import TypeRegistry
1313
from .serializers import SerializeableShape
1414
from .deserializers import DeserializeableShape
1515

@@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]:
289289
output_schema: Schema
290290
"""The schema of the operation's output shape."""
291291

292-
# TODO: Add a type registry for errors
293-
error_registry: Any
292+
error_registry: "TypeRegistry"
294293
"""A TypeRegistry used to create errors."""
295294

296295
effective_auth_schemes: Sequence[ShapeID]
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
2+
from smithy_core.documents import Document, TypeRegistry
3+
from smithy_core.schemas import Schema
4+
from smithy_core.shapes import ShapeID, ShapeType
5+
import pytest
6+
7+
8+
def test_get():
9+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
10+
11+
result = registry[ShapeID("com.example#Test")]
12+
13+
assert result == TestShape
14+
15+
16+
def test_contains():
17+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
18+
19+
assert ShapeID("com.example#Test") in registry
20+
21+
22+
def test_get_sub_registry():
23+
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
24+
registry = TypeRegistry({}, sub_registry)
25+
26+
result = registry[ShapeID("com.example#Test")]
27+
28+
assert result == TestShape
29+
30+
31+
def test_contains_sub_registry():
32+
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
33+
registry = TypeRegistry({}, sub_registry)
34+
35+
assert ShapeID("com.example#Test") in registry
36+
37+
38+
def test_get_no_match():
39+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
40+
41+
with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
42+
registry[ShapeID("com.example#Test2")]
43+
44+
45+
def test_contains_no_match():
46+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
47+
48+
assert ShapeID("com.example#Test2") not in registry
49+
50+
51+
def test_deserialize():
52+
shape_id = ShapeID("com.example#Test")
53+
registry = TypeRegistry({shape_id: TestShape})
54+
55+
result = registry.deserialize(Document("abc123", schema=TestShape.schema))
56+
57+
assert isinstance(result, TestShape) and result.value == "abc123"
58+
59+
60+
class TestShape(DeserializeableShape):
61+
__test__ = False
62+
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)
63+
64+
def __init__(self, value: str):
65+
self.value = value
66+
67+
@classmethod
68+
def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape":
69+
return TestShape(deserializer.read_string(schema=TestShape.schema))

0 commit comments

Comments
 (0)