11# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
3+ import os .path
4+ from asyncio import iscoroutinefunction
35from collections .abc import AsyncIterable
6+ from io import BytesIO
47from typing import Protocol , runtime_checkable , TYPE_CHECKING , Any
58
9+ from smithy_http .aio .interfaces import HTTPRequest , HTTPResponse
10+ from smithy_http .deserializers import HTTPResponseDeserializer
11+ from smithy_http .serializers import HTTPRequestSerializer
12+ from smithy_json import JSONCodec
13+ from ...codecs import Codec
614from ...interfaces import URI , Endpoint
715from ...interfaces import StreamingBlob as SyncStreamingBlob
8-
16+ from ...traits import HTTPTrait , EndpointTrait , RestJson1Trait
17+ from ...type_registry import TypeRegistry
918
1019if TYPE_CHECKING :
1120 from ...schemas import APIOperation
@@ -126,7 +135,7 @@ async def deserialize_response[
126135 operation : APIOperation [OperationInput , OperationOutput ],
127136 request : I ,
128137 response : O ,
129- error_registry : Any , # TODO: add error registry
138+ error_registry : TypeRegistry ,
130139 context : dict [str , Any ], # TODO: replace with a typed context bag
131140 ) -> OperationOutput :
132141 """Deserializes the output from the tranport response or throws an exception.
@@ -138,3 +147,121 @@ async def deserialize_response[
138147 :param context: A context bag for the request.
139148 """
140149 ...
150+
151+
152+ class HttpClientProtocol (ClientProtocol [HTTPRequest , HTTPResponse ]):
153+ def set_service_endpoint (
154+ self ,
155+ * ,
156+ request : HTTPRequest ,
157+ endpoint : Endpoint ,
158+ ) -> HTTPRequest :
159+ """Update the endpoint of a transport request.
160+
161+ :param request: The request whose endpoint should be updated.
162+ :param endpoint: The endpoint to set on the request.
163+ """
164+ uri = endpoint .uri
165+ uri_builder = request .destination
166+
167+ if uri .scheme :
168+ uri_builder .scheme = uri .scheme
169+ if uri .host :
170+ uri_builder .host = uri .host
171+ if uri .port and uri .port > - 1 :
172+ uri_builder .port = uri .port
173+ if uri .path :
174+ # TODO: verify, uri helper?
175+ uri_builder .path = os .path .join (uri .path , uri_builder .path or "" )
176+ # TODO: merge headers from the endpoint properties bag
177+ return request
178+
179+
180+ class HttpBindingClientProtocol (HttpClientProtocol ):
181+ @property
182+ def codec (self ) -> Codec :
183+ """The codec used for the serde of input and output shapes."""
184+ ...
185+
186+ @property
187+ def content_type (self ) -> str :
188+ """The media type of the http payload."""
189+ ...
190+
191+ def serialize_request [
192+ OperationInput : "SerializeableShape" ,
193+ OperationOutput : "DeserializeableShape" ,
194+ ](
195+ self ,
196+ * ,
197+ operation : APIOperation [OperationInput , OperationOutput ],
198+ input : OperationInput ,
199+ endpoint : URI ,
200+ context : dict [str , Any ],
201+ ) -> HTTPRequest :
202+ # TODO: request binding cache like done in SJ
203+ serializer = HTTPRequestSerializer (
204+ payload_codec = self .codec ,
205+ http_trait = operation .schema .expect_trait (HTTPTrait ), # TODO
206+ endpoint_trait = operation .schema .get_trait (EndpointTrait ),
207+ )
208+
209+ input .serialize (serializer = serializer )
210+ request = serializer .result
211+
212+ if request is None :
213+ raise ValueError ("Request is None" ) # TODO
214+
215+ request .fields ["content-type" ].add (self .content_type )
216+ return request
217+
218+ async def deserialize_response [
219+ OperationInput : "SerializeableShape" ,
220+ OperationOutput : "DeserializeableShape" ,
221+ ](
222+ self ,
223+ * ,
224+ operation : APIOperation [OperationInput , OperationOutput ],
225+ request : HTTPRequest ,
226+ response : HTTPResponse ,
227+ error_registry : TypeRegistry ,
228+ context : dict [str , Any ], # TODO: replace with a typed context bag
229+ ) -> OperationOutput :
230+ if not (200 <= response .status <= 299 ): # TODO: extract to utility
231+ # TODO: implement error serde from type registry
232+ raise NotImplementedError
233+
234+ body = response .body
235+ # TODO: extract to utility, seems common
236+ if (read := getattr (body , "read" , None )) is not None and iscoroutinefunction (
237+ read
238+ ):
239+ body = BytesIO (await read ())
240+
241+ # TODO: response binding cache like done in SJ
242+ deserializer = HTTPResponseDeserializer (
243+ payload_codec = self .codec ,
244+ http_trait = operation .schema .expect_trait (HTTPTrait ),
245+ response = response ,
246+ body = body , # type: ignore
247+ )
248+
249+ return operation .output .deserialize (deserializer )
250+
251+
252+ class RestJsonClientProtocol (HttpBindingClientProtocol ):
253+ _id : ShapeID = RestJson1Trait .id
254+ _codec : JSONCodec = JSONCodec ()
255+ _contentType : str = "application/json"
256+
257+ @property
258+ def id (self ) -> ShapeID :
259+ return self ._id
260+
261+ @property
262+ def codec (self ) -> Codec :
263+ return self ._codec
264+
265+ @property
266+ def content_type (self ) -> str :
267+ return self ._contentType
0 commit comments