Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncIterable
from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any
from typing import Protocol, runtime_checkable, TYPE_CHECKING

from ...interfaces import URI, Endpoint, TypedProperties
from ...interfaces import StreamingBlob as SyncStreamingBlob

from ...documents import TypeRegistry

if TYPE_CHECKING:
from ...schemas import APIOperation
Expand Down Expand Up @@ -126,7 +126,7 @@ async def deserialize_response[
operation: "APIOperation[OperationInput, OperationOutput]",
request: I,
response: O,
error_registry: Any, # TODO: add error registry
error_registry: TypeRegistry,
context: TypedProperties,
) -> OperationOutput:
"""Deserializes the output from the tranport response or throws an exception.
Expand Down
26 changes: 26 additions & 0 deletions packages/smithy-core/src/smithy_core/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def shape_type(self) -> ShapeType:
"""The Smithy data model type for the underlying contents of the document."""
return self._type

@property
def discriminator(self) -> ShapeID:
"""The shape ID that corresponds to the contents of the document."""
return self._schema.id

def is_none(self) -> bool:
"""Indicates whether the document contains a null value."""
return self._value is None and self._raw_value is None
Expand Down Expand Up @@ -633,3 +638,24 @@ def read_document(self, schema: "Schema") -> Document:
@override
def read_timestamp(self, schema: "Schema") -> datetime.datetime:
return self._value.as_timestamp()


# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers.
class TypeRegistry:
def __init__(
self,
types: dict[ShapeID, type[DeserializeableShape]],
sub_registry: "TypeRegistry | None" = None,
):
self._types = types
self._sub_registry = sub_registry

def get(self, shape: ShapeID) -> type[DeserializeableShape]:
if shape in self._types:
return self._types[shape]
if self._sub_registry is not None:
return self._sub_registry.get(shape)
raise KeyError(f"Unknown shape: {shape}")

def deserialize(self, document: Document) -> DeserializeableShape:
return document.as_shape(self.get(document.discriminator))
5 changes: 2 additions & 3 deletions packages/smithy-core/src/smithy_core/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .shapes import ShapeID, ShapeType
from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait


if TYPE_CHECKING:
from .documents import TypeRegistry
from .serializers import SerializeableShape
from .deserializers import DeserializeableShape

Expand Down Expand Up @@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]:
output_schema: Schema
"""The schema of the operation's output shape."""

# TODO: Add a type registry for errors
error_registry: Any
error_registry: "TypeRegistry"
"""A TypeRegistry used to create errors."""

effective_auth_schemes: Sequence[ShapeID]
Expand Down
49 changes: 49 additions & 0 deletions packages/smithy-core/tests/unit/test_type_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
from smithy_core.documents import Document, TypeRegistry
from smithy_core.schemas import Schema
from smithy_core.shapes import ShapeID, ShapeType
import pytest


def test_get():
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})

result = registry.get(ShapeID("com.example#Test"))

assert result == TestShape


def test_get_sub_registry():
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
registry = TypeRegistry({}, sub_registry)

result = registry.get(ShapeID("com.example#Test"))

assert result == TestShape


def test_get_no_match():
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})

with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
registry.get(ShapeID("com.example#Test2"))


def test_deserialize():
shape_id = ShapeID("com.example#Test")
registry = TypeRegistry({shape_id: TestShape})

result = registry.deserialize(Document("abc123", schema=TestShape.schema))

assert isinstance(result, TestShape) and result.value == "abc123"


class TestShape(DeserializeableShape):
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)

def __init__(self, value: str):
self.value = value

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape":
return TestShape(deserializer.read_string(schema=TestShape.schema))