Skip to content

Commit f2b4fb4

Browse files
committed
checkpoint
1 parent db1d84a commit f2b4fb4

File tree

5 files changed

+188
-2
lines changed

5 files changed

+188
-2
lines changed

examples/weather/model/weather.smithy

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ resource City {
2626
resource Forecast {
2727
identifiers: { cityId: CityId }
2828
read: GetForecast,
29+
update: PutForecast
30+
}
31+
32+
@http(method: "PUT", uri: "/city/{cityId}/forecast", code: 200)
33+
@idempotent
34+
operation PutForecast {
35+
input := for Forecast {
36+
@required
37+
@httpLabel
38+
$cityId
39+
40+
chanceOfRain: Float
41+
}
42+
43+
output := {}
2944
}
3045

3146
// "pattern" is a trait.
@@ -154,3 +169,5 @@ structure GetForecastInput {
154169
structure GetForecastOutput {
155170
chanceOfRain: Float
156171
}
172+
173+

packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import os.path
4+
from asyncio import iscoroutinefunction
35
from collections.abc import AsyncIterable
6+
from io import BytesIO
47
from typing import Protocol, runtime_checkable, TYPE_CHECKING, Any
58

9+
from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse
10+
from smithy_http.deserializers import HTTPResponseDeserializer
11+
from smithy_http.serializers import HTTPRequestSerializer
12+
from smithy_json import JSONCodec
13+
from ...codecs import Codec
614
from ...interfaces import URI, Endpoint
715
from ...interfaces import StreamingBlob as SyncStreamingBlob
8-
16+
from ...traits import HTTPTrait, EndpointTrait, RestJson1Trait
17+
from ...type_registry import TypeRegistry
918

1019
if TYPE_CHECKING:
1120
from ...schemas import APIOperation
@@ -126,7 +135,7 @@ async def deserialize_response[
126135
operation: APIOperation[OperationInput, OperationOutput],
127136
request: I,
128137
response: O,
129-
error_registry: Any, # TODO: add error registry
138+
error_registry: TypeRegistry,
130139
context: dict[str, Any], # TODO: replace with a typed context bag
131140
) -> OperationOutput:
132141
"""Deserializes the output from the tranport response or throws an exception.
@@ -138,3 +147,121 @@ async def deserialize_response[
138147
:param context: A context bag for the request.
139148
"""
140149
...
150+
151+
152+
class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]):
153+
def set_service_endpoint(
154+
self,
155+
*,
156+
request: HTTPRequest,
157+
endpoint: Endpoint,
158+
) -> HTTPRequest:
159+
"""Update the endpoint of a transport request.
160+
161+
:param request: The request whose endpoint should be updated.
162+
:param endpoint: The endpoint to set on the request.
163+
"""
164+
uri = endpoint.uri
165+
uri_builder = request.destination
166+
167+
if uri.scheme:
168+
uri_builder.scheme = uri.scheme
169+
if uri.host:
170+
uri_builder.host = uri.host
171+
if uri.port and uri.port > -1:
172+
uri_builder.port = uri.port
173+
if uri.path:
174+
# TODO: verify, uri helper?
175+
uri_builder.path = os.path.join(uri.path, uri_builder.path or "")
176+
# TODO: merge headers from the endpoint properties bag
177+
return request
178+
179+
180+
class HttpBindingClientProtocol(HttpClientProtocol):
181+
@property
182+
def codec(self) -> Codec:
183+
"""The codec used for the serde of input and output shapes."""
184+
...
185+
186+
@property
187+
def content_type(self) -> str:
188+
"""The media type of the http payload."""
189+
...
190+
191+
def serialize_request[
192+
OperationInput: "SerializeableShape",
193+
OperationOutput: "DeserializeableShape",
194+
](
195+
self,
196+
*,
197+
operation: APIOperation[OperationInput, OperationOutput],
198+
input: OperationInput,
199+
endpoint: URI,
200+
context: dict[str, Any],
201+
) -> HTTPRequest:
202+
# TODO: request binding cache like done in SJ
203+
serializer = HTTPRequestSerializer(
204+
payload_codec=self.codec,
205+
http_trait=operation.schema.expect_trait(HTTPTrait), # TODO
206+
endpoint_trait=operation.schema.get_trait(EndpointTrait),
207+
)
208+
209+
input.serialize(serializer=serializer)
210+
request = serializer.result
211+
212+
if request is None:
213+
raise ValueError("Request is None") # TODO
214+
215+
request.fields["content-type"].add(self.content_type)
216+
return request
217+
218+
async def deserialize_response[
219+
OperationInput: "SerializeableShape",
220+
OperationOutput: "DeserializeableShape",
221+
](
222+
self,
223+
*,
224+
operation: APIOperation[OperationInput, OperationOutput],
225+
request: HTTPRequest,
226+
response: HTTPResponse,
227+
error_registry: TypeRegistry,
228+
context: dict[str, Any], # TODO: replace with a typed context bag
229+
) -> OperationOutput:
230+
if not (200 <= response.status <= 299): # TODO: extract to utility
231+
# TODO: implement error serde from type registry
232+
raise NotImplementedError
233+
234+
body = response.body
235+
# TODO: extract to utility, seems common
236+
if (read := getattr(body, "read", None)) is not None and iscoroutinefunction(
237+
read
238+
):
239+
body = BytesIO(await read())
240+
241+
# TODO: response binding cache like done in SJ
242+
deserializer = HTTPResponseDeserializer(
243+
payload_codec=self.codec,
244+
http_trait=operation.schema.expect_trait(HTTPTrait),
245+
response=response,
246+
body=body, # type: ignore
247+
)
248+
249+
return operation.output.deserialize(deserializer)
250+
251+
252+
class RestJsonClientProtocol(HttpBindingClientProtocol):
253+
_id: ShapeID = RestJson1Trait.id
254+
_codec: JSONCodec = JSONCodec()
255+
_contentType: str = "application/json"
256+
257+
@property
258+
def id(self) -> ShapeID:
259+
return self._id
260+
261+
@property
262+
def codec(self) -> Codec:
263+
return self._codec
264+
265+
@property
266+
def content_type(self) -> str:
267+
return self._contentType

