Skip to content

Commit fcb036e

Browse files
Ensure streams get closed
1 parent 6db4183 commit fcb036e

File tree

8 files changed

+179
-39
lines changed

8 files changed

+179
-39
lines changed

packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,18 @@ def __init__(
4040
self._source = source
4141
self._is_client_mode = is_client_mode
4242
self._deserializer = deserializer
43+
self._closed = False
4344

4445
async def receive(self) -> E | None:
45-
event = await Event.decode_async(self._source)
46+
if self._closed:
47+
return None
48+
49+
try:
50+
event = await Event.decode_async(self._source)
51+
except Exception as e:
52+
await self.close()
53+
raise IOError("Failed to read from stream.") from e
54+
4655
if event is None:
4756
return None
4857

@@ -57,10 +66,18 @@ async def receive(self) -> E | None:
5766
return result
5867

5968
async def close(self) -> None:
69+
if self._closed:
70+
return
71+
self._closed = True
72+
6073
if (close := getattr(self._source, "close", None)) is not None:
6174
if asyncio.iscoroutine(result := close()):
6275
await result
6376

77+
@property
78+
def closed(self) -> bool:
79+
return self._closed
80+
6481

6582
class EventDeserializer(SpecificShapeDeserializer):
6683
def __init__(

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ def __init__(
5050
self._serializer = EventSerializer(
5151
payload_codec=payload_codec, is_client_mode=is_client_mode
5252
)
53+
self._closed = False
5354

5455
async def send(self, event: E) -> None:
56+
if self._closed:
57+
raise IOError("Attempted to write to closed stream.")
5558
event.serialize(self._serializer)
5659
result = self._serializer.get_result()
5760
if result is None:
@@ -60,13 +63,27 @@ async def send(self, event: E) -> None:
6063
)
6164
if self._signer is not None:
6265
result = self._signer(result)
63-
await self._writer.write(result.encode())
66+
67+
encoded_result = result.encode()
68+
try:
69+
await self._writer.write(encoded_result)
70+
except Exception as e:
71+
await self.close()
72+
raise IOError("Failed to write to stream.") from e
6473

6574
async def close(self) -> None:
75+
if self._closed:
76+
return
77+
self._closed = True
78+
6679
if (close := getattr(self._writer, "close", None)) is not None:
6780
if asyncio.iscoroutine(result := close()):
6881
await result
6982

83+
@property
84+
def closed(self) -> bool:
85+
return self._closed
86+
7087

7188
class EventSerializer(SpecificShapeSerializer):
7289
def __init__(

packages/aws-event-stream/src/aws_event_stream/aio/__init__.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,31 +75,35 @@ def __init__(
7575
self.response: R | None = None
7676

7777
async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]:
78-
async_reader = AsyncBytesReader((await self._awaitable_response).body)
79-
if self.output_stream is None:
80-
self.output_stream = _AWSEventReceiver[O](
81-
payload_codec=self._payload_codec,
82-
source=async_reader,
83-
deserializer=self._deserializer,
84-
is_client_mode=self._is_client_mode,
85-
)
86-
87-
if self.response is None:
88-
if self._deserializeable_response is None:
89-
initial_response = await self._awaitable_output
90-
else:
91-
initial_response_stream = _AWSEventReceiver(
78+
try:
79+
async_reader = AsyncBytesReader((await self._awaitable_response).body)
80+
if self.output_stream is None:
81+
self.output_stream = _AWSEventReceiver[O](
9282
payload_codec=self._payload_codec,
9383
source=async_reader,
94-
deserializer=self._deserializeable_response.deserialize,
84+
deserializer=self._deserializer,
9585
is_client_mode=self._is_client_mode,
9686
)
97-
initial_response = await initial_response_stream.receive()
98-
if initial_response is None:
99-
raise MissingInitialResponse()
100-
self.response = initial_response
101-
else:
102-
initial_response = self.response
87+
88+
if self.response is None:
89+
if self._deserializeable_response is None:
90+
initial_response = await self._awaitable_output
91+
else:
92+
initial_response_stream = _AWSEventReceiver(
93+
payload_codec=self._payload_codec,
94+
source=async_reader,
95+
deserializer=self._deserializeable_response.deserialize,
96+
is_client_mode=self._is_client_mode,
97+
)
98+
initial_response = await initial_response_stream.receive()
99+
if initial_response is None:
100+
raise MissingInitialResponse()
101+
self.response = initial_response
102+
else:
103+
initial_response = self.response
104+
except Exception:
105+
await self.input_stream.close()
106+
raise
103107

104108
return initial_response, self.output_stream
105109

@@ -137,7 +141,11 @@ def __init__(
137141

138142
async def await_output(self) -> R:
139143
if self.response is None:
140-
self.response = await self._awaitable_response
144+
try:
145+
self.response = await self._awaitable_response
146+
except Exception:
147+
await self.input_stream.close()
148+
raise
141149
return self.response
142150

143151

packages/aws-event-stream/tests/unit/_private/test_deserializers.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,43 @@ def test_deserialize_unmodeled_error():
8989

9090
with pytest.raises(UnmodeledEventError, match="InternalError"):
9191
EventStreamOperationInputOutput.deserialize(deserializer)
92+
93+
94+
async def test_receiver_closes_source() -> None:
95+
source = AsyncBytesReader(b"")
96+
deserializer = EventStreamDeserializer()
97+
receiver = AWSAsyncEventReceiver[Any](
98+
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
99+
)
100+
assert not receiver.closed
101+
assert not source.closed
102+
await receiver.close()
103+
assert receiver.closed
104+
assert source.closed
105+
106+
107+
async def test_read_closed_receiver() -> None:
108+
source = AsyncBytesReader(b"")
109+
deserializer = EventStreamDeserializer()
110+
receiver = AWSAsyncEventReceiver[Any](
111+
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
112+
)
113+
114+
await receiver.close()
115+
assert receiver.closed
116+
assert await receiver.receive() is None
117+
118+
119+
async def test_read_closed_receiver_source() -> None:
120+
source = AsyncBytesReader(b"")
121+
deserializer = EventStreamDeserializer()
122+
receiver = AWSAsyncEventReceiver[Any](
123+
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
124+
)
125+
126+
await source.close()
127+
assert source.closed
128+
assert not receiver.closed
129+
with pytest.raises(IOError):
130+
await receiver.receive()
131+
assert receiver.closed

packages/aws-event-stream/tests/unit/_private/test_serializers.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
from typing import Any
4+
35
import pytest
46
from smithy_core.serializers import SerializeableShape
7+
from smithy_core.aio.types import AsyncBytesProvider
58
from smithy_json import JSONCodec
69

7-
from aws_event_stream._private.serializers import EventSerializer
10+
from aws_event_stream._private.serializers import (
11+
EventSerializer,
12+
AWSAsyncEventPublisher,
13+
)
814
from aws_event_stream.events import EventMessage
915

1016
from . import EVENT_STREAM_SERDE_CASES, INITIAL_REQUEST_CASE, INITIAL_RESPONSE_CASE
@@ -36,3 +42,41 @@ def test_serialize_initial_request():
3642

3743
def test_serialize_initial_response():
3844
test_event_serializer_server_mode(*INITIAL_RESPONSE_CASE)
45+
46+
47+
async def test_publisher_closes_reader():
48+
writer = AsyncBytesProvider()
49+
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
50+
payload_codec=JSONCodec(), async_writer=writer
51+
)
52+
53+
assert not publisher.closed
54+
assert not writer.closed
55+
await publisher.close()
56+
assert publisher.closed
57+
assert writer.closed
58+
59+
60+
async def test_send_after_close():
61+
writer = AsyncBytesProvider()
62+
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
63+
payload_codec=JSONCodec(), async_writer=writer
64+
)
65+
66+
await publisher.close()
67+
assert publisher.closed
68+
with pytest.raises(IOError):
69+
await publisher.send(EVENT_STREAM_SERDE_CASES[0][0])
70+
71+
72+
async def test_send_to_closed_writer():
73+
writer = AsyncBytesProvider()
74+
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
75+
payload_codec=JSONCodec(), async_writer=writer
76+
)
77+
78+
await writer.close()
79+
with pytest.raises(IOError):
80+
await publisher.send(EVENT_STREAM_SERDE_CASES[0][0])
81+
82+
assert publisher.closed

packages/smithy-core/src/smithy_core/aio/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ def closed(self) -> bool:
114114

115115
async def close(self) -> None:
116116
"""Closes the stream, as well as the underlying stream where possible."""
117+
self._closed = True
117118
if (close := getattr(self._data, "close", None)) is not None:
118119
if asyncio.iscoroutine(result := close()):
119120
await result
120121

121122
self._data = None
122-
self._closed = True
123123

124124

125125
class SeekableAsyncBytesReader:

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections import deque
88
from collections.abc import AsyncGenerator, AsyncIterable
99
from copy import deepcopy
10+
from functools import partial
1011
from io import BytesIO, BufferedIOBase
1112
from typing import TYPE_CHECKING, Any
1213

@@ -231,7 +232,7 @@ async def send(
231232
:param request: The request including destination URI, fields, payload.
232233
:param request_config: Configuration specific to this request.
233234
"""
234-
crt_request = await self._marshal_request(request)
235+
crt_request, crt_body = await self._marshal_request(request)
235236
connection = await self._get_connection(request.destination)
236237
response_body = CRTResponseBody()
237238
response_factory = CRTResponseFactory(response_body)
@@ -242,8 +243,17 @@ async def send(
242243
)
243244
response_factory.set_done_callback(crt_stream)
244245
response_body.set_stream(crt_stream)
246+
crt_stream.completion_future.add_done_callback(
247+
partial(self._close_input_body, body=crt_body)
248+
)
245249
return await response_factory.await_response()
246250

251+
def _close_input_body(
252+
self, future: ConcurrentFuture[int], *, body: "BufferableByteStream | BytesIO"
253+
) -> None:
254+
if future.exception(timeout=0):
255+
body.close()
256+
247257
async def _create_connection(
248258
self, url: core_interfaces.URI
249259
) -> "crt_http.HttpClientConnection":
@@ -314,7 +324,7 @@ def _render_path(self, url: core_interfaces.URI) -> str:
314324

315325
async def _marshal_request(
316326
self, request: http_aio_interfaces.HTTPRequest
317-
) -> "crt_http.HttpRequest":
327+
) -> tuple["crt_http.HttpRequest", "BufferableByteStream | BytesIO"]:
318328
"""Create :py:class:`awscrt.http.HttpRequest` from
319329
:py:class:`smithy_http.aio.HTTPRequest`"""
320330
headers_list = []
@@ -343,13 +353,11 @@ async def _marshal_request(
343353
crt_body = BytesIO(body)
344354
else:
345355
# If the body is async, or potentially very large, start up a task to read
346-
# it into the BytesIO object that CRT needs. By using asyncio.create_task
347-
# we'll start the coroutine without having to explicitly await it.
356+
# it into the intermediate object that CRT needs. By using
357+
# asyncio.create_task we'll start the coroutine without having to
358+
# explicitly await it.
348359
crt_body = BufferableByteStream()
349-
if not isinstance(body, AsyncIterable):
350-
# If the body isn't already an async iterable, wrap it in one. Objects
351-
# with read methods will be read in chunks so as not to exhaust memory.
352-
body = AsyncBytesReader(body)
360+
body = AsyncBytesReader(body)
353361

354362
# Start the read task in the background.
355363
read_task = asyncio.create_task(self._consume_body_async(body, crt_body))
@@ -365,13 +373,19 @@ async def _marshal_request(
365373
headers=headers,
366374
body_stream=crt_body,
367375
)
368-
return crt_request
376+
return crt_request, crt_body
369377

370378
async def _consume_body_async(
371-
self, source: AsyncIterable[bytes], dest: "BufferableByteStream"
379+
self, source: AsyncBytesReader, dest: "BufferableByteStream"
372380
) -> None:
373-
async for chunk in source:
374-
dest.write(chunk)
381+
try:
382+
async for chunk in source:
383+
dest.write(chunk)
384+
except Exception:
385+
dest.close()
386+
raise
387+
finally:
388+
await source.close()
375389
dest.end_stream()
376390

377391
def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient":

packages/smithy-http/tests/unit/aio/test_crt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def test_client_marshal_request() -> None:
2626
body=BytesIO(),
2727
fields=Fields(),
2828
)
29-
crt_request = await client._marshal_request(request) # type: ignore
29+
crt_request, _ = await client._marshal_request(request) # type: ignore
3030
assert crt_request.headers.get("host") == "example.com" # type: ignore
3131
assert crt_request.headers.get("accept") == "*/*" # type: ignore
3232
assert crt_request.method == "GET" # type: ignore

0 commit comments

Comments
 (0)