Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions packages/smithy-aws-core/src/smithy_aws_core/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from smithy_core.shapes import ShapeID
from smithy_http.aio.interfaces import ErrorExtractor, HTTPResponse


class AmznErrorExtractor(ErrorExtractor):
"""Attempts to extract the Amazon-specific 'X-Amzn-Errortype' error header from a
response."""

def get_error(self, response: HTTPResponse):
if "x-amzn-errortype" in response.fields:
val = response.fields["x-amzn-errortype"].values[0]
return ShapeID(val)
Comment on lines +11 to +14
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this will be enough. This can show up in a few other headers and has several formats which may or may not include a namespace. I could also be in the body.

11 changes: 9 additions & 2 deletions packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@

from smithy_core.codecs import Codec
from smithy_core.shapes import ShapeID
from smithy_http.aio.interfaces import ErrorExtractor
from smithy_http.aio.protocols import HttpBindingClientProtocol
from smithy_json import JSONCodec

from ..traits import RestJson1Trait
from . import AmznErrorExtractor


class RestJsonClientProtocol(HttpBindingClientProtocol):
"""An implementation of the aws.protocols#restJson1 protocol."""

_id: ShapeID = RestJson1Trait.id
_codec: JSONCodec = JSONCodec()
_contentType: Final = "application/json"
_content_type: Final = "application/json"
_error_extractor: ErrorExtractor = AmznErrorExtractor()

@property
def id(self) -> ShapeID:
Expand All @@ -25,4 +28,8 @@ def payload_codec(self) -> Codec:

@property
def content_type(self) -> str:
return self._contentType
return self._content_type

@property
def error_extractor(self) -> ErrorExtractor:
return self._error_extractor
32 changes: 32 additions & 0 deletions packages/smithy-core/src/smithy_core/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Literal

from smithy_core.deserializers import DeserializeableShape

