Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
*/
package software.amazon.smithy.python.aws.codegen;

import java.util.Collections;
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;

import java.util.List;
import software.amazon.smithy.aws.traits.auth.SigV4Trait;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait;
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
import software.amazon.smithy.python.codegen.ApplicationProtocol;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.ConfigProperty;
Expand Down Expand Up @@ -61,8 +60,7 @@ public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
.build())
.addConfigProperty(REGION)
.authScheme(new Sigv4AuthScheme())
.build()
);
.build());
}

@Override
Expand Down Expand Up @@ -129,11 +127,9 @@ public List<DerivedProperty> getAuthProperties() {
.source(DerivedProperty.Source.CONFIG)
.type(Symbol.builder().name("str").build())
.sourcePropertyName("region")
.build()
);
.build());
}


@Override
public Symbol getAuthOptionGenerator(GenerationContext context) {
var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
*/
@SmithyUnstableApi
public final class AwsConfiguration {
private AwsConfiguration() {
}
private AwsConfiguration() {}

public static final ConfigProperty REGION = ConfigProperty.builder()
.name("region")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
@SmithyUnstableApi
public class AwsPythonDependency {

private AwsPythonDependency() {
}
private AwsPythonDependency() {}

/**
* The core aws smithy runtime python package.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
*/
package software.amazon.smithy.python.aws.codegen;

import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;

import java.util.List;
import software.amazon.smithy.aws.traits.ServiceTrait;
import software.amazon.smithy.codegen.core.Symbol;
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.ConfigProperty;
import software.amazon.smithy.python.codegen.GenerationContext;
import software.amazon.smithy.python.codegen.integrations.PythonIntegration;
import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,7 @@ private void generateOperationExecutor(PythonWriter writer) {
var hasStreaming = hasEventStream();
writer.putContext("hasEventStream", hasStreaming);
if (hasStreaming) {
writer.addImports("smithy_core.deserializers",
Set.of(
"ShapeDeserializer",
"DeserializeableShape"));
writer.addImport("smithy_core.deserializers", "ShapeDeserializer");
writer.addStdlibImport("typing", "Any");
}

Expand All @@ -137,7 +134,8 @@ private void generateOperationExecutor(PythonWriter writer) {
writer.addStdlibImport("typing", "Awaitable");
writer.addStdlibImport("typing", "cast");
writer.addStdlibImport("copy", "deepcopy");
writer.addStdlibImport("asyncio", "sleep");
writer.addStdlibImport("asyncio");
writer.addStdlibImports("asyncio", Set.of("sleep", "Future"));

writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.exceptions", "SmithyRetryException");
Expand Down Expand Up @@ -187,6 +185,75 @@ def _classify_error(
""");
writer.dedent();

if (hasStreaming) {
writer.addStdlibImports("typing", Set.of("Any", "Awaitable"));
writer.addStdlibImport("asyncio");
writer.write(
"""
async def _input_stream(
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $4T], Awaitable[$2T]],
deserialize: Callable[[$3T, $4T], Awaitable[Output]],
config: $4T,
operation_name: str,
) -> Any:
request_future = Future[$2T]()
awaitable_output = asyncio.create_task(self._execute_operation(
input, plugins, serialize, deserialize, config, operation_name,
request_future=request_future
))
transport_request = await request_future
${5C|}

async def _output_stream(
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $4T], Awaitable[$2T]],
deserialize: Callable[[$3T, $4T], Awaitable[Output]],
config: $4T,
operation_name: str,
event_deserializer: Callable[[ShapeDeserializer], Any],
) -> Any:
response_future = Future[$3T]()
output = await self._execute_operation(
input, plugins, serialize, deserialize, config, operation_name,
response_future=response_future
)
transport_response = await response_future
${6C|}

async def _duplex_stream(
self,
input: Input,
plugins: list[$1T],
serialize: Callable[[Input, $4T], Awaitable[$2T]],
deserialize: Callable[[$3T, $4T], Awaitable[Output]],
config: $4T,
operation_name: str,
event_deserializer: Callable[[ShapeDeserializer], Any],
) -> Any:
request_future = Future[$2T]()
response_future = Future[$3T]()
awaitable_output = asyncio.create_task(self._execute_operation(
input, plugins, serialize, deserialize, config, operation_name,
request_future=request_future,
response_future=response_future
))
transport_request = await request_future
${7C|}
""",
pluginSymbol,
transportRequest,
transportResponse,
configSymbol,
writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)),
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)),
writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w)));
}

writer.write(
"""
async def _execute_operation(
Expand All @@ -197,25 +264,25 @@ async def _execute_operation(
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
operation_name: str,
${?hasEventStream}
has_input_stream: bool = False,
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
event_response_deserializer: DeserializeableShape | None = None,
${/hasEventStream}
request_future: Future[$2T] | None = None,
response_future: Future[$3T] | None = None,
) -> Output:
try:
return await self._handle_execution(
input, plugins, serialize, deserialize, config, operation_name,
${?hasEventStream}
has_input_stream, event_deserializer, event_response_deserializer,
${/hasEventStream}
request_future, response_future,
)
except Exception as e:
if request_future is not None and not request_future.done:
request_future.set_exception($4T(e))
if response_future is not None and not response_future.done:
response_future.set_exception($4T(e))

# Make sure every exception that we throw is an instance of $4T so
# customers can reliably catch everything we throw.
if not isinstance(e, $4T):
raise $4T(e) from e
raise e
raise

async def _handle_execution(
self,
Expand All @@ -225,11 +292,8 @@ async def _handle_execution(
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
config: $5T,
operation_name: str,
${?hasEventStream}
has_input_stream: bool = False,
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
event_response_deserializer: DeserializeableShape | None = None,
${/hasEventStream}
request_future: Future[$2T] | None,
response_future: Future[$3T] | None,
) -> Output:
logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input)
context: InterceptorContext[Input, None, None, None] = InterceptorContext(
Expand Down Expand Up @@ -307,6 +371,7 @@ async def _handle_execution(
context_with_transport_request.copy(),
config,
operation_name,
request_future,
)

# We perform this type-ignored re-assignment because `context` needs
Expand Down Expand Up @@ -342,6 +407,10 @@ await seek(0)
else:
# Step 8: Invoke record_success
retry_strategy.record_success(token=retry_token)
if response_future is not None:
response_future.set_result(
context_with_response.response, # type: ignore
)
break
except Exception as e:
if context.response is not None:
Expand All @@ -355,16 +424,7 @@ await seek(0)
execution_context = cast(
InterceptorContext[Input, Output, $2T | None, $3T | None], context
)
${^hasEventStream}
return await self._finalize_execution(interceptors, execution_context)
${/hasEventStream}
${?hasEventStream}
operation_output = await self._finalize_execution(interceptors, execution_context)
if has_input_stream or event_deserializer is not None:
${6C|}
else:
return operation_output
${/hasEventStream}

async def _handle_attempt(
self,
Expand All @@ -373,6 +433,7 @@ async def _handle_attempt(
context: InterceptorContext[Input, None, $2T, None],
config: $5T,
operation_name: str,
request_future: Future[$2T] | None,
) -> InterceptorContext[Input, Output, $2T, $3T | None]:
try:
# assert config.interceptors is not None
Expand All @@ -385,8 +446,7 @@ async def _handle_attempt(
transportRequest,
transportResponse,
errorSymbol,
configSymbol,
writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w)));
configSymbol);

boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty();
writer.pushState(new ResolveIdentitySection());
Expand Down Expand Up @@ -533,10 +593,19 @@ async def _handle_attempt(
)
logger.debug("HTTP request config: %s", request_config)
logger.debug("Sending HTTP request: %s", context_with_response.transport_request)
context_with_response._transport_response = await config.http_client.send(
request=context_with_response.transport_request,
request_config=request_config,
)

if request_future is not None:
response_task = asyncio.create_task(config.http_client.send(
request=context_with_response.transport_request,
request_config=request_config,
))
request_future.set_result(context_with_response.transport_request)
context_with_response._transport_response = await response_task
else:
context_with_response._transport_response = await config.http_client.send(
request=context_with_response.transport_request,
request_config=request_config,
)
logger.debug("Received HTTP response: %s", context_with_response.transport_response)

""", transportRequest, transportResponse);
Expand Down Expand Up @@ -834,16 +903,14 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
return await self._duplex_stream(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
has_input_stream=True,
event_deserializer=$T().deserialize,
event_response_deserializer=${output:T},
) # type: ignore
${/hasProtocol}
""",
Expand All @@ -862,14 +929,13 @@ raise NotImplementedError()
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
return await self._input_stream(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
has_input_stream=True,
) # type: ignore
${/hasProtocol}
""", writer.consumer(w -> writeSharedOperationInit(w, operation, input)));
Expand All @@ -887,15 +953,14 @@ raise NotImplementedError()
raise NotImplementedError()
${/hasProtocol}
${?hasProtocol}
return await self._execute_operation(
return await self._output_stream(
input=input,
plugins=operation_plugins,
serialize=${serSymbol:T},
deserialize=${deserSymbol:T},
config=self._config,
operation_name=${operationName:S},
event_deserializer=$T().deserialize,
event_response_deserializer=${output:T},
) # type: ignore
${/hasProtocol}
""",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,9 @@ default void generateSharedDeserializerComponents(GenerationContext context) {}
*/
default void generateProtocolTests(GenerationContext context) {}

/**
* Generates the code to wrap an operation output into an event stream.
*
* <p>Important context variables are:
* <ul>
* <li>execution_context - Has the context, including the transport input and output.</li>
* <li>operation_output - The deserialized operation output.</li>
* <li>has_input_stream - Whether or not there is an input stream.</li>
* <li>event_deserializer - The deserialize method for output events, or None for no output stream.</li>
* <li>event_response_deserializer - A DeserializeableShape representing the operation's output shape,
* or None for no output stream. This is used when the operation sends the initial response over the
* event stream.
* </li>
* </ul>
*
* @param context Generation context.
* @param writer The writer to write to.
*/
default void wrapEventStream(GenerationContext context, PythonWriter writer) {}
default void wrapInputStream(GenerationContext context, PythonWriter writer) {}

default void wrapOutputStream(GenerationContext context, PythonWriter writer) {}

default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {}
}
Loading