Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import datetime
from collections.abc import Callable

from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable
from smithy_core.codecs import Codec
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
from smithy_core.interfaces import BytesReader
from smithy_core.deserializers import (
DeserializeableShape,
ShapeDeserializer,
SpecificShapeDeserializer,
)
from smithy_core.schemas import Schema
from smithy_core.utils import expect_type
from smithy_event_stream.aio.interfaces import AsyncEventReceiver

from ..events import HEADERS_DICT, Event
from ..exceptions import EventError, UnmodeledEventError
Expand All @@ -17,11 +22,38 @@
INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE)


class EventDeserializer(SpecificShapeDeserializer):
class AWSAsyncEventReceiver[E: DeserializeableShape](AsyncEventReceiver[E]):
def __init__(
self, source: BytesReader, payload_codec: Codec, is_client_mode: bool = True
self,
payload_codec: Codec,
source: AsyncByteStream,
deserializer: Callable[[ShapeDeserializer], E],
is_client_mode: bool = True,
) -> None:
self._payload_codec = payload_codec
self._source = source
self._is_client_mode = is_client_mode
self._deserializer = deserializer

async def receive(self) -> E | None:
event = await Event.decode_async(self._source)
deserializer = EventDeserializer(
event=event,
payload_codec=self._payload_codec,
is_client_mode=self._is_client_mode,
)
return self._deserializer(deserializer)

async def close(self) -> None:
if isinstance(self._source, AsyncCloseable):
await self._source.close()


class EventDeserializer(SpecificShapeDeserializer):
def __init__(
self, event: Event, payload_codec: Codec, is_client_mode: bool = True
) -> None:
self._event = event
self._payload_codec = payload_codec
self._is_client_mode = is_client_mode

Expand All @@ -30,13 +62,12 @@ def read_struct(
schema: Schema,
consumer: Callable[[Schema, ShapeDeserializer], None],
) -> None:
event = Event.decode(self._source)
headers = event.message.headers
headers = self._event.message.headers

payload_deserializer = None
if event.message.payload:
if self._event.message.payload:
payload_deserializer = self._payload_codec.create_deserializer(
event.message.payload
self._event.message.payload
)

message_deserializer = EventMessageDeserializer(headers, payload_deserializer)
Expand All @@ -61,7 +92,7 @@ def read_struct(
expect_type(str, headers[":error-message"]),
)
case _:
raise EventError(f"Unknown event structure: {event}")
raise EventError(f"Unknown event structure: {self._event}")


class EventMessageDeserializer(SpecificShapeDeserializer):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from io import BytesIO
from typing import Never

from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter
from smithy_core.codecs import Codec
from smithy_core.exceptions import ExpectationNotMetException
from smithy_core.schemas import Schema
from smithy_core.serializers import (
InterceptingSerializer,
SerializeableShape,
ShapeSerializer,
SpecificShapeSerializer,
)
from smithy_core.shapes import ShapeType
from smithy_core.utils import expect_type
from smithy_event_stream.aio.interfaces import AsyncEventPublisher

from ..events import EventHeaderEncoder, EventMessage
from ..exceptions import InvalidHeaderValue
Expand All @@ -30,6 +34,40 @@
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"


type Signer = Callable[[EventMessage], EventMessage]
"""A function that takes an event message and signs it, and returns it signed."""


class AWSAsyncEventPublisher[E: SerializeableShape](AsyncEventPublisher[E]):
def __init__(
self,
payload_codec: Codec,
async_writer: AsyncWriter,
signer: Signer | None = None,
is_client_mode: bool = True,
):
self._writer = async_writer
self._signer = signer
self._serializer = EventSerializer(
payload_codec=payload_codec, is_client_mode=is_client_mode
)

async def send(self, event: E) -> None:
event.serialize(self._serializer)
result = self._serializer.get_result()
if result is None:
raise ExpectationNotMetException(
"Expected an event message to be serialized, but was None."
)
if self._signer is not None:
result = self._signer(result)
await self._writer.write(result.encode())

async def close(self) -> None:
if isinstance(self._writer, AsyncCloseable):
await self._writer.close()


class EventSerializer(SpecificShapeSerializer):
def __init__(
self,
Expand Down
Loading
Loading