Skip to content

Commit 1c8bed7

Browse files
committed
Add error classification
1 parent cb012c0 commit 1c8bed7

File tree

8 files changed

+210
-30
lines changed

8 files changed

+210
-30
lines changed

packages/smithy-core/src/smithy_core/aio/client.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ..auth import AuthParams
1313
from ..deserializers import DeserializeableShape, ShapeDeserializer
1414
from ..endpoints import EndpointResolverParams
15-
from ..exceptions import RetryError, SmithyError
15+
from ..exceptions import ClientTimeoutError, RetryError, SmithyError
1616
from ..interceptors import (
1717
InputContext,
1818
Interceptor,
@@ -448,24 +448,32 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
448448

449449
_LOGGER.debug("Sending request %s", request_context.transport_request)
450450

451-
if request_future is not None:
452-
# If we have an input event stream (or duplex event stream) then we
453-
# need to let the client return ASAP so that it can start sending
454-
# events. So here we start the transport send in a background task
455-
# then set the result of the request future. It's important to sequence
456-
# it just like that so that the client gets a stream that's ready
457-
# to send.
458-
transport_task = asyncio.create_task(
459-
self.transport.send(request=request_context.transport_request)
460-
)
461-
request_future.set_result(request_context)
462-
transport_response = await transport_task
463-
else:
464-
# If we don't have an input stream, there's no point in creating a
465-
# task, so we just immediately await the coroutine.
466-
transport_response = await self.transport.send(
467-
request=request_context.transport_request
468-
)
451+
try:
452+
if request_future is not None:
453+
# If we have an input event stream (or duplex event stream) then we
454+
# need to let the client return ASAP so that it can start sending
455+
# events. So here we start the transport send in a background task
456+
# then set the result of the request future. It's important to sequence
457+
# it just like that so that the client gets a stream that's ready
458+
# to send.
459+
transport_task = asyncio.create_task(
460+
self.transport.send(request=request_context.transport_request)
461+
)
462+
request_future.set_result(request_context)
463+
transport_response = await transport_task
464+
else:
465+
# If we don't have an input stream, there's no point in creating a
466+
# task, so we just immediately await the coroutine.
467+
transport_response = await self.transport.send(
468+
request=request_context.transport_request
469+
)
470+
except Exception as e:
471+
error_info = self.transport.get_error_info(e)
472+
if error_info.is_timeout_error:
473+
raise ClientTimeoutError(
474+
message=f"Client timeout occurred: {e}", fault=error_info.fault
475+
) from e
476+
raise
469477

470478
_LOGGER.debug("Received response: %s", transport_response)
471479

packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
from collections.abc import AsyncIterable, Callable
4-
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable
56

67
from ...documents import TypeRegistry
78
from ...endpoints import EndpointResolverParams
@@ -10,6 +11,18 @@
1011
from ...interfaces import StreamingBlob as SyncStreamingBlob
1112
from .eventstream import EventPublisher, EventReceiver
1213

14+
15+
@dataclass(frozen=True)
16+
class ErrorInfo:
17+
"""Information about an error from a transport."""
18+
19+
is_timeout_error: bool
20+
"""Whether this error represents a timeout condition."""
21+
22+
fault: Literal["client", "server"] = "client"
23+
"""Whether the client or server is at fault."""
24+
25+
1326
if TYPE_CHECKING:
1427
from typing_extensions import TypeForm
1528

@@ -86,7 +99,23 @@ async def resolve_endpoint(self, params: EndpointResolverParams[Any]) -> Endpoin
8699

87100

88101
class ClientTransport[I: Request, O: Response](Protocol):
89-
"""Protocol-agnostic representation of a client tranport (e.g. an HTTP client)."""
102+
"""Protocol-agnostic representation of a client transport (e.g. an HTTP client).
103+
104+
Transport implementations must define the get_error_info method to determine which
105+
exceptions represent timeout conditions for that transport.
106+
"""
107+
108+
def get_error_info(self, exception: Exception, **kwargs) -> ErrorInfo:
109+
"""Get information about an exception.
110+
111+
Args:
112+
exception: The exception to analyze
113+
**kwargs: Additional context for analysis
114+
115+
Returns:
116+
ErrorInfo with timeout and fault information.
117+
"""
118+
...
90119

91120
async def send(self, request: I) -> O:
92121
"""Send a request over the transport and receive the response."""

packages/smithy-core/src/smithy_core/exceptions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class CallError(SmithyError):
5050
is_throttling_error: bool = False
5151
"""Whether the error is a throttling error."""
5252

53+
is_timeout_error: bool = False
54+
"""Whether the error represents a timeout condition."""
55+
5356
def __post_init__(self):
5457
super().__init__(self.message)
5558

@@ -61,6 +64,20 @@ class ModeledError(CallError):
6164
fault: Fault = "client"
6265

6366

67+
@dataclass(kw_only=True)
68+
class ClientTimeoutError(CallError):
69+
"""Exception raised when a client-side timeout occurs.
70+
71+
This error indicates that the client transport layer encountered a timeout while
72+
attempting to communicate with the server. This typically occurs when network
73+
requests take longer than the configured timeout period.
74+
"""
75+
76+
fault: Fault = "client"
77+
is_timeout_error: bool = True
78+
is_retry_safe: bool = True
79+
80+
6481
class SerializationError(SmithyError):
6582
"""Base exception type for exceptions raised during serialization."""
6683

packages/smithy-http/src/smithy_http/aio/aiohttp.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
except ImportError:
2121
HAS_AIOHTTP = False # type: ignore
2222

23-
from smithy_core.aio.interfaces import StreamingBlob
23+
from smithy_core.aio.interfaces import ErrorInfo, StreamingBlob
2424
from smithy_core.aio.types import AsyncBytesReader
2525
from smithy_core.aio.utils import async_list
2626
from smithy_core.exceptions import MissingDependencyError
@@ -52,6 +52,14 @@ def __post_init__(self) -> None:
5252
class AIOHTTPClient(HTTPClient):
5353
"""Implementation of :py:class:`.interfaces.HTTPClient` using aiohttp."""
5454

55+
def get_error_info(self, exception: Exception, **kwargs) -> ErrorInfo:
56+
"""Get information about aiohttp errors."""
57+
58+
if isinstance(exception, TimeoutError):
59+
return ErrorInfo(is_timeout_error=True)
60+
61+
return ErrorInfo(is_timeout_error=False)
62+
5563
def __init__(
5664
self,
5765
*,

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from inspect import iscoroutinefunction
99
from typing import TYPE_CHECKING, Any
1010

11+
from awscrt.exceptions import AwsCrtError
12+
1113
if TYPE_CHECKING:
1214
# pyright doesn't like optional imports. This is reasonable because if we use these
1315
# in type hints then they'd result in runtime errors.
@@ -33,6 +35,7 @@
3335

3436
from smithy_core import interfaces as core_interfaces
3537
from smithy_core.aio import interfaces as core_aio_interfaces
38+
from smithy_core.aio.interfaces import ErrorInfo
3639
from smithy_core.aio.types import AsyncBytesReader
3740
from smithy_core.exceptions import MissingDependencyError
3841

@@ -133,6 +136,22 @@ class AWSCRTHTTPClient(http_aio_interfaces.HTTPClient):
133136
_HTTP_PORT = 80
134137
_HTTPS_PORT = 443
135138

139+
def get_error_info(self, exception: Exception, **kwargs) -> ErrorInfo:
140+
"""Get information about CRT errors."""
141+
142+
timeout_indicators = (
143+
"AWS_IO_SOCKET_TIMEOUT",
144+
"AWS_IO_CHANNEL_ERROR_SOCKET_TIMEOUT",
145+
"AWS_ERROR_HTTP_REQUEST_TIMEOUT",
146+
)
147+
if isinstance(exception, TimeoutError):
148+
return ErrorInfo(is_timeout_error=True, fault="client")
149+
150+
if isinstance(exception, AwsCrtError) and exception.name in timeout_indicators:
151+
return ErrorInfo(is_timeout_error=True, fault="client")
152+
153+
return ErrorInfo(is_timeout_error=False)
154+
136155
def __init__(
137156
self,
138157
eventloop: _AWSCRTEventLoop | None = None,

packages/smithy-http/src/smithy_http/aio/protocols.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ async def _create_error(
215215
)
216216
return error_shape.deserialize(deserializer)
217217

218-
is_throttle = response.status == 429
219218
message = (
220219
f"Unknown error for operation {operation.schema.id} "
221220
f"- status: {response.status}"
@@ -224,11 +223,22 @@ async def _create_error(
224223
message += f" - id: {error_id}"
225224
if response.reason is not None:
226225
message += f" - reason: {response.status}"
226+
227+
if response.status == 408:
228+
is_timeout = True
229+
fault = "server"
230+
else:
231+
is_timeout = False
232+
fault = "client" if response.status < 500 else "server"
233+
234+
is_throttle = response.status == 429
235+
227236
return CallError(
228237
message=message,
229-
fault="client" if response.status < 500 else "server",
238+
fault=fault,
230239
is_throttling_error=is_throttle,
231-
is_retry_safe=is_throttle or None,
240+
is_timeout_error=is_timeout,
241+
is_retry_safe=is_throttle or is_timeout or None,
232242
)
233243

234244
def _matches_content_type(self, response: HTTPResponse) -> bool:

packages/smithy-http/tests/unit/aio/test_protocols.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from typing import Any
5+
from unittest.mock import Mock
56

67
import pytest
78
from smithy_core import URI
89
from smithy_core.documents import TypeRegistry
910
from smithy_core.endpoints import Endpoint
10-
from smithy_core.interfaces import TypedProperties
1111
from smithy_core.interfaces import URI as URIInterface
1212
from smithy_core.schemas import APIOperation
1313
from smithy_core.shapes import ShapeID
14+
from smithy_core.types import TypedProperties
1415
from smithy_http import Fields
15-
from smithy_http.aio import HTTPRequest
16+
from smithy_http.aio import HTTPRequest, HTTPResponse
1617
from smithy_http.aio.interfaces import HTTPRequest as HTTPRequestInterface
1718
from smithy_http.aio.interfaces import HTTPResponse as HTTPResponseInterface
18-
from smithy_http.aio.protocols import HttpClientProtocol
19+
from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol
1920

2021

21-
class TestProtocol(HttpClientProtocol):
22+
class MockProtocol(HttpClientProtocol):
2223
_id = ShapeID("ns.foo#bar")
2324

2425
@property
@@ -125,7 +126,7 @@ def deserialize_response(
125126
def test_http_protocol_joins_uris(
126127
request_uri: URI, endpoint_uri: URI, expected: URI
127128
) -> None:
128-
protocol = TestProtocol()
129+
protocol = MockProtocol()
129130
request = HTTPRequest(
130131
destination=request_uri,
131132
method="GET",
@@ -135,3 +136,28 @@ def test_http_protocol_joins_uris(
135136
updated_request = protocol.set_service_endpoint(request=request, endpoint=endpoint)
136137
actual = updated_request.destination
137138
assert actual == expected
139+
140+
141+
@pytest.mark.asyncio
142+
async def test_http_408_creates_timeout_error() -> None:
143+
"""Test that HTTP 408 creates a timeout error with server fault."""
144+
protocol = Mock(spec=HttpBindingClientProtocol)
145+
protocol.error_identifier = Mock()
146+
protocol.error_identifier.identify.return_value = None
147+
148+
response = HTTPResponse(status=408, fields=Fields())
149+
150+
error = await HttpBindingClientProtocol._create_error(
151+
protocol,
152+
operation=Mock(),
153+
request=HTTPRequest(
154+
destination=URI(host="example.com"), method="POST", fields=Fields()
155+
),
156+
response=response,
157+
response_body=b"",
158+
error_registry=TypeRegistry({}),
159+
context=TypedProperties(),
160+
)
161+
162+
assert error.is_timeout_error is True
163+
assert error.fault == "server"
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
from smithy_core.aio.interfaces import ErrorInfo
6+
7+
try:
8+
from smithy_http.aio.aiohttp import AIOHTTPClient
9+
10+
HAS_AIOHTTP = True
11+
except ImportError:
12+
HAS_AIOHTTP = False
13+
14+
try:
15+
from smithy_http.aio.crt import AWSCRTHTTPClient
16+
17+
HAS_CRT = True
18+
except ImportError:
19+
HAS_CRT = False
20+
21+
22+
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not available")
23+
class TestAIOHTTPTimeoutErrorHandling:
24+
"""Test timeout error handling for AIOHTTPClient."""
25+
26+
@pytest.fixture
27+
async def client(self):
28+
return AIOHTTPClient()
29+
30+
@pytest.mark.asyncio
31+
async def test_timeout_error_detection(self, client):
32+
"""Test timeout error detection for standard TimeoutError."""
33+
timeout_err = TimeoutError("Connection timed out")
34+
result = client.get_error_info(timeout_err)
35+
assert result == ErrorInfo(is_timeout_error=True, fault="client")
36+
37+
@pytest.mark.asyncio
38+
async def test_non_timeout_error_detection(self, client):
39+
"""Test non-timeout error detection."""
40+
other_err = ValueError("Not a timeout")
41+
result = client.get_error_info(other_err)
42+
assert result == ErrorInfo(is_timeout_error=False, fault="client")
43+
44+
45+
@pytest.mark.skipif(not HAS_CRT, reason="AWS CRT not available")
46+
class TestAWSCRTTimeoutErrorHandling:
47+
"""Test timeout error handling for AWSCRTHTTPClient."""
48+
49+
@pytest.fixture
50+
def client(self):
51+
return AWSCRTHTTPClient()
52+
53+
def test_timeout_error_detection(self, client):
54+
"""Test timeout error detection for standard TimeoutError."""
55+
timeout_err = TimeoutError("Connection timed out")
56+
result = client.get_error_info(timeout_err)
57+
assert result == ErrorInfo(is_timeout_error=True, fault="client")
58+
59+
def test_non_timeout_error_detection(self, client):
60+
"""Test non-timeout error detection."""
61+
other_err = ValueError("Not a timeout")
62+
result = client.get_error_info(other_err)
63+
assert result == ErrorInfo(is_timeout_error=False, fault="client")

0 commit comments

Comments
 (0)