Skip to content

Commit e1d7824

Browse files
Return immediately on streaming inputs
1 parent 94e9f57 commit e1d7824

File tree

8 files changed

+184
-158
lines changed

8 files changed

+184
-158
lines changed

codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
*/
55
package software.amazon.smithy.python.aws.codegen;
66

7-
import java.util.Collections;
7+
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
8+
89
import java.util.List;
910
import software.amazon.smithy.aws.traits.auth.SigV4Trait;
1011
import software.amazon.smithy.codegen.core.Symbol;
1112
import software.amazon.smithy.model.shapes.ShapeId;
12-
import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait;
13-
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
1413
import software.amazon.smithy.python.codegen.ApplicationProtocol;
1514
import software.amazon.smithy.python.codegen.CodegenUtils;
1615
import software.amazon.smithy.python.codegen.ConfigProperty;
@@ -61,8 +60,7 @@ public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
6160
.build())
6261
.addConfigProperty(REGION)
6362
.authScheme(new Sigv4AuthScheme())
64-
.build()
65-
);
63+
.build());
6664
}
6765

6866
@Override
@@ -129,11 +127,9 @@ public List<DerivedProperty> getAuthProperties() {
129127
.source(DerivedProperty.Source.CONFIG)
130128
.type(Symbol.builder().name("str").build())
131129
.sourcePropertyName("region")
132-
.build()
133-
);
130+
.build());
134131
}
135132

136-
137133
@Override
138134
public Symbol getAuthOptionGenerator(GenerationContext context) {
139135
var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings());

codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsConfiguration.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
*/
1414
@SmithyUnstableApi
1515
public final class AwsConfiguration {
16-
private AwsConfiguration() {
17-
}
16+
private AwsConfiguration() {}
1817

1918
public static final ConfigProperty REGION = ConfigProperty.builder()
2019
.name("region")

codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsPythonDependency.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
@SmithyUnstableApi
1414
public class AwsPythonDependency {
1515

16-
private AwsPythonDependency() {
17-
}
16+
private AwsPythonDependency() {}
1817

1918
/**
2019
* The core aws smithy runtime python package.

codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsStandardRegionalEndpointsIntegration.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
*/
55
package software.amazon.smithy.python.aws.codegen;
66

7+
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
8+
79
import java.util.List;
810
import software.amazon.smithy.aws.traits.ServiceTrait;
9-
import software.amazon.smithy.codegen.core.Symbol;
10-
import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION;
1111
import software.amazon.smithy.python.codegen.CodegenUtils;
12-
import software.amazon.smithy.python.codegen.ConfigProperty;
1312
import software.amazon.smithy.python.codegen.GenerationContext;
1413
import software.amazon.smithy.python.codegen.integrations.PythonIntegration;
1514
import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin;

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 106 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,7 @@ private void generateOperationExecutor(PythonWriter writer) {
120120
var hasStreaming = hasEventStream();
121121
writer.putContext("hasEventStream", hasStreaming);
122122
if (hasStreaming) {
123-
writer.addImports("smithy_core.deserializers",
124-
Set.of(
125-
"ShapeDeserializer",
126-
"DeserializeableShape"));
123+
writer.addImport("smithy_core.deserializers", "ShapeDeserializer");
127124
writer.addStdlibImport("typing", "Any");
128125
}
129126

@@ -137,7 +134,8 @@ private void generateOperationExecutor(PythonWriter writer) {
137134
writer.addStdlibImport("typing", "Awaitable");
138135
writer.addStdlibImport("typing", "cast");
139136
writer.addStdlibImport("copy", "deepcopy");
140-
writer.addStdlibImport("asyncio", "sleep");
137+
writer.addStdlibImport("asyncio");
138+
writer.addStdlibImports("asyncio", Set.of("sleep", "Future"));
141139

142140
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
143141
writer.addImport("smithy_core.exceptions", "SmithyRetryException");
@@ -187,6 +185,75 @@ def _classify_error(
187185
""");
188186
writer.dedent();
189187

188+
if (hasStreaming) {
189+
writer.addStdlibImports("typing", Set.of("Any", "Awaitable"));
190+
writer.addStdlibImport("asyncio");
191+
writer.write(
192+
"""
193+
async def _input_stream(
194+
self,
195+
input: Input,
196+
plugins: list[$1T],
197+
serialize: Callable[[Input, $4T], Awaitable[$2T]],
198+
deserialize: Callable[[$3T, $4T], Awaitable[Output]],
199+
config: $4T,
200+
operation_name: str,
201+
) -> Any:
202+
request_future = Future[$2T]()
203+
awaitable_output = asyncio.create_task(self._execute_operation(
204+
input, plugins, serialize, deserialize, config, operation_name,
205+
request_future=request_future
206+
))
207+
transport_request = await request_future
208+
${5C|}
209+
210+
async def _output_stream(
211+
self,
212+
input: Input,
213+
plugins: list[$1T],
214+
serialize: Callable[[Input, $4T], Awaitable[$2T]],
215+
deserialize: Callable[[$3T, $4T], Awaitable[Output]],
216+
config: $4T,
217+
operation_name: str,
218+
event_deserializer: Callable[[ShapeDeserializer], Any],
219+
) -> Any:
220+
response_future = Future[$3T]()
221+
output = await self._execute_operation(
222+
input, plugins, serialize, deserialize, config, operation_name,
223+
response_future=response_future
224+
)
225+
transport_response = await response_future
226+
${6C|}
227+
228+
async def _duplex_stream(
229+
self,
230+
input: Input,
231+
plugins: list[$1T],
232+
serialize: Callable[[Input, $4T], Awaitable[$2T]],
233+
deserialize: Callable[[$3T, $4T], Awaitable[Output]],
234+
config: $4T,
235+
operation_name: str,
236+
event_deserializer: Callable[[ShapeDeserializer], Any],
237+
) -> Any:
238+
request_future = Future[$2T]()
239+
response_future = Future[$3T]()
240+
awaitable_output = asyncio.create_task(self._execute_operation(
241+
input, plugins, serialize, deserialize, config, operation_name,
242+
request_future=request_future,
243+
response_future=response_future
244+
))
245+
transport_request = await request_future
246+
${7C|}
247+
""",
248+
pluginSymbol,
249+
transportRequest,
250+
transportResponse,
251+
configSymbol,
252+
writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)),
253+
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)),
254+
writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w)));
255+
}
256+
190257
writer.write(
191258
"""
192259
async def _execute_operation(
@@ -197,25 +264,25 @@ async def _execute_operation(
197264
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
198265
config: $5T,
199266
operation_name: str,
200-
${?hasEventStream}
201-
has_input_stream: bool = False,
202-
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
203-
event_response_deserializer: DeserializeableShape | None = None,
204-
${/hasEventStream}
267+
request_future: Future[$2T] | None = None,
268+
response_future: Future[$3T] | None = None,
205269
) -> Output:
206270
try:
207271
return await self._handle_execution(
208272
input, plugins, serialize, deserialize, config, operation_name,
209-
${?hasEventStream}
210-
has_input_stream, event_deserializer, event_response_deserializer,
211-
${/hasEventStream}
273+
request_future, response_future,
212274
)
213275
except Exception as e:
276+
if request_future is not None and not request_future.done:
277+
request_future.set_exception($4T(e))
278+
if response_future is not None and not response_future.done:
279+
response_future.set_exception($4T(e))
280+
214281
# Make sure every exception that we throw is an instance of $4T so
215282
# customers can reliably catch everything we throw.
216283
if not isinstance(e, $4T):
217284
raise $4T(e) from e
218-
raise e
285+
raise
219286
220287
async def _handle_execution(
221288
self,
@@ -225,11 +292,8 @@ async def _handle_execution(
225292
deserialize: Callable[[$3T, $5T], Awaitable[Output]],
226293
config: $5T,
227294
operation_name: str,
228-
${?hasEventStream}
229-
has_input_stream: bool = False,
230-
event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
231-
event_response_deserializer: DeserializeableShape | None = None,
232-
${/hasEventStream}
295+
request_future: Future[$2T] | None,
296+
response_future: Future[$3T] | None,
233297
) -> Output:
234298
logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input)
235299
context: InterceptorContext[Input, None, None, None] = InterceptorContext(
@@ -307,6 +371,7 @@ async def _handle_execution(
307371
context_with_transport_request.copy(),
308372
config,
309373
operation_name,
374+
request_future,
310375
)
311376
312377
# We perform this type-ignored re-assignment because `context` needs
@@ -342,6 +407,10 @@ await seek(0)
342407
else:
343408
# Step 8: Invoke record_success
344409
retry_strategy.record_success(token=retry_token)
410+
if response_future is not None:
411+
response_future.set_result(
412+
context_with_response.response, # type: ignore
413+
)
345414
break
346415
except Exception as e:
347416
if context.response is not None:
@@ -355,16 +424,7 @@ await seek(0)
355424
execution_context = cast(
356425
InterceptorContext[Input, Output, $2T | None, $3T | None], context
357426
)
358-
${^hasEventStream}
359427
return await self._finalize_execution(interceptors, execution_context)
360-
${/hasEventStream}
361-
${?hasEventStream}
362-
operation_output = await self._finalize_execution(interceptors, execution_context)
363-
if has_input_stream or event_deserializer is not None:
364-
${6C|}
365-
else:
366-
return operation_output
367-
${/hasEventStream}
368428
369429
async def _handle_attempt(
370430
self,
@@ -373,6 +433,7 @@ async def _handle_attempt(
373433
context: InterceptorContext[Input, None, $2T, None],
374434
config: $5T,
375435
operation_name: str,
436+
request_future: Future[$2T] | None,
376437
) -> InterceptorContext[Input, Output, $2T, $3T | None]:
377438
try:
378439
# assert config.interceptors is not None
@@ -385,8 +446,7 @@ async def _handle_attempt(
385446
transportRequest,
386447
transportResponse,
387448
errorSymbol,
388-
configSymbol,
389-
writer.consumer(w -> context.protocolGenerator().wrapEventStream(context, w)));
449+
configSymbol);
390450

391451
boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty();
392452
writer.pushState(new ResolveIdentitySection());
@@ -533,10 +593,19 @@ async def _handle_attempt(
533593
)
534594
logger.debug("HTTP request config: %s", request_config)
535595
logger.debug("Sending HTTP request: %s", context_with_response.transport_request)
536-
context_with_response._transport_response = await config.http_client.send(
537-
request=context_with_response.transport_request,
538-
request_config=request_config,
539-
)
596+
597+
if request_future is not None:
598+
response_task = asyncio.create_task(config.http_client.send(
599+
request=context_with_response.transport_request,
600+
request_config=request_config,
601+
))
602+
request_future.set_result(context_with_response.transport_request)
603+
context_with_response._transport_response = await response_task
604+
else:
605+
context_with_response._transport_response = await config.http_client.send(
606+
request=context_with_response.transport_request,
607+
request_config=request_config,
608+
)
540609
logger.debug("Received HTTP response: %s", context_with_response.transport_response)
541610
542611
""", transportRequest, transportResponse);
@@ -834,16 +903,14 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op
834903
raise NotImplementedError()
835904
${/hasProtocol}
836905
${?hasProtocol}
837-
return await self._execute_operation(
906+
return await self._duplex_stream(
838907
input=input,
839908
plugins=operation_plugins,
840909
serialize=${serSymbol:T},
841910
deserialize=${deserSymbol:T},
842911
config=self._config,
843912
operation_name=${operationName:S},
844-
has_input_stream=True,
845913
event_deserializer=$T().deserialize,
846-
event_response_deserializer=${output:T},
847914
) # type: ignore
848915
${/hasProtocol}
849916
""",
@@ -862,14 +929,13 @@ raise NotImplementedError()
862929
raise NotImplementedError()
863930
${/hasProtocol}
864931
${?hasProtocol}
865-
return await self._execute_operation(
932+
return await self._input_stream(
866933
input=input,
867934
plugins=operation_plugins,
868935
serialize=${serSymbol:T},
869936
deserialize=${deserSymbol:T},
870937
config=self._config,
871938
operation_name=${operationName:S},
872-
has_input_stream=True,
873939
) # type: ignore
874940
${/hasProtocol}
875941
""", writer.consumer(w -> writeSharedOperationInit(w, operation, input)));
@@ -887,15 +953,14 @@ raise NotImplementedError()
887953
raise NotImplementedError()
888954
${/hasProtocol}
889955
${?hasProtocol}
890-
return await self._execute_operation(
956+
return await self._output_stream(
891957
input=input,
892958
plugins=operation_plugins,
893959
serialize=${serSymbol:T},
894960
deserialize=${deserSymbol:T},
895961
config=self._config,
896962
operation_name=${operationName:S},
897963
event_deserializer=$T().deserialize,
898-
event_response_deserializer=${output:T},
899964
) # type: ignore
900965
${/hasProtocol}
901966
""",

codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,9 @@ default void generateSharedDeserializerComponents(GenerationContext context) {}
154154
*/
155155
default void generateProtocolTests(GenerationContext context) {}
156156

157-
/**
158-
* Generates the code to wrap an operation output into an event stream.
159-
*
160-
* <p>Important context variables are:
161-
* <ul>
162-
* <li>execution_context - Has the context, including the transport input and output.</li>
163-
* <li>operation_output - The deserialized operation output.</li>
164-
* <li>has_input_stream - Whether or not there is an input stream.</li>
165-
* <li>event_deserializer - The deserialize method for output events, or None for no output stream.</li>
166-
* <li>event_response_deserializer - A DeserializeableShape representing the operation's output shape,
167-
* or None for no output stream. This is used when the operation sends the initial response over the
168-
* event stream.
169-
* </li>
170-
* </ul>
171-
*
172-
* @param context Generation context.
173-
* @param writer The writer to write to.
174-
*/
175-
default void wrapEventStream(GenerationContext context, PythonWriter writer) {}
157+
default void wrapInputStream(GenerationContext context, PythonWriter writer) {}
158+
159+
default void wrapOutputStream(GenerationContext context, PythonWriter writer) {}
160+
161+
default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {}
176162
}

0 commit comments

Comments
 (0)