type Fault = Literal["client", "server", "other"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other?



@dataclass(kw_only=True, frozen=True)
class CallException(RuntimeError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All exceptions we define should inherit from SmithyException

"""The top-level exception that should be used to throw application-level errors
from clients and servers.
This should be used in protocol error deserialization, throwing errors based on
protocol-hints, network errors, and shape validation errors. It should not be used
for illegal arguments, null argument validation, or other kinds of logic errors
sufficiently covered by the Java standard library.
"""

fault: Fault = "other"
"""The party that is at fault for the error, if any."""

message: str = ""
"""The error message."""

# TODO: retry-ability and associated information (throttling, duration, etc.), perhaps 'Retryability' dataclass?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah having a RetryableException protocol, mixin, or something would be ideal. Classifying is too complex without embedding the information in the exception.



@dataclass(kw_only=True, frozen=True)
class ModeledException(CallException, DeserializeableShape):
"""The top-level exception that should be used to throw modeled errors from clients
and servers."""
16 changes: 16 additions & 0 deletions packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from smithy_core.aio.interfaces import ClientTransport, Request, Response
from smithy_core.aio.utils import read_streaming_blob, read_streaming_blob_async
from smithy_core.shapes import ShapeID

from ...interfaces import (
Fields,
Expand Down Expand Up @@ -83,3 +84,18 @@ async def send(
:param request_config: Configuration specific to this request.
"""
...


class ErrorExtractor(Protocol):
"""Extract error shape IDs from an HTTP response."""

def get_error(
self,
response: HTTPResponse,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be tied to HTTP

) -> ShapeID | None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might actually be the place to get the retry info

"""Get the shape id for an error by using information (such as headers) from a
response.

:param response: The response object to derive an error shape from.
"""
...
31 changes: 21 additions & 10 deletions packages/smithy-http/src/smithy_http/aio/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from smithy_core.serializers import SerializeableShape
from smithy_core.traits import EndpointTrait, HTTPTrait

from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse
from smithy_http.deserializers import HTTPResponseDeserializer
from smithy_http.aio.interfaces import ErrorExtractor, HTTPRequest, HTTPResponse
from smithy_http.deserializers import HTTPErrorDeserializer, HTTPResponseDeserializer
from smithy_http.serializers import HTTPRequestSerializer


Expand Down Expand Up @@ -47,12 +47,17 @@ class HttpBindingClientProtocol(HttpClientProtocol):
@property
def payload_codec(self) -> Codec:
"""The codec used for the serde of input and output payloads."""
raise NotImplementedError()
raise NotImplementedError

@property
def content_type(self) -> str:
"""The media type of the http payload."""
raise NotImplementedError()
raise NotImplementedError

@property
def error_extractor(self) -> ErrorExtractor:
"""The error extractor used to extract errors from the response."""
raise NotImplementedError

def serialize_request[
OperationInput: "SerializeableShape",
Expand Down Expand Up @@ -94,19 +99,25 @@ async def deserialize_response[
error_registry: TypeRegistry,
context: TypedProperties,
) -> OperationOutput:
if not (200 <= response.status <= 299):
# TODO: implement error serde from type registry
raise NotImplementedError

body = response.body

# if body is not streaming and is async, we have to buffer it
body = response.body
if not operation.output_stream_member:
if (
read := getattr(body, "read", None)
) is not None and iscoroutinefunction(read):
body = BytesIO(await read())

# handle error response
if not (200 <= response.status <= 299):
error_deserializer = HTTPErrorDeserializer(
payload_codec=self.payload_codec,
extractor=self.error_extractor,
response=response,
body=body, # type: ignore
)

raise error_deserializer.read_error(operation, error_registry, context)

# TODO(optimization): response binding cache like done in SJ
deserializer = HTTPResponseDeserializer(
payload_codec=self.payload_codec,
Expand Down
89 changes: 82 additions & 7 deletions packages/smithy-http/src/smithy_http/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import datetime
from collections.abc import Callable
from decimal import Decimal
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from smithy_core.codecs import Codec
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
from smithy_core.exceptions import UnsupportedStreamException
from smithy_core.interfaces import is_bytes_reader, is_streaming_blob
from smithy_core.schemas import Schema
from smithy_core.documents import TypeRegistry
from smithy_core.errors import CallException, ModeledException
from smithy_core.exceptions import (
ExpectationNotMetException,
UnsupportedStreamException,
)
from smithy_core.interfaces import TypedProperties, is_bytes_reader, is_streaming_blob
from smithy_core.prelude import DOCUMENT
from smithy_core.schemas import APIOperation, Schema
from smithy_core.shapes import ShapeType
from smithy_core.traits import (
HTTPHeaderTrait,
Expand All @@ -22,15 +28,15 @@
from smithy_core.types import TimestampFormat
from smithy_core.utils import ensure_utc, strict_parse_bool, strict_parse_float

from .aio.interfaces import HTTPResponse
from .interfaces import Field, Fields
from smithy_http.aio.interfaces import ErrorExtractor, HTTPResponse
from smithy_http.interfaces import Field, Fields

if TYPE_CHECKING:
from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob
from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob


__all__ = ["HTTPResponseDeserializer"]
__all__ = ["HTTPErrorDeserializer", "HTTPResponseDeserializer"]


class HTTPResponseDeserializer(SpecificShapeDeserializer):
Expand Down Expand Up @@ -257,3 +263,72 @@ def _consume_payload(self) -> bytes:
"Unable to read async stream. This stream must be buffered prior "
"to creating the deserializer."
)


class HTTPErrorDeserializer:
"""Binds an error response to a modelled or unknown exception."""

def __init__(
self,
payload_codec: Codec,
extractor: ErrorExtractor,
response: HTTPResponse,
body: "SyncStreamingBlob",
) -> None:
"""Initialize an HTTPErrorDeserializer.

:param payload_codec: The Codec to use to deserialize the payload, if present.
:param extractor: The error extractor to get error shape id from the response.
:param response: The HTTP response to read from.
:param body: The HTTP response body in a synchronously readable form. This is
necessary for async response bodies when there is no streaming member.
"""
self._payload_codec = payload_codec
self._response = response
self._body = body
self._extractor = extractor
self._codec = payload_codec

def read_error(
self,
operation: APIOperation[Any, Any],
error_registry: TypeRegistry,
context: TypedProperties,
) -> CallException:
body = self._body
if isinstance(body, bytearray):
body = bytes(body)
deserializer = self._payload_codec.create_deserializer(body)
document = deserializer.read_document(DOCUMENT)
Comment on lines +301 to +302
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can errors have streaming payloads? (blob streams that is)


# try to get the error shape-id from the extractor
error_id = self._extractor.get_error(self._response)

# if none, get it from the parsed document (e.g. '__type')
if error_id is None:
error_id = document.discriminator
Comment on lines +307 to +309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're getting this document from the payload codec, but that's not necessarily gonna give you what you want. Different JSON protocols could embed this differently.


if error_id is not None:
error_shape = error_registry.get(error_id)
# make sure the error shape is derived from modeled exception
if not isinstance(error_shape, ModeledException):
raise ExpectationNotMetException(
f"Modeled errors must be derived from 'ModeledException', but got {error_shape}"
)

# return the deserialized error
return error_shape.deserialize(deserializer)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to rewind the body since you've already read from it


# unknown error (no header, no type/unrecognized type)
fault = "other"
if 400 <= self._response.status < 500:
fault = "client"
elif self._response.status >= 500:
fault = "server"
message = (
f"Unknown error: {operation.output_schema.id} "
f"- code: {self._response.status} "
f"- reason: {self._response.reason}"
)

return CallException(message=message, fault=fault)
4 changes: 3 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.