diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 4f58b9d3e..519046979 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -1,19 +1,45 @@ -from typing import Final +from typing import Any, Final from smithy_core.codecs import Codec +from smithy_core.schemas import APIOperation from smithy_core.shapes import ShapeID +from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse from smithy_http.aio.protocols import HttpBindingClientProtocol from smithy_json import JSONCodec from ..traits import RestJson1Trait +class AWSErrorIdentifier(HTTPErrorIdentifier): + _HEADER_KEY: Final = "x-amzn-errortype" + + def identify( + self, + *, + operation: APIOperation[Any, Any], + response: HTTPResponse, + ) -> ShapeID | None: + if self._HEADER_KEY not in response.fields: + return None + + error_field = response.fields[self._HEADER_KEY] + code = error_field.values[0] if len(error_field.values) > 0 else None + if not code: + return None + + code = code.split(":")[0] + if "#" in code: + return ShapeID(code) + return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace) + + class RestJsonClientProtocol(HttpBindingClientProtocol): """An implementation of the aws.protocols#restJson1 protocol.""" - _id: ShapeID = RestJson1Trait.id - _codec: JSONCodec = JSONCodec() + _id: Final = RestJson1Trait.id + _codec: Final = JSONCodec() _contentType: Final = "application/json" + _error_identifier: Final = AWSErrorIdentifier() @property def id(self) -> ShapeID: @@ -26,3 +52,7 @@ def payload_codec(self) -> Codec: @property def content_type(self) -> str: return self._contentType + + @property + def error_identifier(self) -> HTTPErrorIdentifier: + return self._error_identifier diff --git a/packages/smithy-aws-core/src/smithy_aws_core/traits.py b/packages/smithy-aws-core/src/smithy_aws_core/traits.py index 180818774..1f5e0bca7 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/traits.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/traits.py @@ -10,8 +10,9 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass, field +from smithy_core.documents import DocumentValue from smithy_core.shapes import ShapeID -from smithy_core.traits import DocumentValue, DynamicTrait, Trait +from smithy_core.traits import DynamicTrait, Trait @dataclass(init=False, frozen=True) diff --git a/packages/smithy-aws-core/tests/unit/aio/__init__.py b/packages/smithy-aws-core/tests/unit/aio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py new file mode 100644 index 000000000..82cf7d1e8 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import Mock + +import pytest +from smithy_aws_core.aio.protocols import AWSErrorIdentifier +from smithy_core.schemas import APIOperation, Schema +from smithy_core.shapes import ShapeID, ShapeType +from smithy_http import Fields, tuples_to_fields +from smithy_http.aio import HTTPResponse + + +@pytest.mark.parametrize( + "header, expected", + [ + ("FooError", "com.test#FooError"), + ( + "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/", + "com.test#FooError", + ), + ( + "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", + "com.test#FooError", + ), + ("", None), + (None, None), + ], +) +def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> None: + fields = Fields() + if header is not None: + fields = tuples_to_fields([("x-amzn-errortype", header)]) + http_response = HTTPResponse(status=500, fields=fields) + + operation = Mock(spec=APIOperation) + operation.schema = Schema( + id=ShapeID("com.test#TestOperation"), shape_type=ShapeType.OPERATION + ) + + error_identifier = AWSErrorIdentifier() + actual = error_identifier.identify(operation=operation, response=http_response) + + assert actual == expected diff --git a/packages/smithy-core/src/smithy_core/interfaces/__init__.py b/packages/smithy-core/src/smithy_core/interfaces/__init__.py index 3b1936a00..0123d2d2d 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/__init__.py +++ b/packages/smithy-core/src/smithy_core/interfaces/__init__.py @@ -79,6 +79,14 @@ def is_bytes_reader(obj: Any) -> TypeGuard[BytesReader]: ) +@runtime_checkable +class SeekableBytesReader(BytesReader, Protocol): + """A synchronous bytes reader with seek and tell methods.""" + + def tell(self) -> int: ... + def seek(self, offset: int, whence: int = 0, /) -> int: ... + + # A union of all acceptable streaming blob types. Deserialized payloads will # always return a ByteStream, or AsyncByteStream if async is enabled. type StreamingBlob = BytesReader | bytes | bytearray diff --git a/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py b/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py index c3aaa390b..9c169c7c7 100644 --- a/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py +++ b/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py @@ -1,9 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Protocol +from typing import Any, Protocol from smithy_core.aio.interfaces import ClientTransport, Request, Response from smithy_core.aio.utils import read_streaming_blob, read_streaming_blob_async +from smithy_core.schemas import APIOperation +from smithy_core.shapes import ShapeID from ...interfaces import ( Fields, @@ -83,3 +85,19 @@ async def send( :param request_config: Configuration specific to this request. """ ... + + +class HTTPErrorIdentifier: + """A class that uses HTTP response metadata to identify errors. + + The body of the response SHOULD NOT be touched by this. The payload codec will be + used instead to check for an ID in the body. + """ + + def identify( + self, + *, + operation: APIOperation[Any, Any], + response: HTTPResponse, + ) -> ShapeID | None: + """Idenitfy the ShapeID of an error from an HTTP response.""" diff --git a/packages/smithy-http/src/smithy_http/aio/protocols.py b/packages/smithy-http/src/smithy_http/aio/protocols.py index e038b6f01..e5591923a 100644 --- a/packages/smithy-http/src/smithy_http/aio/protocols.py +++ b/packages/smithy-http/src/smithy_http/aio/protocols.py @@ -1,20 +1,29 @@ import os from inspect import iscoroutinefunction from io import BytesIO +from typing import Any from smithy_core.aio.interfaces import ClientProtocol from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape from smithy_core.documents import TypeRegistry -from smithy_core.exceptions import ExpectationNotMetError -from smithy_core.interfaces import Endpoint, TypedProperties, URI +from smithy_core.exceptions import CallError, ExpectationNotMetError, ModeledError +from smithy_core.interfaces import ( + Endpoint, + SeekableBytesReader, + TypedProperties, + URI, + is_streaming_blob, +) +from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob +from smithy_core.prelude import DOCUMENT from smithy_core.schemas import APIOperation from smithy_core.serializers import SerializeableShape from smithy_core.traits import EndpointTrait, HTTPTrait -from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse -from smithy_http.deserializers import HTTPResponseDeserializer -from smithy_http.serializers import HTTPRequestSerializer +from ..deserializers import HTTPResponseDeserializer +from ..serializers import HTTPRequestSerializer +from .interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]): @@ -54,6 +63,12 @@ def content_type(self) -> str: """The media type of the http payload.""" raise NotImplementedError() + @property + def error_identifier(self) -> HTTPErrorIdentifier: + """The class used to identify the shape IDs of errors based on fields or other + response information.""" + raise NotImplementedError() + def serialize_request[ OperationInput: "SerializeableShape", OperationOutput: "DeserializeableShape", @@ -94,19 +109,25 @@ async def deserialize_response[ error_registry: TypeRegistry, context: TypedProperties, ) -> OperationOutput: - if not (200 <= response.status <= 299): - # TODO: implement error serde from type registry - raise NotImplementedError - body = response.body # if body is not streaming and is async, we have to buffer it - if not operation.output_stream_member: + if not operation.output_stream_member and not is_streaming_blob(body): if ( read := getattr(body, "read", None) ) is not None and iscoroutinefunction(read): body = BytesIO(await read()) + if not self._is_success(operation, context, response): + raise await self._create_error( + operation=operation, + request=request, + response=response, + response_body=body, # type: ignore + error_registry=error_registry, + context=context, + ) + # TODO(optimization): response binding cache like done in SJ deserializer = HTTPResponseDeserializer( payload_codec=self.payload_codec, @@ -116,3 +137,69 @@ async def deserialize_response[ ) return operation.output.deserialize(deserializer) + + def _is_success( + self, + operation: APIOperation[Any, Any], + context: TypedProperties, + response: HTTPResponse, + ) -> bool: + return 200 <= response.status < 300 + + async def _create_error( + self, + operation: APIOperation[Any, Any], + request: HTTPRequest, + response: HTTPResponse, + response_body: SyncStreamingBlob, + error_registry: TypeRegistry, + context: TypedProperties, + ) -> CallError: + error_id = self.error_identifier.identify( + operation=operation, response=response + ) + + if error_id is None: + if isinstance(response_body, bytearray): + response_body = bytes(response_body) + deserializer = self.payload_codec.create_deserializer(source=response_body) + document = deserializer.read_document(schema=DOCUMENT) + + if document.discriminator in error_registry: + error_id = document.discriminator + if isinstance(response_body, SeekableBytesReader): + response_body.seek(0) + + if error_id is not None and error_id in error_registry: + error_shape = error_registry.get(error_id) + + # make sure the error shape is derived from modeled exception + if not issubclass(error_shape, ModeledError): + raise ExpectationNotMetError( + f"Modeled errors must be derived from 'ModeledError', " + f"but got {error_shape}" + ) + + deserializer = HTTPResponseDeserializer( + payload_codec=self.payload_codec, + http_trait=operation.schema.expect_trait(HTTPTrait), + response=response, + body=response_body, + ) + return error_shape.deserialize(deserializer) + + is_throttle = response.status == 429 + message = ( + f"Unknown error for operation {operation.schema.id} " + f"- status: {response.status}" + ) + if error_id is not None: + message += f" - id: {error_id}" + if response.reason is not None: + message += f" - reason: {response.status}" + return CallError( + message=message, + fault="client" if response.status < 500 else "server", + is_throttling_error=is_throttle, + is_retry_safe=is_throttle or None, + ) diff --git a/packages/smithy-http/src/smithy_http/deserializers.py b/packages/smithy-http/src/smithy_http/deserializers.py index 0ddd63cbf..f48136db1 100644 --- a/packages/smithy-http/src/smithy_http/deserializers.py +++ b/packages/smithy-http/src/smithy_http/deserializers.py @@ -39,16 +39,17 @@ class HTTPResponseDeserializer(SpecificShapeDeserializer): # Note: caller will have to read the body if it's async and not streaming def __init__( self, + *, payload_codec: Codec, - http_trait: HTTPTrait, response: HTTPResponse, + http_trait: HTTPTrait | None = None, body: "SyncStreamingBlob | None" = None, ) -> None: """Initialize an HTTPResponseDeserializer. :param payload_codec: The Codec to use to deserialize the payload, if present. - :param http_trait: The HTTP trait of the operation being handled. :param response: The HTTP response to read from. + :param http_trait: The HTTP trait of the operation being handled. :param body: The HTTP response body in a synchronously readable form. This is necessary for async response bodies when there is no streaming member. """