packages/smithy-core/src/smithy_core/documents.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ def shape_type(self) -> ShapeType:
143143
"""The Smithy data model type for the underlying contents of the document."""
144144
return self._type
145145

146+
@property
147+
def discriminator(self) -> ShapeID:
148+
"""The shape ID that corresponds to the contents of the document."""
149+
# TODO: custom exception?
150+
raise NotImplementedError(f"{self} document has no discriminator.")
151+
146152
def is_none(self) -> bool:
147153
"""Indicates whether the document contains a null value."""
148154
return self._value is None and self._raw_value is None

packages/smithy-core/src/smithy_core/traits.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,9 @@ def host_prefix(self) -> str:
313313
class HostLabelTrait(Trait, id=ShapeID("smithy.api#hostLabel")):
314314
def __post_init__(self):
315315
assert self.document_value is None
316+
317+
318+
@dataclass(init=False, frozen=True)
319+
class RestJson1Trait(Trait, id=ShapeID("aws.protocols#restJson1")):
320+
def __post_init__(self):
321+
assert self.document_value is None
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from smithy_core.deserializers import (
5+
DeserializeableShape,
6+
) # TODO: fix typo in deserializable
7+
from smithy_core.documents import Document
8+
from smithy_core.shapes import ShapeID
9+
10+
11+
# A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers.
12+
# TODO: protocol
13+
class TypeRegistry:
14+
def __init__(
15+
self,
16+
types: dict[ShapeID, DeserializeableShape],
17+
sub_registry: "TypeRegistry | None" = None,
18+
):
19+
self._types = types
20+
self._sub_registry = sub_registry
21+
22+
def get(self, shape: ShapeID) -> type[DeserializeableShape]:
23+
if shape in self._types:
24+
return type(self._types[shape])
25+
if self._sub_registry is not None:
26+
return self._sub_registry.get(shape)
27+
raise ValueError(f"Unknown shape: {shape}") # TODO: real exception?
28+
29+
def deserialize(self, document: Document) -> DeserializeableShape:
30+
return document.as_shape(self.get(document.discriminator))

0 commit comments

Comments
 (0)