Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ private void writeUtilStubs(Symbol serviceSymbol) {
writer.addImport("smithy_http", "tuples_to_fields");
writer.addImport("smithy_http.aio", "HTTPResponse", "_HTTPResponse");
writer.addImport("smithy_core.aio.utils", "async_list");
writer.addImport("smithy_core.aio.interfaces", "ErrorInfo");
writer.addStdlibImport("typing", "Any");

writer.write("""
class $1L($2T):
Expand All @@ -634,6 +636,10 @@ class $3L:
def __init__(self, *, client_config: HTTPClientConfiguration | None = None):
self._client_config = client_config

def get_error_info(self, exception: Exception, **kwargs: Any) -> ErrorInfo:
\"\"\"Get information about an exception.\"\"\"
return ErrorInfo(is_timeout_error=False, fault="client")

async def send(
self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None
) -> HTTPResponse:
Expand All @@ -657,6 +663,10 @@ def __init__(
self.fields = tuples_to_fields(headers or [])
self.body = body

def get_error_info(self, exception: Exception, **kwargs: Any) -> ErrorInfo:
\"\"\"Get information about an exception.\"\"\"
return ErrorInfo(is_timeout_error=False, fault="client")

async def send(
self, request: HTTPRequest, *, request_config: HTTPRequestConfiguration | None = None
) -> _HTTPResponse:
Expand Down
46 changes: 27 additions & 19 deletions packages/smithy-core/src/smithy_core/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..auth import AuthParams
from ..deserializers import DeserializeableShape, ShapeDeserializer
from ..endpoints import EndpointResolverParams
from ..exceptions import RetryError, SmithyError
from ..exceptions import ClientTimeoutError, RetryError, SmithyError
from ..interceptors import (
InputContext,
Interceptor,
Expand Down Expand Up @@ -448,24 +448,32 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](

_LOGGER.debug("Sending request %s", request_context.transport_request)

if request_future is not None:
# If we have an input event stream (or duplex event stream) then we
# need to let the client return ASAP so that it can start sending
# events. So here we start the transport send in a background task
# then set the result of the request future. It's important to sequence
# it just like that so that the client gets a stream that's ready
# to send.
transport_task = asyncio.create_task(
self.transport.send(request=request_context.transport_request)
)
request_future.set_result(request_context)
transport_response = await transport_task
else:
# If we don't have an input stream, there's no point in creating a
# task, so we just immediately await the coroutine.
transport_response = await self.transport.send(
request=request_context.transport_request
)
try:
if request_future is not None:
# If we have an input event stream (or duplex event stream) then we
# need to let the client return ASAP so that it can start sending
# events. So here we start the transport send in a background task
# then set the result of the request future. It's important to sequence
# it just like that so that the client gets a stream that's ready
# to send.
transport_task = asyncio.create_task(
self.transport.send(request=request_context.transport_request)
)
request_future.set_result(request_context)
transport_response = await transport_task
else:
# If we don't have an input stream, there's no point in creating a
# task, so we just immediately await the coroutine.
transport_response = await self.transport.send(
request=request_context.transport_request
)
except Exception as e:
error_info = self.transport.get_error_info(e)
if error_info.is_timeout_error:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based off the docstring for ClientTimeoutError, shouldn't we be checking the fault value? What if the fault is set to server?

Exception raised when a client-side timeout occurs.

raise ClientTimeoutError(
message=f"Client timeout occurred: {e}", fault=error_info.fault
) from e
raise

_LOGGER.debug("Received response: %s", transport_response)

Expand Down
33 changes: 31 additions & 2 deletions packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import AsyncIterable, Callable
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable

from ...documents import TypeRegistry
from ...endpoints import EndpointResolverParams
Expand All @@ -10,6 +11,18 @@
from ...interfaces import StreamingBlob as SyncStreamingBlob
from .eventstream import EventPublisher, EventReceiver


@dataclass(frozen=True)
class ErrorInfo:
"""Information about an error from a transport."""

is_timeout_error: bool
"""Whether this error represents a timeout condition."""

fault: Literal["client", "server"] = "client"
"""Whether the client or server is at fault."""


if TYPE_CHECKING:
from typing_extensions import TypeForm

Expand Down Expand Up @@ -86,7 +99,23 @@ async def resolve_endpoint(self, params: EndpointResolverParams[Any]) -> Endpoin


class ClientTransport[I: Request, O: Response](Protocol):
"""Protocol-agnostic representation of a client tranport (e.g. an HTTP client)."""
"""Protocol-agnostic representation of a client transport (e.g. an HTTP client).

Transport implementations must define the get_error_info method to determine which
exceptions represent timeout conditions for that transport.
"""

def get_error_info(self, exception: Exception, **kwargs: Any) -> ErrorInfo:
"""Get information about an exception.

Args:
exception: The exception to analyze
**kwargs: Additional context for analysis

Returns:
ErrorInfo with timeout and fault information.
"""
...

async def send(self, request: I) -> O:
"""Send a request over the transport and receive the response."""
Expand Down
17 changes: 17 additions & 0 deletions packages/smithy-core/src/smithy_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class CallError(SmithyError):
is_throttling_error: bool = False
"""Whether the error is a throttling error."""

is_timeout_error: bool = False
"""Whether the error represents a timeout condition."""

def __post_init__(self):
super().__init__(self.message)

Expand All @@ -61,6 +64,20 @@ class ModeledError(CallError):
fault: Fault = "client"


@dataclass(kw_only=True)
class ClientTimeoutError(CallError):
"""Exception raised when a client-side timeout occurs.

This error indicates that the client transport layer encountered a timeout while
attempting to communicate with the server. This typically occurs when network
requests take longer than the configured timeout period.
"""

fault: Fault = "client"
is_timeout_error: bool = True
is_retry_safe: bool | None = True


class SerializationError(SmithyError):
"""Base exception type for exceptions raised during serialization."""

Expand Down
10 changes: 9 additions & 1 deletion packages/smithy-http/src/smithy_http/aio/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ImportError:
HAS_AIOHTTP = False # type: ignore

from smithy_core.aio.interfaces import StreamingBlob
from smithy_core.aio.interfaces import ErrorInfo, StreamingBlob
from smithy_core.aio.types import AsyncBytesReader
from smithy_core.aio.utils import async_list
from smithy_core.exceptions import MissingDependencyError
Expand Down Expand Up @@ -52,6 +52,14 @@ def __post_init__(self) -> None:
class AIOHTTPClient(HTTPClient):
"""Implementation of :py:class:`.interfaces.HTTPClient` using aiohttp."""

def get_error_info(self, exception: Exception, **kwargs: Any) -> ErrorInfo:
"""Get information about aiohttp errors."""

if isinstance(exception, TimeoutError):
return ErrorInfo(is_timeout_error=True)

return ErrorInfo(is_timeout_error=False)

def __init__(
self,
*,
Expand Down
19 changes: 19 additions & 0 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING, Any

from awscrt.exceptions import AwsCrtError

if TYPE_CHECKING:
# pyright doesn't like optional imports. This is reasonable because if we use these
# in type hints then they'd result in runtime errors.
Expand All @@ -33,6 +35,7 @@

from smithy_core import interfaces as core_interfaces
from smithy_core.aio import interfaces as core_aio_interfaces
from smithy_core.aio.interfaces import ErrorInfo
from smithy_core.aio.types import AsyncBytesReader
from smithy_core.exceptions import MissingDependencyError

Expand Down Expand Up @@ -133,6 +136,22 @@ class AWSCRTHTTPClient(http_aio_interfaces.HTTPClient):
_HTTP_PORT = 80
_HTTPS_PORT = 443

def get_error_info(self, exception: Exception, **kwargs: Any) -> ErrorInfo:
"""Get information about CRT errors."""

timeout_indicators = (
"AWS_IO_SOCKET_TIMEOUT",
"AWS_IO_CHANNEL_ERROR_SOCKET_TIMEOUT",
"AWS_ERROR_HTTP_REQUEST_TIMEOUT",
Comment on lines +143 to +145
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these coming from? Is this an exhaustive list for CRT errors?

)
if isinstance(exception, TimeoutError):
return ErrorInfo(is_timeout_error=True, fault="client")

if isinstance(exception, AwsCrtError) and exception.name in timeout_indicators:
return ErrorInfo(is_timeout_error=True, fault="client")

return ErrorInfo(is_timeout_error=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why set fault="client" above and not here?


def __init__(
self,
eventloop: _AWSCRTEventLoop | None = None,
Expand Down
16 changes: 13 additions & 3 deletions packages/smithy-http/src/smithy_http/aio/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ async def _create_error(
)
return error_shape.deserialize(deserializer)

is_throttle = response.status == 429
message = (
f"Unknown error for operation {operation.schema.id} "
f"- status: {response.status}"
Expand All @@ -224,11 +223,22 @@ async def _create_error(
message += f" - id: {error_id}"
if response.reason is not None:
message += f" - reason: {response.status}"

if response.status == 408:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why special case 408?

is_timeout = True
fault = "server"
else:
is_timeout = False
fault = "client" if response.status < 500 else "server"

is_throttle = response.status == 429

return CallError(
message=message,
fault="client" if response.status < 500 else "server",
fault=fault,
is_throttling_error=is_throttle,
is_retry_safe=is_throttle or None,
is_timeout_error=is_timeout,
is_retry_safe=is_throttle or is_timeout or None,
)

def _matches_content_type(self, response: HTTPResponse) -> bool:
Expand Down
35 changes: 31 additions & 4 deletions packages/smithy-http/tests/unit/aio/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from unittest.mock import Mock

import pytest
from smithy_core import URI
Expand All @@ -11,14 +12,15 @@
from smithy_core.interfaces import URI as URIInterface
from smithy_core.schemas import APIOperation
from smithy_core.shapes import ShapeID
from smithy_core.types import TypedProperties as ConcreteTypedProperties
from smithy_http import Fields
from smithy_http.aio import HTTPRequest
from smithy_http.aio import HTTPRequest, HTTPResponse
from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface
from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface
from smithy_http.aio.protocols import HttpClientProtocol
from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol


class TestProtocol(HttpClientProtocol):
class MockProtocol(HttpClientProtocol):
_id = ShapeID("ns.foo#bar")

@property
Expand Down Expand Up @@ -125,7 +127,7 @@ def deserialize_response(
def test_http_protocol_joins_uris(
request_uri: URI, endpoint_uri: URI, expected: URI
) -> None:
protocol = TestProtocol()
protocol = MockProtocol()
request = HTTPRequest(
destination=request_uri,
method="GET",
Expand All @@ -135,3 +137,28 @@ def test_http_protocol_joins_uris(
updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint)
actual = updated_request.destination
assert actual == expected


@pytest.mark.asyncio
async def test_http_408_creates_timeout_error() -> None:
"""Test that HTTP 408 creates a timeout error with server fault."""
protocol = Mock(spec=HttpBindingClientProtocol)
protocol.error_identifier = Mock()
protocol.error_identifier.identify.return_value = None

response = HTTPResponse(status=408, fields=Fields())

error = await HttpBindingClientProtocol._create_error( # type: ignore[reportPrivateUsage]
protocol,
operation=Mock(),
request=HTTPRequest(
destination=URI(host="example.com"), method="POST", fields=Fields()
),
response=response,
response_body=b"",
error_registry=TypeRegistry({}),
context=ConcreteTypedProperties(),
)

assert error.is_timeout_error is True
assert error.fault == "server"
Loading
Loading