Skip to content

Commit cea2326

Browse files
Generate event stream operation method signatures
This updates operation generation to generate event stream operations with the EventStream type as its return value. It also updates union generation so that unions contain their own deserialize functions. This is needed to make the them pass the type check, but also it is best to have them own as much of that as possible so that the deserializer function can be left to only dispatch duty.
1 parent c61fbf6 commit cea2326

File tree

3 files changed

+105
-41
lines changed

3 files changed

+105
-41
lines changed

codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.util.LinkedHashSet;
2121
import java.util.Set;
2222
import software.amazon.smithy.codegen.core.SymbolReference;
23+
import software.amazon.smithy.model.knowledge.EventStreamIndex;
24+
import software.amazon.smithy.model.knowledge.EventStreamInfo;
2325
import software.amazon.smithy.model.knowledge.ServiceIndex;
2426
import software.amazon.smithy.model.knowledge.TopDownIndex;
2527
import software.amazon.smithy.model.shapes.OperationShape;
@@ -104,8 +106,14 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
104106
""", configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));
105107

106108
var topDownIndex = TopDownIndex.of(context.model());
109+
var eventStreamIndex = EventStreamIndex.of(context.model());
107110
for (OperationShape operation : topDownIndex.getContainedOperations(service)) {
108-
generateOperation(writer, operation);
111+
if (eventStreamIndex.getInputInfo(operation).isPresent()
112+
|| eventStreamIndex.getOutputInfo(operation).isPresent()) {
113+
generateEventStreamOperation(writer, operation);
114+
} else {
115+
generateOperation(writer, operation);
116+
}
109117
}
110118
});
111119

@@ -348,7 +356,7 @@ async def _handle_attempt(
348356
)
349357
350358
""", CodegenUtils.getHttpAuthParamsSymbol(context.settings()),
351-
writer.consumer(this::initializeHttpAuthParameters));
359+
writer.consumer(this::initializeHttpAuthParameters));
352360
writer.popState();
353361

354362
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
@@ -641,48 +649,48 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
641649

