Skip to content

Commit 4b56b21

Browse files
committed
Add http and restjson1 client protocols
1 parent c11f1de commit 4b56b21

File tree

3 files changed

+181
-0
lines changed

3 files changed

+181
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Literal
2+
3+
from smithy_aws_core.traits import RestJson1Trait
4+
from smithy_core.protocols import HttpBindingClientProtocol
5+
from smithy_core.codecs import Codec
6+
from smithy_core.shapes import ShapeID
7+
from smithy_json import JSONCodec
8+
9+
10+
class RestJsonClientProtocol(HttpBindingClientProtocol):
11+
"""An implementation of the aws.protocols#restJson1 protocol."""
12+
13+
_id: ShapeID = RestJson1Trait.id
14+
_codec: JSONCodec = JSONCodec()
15+
_contentType: Literal["application/json"] = "application/json"
16+
17+
@property
18+
def id(self) -> ShapeID:
19+
return self._id
20+
21+
@property
22+
def codec(self) -> Codec:
23+
return self._codec
24+
25+
@property
26+
def content_type(self) -> str:
27+
return self._contentType
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# This ruff check warns against using the assert statement, which can be stripped out
5+
# when running Python with certain (common) optimization settings. Assert is used here
6+
# for trait values. Since these are always generated, we can be fairly confident that
7+
# they're correct regardless, so it's okay if the checks are stripped out.
8+
# ruff: noqa: S101
9+
10+
from dataclasses import dataclass, field
11+
from typing import Mapping, Sequence
12+
13+
from smithy_core.shapes import ShapeID
14+
from smithy_core.traits import Trait, DocumentValue, DynamicTrait
15+
16+
17+
@dataclass(init=False, frozen=True)
18+
class RestJson1Trait(Trait, id=ShapeID("aws.protocols#restJson1")):
19+
http: set[str] = field(repr=False, hash=False, compare=False, default_factory=set)
20+
eventStreamHttp: set[str] = field(
21+
repr=False, hash=False, compare=False, default_factory=set
22+
)
23+
24+
def __init__(self, value: DocumentValue | DynamicTrait = None):
25+
super().__init__(value)
26+
assert isinstance(self.document_value, Mapping)
27+
28+
assert isinstance(self.document_value["http"], Sequence)
29+
for val in self.document_value["http"]:
30+
assert isinstance(val, str)
31+
self.http.add(val)
32+
33+
if vals := self.document_value.get("eventStreamHttp") is None:
34+
object.__setattr__(self, "eventStreamHttp", self.http)
35+
else:
36+
# check that eventStreamHttp is a subset of http
37+
assert isinstance(vals, Sequence)
38+
for val in self.document_value["http"]:
39+
assert val in self.http
40+
assert isinstance(val, str)
41+
self.eventStreamHttp.add(val)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import os
2+
from inspect import iscoroutinefunction
3+
from io import BytesIO
4+
5+
from smithy_core.aio.interfaces import ClientProtocol
6+
from smithy_core.codecs import Codec
7+
from smithy_core.deserializers import DeserializeableShape
8+
from smithy_core.documents import TypeRegistry
9+
from smithy_core.interfaces import Endpoint, TypedProperties, URI
10+
from smithy_core.schemas import APIOperation
11+
from smithy_core.serializers import SerializeableShape
12+
from smithy_core.traits import HTTPTrait, EndpointTrait
13+
from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse
14+
from smithy_http.deserializers import HTTPResponseDeserializer
15+
from smithy_http.serializers import HTTPRequestSerializer
16+
17+
18+
class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]):
19+
"""An HTTP-based protocol."""
20+
21+
def set_service_endpoint(
22+
self,
23+
*,
24+
request: HTTPRequest,
25+
endpoint: Endpoint,
26+
) -> HTTPRequest:
27+
uri = endpoint.uri
28+
uri_builder = request.destination
29+
30+
if uri.scheme:
31+
uri_builder.scheme = uri.scheme
32+
if uri.host:
33+
uri_builder.host = uri.host
34+
if uri.port and uri.port > -1:
35+
uri_builder.port = uri.port
36+
if uri.path:
37+
uri_builder.path = os.path.join(uri.path, uri_builder.path or "")
38+
# TODO: merge headers from the endpoint properties bag
39+
return request
40+
41+
42+
class HttpBindingClientProtocol(HttpClientProtocol):
43+
"""An HTTP-based protocol that uses HTTP binding traits."""
44+
45+
@property
46+
def codec(self) -> Codec:
47+
"""The codec used for the serde of input and output shapes."""
48+
...
49+
50+
@property
51+
def content_type(self) -> str:
52+
"""The media type of the http payload."""
53+
...
54+
55+
def serialize_request[
56+
OperationInput: "SerializeableShape",
57+
OperationOutput: "DeserializeableShape",
58+
](
59+
self,
60+
*,
61+
operation: APIOperation[OperationInput, OperationOutput],
62+
input: OperationInput,
63+
endpoint: URI,
64+
context: TypedProperties,
65+
) -> HTTPRequest:
66+
# TODO: request binding cache like done in SJ
67+
serializer = HTTPRequestSerializer(
68+
payload_codec=self.codec,
69+
http_trait=operation.schema.expect_trait(HTTPTrait), # TODO
70+
endpoint_trait=operation.schema.get_trait(EndpointTrait),
71+
)
72+
73+
input.serialize(serializer=serializer)
74+
request = serializer.result
75+
76+
if request is None:
77+
raise ValueError("Request is None") # TODO
78+
79+
request.fields["content-type"].add(self.content_type)
80+
return request
81+
82+
async def deserialize_response[
83+
OperationInput: "SerializeableShape",
84+
OperationOutput: "DeserializeableShape",
85+
](
86+
self,
87+
*,
88+
operation: APIOperation[OperationInput, OperationOutput],
89+
request: HTTPRequest,
90+
response: HTTPResponse,
91+
error_registry: TypeRegistry,
92+
context: TypedProperties,
93+
) -> OperationOutput:
94+
if not (200 <= response.status <= 299): # TODO: extract to utility
95+
# TODO: implement error serde from type registry
96+
raise NotImplementedError
97+
98+
body = response.body
99+
# TODO: extract to utility, seems common
100+
if (read := getattr(body, "read", None)) is not None and iscoroutinefunction(
101+
read
102+
):
103+
body = BytesIO(await read())
104+
105+
# TODO: response binding cache like done in SJ
106+
deserializer = HTTPResponseDeserializer(
107+
payload_codec=self.codec,
108+
http_trait=operation.schema.expect_trait(HTTPTrait),
109+
response=response,
110+
body=body, # type: ignore
111+
)
112+
113+
return operation.output.deserialize(deserializer)

0 commit comments

Comments
 (0)