99
1010from smithy_core import URI
1111from smithy_core .codecs import Codec
12+ from smithy_core .exceptions import SerializationError
1213from smithy_core .schemas import Schema
1314from smithy_core .serializers import (
1415 InterceptingSerializer ,
2425 HTTPQueryTrait ,
2526 HTTPTrait ,
2627 MediaTypeTrait ,
28+ RequiresLengthTrait ,
2729 TimestampFormatTrait ,
2830)
2931from smithy_core .types import PathPattern , TimestampFormat
3032from smithy_core .utils import serialize_float
3133
32- from . import tuples_to_fields
34+ from . import Field , tuples_to_fields
3335from .aio import HTTPRequest as _HTTPRequest
3436from .aio import HTTPResponse as _HTTPResponse
3537from .aio .interfaces import HTTPRequest , HTTPResponse
4345__all__ = ["HTTPRequestSerializer" , "HTTPResponseSerializer" ]
4446
4547
48+ # TODO: refactor this to share code with response serializer
4649class HTTPRequestSerializer (SpecificShapeSerializer ):
4750 """Binds a serializable shape to an HTTP request.
4851
@@ -82,8 +85,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8285 host_prefix = self ._endpoint_trait .host_prefix
8386
8487 content_type = self ._payload_codec .media_type
88+ content_length : int | None = None
89+ content_length_required = False
90+
8591 binding_matcher = RequestBindingMatcher (schema )
8692 if (payload_member := binding_matcher .payload_member ) is not None :
93+ content_length_required = RequiresLengthTrait in payload_member
8794 if payload_member .shape_type in (
8895 ShapeType .BLOB ,
8996 ShapeType .STRING ,
@@ -105,6 +112,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
105112 )
106113 yield binding_serializer
107114 payload = payload_serializer .payload
115+ try :
116+ content_length = len (payload )
117+ except TypeError :
118+ pass
108119 else :
109120 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
110121 content_type = media_type .value
@@ -117,6 +128,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
117128 binding_matcher ,
118129 )
119130 yield binding_serializer
131+ content_length = payload .tell ()
132+ payload .seek (0 )
120133 else :
121134 payload = BytesIO ()
122135 payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -131,25 +144,36 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
131144 binding_matcher ,
132145 )
133146 yield binding_serializer
147+ content_length = payload .tell ()
134148 else :
135149 content_type = None
150+ content_length = 0
136151 binding_serializer = HTTPRequestBindingSerializer (
137152 payload_serializer ,
138153 self ._http_trait .path ,
139154 host_prefix ,
140155 binding_matcher ,
141156 )
142157 yield binding_serializer
143-
144- if (
145- seek := getattr (payload , "seek" , None )
146- ) is not None and not iscoroutinefunction (seek ):
147- seek (0 )
158+ payload .seek (0 )
148159
149160 headers = binding_serializer .header_serializer .headers
150161 if content_type is not None :
151162 headers .append (("content-type" , content_type ))
152163
164+ if content_length is not None :
165+ headers .append (("content-length" , str (content_length )))
166+
167+ fields = tuples_to_fields (headers )
168+ if content_length_required and "content-length" not in fields :
169+ content_length = _compute_content_length (payload )
170+ if content_length is None :
171+ raise SerializationError (
172+ "This operation requires the the content length of the input "
173+ "stream, but it was not provided and was unable to be computed."
174+ )
175+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
176+
153177 self .result = _HTTPRequest (
154178 method = self ._http_trait .method ,
155179 destination = URI (
@@ -160,11 +184,30 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
160184 prefix = self ._http_trait .query or "" ,
161185 ),
162186 ),
163- fields = tuples_to_fields ( headers ) ,
187+ fields = fields ,
164188 body = payload ,
165189 )
166190
167191
192+ def _compute_content_length (payload : Any ) -> int | None :
193+ if (tell := getattr (payload , "tell" , None )) is not None and not iscoroutinefunction (
194+ tell
195+ ):
196+ start : int = tell ()
197+ if (end := _seek (payload , 0 , 2 )) is not None :
198+ content_length : int = end - start
199+ _seek (payload , start , 0 )
200+ return content_length
201+ return None
202+
203+
204+ def _seek (payload : Any , pos : int , whence : int = 0 ) -> None :
205+ if (seek := getattr (payload , "seek" , None )) is not None and not iscoroutinefunction (
206+ seek
207+ ):
208+ seek (pos , whence )
209+
210+
168211class HTTPRequestBindingSerializer (InterceptingSerializer ):
169212 """Delegates HTTP request bindings to binding-location-specific serializers."""
170213
@@ -235,8 +278,12 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
235278 binding_serializer : HTTPResponseBindingSerializer
236279
237280 content_type : str | None = self ._payload_codec .media_type
281+ content_length : int | None = None
282+ content_length_required = False
283+
238284 binding_matcher = ResponseBindingMatcher (schema )
239285 if (payload_member := binding_matcher .payload_member ) is not None :
286+ content_length_required = RequiresLengthTrait in payload_member
240287 if payload_member .shape_type in (ShapeType .BLOB , ShapeType .STRING ):
241288 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
242289 content_type = media_type .value
@@ -250,6 +297,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
250297 )
251298 yield binding_serializer
252299 payload = payload_serializer .payload
300+ try :
301+ content_length = len (payload )
302+ except TypeError :
303+ pass
253304 else :
254305 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
255306 content_type = media_type .value
@@ -259,6 +310,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
259310 payload_serializer , binding_matcher
260311 )
261312 yield binding_serializer
313+ content_length = payload .tell ()
314+ payload .seek (0 )
262315 else :
263316 payload = BytesIO ()
264317 payload_serializer = self ._payload_codec .create_serializer (payload )
@@ -270,23 +323,34 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
270323 body_serializer , binding_matcher
271324 )
272325 yield binding_serializer
326+ content_length = payload .tell ()
273327 else :
274328 content_type = None
329+ content_length = 0
275330 binding_serializer = HTTPResponseBindingSerializer (
276331 payload_serializer ,
277332 binding_matcher ,
278333 )
279334 yield binding_serializer
280-
281- if (
282- seek := getattr (payload , "seek" , None )
283- ) is not None and not iscoroutinefunction (seek ):
284- seek (0 )
335+ payload .seek (0 )
285336
286337 headers = binding_serializer .header_serializer .headers
287338 if content_type is not None :
288339 headers .append (("content-type" , content_type ))
289340
341+ if content_length is not None :
342+ headers .append (("content-length" , str (content_length )))
343+
344+ fields = tuples_to_fields (headers )
345+ if content_length_required and "content-length" not in fields :
346+ content_length = _compute_content_length (payload )
347+ if content_length is None :
348+ raise SerializationError (
349+ "This operation requires the the content length of the input "
350+ "stream, but it was not provided and was unable to be computed."
351+ )
352+ fields .set_field (Field (name = "content-length" , values = [str (content_length )]))
353+
290354 status = binding_serializer .response_code_serializer .response_code
291355 if status is None :
292356 if binding_matcher .response_status > 0 :
0 commit comments