Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
*/
package software.amazon.smithy.python.codegen;

import static software.amazon.smithy.python.codegen.SymbolProperties.OPERATION_METHOD;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Set;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.EventStreamIndex;
import software.amazon.smithy.model.knowledge.EventStreamInfo;
import software.amazon.smithy.model.knowledge.ServiceIndex;
Expand All @@ -34,10 +38,14 @@
final class ClientGenerator implements Runnable {

private final GenerationContext context;
private final Model model;
private final ServiceShape service;
private final SymbolProvider symbolProvider;

ClientGenerator(GenerationContext context, ServiceShape service) {
this.context = context;
this.symbolProvider = context.symbolProvider();
this.model = context.model();
this.service = service;
}

Expand All @@ -47,7 +55,7 @@ public void run() {
}

private void generateService(PythonWriter writer) {
var serviceSymbol = context.symbolProvider().toSymbol(service);
var serviceSymbol = symbolProvider.toSymbol(service);
var configSymbol = CodegenUtils.getConfigSymbol(context.settings());
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
writer.addLogger();
Expand Down Expand Up @@ -77,7 +85,7 @@ private void generateService(PythonWriter writer) {

for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) {
if (runtimeClientPlugin.matchesService(context.model(), service)) {
if (runtimeClientPlugin.matchesService(model, service)) {
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
}
}
Expand All @@ -97,8 +105,8 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None):
plugin(self._config)
""", configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins)));

var topDownIndex = TopDownIndex.of(context.model());
var eventStreamIndex = EventStreamIndex.of(context.model());
var topDownIndex = TopDownIndex.of(model);
var eventStreamIndex = EventStreamIndex.of(model);
for (OperationShape operation : topDownIndex.getContainedOperations(service)) {
if (eventStreamIndex.getInputInfo(operation).isPresent()
|| eventStreamIndex.getOutputInfo(operation).isPresent()) {
Expand Down Expand Up @@ -253,7 +261,7 @@ async def _duplex_stream(
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)),
writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w)));
}

writer.addStdlibImport("typing", "Any");
writer.write(
"""
async def _execute_operation(
Expand Down Expand Up @@ -448,7 +456,7 @@ async def _handle_attempt(
errorSymbol,
configSymbol);

boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty();
boolean supportsAuth = !ServiceIndex.of(model).getAuthSchemes(service).isEmpty();
writer.pushState(new ResolveIdentitySection());
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
writer.pushState(new InitializeHttpAuthParametersSection());
Expand Down Expand Up @@ -731,8 +739,8 @@ async def _finalize_execution(
}

private boolean hasEventStream() {
var streamIndex = EventStreamIndex.of(context.model());
var topDownIndex = TopDownIndex.of(context.model());
var streamIndex = EventStreamIndex.of(model);
var topDownIndex = TopDownIndex.of(model);
for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) {
if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) {
return true;
Expand All @@ -745,7 +753,7 @@ private void initializeHttpAuthParameters(PythonWriter writer) {
var derived = new LinkedHashSet<DerivedProperty>();
for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) {
if (plugin.matchesService(context.model(), service)
if (plugin.matchesService(model, service)
&& plugin.getAuthScheme().isPresent()
&& plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) {
derived.addAll(plugin.getAuthScheme().get().getAuthProperties());
Expand Down Expand Up @@ -773,18 +781,18 @@ private void writeDefaultPlugins(PythonWriter writer, Collection<SymbolReference
* Generates the function for a single operation.
*/
private void generateOperation(PythonWriter writer, OperationShape operation) {
var operationSymbol = context.symbolProvider().toSymbol(operation);
var operationMethodSymbol = symbolProvider.toSymbol(operation).expectProperty(OPERATION_METHOD);
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());

var input = context.model().expectShape(operation.getInputShape());
var inputSymbol = context.symbolProvider().toSymbol(input);
var input = model.expectShape(operation.getInputShape());
var inputSymbol = symbolProvider.toSymbol(input);

var output = context.model().expectShape(operation.getOutputShape());
var outputSymbol = context.symbolProvider().toSymbol(output);
var output = model.expectShape(operation.getOutputShape());
var outputSymbol = symbolProvider.toSymbol(output);

writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:",
"",
operationSymbol.getName(),
operationMethodSymbol.getName(),
inputSymbol,
pluginSymbol,
outputSymbol,
Expand Down Expand Up @@ -834,7 +842,7 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat
var defaultPlugins = new LinkedHashSet<SymbolReference>();
for (PythonIntegration integration : context.integrations()) {
for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) {
if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) {
if (runtimeClientPlugin.matchesOperation(model, service, operation)) {
runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add);
}
}
Expand All @@ -852,29 +860,29 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat
private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) {
writer.pushState();
writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM);
var operationSymbol = context.symbolProvider().toSymbol(operation);
writer.putContext("operationName", operationSymbol.getName());
var operationMethodSymbol = symbolProvider.toSymbol(operation).expectProperty(OPERATION_METHOD);
writer.putContext("operationName", operationMethodSymbol.getName());
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
writer.putContext("plugin", pluginSymbol);

var input = context.model().expectShape(operation.getInputShape());
var inputSymbol = context.symbolProvider().toSymbol(input);
var input = model.expectShape(operation.getInputShape());
var inputSymbol = symbolProvider.toSymbol(input);
writer.putContext("input", inputSymbol);

var eventStreamIndex = EventStreamIndex.of(context.model());
var eventStreamIndex = EventStreamIndex.of(model);
var inputStreamSymbol = eventStreamIndex.getInputInfo(operation)
.map(EventStreamInfo::getEventStreamTarget)
.map(target -> context.symbolProvider().toSymbol(target))
.map(symbolProvider::toSymbol)
.orElse(null);
writer.putContext("inputStream", inputStreamSymbol);

var output = context.model().expectShape(operation.getOutputShape());
var outputSymbol = context.symbolProvider().toSymbol(output);
var output = model.expectShape(operation.getOutputShape());
var outputSymbol = symbolProvider.toSymbol(output);
writer.putContext("output", outputSymbol);

var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation)
.map(EventStreamInfo::getEventStreamTarget)
.map(target -> context.symbolProvider().toSymbol(target))
.map(symbolProvider::toSymbol)
.orElse(null);
writer.putContext("outputStream", outputStreamSymbol);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective;
import software.amazon.smithy.codegen.core.directed.GenerateListDirective;
import software.amazon.smithy.codegen.core.directed.GenerateMapDirective;
import software.amazon.smithy.codegen.core.directed.GenerateOperationDirective;
import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective;
import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective;
import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective;
Expand All @@ -35,6 +36,7 @@
import software.amazon.smithy.python.codegen.generators.IntEnumGenerator;
import software.amazon.smithy.python.codegen.generators.ListGenerator;
import software.amazon.smithy.python.codegen.generators.MapGenerator;
import software.amazon.smithy.python.codegen.generators.OperationGenerator;
import software.amazon.smithy.python.codegen.generators.ProtocolGenerator;
import software.amazon.smithy.python.codegen.generators.SchemaGenerator;
import software.amazon.smithy.python.codegen.generators.ServiceErrorGenerator;
Expand Down Expand Up @@ -134,6 +136,19 @@ public void generateService(GenerateServiceDirective<GenerationContext, PythonSe
protocolGenerator.generateProtocolTests(directive.context());
}

@Override
public void generateOperation(GenerateOperationDirective<GenerationContext, PythonSettings> directive) {
DirectedCodegen.super.generateOperation(directive);

directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> {
OperationGenerator generator = new OperationGenerator(
directive.context(),
writer,
directive.shape());
generator.run();
});
}

@Override
public void generateStructure(GenerateStructureDirective<GenerationContext, PythonSettings> directive) {
directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ with raises(KeyError):
except Exception as err:
fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}")
""",
context.symbolProvider().toSymbol(operation),
context.symbolProvider()
.toSymbol(operation)
.expectProperty(SymbolProperties.OPERATION_METHOD),
TEST_HTTP_SERVICE_ERR_SYMBOL,
testCase.getMethod(),
testCase.getUri(),
Expand Down Expand Up @@ -459,7 +461,9 @@ private void generateResponseTest(OperationShape operation, HttpResponseTestCase

${C|}
""",
context.symbolProvider().toSymbol(operation),
context.symbolProvider()
.toSymbol(operation)
.expectProperty(SymbolProperties.OPERATION_METHOD),
(Runnable) () -> testCase.getParams().accept(new ValueNodeVisitor(outputShape)),
(Runnable) () -> assertResponseEqual(testCase, operation));
});
Expand Down Expand Up @@ -507,7 +511,9 @@ private void generateErrorResponseTest(
if type(err).__name__ != $2S:
fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}")
""",
context.symbolProvider().toSymbol(operation),
context.symbolProvider()
.toSymbol(operation)
.expectProperty(SymbolProperties.OPERATION_METHOD),
error.getId().getName());
// TODO: Correctly assert the status code and other values
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,17 @@ public Symbol bigDecimalShape(BigDecimalShape shape) {
public Symbol operationShape(OperationShape shape) {
// Operation names are escaped like members because ultimately they're
// properties on an object too.
var name = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service)));
return createGeneratedSymbolBuilder(shape, name, "client").build();
var methodName = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service)));
var methodSymbol = createGeneratedSymbolBuilder(shape, methodName, "client", false)
.putProperty(SymbolProperties.IMPORTABLE, false)
.build();

// We add a symbol for the method in the client as a property, whereas the actual
// operation symbol points to the generated type for it
var name = CaseUtils.toSnakeCase(getDefaultShapeName(shape)).toUpperCase(Locale.ENGLISH);
return createGeneratedSymbolBuilder(shape, name, SHAPES_FILE)
.putProperty(SymbolProperties.OPERATION_METHOD, methodSymbol)
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,16 @@ public final class SymbolProperties {
*/
public static final Property<Symbol> DESERIALIZER = Property.named("deserializer");

/**
* Contains a symbol pointing to an operation shape's method in the client. This is
* only used for operations.
*/
public static final Property<Symbol> OPERATION_METHOD = Property.named("operationMethod");

/**
* Whether a symbol is importable (i.e. an instance method is not "importable")
*/
public static final Property<Boolean> IMPORTABLE = Property.named("nonImportable");

private SymbolProperties() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.python.codegen.generators;

import java.util.List;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.ServiceIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.python.codegen.GenerationContext;
import software.amazon.smithy.python.codegen.SmithyPythonDependency;
import software.amazon.smithy.python.codegen.SymbolProperties;
import software.amazon.smithy.python.codegen.writer.PythonWriter;
import software.amazon.smithy.utils.SmithyInternalApi;

@SmithyInternalApi
public final class OperationGenerator implements Runnable {
private static final Logger LOGGER = Logger.getLogger(OperationGenerator.class.getName());

private final GenerationContext context;
private final PythonWriter writer;
private final OperationShape shape;
private final SymbolProvider symbolProvider;
private final Model model;

public OperationGenerator(GenerationContext context, PythonWriter writer, OperationShape shape) {
this.context = context;
this.writer = writer;
this.shape = shape;
this.symbolProvider = context.symbolProvider();
this.model = context.model();
}

@Override
public void run() {
var opSymbol = symbolProvider.toSymbol(shape);
var inSymbol = symbolProvider.toSymbol(model.expectShape(shape.getInputShape()));
var outSymbol = symbolProvider.toSymbol(model.expectShape(shape.getOutputShape()));

writer.addStdlibImport("dataclasses", "dataclass");
writer.addDependency(SmithyPythonDependency.SMITHY_CORE);
writer.addImport("smithy_core.schemas", "APIOperation");
writer.addImport("smithy_core.documents", "TypeRegistry");

writer.write("""
$1L = APIOperation(
input = $2T,
output = $3T,
schema = $4T,
input_schema = $5T,
output_schema = $6T,
error_registry = TypeRegistry({
$7C
}),
effective_auth_schemes = [
$8C
]
)
""",
opSymbol.getName(),
inSymbol,
outSymbol,
opSymbol.expectProperty(SymbolProperties.SCHEMA),
inSymbol.expectProperty(SymbolProperties.SCHEMA),
outSymbol.expectProperty(SymbolProperties.SCHEMA),
writer.consumer(this::writeErrorTypeRegistry),
writer.consumer(this::writeAuthSchemes)
);
}

private void writeErrorTypeRegistry(PythonWriter writer) {
List<ShapeId> errors = shape.getErrors();
if (!errors.isEmpty()) {
writer.addImport("smithy_core.shapes", "ShapeID");
}
for (var error : errors) {
var errSymbol = symbolProvider.toSymbol(model.expectShape(error));
writer.write("ShapeID($S): $T,", error, errSymbol);
}
}

private void writeAuthSchemes(PythonWriter writer) {
var authSchemes = ServiceIndex.of(model)
.getEffectiveAuthSchemes(context.settings().service(),
shape.getId(),
ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE);

if (!authSchemes.isEmpty()) {
writer.addImport("smithy_core.shapes", "ShapeID");
}

for (var authSchemeId : authSchemes.keySet()) {
writer.write("ShapeID($S)", authSchemeId);
}

}
}
Loading