2020import java .util .LinkedHashSet ;
2121import java .util .Set ;
2222import software .amazon .smithy .codegen .core .SymbolReference ;
23+ import software .amazon .smithy .model .knowledge .EventStreamIndex ;
24+ import software .amazon .smithy .model .knowledge .EventStreamInfo ;
2325import software .amazon .smithy .model .knowledge .ServiceIndex ;
2426import software .amazon .smithy .model .knowledge .TopDownIndex ;
2527import software .amazon .smithy .model .shapes .OperationShape ;
@@ -104,8 +106,14 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
104106 """ , configSymbol , pluginSymbol , writer .consumer (w -> writeDefaultPlugins (w , defaultPlugins )));
105107
106108 var topDownIndex = TopDownIndex .of (context .model ());
109+ var eventStreamIndex = EventStreamIndex .of (context .model ());
107110 for (OperationShape operation : topDownIndex .getContainedOperations (service )) {
108- generateOperation (writer , operation );
111+ if (eventStreamIndex .getInputInfo (operation ).isPresent ()
112+ || eventStreamIndex .getOutputInfo (operation ).isPresent ()) {
113+ generateEventStreamOperation (writer , operation );
114+ } else {
115+ generateOperation (writer , operation );
116+ }
109117 }
110118 });
111119
@@ -348,7 +356,7 @@ async def _handle_attempt(
348356 )
349357
350358 """ , CodegenUtils .getHttpAuthParamsSymbol (context .settings ()),
351- writer .consumer (this ::initializeHttpAuthParameters ));
359+ writer .consumer (this ::initializeHttpAuthParameters ));
352360 writer .popState ();
353361
354362 writer .addDependency (SmithyPythonDependency .SMITHY_CORE );
@@ -641,48 +649,48 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
641649
642650 writer .openBlock ("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:" , "" ,
643651 operationSymbol .getName (), inputSymbol , pluginSymbol , outputSymbol , () -> {
644- writer .writeDocs (() -> {
645- var docs = operation .getTrait (DocumentationTrait .class )
646- .map (StringTrait ::getValue )
647- .orElse (String .format ("Invokes the %s operation." , operation .getId ().getName ()));
652+ writer .writeDocs (() -> {
653+ var docs = operation .getTrait (DocumentationTrait .class )
654+ .map (StringTrait ::getValue )
655+ .orElse (String .format ("Invokes the %s operation." , operation .getId ().getName ()));
648656
649- var inputDocs = input .getTrait (DocumentationTrait .class )
650- .map (StringTrait ::getValue )
651- .orElse ("The operation's input." );
657+ var inputDocs = input .getTrait (DocumentationTrait .class )
658+ .map (StringTrait ::getValue )
659+ .orElse ("The operation's input." );
652660
653- writer .write ("""
661+ writer .write ("""
654662 $L
655663
656664 :param input: $L
657665
658666 :param plugins: A list of callables that modify the configuration dynamically.
659667 Changes made by these plugins only apply for the duration of the operation
660668 execution and will not affect any other operation invocations.""" , docs , inputDocs );
661- });
662-
663- var defaultPlugins = new LinkedHashSet <SymbolReference >();
664- for (PythonIntegration integration : context .integrations ()) {
665- for (RuntimeClientPlugin runtimeClientPlugin : integration .getClientPlugins ()) {
666- if (runtimeClientPlugin .matchesOperation (context .model (), service , operation )) {
667- runtimeClientPlugin .getPythonPlugin ().ifPresent (defaultPlugins ::add );
669+ });
670+
671+ var defaultPlugins = new LinkedHashSet <SymbolReference >();
672+ for (PythonIntegration integration : context .integrations ()) {
673+ for (RuntimeClientPlugin runtimeClientPlugin : integration .getClientPlugins ()) {
674+ if (runtimeClientPlugin .matchesOperation (context .model (), service , operation )) {
675+ runtimeClientPlugin .getPythonPlugin ().ifPresent (defaultPlugins ::add );
676+ }
677+ }
668678 }
669- }
670- }
671- writer .write ("""
679+ writer .write ("""
672680 operation_plugins: list[Plugin] = [
673681 $C
674682 ]
675683 if plugins:
676684 operation_plugins.extend(plugins)
677685 """ , writer .consumer (w -> writeDefaultPlugins (w , defaultPlugins )));
678686
679- if (context .protocolGenerator () == null ) {
680- writer .write ("raise NotImplementedError()" );
681- } else {
682- var protocolGenerator = context .protocolGenerator ();
683- var serSymbol = protocolGenerator .getSerializationFunction (context , operation );
684- var deserSymbol = protocolGenerator .getDeserializationFunction (context , operation );
685- writer .write ("""
687+ if (context .protocolGenerator () == null ) {
688+ writer .write ("raise NotImplementedError()" );
689+ } else {
690+ var protocolGenerator = context .protocolGenerator ();
691+ var serSymbol = protocolGenerator .getSerializationFunction (context , operation );
692+ var deserSymbol = protocolGenerator .getDeserializationFunction (context , operation );
693+ writer .write ("""
686694 return await self._execute_operation(
687695 input=input,
688696 plugins=operation_plugins,
@@ -692,7 +700,47 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
692700 operation_name=$S,
693701 )
694702 """ , serSymbol , deserSymbol , operation .getId ().getName ());
695- }
696- });
703+ }
704+ });
705+ }
706+
707+ private void generateEventStreamOperation (PythonWriter writer , OperationShape operation ) {
708+ writer .pushState ();
709+ writer .addDependency (SmithyPythonDependency .SMITHY_EVENT_STREAM );
710+ writer .addImports ("smithy_event_stream.aio.interfaces" , Set .of (
711+ "EventStream" , "InputEventStream" , "OutputEventStream" ));
712+ var operationSymbol = context .symbolProvider ().toSymbol (operation );
713+ var pluginSymbol = CodegenUtils .getPluginSymbol (context .settings ());
714+
715+ var input = context .model ().expectShape (operation .getInputShape ());
716+ var inputSymbol = context .symbolProvider ().toSymbol (input );
717+
718+ var eventStreamIndex = EventStreamIndex .of (context .model ());
719+ var inputStreamSymbol = eventStreamIndex .getInputInfo (operation )
720+ .map (EventStreamInfo ::getEventStreamTarget )
721+ .map (target -> context .symbolProvider ().toSymbol (target ))
722+ .orElse (null );
723+ writer .putContext ("inputStream" , inputStreamSymbol );
724+
725+ var output = context .model ().expectShape (operation .getOutputShape ());
726+ var outputSymbol = context .symbolProvider ().toSymbol (output );
727+ var outputStreamSymbol = eventStreamIndex .getOutputInfo (operation )
728+ .map (EventStreamInfo ::getEventStreamTarget )
729+ .map (target -> context .symbolProvider ().toSymbol (target ))
730+ .orElse (null );
731+ writer .putContext ("outputStream" , outputStreamSymbol );
732+
733+ writer .write ("""
734+ async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[
735+ ${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\
736+ ${^inputStream}None${/inputStream},
737+ ${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\
738+ ${^outputStream}None${/outputStream},
739+ $T
740+ ]:
741+ raise NotImplementedError()
742+ """ , operationSymbol .getName (), inputSymbol , pluginSymbol , outputSymbol );
743+
744+ writer .popState ();
697745 }
698746}
0 commit comments