Skip to content

Commit b2f4b4c

Browse files
committed
Add a type registry
1 parent cd6641f commit b2f4b4c

File tree

5 files changed

+88
-6
lines changed

5 files changed

+88
-6
lines changed

packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
from collections.abc import AsyncIterable
4-
from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any
4+
from typing import Protocol, runtime_checkable, TYPE_CHECKING
55

66
from ...interfaces import URI, Endpoint, TypedProperties
77
from ...interfaces import StreamingBlob as SyncStreamingBlob
8-
8+
from ...type_registry import TypeRegistry
99

1010
if TYPE_CHECKING:
1111
from ...schemas import APIOperation
@@ -126,7 +126,7 @@ async def deserialize_response[
126126
operation: "APIOperation[OperationInput, OperationOutput]",
127127
request: I,
128128
response: O,
129-
error_registry: Any, # TODO: add error registry
129+
error_registry: TypeRegistry,
130130
context: TypedProperties,
131131
) -> OperationOutput:
132132
"""Deserializes the output from the tranport response or throws an exception.

packages/smithy-core/src/smithy_core/documents.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType:
143143
"""The Smithy data model type for the underlying contents of the document."""
144144
return self._type
145145

146+
@property
147+
def discriminator(self) -> ShapeID:
148+
"""The shape ID that corresponds to the contents of the document."""
149+
return self._schema.id
150+
146151
def is_none(self) -> bool:
147152
"""Indicates whether the document contains a null value."""
148153
return self._value is None and self._raw_value is None

packages/smithy-core/src/smithy_core/schemas.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .exceptions import ExpectationNotMetException, SmithyException
88
from .shapes import ShapeID, ShapeType
99
from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait
10-
10+
from .type_registry import TypeRegistry
1111

1212
if TYPE_CHECKING:
1313
from .serializers import SerializeableShape
@@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]:
289289
output_schema: Schema
290290
"""The schema of the operation's output shape."""
291291

292-
# TODO: Add a type registry for errors
293-
error_registry: Any
292+
error_registry: TypeRegistry
294293
"""A TypeRegistry used to create errors."""
295294

296295
effective_auth_schemes: Sequence[ShapeID]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from smithy_core.deserializers import (
5+
DeserializeableShape,
6+
) # TODO: fix typo in deserializable
7+
from smithy_core.documents import Document
8+
from smithy_core.shapes import ShapeID
9+
10+
11+
# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers.
12+
# TODO: protocol? Also, move into documents.py?
13+
class TypeRegistry:
14+
def __init__(
15+
self,
16+
types: dict[ShapeID, type[DeserializeableShape]],
17+
sub_registry: "TypeRegistry | None" = None,
18+
):
19+
self._types = types
20+
self._sub_registry = sub_registry
21+
22+
def get(self, shape: ShapeID) -> type[DeserializeableShape]:
23+
if shape in self._types:
24+
return self._types[shape]
25+
if self._sub_registry is not None:
26+
return self._sub_registry.get(shape)
27+
raise KeyError(f"Unknown shape: {shape}")
28+
29+
def deserialize(self, document: Document) -> DeserializeableShape:
30+
return document.as_shape(self.get(document.discriminator))
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
2+
from smithy_core.documents import Document
3+
from smithy_core.schemas import Schema
4+
from smithy_core.shapes import ShapeID, ShapeType
5+
from smithy_core.type_registry import TypeRegistry
6+
import pytest
7+
8+
9+
class TestTypeRegistry:
10+
def test_get(self):
11+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
12+
13+
result = registry.get(ShapeID("com.example#Test"))
14+
15+
assert result == TestShape
16+
17+
def test_get_sub_registry(self):
18+
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
19+
registry = TypeRegistry({}, sub_registry)
20+
21+
result = registry.get(ShapeID("com.example#Test"))
22+
23+
assert result == TestShape
24+
25+
def test_get_no_match(self):
26+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
27+
28+
with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
29+
registry.get(ShapeID("com.example#Test2"))
30+
31+
def test_deserialize(self):
32+
shape_id = ShapeID("com.example#Test")
33+
registry = TypeRegistry({shape_id: TestShape})
34+
35+
result = registry.deserialize(Document("abc123", schema=TestShape.schema))
36+
37+
assert isinstance(result, TestShape) and result.value == "abc123"
38+
39+
40+
class TestShape(DeserializeableShape):
41+
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)
42+
43+
def __init__(self, value: str):
44+
self.value = value
45+
46+
@classmethod
47+
def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape":
48+
return TestShape(deserializer.read_string(schema=TestShape.schema))

0 commit comments

Comments
 (0)