|
1 | 1 | import os |
| 2 | +from collections.abc import AsyncIterable |
2 | 3 | from inspect import iscoroutinefunction |
3 | | -from io import BytesIO |
4 | 4 | from typing import Any |
5 | 5 |
|
6 | | -from smithy_core.aio.interfaces import ClientProtocol |
| 6 | +from smithy_core.aio.interfaces import AsyncByteStream, ClientProtocol |
| 7 | +from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob |
7 | 8 | from smithy_core.codecs import Codec |
8 | 9 | from smithy_core.deserializers import DeserializeableShape |
9 | 10 | from smithy_core.documents import TypeRegistry |
@@ -109,35 +110,45 @@ async def deserialize_response[ |
109 | 110 | error_registry: TypeRegistry, |
110 | 111 | context: TypedProperties, |
111 | 112 | ) -> OperationOutput: |
112 | | - body = response.body |
113 | | - |
114 | | - # if body is not streaming and is async, we have to buffer it |
115 | | - if not operation.output_stream_member and not is_streaming_blob(body): |
116 | | - if ( |
117 | | - read := getattr(body, "read", None) |
118 | | - ) is not None and iscoroutinefunction(read): |
119 | | - body = BytesIO(await read()) |
120 | | - |
121 | 113 | if not self._is_success(operation, context, response): |
122 | 114 | raise await self._create_error( |
123 | 115 | operation=operation, |
124 | 116 | request=request, |
125 | 117 | response=response, |
126 | | - response_body=body, # type: ignore |
| 118 | + response_body=await self._buffer_async_body(response.body), |
127 | 119 | error_registry=error_registry, |
128 | 120 | context=context, |
129 | 121 | ) |
130 | 122 |
|
| 123 | + # if body is not streaming and is async, we have to buffer it |
| 124 | + body: SyncStreamingBlob | None = None |
| 125 | + if not operation.output_stream_member and not is_streaming_blob(body): |
| 126 | + body = await self._buffer_async_body(response.body) |
| 127 | + |
131 | 128 | # TODO(optimization): response binding cache like done in SJ |
132 | 129 | deserializer = HTTPResponseDeserializer( |
133 | 130 | payload_codec=self.payload_codec, |
134 | 131 | http_trait=operation.schema.expect_trait(HTTPTrait), |
135 | 132 | response=response, |
136 | | - body=body, # type: ignore |
| 133 | + body=body, |
137 | 134 | ) |
138 | 135 |
|
139 | 136 | return operation.output.deserialize(deserializer) |
140 | 137 |
|
| 138 | + async def _buffer_async_body(self, stream: AsyncStreamingBlob) -> SyncStreamingBlob: |
| 139 | + match stream: |
| 140 | + case AsyncByteStream(): |
| 141 | + if not iscoroutinefunction(stream.read): |
| 142 | + return stream # type: ignore |
| 143 | + return await stream.read() |
| 144 | + case AsyncIterable(): |
| 145 | + full = b"" |
| 146 | + async for chunk in stream: |
| 147 | + full += chunk |
| 148 | + return full |
| 149 | + case _: |
| 150 | + return stream |
| 151 | + |
141 | 152 | def _is_success( |
142 | 153 | self, |
143 | 154 | operation: APIOperation[Any, Any], |
|
0 commit comments