Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
*/
package software.amazon.smithy.python.aws.codegen;

import java.util.Collections;
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;

import java.util.List;
import software.amazon.smithy.aws.traits.auth.SigV4Trait;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait;
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
import software.amazon.smithy.python.codegen.ApplicationProtocol;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.ConfigProperty;
Expand Down Expand Up @@ -61,8 +60,7 @@ public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
.build())
.addConfigProperty(REGION)
.authScheme(new Sigv4AuthScheme())
.build()
);
.build());
}

@Override
Expand Down Expand Up @@ -129,11 +127,9 @@ public List<DerivedProperty> getAuthProperties() {
.source(DerivedProperty.Source.CONFIG)
.type(Symbol.builder().name("str").build())
.sourcePropertyName("region")
.build()
);
.build());
}


@Override
public Symbol getAuthOptionGenerator(GenerationContext context) {
var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
*/
@SmithyUnstableApi
public final class AwsConfiguration {
private AwsConfiguration() {
}
private AwsConfiguration() {}

public static final ConfigProperty REGION = ConfigProperty.builder()
.name("region")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
@SmithyUnstableApi
public class AwsPythonDependency {

private AwsPythonDependency() {
}
private AwsPythonDependency() {}

/**
* The core aws smithy runtime python package.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
*/
package software.amazon.smithy.python.aws.codegen;

import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;

import java.util.List;
import software.amazon.smithy.aws.traits.ServiceTrait;
import software.amazon.smithy.codegen.core.Symbol;
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.ConfigProperty;
import software.amazon.smithy.python.codegen.GenerationContext;
import software.amazon.smithy.python.codegen.integrations.PythonIntegration;
import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin;
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,9 @@ default void generateSharedDeserializerComponents(GenerationContext context) {}
*/
default void generateProtocolTests(GenerationContext context) {}

/**
* Generates the code to wrap an operation output into an event stream.
*
* <p>Important context variables are:
* <ul>
* <li>execution_context - Has the context, including the transport input and output.</li>
* <li>operation_output - The deserialized operation output.</li>
* <li>has_input_stream - Whether or not there is an input stream.</li>
* <li>event_deserializer - The deserialize method for output events, or None for no output stream.</li>
* <li>event_response_deserializer - A DeserializeableShape representing the operation's output shape,
* or None for no output stream. This is used when the operation sends the initial response over the
* event stream.
* </li>
* </ul>
*
* @param context Generation context.
* @param writer The writer to write to.
*/
default void wrapEventStream(GenerationContext context, PythonWriter writer) {}
default void wrapInputStream(GenerationContext context, PythonWriter writer) {}

default void wrapOutputStream(GenerationContext context, PythonWriter writer) {}

default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -390,48 +390,64 @@ protected void resolveErrorCodeAndMessage(
}

@Override
public void wrapEventStream(GenerationContext context, PythonWriter writer) {
public void wrapInputStream(GenerationContext context, PythonWriter writer) {
writer.addDependency(SmithyPythonDependency.SMITHY_JSON);
writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM);
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImports("aws_event_stream.aio",
Set.of(
"AWSDuplexEventStream",
"AWSInputEventStream",
"AWSOutputEventStream"));
writer.addImport("smithy_json", "JSONCodec");
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
writer.addImport("smithy_core.types", "TimestampFormat");
writer.addStdlibImport("typing", "Any");
writer.addImport("aws_event_stream.aio", "AWSInputEventStream");
writer.write(
"""
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
return AWSInputEventStream[Any, Any](
payload_codec=codec,
awaitable_output=awaitable_output,
async_writer=request_context.transport_request.body, # type: ignore
)
""");
}

writer.write("""
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
if has_input_stream:
if event_deserializer is not None:
return AWSDuplexEventStream[Any, Any, Any](
@Override
public void wrapOutputStream(GenerationContext context, PythonWriter writer) {
writer.addDependency(SmithyPythonDependency.SMITHY_JSON);
writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM);
writer.addImport("smithy_json", "JSONCodec");
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
writer.addImport("smithy_core.types", "TimestampFormat");
writer.addImport("aws_event_stream.aio", "AWSOutputEventStream");
writer.write(
"""
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
return AWSOutputEventStream[Any, Any](
payload_codec=codec,
initial_response=operation_output,
async_writer=execution_context.transport_request.body, # type: ignore
initial_response=output,
async_reader=AsyncBytesReader(
execution_context.transport_response.body # type: ignore
transport_response.body # type: ignore
),
deserializer=event_deserializer, # type: ignore
)
else:
return AWSInputEventStream[Any, Any](
""");
}

@Override
public void wrapDuplexStream(GenerationContext context, PythonWriter writer) {
writer.addDependency(SmithyPythonDependency.SMITHY_JSON);
writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM);
writer.addImport("smithy_json", "JSONCodec");
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
writer.addImport("smithy_core.types", "TimestampFormat");
writer.addImport("aws_event_stream.aio", "AWSDuplexEventStream");
writer.write(
"""
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
return AWSDuplexEventStream[Any, Any, Any](
payload_codec=codec,
initial_response=operation_output,
async_writer=execution_context.transport_request.body, # type: ignore
async_writer=request_context.transport_request.body, # type: ignore
awaitable_output=awaitable_output,
awaitable_response=response_future,
deserializer=event_deserializer, # type: ignore
)
else:
return AWSOutputEventStream[Any, Any](
payload_codec=codec,
initial_response=operation_output,
async_reader=AsyncBytesReader(
execution_context.transport_response.body # type: ignore
),
deserializer=event_deserializer, # type: ignore
)
""");
""");
}
}
72 changes: 19 additions & 53 deletions packages/aws-event-stream/src/aws_event_stream/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
from collections.abc import Callable
from typing import Self
from typing import Self, Awaitable

from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter
from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter, Response
from smithy_core.aio.types import AsyncBytesReader
from smithy_core.codecs import Codec
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
from smithy_core.serializers import SerializeableShape
Expand Down Expand Up @@ -33,8 +33,8 @@ def __init__(
payload_codec: Codec,
async_writer: AsyncWriter,
deserializer: Callable[[ShapeDeserializer], O],
async_reader: AsyncByteStream | None = None,
initial_response: R | None = None,
awaitable_response: Awaitable[Response],
awaitable_output: Awaitable[R],
deserializeable_response: type[R] | None = None,
signer: Signer | None = None,
is_client_mode: bool = True,
Expand Down Expand Up @@ -68,36 +68,14 @@ def __init__(
self._deserializer = deserializer
self._payload_codec = payload_codec
self._is_client_mode = is_client_mode
self._deserializeable_response = deserializeable_response

# Create a future to allow awaiting the reader
loop = asyncio.get_event_loop()
self._reader_future: asyncio.Future[AsyncByteStream] = loop.create_future()
if async_reader is not None:
self._reader_future.set_result(async_reader)

# Create a future to allow awaiting the initial response
self._response = initial_response
self._deserializerable_response = deserializeable_response
self._response_future: asyncio.Future[R] = loop.create_future()

@property
def response(self) -> R | None:
return self._response

@response.setter
def response(self, value: R) -> None:
self._response_future.set_result(value)
self._response = value

def set_reader(self, value: AsyncByteStream) -> None:
"""Sets the object to read events from.

:param value: An async readable object to read event bytes from.
"""
self._reader_future.set_result(value)
self._awaitable_response = awaitable_response
self._awaitable_output = awaitable_output
self.response: R | None = None

async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]:
async_reader = await self._reader_future
async_reader = AsyncBytesReader((await self._awaitable_response).body)
if self.output_stream is None:
self.output_stream = _AWSEventReceiver[O](
payload_codec=self._payload_codec,
Expand All @@ -107,13 +85,13 @@ async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]:
)

if self.response is None:
if self._deserializerable_response is None:
initial_response = await self._response_future
if self._deserializeable_response is None:
initial_response = await self._awaitable_output
else:
initial_response_stream = _AWSEventReceiver(
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializerable_response.deserialize,
deserializer=self._deserializeable_response.deserialize,
is_client_mode=self._is_client_mode,
)
initial_response = await initial_response_stream.receive()
Expand All @@ -133,7 +111,7 @@ def __init__(
self,
payload_codec: Codec,
async_writer: AsyncWriter,
initial_response: R | None = None,
awaitable_output: Awaitable[R],
signer: Signer | None = None,
is_client_mode: bool = True,
) -> None:
Expand All @@ -147,13 +125,8 @@ def __init__(
:param is_client_mode: Whether the stream is being constructed for a client or
server implementation.
"""
self._response = initial_response

# Create a future to allow awaiting the initial response.
loop = asyncio.get_event_loop()
self._response_future: asyncio.Future[R] = loop.create_future()
if initial_response is not None:
self._response_future.set_result(initial_response)
self.response: R | None = None
self._awaitable_response = awaitable_output

self.input_stream = _AWSEventPublisher(
payload_codec=payload_codec,
Expand All @@ -162,17 +135,10 @@ def __init__(
is_client_mode=is_client_mode,
)

@property
def response(self) -> R | None:
return self._response

@response.setter
def response(self, value: R) -> None:
self._response_future.set_result(value)
self._response = value

async def await_output(self) -> R:
return await self._response_future
if self.response is None:
self.response = await self._awaitable_response
return self.response


class AWSOutputEventStream[O: DeserializeableShape, R: DeserializeableShape](
Expand Down