Skip to content

Commit 6db4183

Browse files
authored
Add operation shape generation (#445)
1 parent 31ec8e2 commit 6db4183

File tree

8 files changed

+188
-44
lines changed

8 files changed

+188
-44
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
*/
55
package software.amazon.smithy.python.codegen;
66

7+
import static software.amazon.smithy.python.codegen.SymbolProperties.OPERATION_METHOD;
8+
79
import java.util.Collection;
810
import java.util.LinkedHashSet;
911
import java.util.Set;
12+
import software.amazon.smithy.codegen.core.SymbolProvider;
1013
import software.amazon.smithy.codegen.core.SymbolReference;
14+
import software.amazon.smithy.model.Model;
1115
import software.amazon.smithy.model.knowledge.EventStreamIndex;
1216
import software.amazon.smithy.model.knowledge.EventStreamInfo;
1317
import software.amazon.smithy.model.knowledge.ServiceIndex;
@@ -34,10 +38,14 @@
3438
final class ClientGenerator implements Runnable {
3539

3640
private final GenerationContext context;
41+
private final Model model;
3742
private final ServiceShape service;
43+
private final SymbolProvider symbolProvider;
3844

3945
ClientGenerator(GenerationContext context, ServiceShape service) {
4046
this.context = context;
47+
this.symbolProvider = context.symbolProvider();
48+
this.model = context.model();
4149
this.service = service;
4250
}
4351

@@ -47,7 +55,7 @@ public void run() {
4755
}
4856

4957
private void generateService(PythonWriter writer) {
50-
var serviceSymbol = context.symbolProvider().toSymbol(service);
58+
var serviceSymbol = symbolProvider.toSymbol(service);
5159
var configSymbol = CodegenUtils.getConfigSymbol(context.settings());
5260
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
5361
writer.addLogger();
@@ -77,7 +85,7 @@ private void generateService(PythonWriter writer) {
7785

7886
for (PythonIntegration integration : context.integrations()) {
7987
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) {
80-
if (runtimeClientPlugin.matchesService(context.model(), service)) {
88+
if (runtimeClientPlugin.matchesService(model, service)) {
8189
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
8290
}
8391
}
@@ -97,8 +105,8 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
97105
plugin(self._config)
98106
""", configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));
99107

100-
var topDownIndex = TopDownIndex.of(context.model());
101-
var eventStreamIndex = EventStreamIndex.of(context.model());
108+
var topDownIndex = TopDownIndex.of(model);
109+
var eventStreamIndex = EventStreamIndex.of(model);
102110
for (OperationShape operation : topDownIndex.getContainedOperations(service)) {
103111
if (eventStreamIndex.getInputInfo(operation).isPresent()
104112
|| eventStreamIndex.getOutputInfo(operation).isPresent()) {
@@ -253,7 +261,7 @@ async def _duplex_stream(
253261
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)),
254262
writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w)));
255263
}
256-
264+
writer.addStdlibImport("typing", "Any");
257265
writer.write(
258266
"""
259267
async def _execute_operation(
@@ -448,7 +456,7 @@ async def _handle_attempt(
448456
errorSymbol,
449457
configSymbol);
450458

451-
boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty();
459+
boolean supportsAuth = !ServiceIndex.of(model).getAuthSchemes(service).isEmpty();
452460
writer.pushState(new ResolveIdentitySection());
453461
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
454462
writer.pushState(new InitializeHttpAuthParametersSection());
@@ -731,8 +739,8 @@ async def _finalize_execution(
731739
}
732740

733741
private boolean hasEventStream() {
734-
var streamIndex = EventStreamIndex.of(context.model());
735-
var topDownIndex = TopDownIndex.of(context.model());
742+
var streamIndex = EventStreamIndex.of(model);
743+
var topDownIndex = TopDownIndex.of(model);
736744
for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) {
737745
if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) {
738746
return true;
@@ -745,7 +753,7 @@ private void initializeHttpAuthParameters(PythonWriter writer) {
745753
var derived = new LinkedHashSet<DerivedProperty>();
746754
for (PythonIntegration integration : context.integrations()) {
747755
for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) {
748-
if (plugin.matchesService(context.model(), service)
756+
if (plugin.matchesService(model, service)
749757
&& plugin.getAuthScheme().isPresent()
750758
&& plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) {
751759
derived.addAll(plugin.getAuthScheme().get().getAuthProperties());
@@ -773,18 +781,18 @@ private void writeDefaultPlugins(PythonWriter writer, Collection<SymbolReference
773781
* Generates the function for a single operation.
774782
*/
775783
private void generateOperation(PythonWriter writer, OperationShape operation) {
776-
var operationSymbol = context.symbolProvider().toSymbol(operation);
784+
var operationMethodSymbol = symbolProvider.toSymbol(operation).expectProperty(OPERATION_METHOD);
777785
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
778786

779-
var input = context.model().expectShape(operation.getInputShape());
780-
var inputSymbol = context.symbolProvider().toSymbol(input);
787+
var input = model.expectShape(operation.getInputShape());
788+
var inputSymbol = symbolProvider.toSymbol(input);
781789

782-
var output = context.model().expectShape(operation.getOutputShape());
783-
var outputSymbol = context.symbolProvider().toSymbol(output);
790+
var output = model.expectShape(operation.getOutputShape());
791+
var outputSymbol = symbolProvider.toSymbol(output);
784792

785793
writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:",
786794
"",
787-
operationSymbol.getName(),
795+
operationMethodSymbol.getName(),
788796
inputSymbol,
789797
pluginSymbol,
790798
outputSymbol,
@@ -834,7 +842,7 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat
834842
var defaultPlugins = new LinkedHashSet<SymbolReference>();
835843
for (PythonIntegration integration : context.integrations()) {
836844
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) {
837-
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
845+
if (runtimeClientPlugin.matchesOperation(model, service, operation)) {
838846
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
839847
}
840848
}
@@ -852,29 +860,29 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat
852860
private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) {
853861
writer.pushState();
854862
writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM);
855-
var operationSymbol = context.symbolProvider().toSymbol(operation);
856-
writer.putContext("operationName", operationSymbol.getName());
863+
var operationMethodSymbol = symbolProvider.toSymbol(operation).expectProperty(OPERATION_METHOD);
864+
writer.putContext("operationName", operationMethodSymbol.getName());
857865
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
858866
writer.putContext("plugin", pluginSymbol);
859867

860-
var input = context.model().expectShape(operation.getInputShape());
861-
var inputSymbol = context.symbolProvider().toSymbol(input);
868+
var input = model.expectShape(operation.getInputShape());
869+
var inputSymbol = symbolProvider.toSymbol(input);
862870
writer.putContext("input", inputSymbol);
863871

864-
var eventStreamIndex = EventStreamIndex.of(context.model());
872+
var eventStreamIndex = EventStreamIndex.of(model);
865873
var inputStreamSymbol = eventStreamIndex.getInputInfo(operation)
866874
.map(EventStreamInfo::getEventStreamTarget)
867-
.map(target -> context.symbolProvider().toSymbol(target))
875+
.map(symbolProvider::toSymbol)
868876
.orElse(null);
869877
writer.putContext("inputStream", inputStreamSymbol);
870878

871-
var output = context.model().expectShape(operation.getOutputShape());
872-
var outputSymbol = context.symbolProvider().toSymbol(output);
879+
var output = model.expectShape(operation.getOutputShape());
880+
var outputSymbol = symbolProvider.toSymbol(output);
873881
writer.putContext("output", outputSymbol);
874882

875883
var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation)
876884
.map(EventStreamInfo::getEventStreamTarget)
877-
.map(target -> context.symbolProvider().toSymbol(target))
885+
.map(symbolProvider::toSymbol)
878886
.orElse(null);
879887
writer.putContext("outputStream", outputStreamSymbol);
880888

codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective;
2222
import software.amazon.smithy.codegen.core.directed.GenerateListDirective;
2323
import software.amazon.smithy.codegen.core.directed.GenerateMapDirective;
24+
import software.amazon.smithy.codegen.core.directed.GenerateOperationDirective;
2425
import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective;
2526
import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective;
2627
import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective;
@@ -35,6 +36,7 @@
3536
import software.amazon.smithy.python.codegen.generators.IntEnumGenerator;
3637
import software.amazon.smithy.python.codegen.generators.ListGenerator;
3738
import software.amazon.smithy.python.codegen.generators.MapGenerator;
39+
import software.amazon.smithy.python.codegen.generators.OperationGenerator;
3840
import software.amazon.smithy.python.codegen.generators.ProtocolGenerator;
3941
import software.amazon.smithy.python.codegen.generators.SchemaGenerator;
4042
import software.amazon.smithy.python.codegen.generators.ServiceErrorGenerator;
@@ -134,6 +136,19 @@ public void generateService(GenerateServiceDirective<GenerationContext, PythonSe
134136
protocolGenerator.generateProtocolTests(directive.context());
135137
}
136138

139+
@Override
140+
public void generateOperation(GenerateOperationDirective<GenerationContext, PythonSettings> directive) {
141+
DirectedCodegen.super.generateOperation(directive);
142+
143+
directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> {
144+
OperationGenerator generator = new OperationGenerator(
145+
directive.context(),
146+
writer,
147+
directive.shape());
148+
generator.run();
149+
});
150+
}
151+
137152
@Override
138153
public void generateStructure(GenerateStructureDirective<GenerationContext, PythonSettings> directive) {
139154
directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> {

codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,9 @@ with raises(KeyError):
255255
except Exception as err:
256256
fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}")
257257
""",
258-
context.symbolProvider().toSymbol(operation),
258+
context.symbolProvider()
259+
.toSymbol(operation)
260+
.expectProperty(SymbolProperties.OPERATION_METHOD),
259261
TEST_HTTP_SERVICE_ERR_SYMBOL,
260262
testCase.getMethod(),
261263
testCase.getUri(),
@@ -459,7 +461,9 @@ private void generateResponseTest(OperationShape operation, HttpResponseTestCase
459461
460462
${C|}
461463
""",
462-
context.symbolProvider().toSymbol(operation),
464+
context.symbolProvider()
465+
.toSymbol(operation)
466+
.expectProperty(SymbolProperties.OPERATION_METHOD),
463467
(Runnable) () -> testCase.getParams().accept(new ValueNodeVisitor(outputShape)),
464468
(Runnable) () -> assertResponseEqual(testCase, operation));
465469
});
@@ -507,7 +511,9 @@ private void generateErrorResponseTest(
507511
if type(err).__name__ != $2S:
508512
fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}")
509513
""",
510-
context.symbolProvider().toSymbol(operation),
514+
context.symbolProvider()
515+
.toSymbol(operation)
516+
.expectProperty(SymbolProperties.OPERATION_METHOD),
511517
error.getId().getName());
512518
// TODO: Correctly assert the status code and other values
513519
});

codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,17 @@ public Symbol bigDecimalShape(BigDecimalShape shape) {
276276
public Symbol operationShape(OperationShape shape) {
277277
// Operation names are escaped like members because ultimately they're
278278
// properties on an object too.
279-
var name = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service)));
280-
return createGeneratedSymbolBuilder(shape, name, "client").build();
279+
var methodName = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service)));
280+
var methodSymbol = createGeneratedSymbolBuilder(shape, methodName, "client", false)
281+
.putProperty(SymbolProperties.IMPORTABLE, false)
282+
.build();
283+
284+
// We add a symbol for the method in the client as a property, whereas the actual
285+
// operation symbol points to the generated type for it
286+
var name = CaseUtils.toSnakeCase(getDefaultShapeName(shape)).toUpperCase(Locale.ENGLISH);
287+
return createGeneratedSymbolBuilder(shape, name, SHAPES_FILE)
288+
.putProperty(SymbolProperties.OPERATION_METHOD, methodSymbol)
289+
.build();
281290
}
282291

283292
@Override

codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,16 @@ public final class SymbolProperties {
7373
*/
7474
public static final Property<Symbol> DESERIALIZER = Property.named("deserializer");
7575

76+
/**
77+
* Contains a symbol pointing to an operation shape's method in the client. This is
78+
* only used for operations.
79+
*/
80+
public static final Property<Symbol> OPERATION_METHOD = Property.named("operationMethod");
81+
82+
/**
83+
* Whether a symbol is importable (i.e. an instance method is not "importable")
84+
*/
85+
public static final Property<Boolean> IMPORTABLE = Property.named("nonImportable");
86+
7687
private SymbolProperties() {}
7788
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package software.amazon.smithy.python.codegen.generators;
6+
7+
import java.util.List;
8+
import java.util.logging.Logger;
9+
import software.amazon.smithy.codegen.core.SymbolProvider;
10+
import software.amazon.smithy.model.Model;
11+
import software.amazon.smithy.model.knowledge.ServiceIndex;
12+
import software.amazon.smithy.model.shapes.OperationShape;
13+
import software.amazon.smithy.model.shapes.ShapeId;
14+
import software.amazon.smithy.python.codegen.GenerationContext;
15+
import software.amazon.smithy.python.codegen.SmithyPythonDependency;
16+
import software.amazon.smithy.python.codegen.SymbolProperties;
17+
import software.amazon.smithy.python.codegen.writer.PythonWriter;
18+
import software.amazon.smithy.utils.SmithyInternalApi;
19+
20+
@SmithyInternalApi
21+
public final class OperationGenerator implements Runnable {
22+
private static final Logger LOGGER = Logger.getLogger(OperationGenerator.class.getName());
23+
24+
private final GenerationContext context;
25+
private final PythonWriter writer;
26+
private final OperationShape shape;
27+
private final SymbolProvider symbolProvider;
28+
private final Model model;
29+
30+
public OperationGenerator(GenerationContext context, PythonWriter writer, OperationShape shape) {
31+
this.context = context;
32+
this.writer = writer;
33+
this.shape = shape;
34+
this.symbolProvider = context.symbolProvider();
35+
this.model = context.model();
36+
}
37+
38+
@Override
39+
public void run() {
40+
var opSymbol = symbolProvider.toSymbol(shape);
41+
var inSymbol = symbolProvider.toSymbol(model.expectShape(shape.getInputShape()));
42+
var outSymbol = symbolProvider.toSymbol(model.expectShape(shape.getOutputShape()));
43+
44+
writer.addStdlibImport("dataclasses", "dataclass");
45+
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
46+
writer.addImport("smithy_core.schemas", "APIOperation");
47+
writer.addImport("smithy_core.documents", "TypeRegistry");
48+
49+
writer.write("""
50+
$1L = APIOperation(
51+
input = $2T,
52+
output = $3T,
53+
schema = $4T,
54+
input_schema = $5T,
55+
output_schema = $6T,
56+
error_registry = TypeRegistry({
57+
$7C
58+
}),
59+
effective_auth_schemes = [
60+
$8C
61+
]
62+
)
63+
""",
64+
opSymbol.getName(),
65+
inSymbol,
66+
outSymbol,
67+
opSymbol.expectProperty(SymbolProperties.SCHEMA),
68+
inSymbol.expectProperty(SymbolProperties.SCHEMA),
69+
outSymbol.expectProperty(SymbolProperties.SCHEMA),
70+
writer.consumer(this::writeErrorTypeRegistry),
71+
writer.consumer(this::writeAuthSchemes)
72+
);
73+
}
74+
75+
private void writeErrorTypeRegistry(PythonWriter writer) {
76+
List<ShapeId> errors = shape.getErrors();
77+
if (!errors.isEmpty()) {
78+
writer.addImport("smithy_core.shapes", "ShapeID");
79+
}
80+
for (var error : errors) {
81+
var errSymbol = symbolProvider.toSymbol(model.expectShape(error));
82+
writer.write("ShapeID($S): $T,", error, errSymbol);
83+
}
84+
}
85+
86+
private void writeAuthSchemes(PythonWriter writer) {
87+
var authSchemes = ServiceIndex.of(model)
88+
.getEffectiveAuthSchemes(context.settings().service(),
89+
shape.getId(),
90+
ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE);
91+
92+
if (!authSchemes.isEmpty()) {
93+
writer.addImport("smithy_core.shapes", "ShapeID");
94+
}
95+
96+
for (var authSchemeId : authSchemes.keySet()) {
97+
writer.write("ShapeID($S)", authSchemeId);
98+
}
99+
100+
}
101+
}

0 commit comments

Comments
 (0)