642650
writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:", "",
643651
operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, () -> {
644-
writer.writeDocs(() -> {
645-
var docs = operation.getTrait(DocumentationTrait.class)
646-
.map(StringTrait::getValue)
647-
.orElse(String.format("Invokes the %s operation.", operation.getId().getName()));
652+
writer.writeDocs(() -> {
653+
var docs = operation.getTrait(DocumentationTrait.class)
654+
.map(StringTrait::getValue)
655+
.orElse(String.format("Invokes the %s operation.", operation.getId().getName()));
648656

649-
var inputDocs = input.getTrait(DocumentationTrait.class)
650-
.map(StringTrait::getValue)
651-
.orElse("The operation's input.");
657+
var inputDocs = input.getTrait(DocumentationTrait.class)
658+
.map(StringTrait::getValue)
659+
.orElse("The operation's input.");
652660

653-
writer.write("""
661+
writer.write("""
654662
$L
655663
656664
:param input: $L
657665
658666
:param plugins: A list of callables that modify the configuration dynamically.
659667
Changes made by these plugins only apply for the duration of the operation
660668
execution and will not affect any other operation invocations.""", docs, inputDocs);
661-
});
662-
663-
var defaultPlugins = new LinkedHashSet<SymbolReference>();
664-
for (PythonIntegration integration : context.integrations()) {
665-
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) {
666-
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
667-
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
669+
});
670+
671+
var defaultPlugins = new LinkedHashSet<SymbolReference>();
672+
for (PythonIntegration integration : context.integrations()) {
673+
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins()) {
674+
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
675+
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
676+
}
677+
}
668678
}
669-
}
670-
}
671-
writer.write("""
679+
writer.write("""
672680
operation_plugins: list[Plugin] = [
673681
$C
674682
]
675683
if plugins:
676684
operation_plugins.extend(plugins)
677685
""", writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));
678686

679-
if (context.protocolGenerator() == null) {
680-
writer.write("raise NotImplementedError()");
681-
} else {
682-
var protocolGenerator = context.protocolGenerator();
683-
var serSymbol = protocolGenerator.getSerializationFunction(context, operation);
684-
var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation);
685-
writer.write("""
687+
if (context.protocolGenerator() == null) {
688+
writer.write("raise NotImplementedError()");
689+
} else {
690+
var protocolGenerator = context.protocolGenerator();
691+
var serSymbol = protocolGenerator.getSerializationFunction(context, operation);
692+
var deserSymbol = protocolGenerator.getDeserializationFunction(context, operation);
693+
writer.write("""
686694
return await self._execute_operation(
687695
input=input,
688696
plugins=operation_plugins,
@@ -692,7 +700,47 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
692700
operation_name=$S,
693701
)
694702
""", serSymbol, deserSymbol, operation.getId().getName());
695-
}
696-
});
703+
}
704+
});
705+
}
706+
707+
private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) {
708+
writer.pushState();
709+
writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM);
710+
writer.addImports("smithy_event_stream.aio.interfaces", Set.of(
711+
"EventStream", "InputEventStream", "OutputEventStream"));
712+
var operationSymbol = context.symbolProvider().toSymbol(operation);
713+
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
714+
715+
var input = context.model().expectShape(operation.getInputShape());
716+
var inputSymbol = context.symbolProvider().toSymbol(input);
717+
718+
var eventStreamIndex = EventStreamIndex.of(context.model());
719+
var inputStreamSymbol = eventStreamIndex.getInputInfo(operation)
720+
.map(EventStreamInfo::getEventStreamTarget)
721+
.map(target -> context.symbolProvider().toSymbol(target))
722+
.orElse(null);
723+
writer.putContext("inputStream", inputStreamSymbol);
724+
725+
var output = context.model().expectShape(operation.getOutputShape());
726+
var outputSymbol = context.symbolProvider().toSymbol(output);
727+
var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation)
728+
.map(EventStreamInfo::getEventStreamTarget)
729+
.map(target -> context.symbolProvider().toSymbol(target))
730+
.orElse(null);
731+
writer.putContext("outputStream", outputStreamSymbol);
732+
733+
writer.write("""
734+
async def $L(self, input: $T, plugins: list[$T] | None = None) -> EventStream[
735+
${?inputStream}InputEventStream[${inputStream:T}]${/inputStream}\
736+
${^inputStream}None${/inputStream},
737+
${?outputStream}OutputEventStream[${outputStream:T}]${/outputStream}\
738+
${^outputStream}None${/outputStream},
739+
$T
740+
]:
741+
raise NotImplementedError()
742+
""", operationSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol);
743+
744+
writer.popState();
697745
}
698746
}

codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/SmithyPythonDependency.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ public final class SmithyPythonDependency {
6464
false
6565
);
6666

67+
public static final PythonDependency SMITHY_EVENT_STREAM = new PythonDependency(
68+
"smithy_event_stream",
69+
"==0.0.1",
70+
Type.DEPENDENCY,
71+
false
72+
);
73+
6774
/**
6875
* testing framework used in generated functional tests.
6976
*/

codegen/smithy-python-codegen/src/main/java/software/amazon/smithy/python/codegen/UnionGenerator.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,20 @@ public void run() {
7272

7373
writer.write("""
7474
@dataclass
75-
class $L:
76-
${C|}
75+
class $1L:
76+
${2C|}
7777
78-
value: $T
78+
value: $3T
7979
8080
def serialize(self, serializer: ShapeSerializer):
81-
serializer.write_struct($T, self)
81+
serializer.write_struct($4T, self)
8282
8383
def serialize_members(self, serializer: ShapeSerializer):
84-
${C|}
84+
${5C|}
85+
86+
@classmethod
87+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
88+
return cls(value=${6C|})
8589
8690
""",
8791
memberSymbol.getName(),
@@ -90,7 +94,11 @@ def serialize_members(self, serializer: ShapeSerializer):
9094
targetSymbol,
9195
schemaSymbol,
9296
writer.consumer(w -> target.accept(
93-
new MemberSerializerGenerator(context, w, member, "serializer"))));
97+
new MemberSerializerGenerator(context, w, member, "serializer"))),
98+
writer.consumer(w -> target.accept(
99+
new MemberDeserializerGenerator(context, w, member, "deserializer")))
100+
101+
);
94102
}
95103

96104
// Note that the unknown variant doesn't implement __eq__. This is because
@@ -118,11 +126,15 @@ raise SmithyException("Unknown union variants may not be serialized.")
118126
def serialize_members(self, serializer: ShapeSerializer):
119127
raise SmithyException("Unknown union variants may not be serialized.")
120128
129+
@classmethod
130+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
131+
raise NotImplementedError()
132+
121133
""", unknownSymbol.getName());
122134
memberNames.add(unknownSymbol.getName());
123135

124-
shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeComment(trait.getValue()));
125136
writer.write("type $L = $L\n", parentName, String.join(" | ", memberNames));
137+
shape.getTrait(DocumentationTrait.class).ifPresent(trait -> writer.writeDocs(trait.getValue()));
126138

127139
generateDeserializer();
128140
writer.popState();
@@ -173,13 +185,10 @@ raise SmithyException("Unions must have exactly one value, but found more than o
173185
private void deserializeMembers() {
174186
int index = 0;
175187
for (MemberShape member : shape.members()) {
176-
var target = model.expectShape(member.getTarget());
177188
writer.write("""
178189
case $L:
179-
self._set_result($T(${C|}))
180-
""", index++, symbolProvider.toSymbol(member), writer.consumer(w ->
181-
target.accept(new MemberDeserializerGenerator(context, writer, member, "de"))
182-
));
190+
self._set_result($T.deserialize(de))
191+
""", index++, symbolProvider.toSymbol(member));
183192
}
184193
}
185194
}

0 commit comments

Comments
 (0)