diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 399b6c479..9d506c7e7 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -4,10 +4,14 @@ */ package software.amazon.smithy.python.codegen; +import static software.amazon.smithy.python.codegen.SymbolProperties.OPERATION_METHOD; + import java.util.Collection; import java.util.LinkedHashSet; import java.util.Set; +import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.codegen.core.SymbolReference; +import software.amazon.smithy.model.Model; import software.amazon.smithy.model.knowledge.EventStreamIndex; import software.amazon.smithy.model.knowledge.EventStreamInfo; import software.amazon.smithy.model.knowledge.ServiceIndex; @@ -34,10 +38,14 @@ final class ClientGenerator implements Runnable { private final GenerationContext context; + private final Model model; private final ServiceShape service; + private final SymbolProvider symbolProvider; ClientGenerator(GenerationContext context, ServiceShape service) { this.context = context; + this.symbolProvider = context.symbolProvider(); + this.model = context.model(); this.service = service; } @@ -47,7 +55,7 @@ public void run() { } private void generateService(PythonWriter writer) { - var serviceSymbol = context.symbolProvider().toSymbol(service); + var serviceSymbol = symbolProvider.toSymbol(service); var configSymbol = CodegenUtils.getConfigSymbol(context.settings()); var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); writer.addLogger(); @@ -77,7 +85,7 @@ private void generateService(PythonWriter writer) { for (PythonIntegration integration : context.integrations()) { for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) { - if (runtimeClientPlugin.matchesService(context.model(), service)) { + if (runtimeClientPlugin.matchesService(model, service)) { runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); } } @@ -97,8 +105,8 @@ def __init__(self, config: $1T | None = None, plugins: list[$2T] | None = None): plugin(self._config) """, configSymbol, pluginSymbol, writer.consumer(w -> writeDefaultPlugins(w, defaultPlugins))); - var topDownIndex = TopDownIndex.of(context.model()); - var eventStreamIndex = EventStreamIndex.of(context.model()); + var topDownIndex = TopDownIndex.of(model); + var eventStreamIndex = EventStreamIndex.of(model); for (OperationShape operation : topDownIndex.getContainedOperations(service)) { if (eventStreamIndex.getInputInfo(operation).isPresent() || eventStreamIndex.getOutputInfo(operation).isPresent()) { @@ -253,7 +261,7 @@ async def _duplex_stream( writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)), writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w))); } - + writer.addStdlibImport("typing", "Any"); writer.write( """ async def _execute_operation( @@ -448,7 +456,7 @@ async def _handle_attempt( errorSymbol, configSymbol); - boolean supportsAuth = !ServiceIndex.of(context.model()).getAuthSchemes(service).isEmpty(); + boolean supportsAuth = !ServiceIndex.of(model).getAuthSchemes(service).isEmpty(); writer.pushState(new ResolveIdentitySection()); if (context.applicationProtocol().isHttpProtocol() && supportsAuth) { writer.pushState(new InitializeHttpAuthParametersSection()); @@ -731,8 +739,8 @@ async def _finalize_execution( } private boolean hasEventStream() { - var streamIndex = EventStreamIndex.of(context.model()); - var topDownIndex = TopDownIndex.of(context.model()); + var streamIndex = EventStreamIndex.of(model); + var topDownIndex = TopDownIndex.of(model); for (OperationShape operation : topDownIndex.getContainedOperations(context.settings().service())) { if (streamIndex.getInputInfo(operation).isPresent() || streamIndex.getOutputInfo(operation).isPresent()) { return true; @@ -745,7 +753,7 @@ private void initializeHttpAuthParameters(PythonWriter writer) { var derived = new LinkedHashSet(); for (PythonIntegration integration : context.integrations()) { for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) { - if (plugin.matchesService(context.model(), service) + if (plugin.matchesService(model, service) && plugin.getAuthScheme().isPresent() && plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) { derived.addAll(plugin.getAuthScheme().get().getAuthProperties()); @@ -773,18 +781,18 @@ private void writeDefaultPlugins(PythonWriter writer, Collection $T:", "", - operationSymbol.getName(), + operationMethodSymbol.getName(), inputSymbol, pluginSymbol, outputSymbol, @@ -834,7 +842,7 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat var defaultPlugins = new LinkedHashSet(); for (PythonIntegration integration : context.integrations()) { for (RuntimeClientPlugin runtimeClientPlugin : integration.getClientPlugins(context)) { - if (runtimeClientPlugin.matchesOperation(context.model(), service, operation)) { + if (runtimeClientPlugin.matchesOperation(model, service, operation)) { runtimeClientPlugin.getPythonPlugin().ifPresent(defaultPlugins::add); } } @@ -852,29 +860,29 @@ private void writeSharedOperationInit(PythonWriter writer, OperationShape operat private void generateEventStreamOperation(PythonWriter writer, OperationShape operation) { writer.pushState(); writer.addDependency(SmithyPythonDependency.SMITHY_EVENT_STREAM); - var operationSymbol = context.symbolProvider().toSymbol(operation); - writer.putContext("operationName", operationSymbol.getName()); + var operationMethodSymbol = symbolProvider.toSymbol(operation).expectProperty(OPERATION_METHOD); + writer.putContext("operationName", operationMethodSymbol.getName()); var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); writer.putContext("plugin", pluginSymbol); - var input = context.model().expectShape(operation.getInputShape()); - var inputSymbol = context.symbolProvider().toSymbol(input); + var input = model.expectShape(operation.getInputShape()); + var inputSymbol = symbolProvider.toSymbol(input); writer.putContext("input", inputSymbol); - var eventStreamIndex = EventStreamIndex.of(context.model()); + var eventStreamIndex = EventStreamIndex.of(model); var inputStreamSymbol = eventStreamIndex.getInputInfo(operation) .map(EventStreamInfo::getEventStreamTarget) - .map(target -> context.symbolProvider().toSymbol(target)) + .map(symbolProvider::toSymbol) .orElse(null); writer.putContext("inputStream", inputStreamSymbol); - var output = context.model().expectShape(operation.getOutputShape()); - var outputSymbol = context.symbolProvider().toSymbol(output); + var output = model.expectShape(operation.getOutputShape()); + var outputSymbol = symbolProvider.toSymbol(output); writer.putContext("output", outputSymbol); var outputStreamSymbol = eventStreamIndex.getOutputInfo(operation) .map(EventStreamInfo::getEventStreamTarget) - .map(target -> context.symbolProvider().toSymbol(target)) + .map(symbolProvider::toSymbol) .orElse(null); writer.putContext("outputStream", outputStreamSymbol); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java index 2d28fe013..54a2531f0 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java @@ -21,6 +21,7 @@ import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective; import software.amazon.smithy.codegen.core.directed.GenerateListDirective; import software.amazon.smithy.codegen.core.directed.GenerateMapDirective; +import software.amazon.smithy.codegen.core.directed.GenerateOperationDirective; import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective; import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective; import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective; @@ -35,6 +36,7 @@ import software.amazon.smithy.python.codegen.generators.IntEnumGenerator; import software.amazon.smithy.python.codegen.generators.ListGenerator; import software.amazon.smithy.python.codegen.generators.MapGenerator; +import software.amazon.smithy.python.codegen.generators.OperationGenerator; import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; import software.amazon.smithy.python.codegen.generators.SchemaGenerator; import software.amazon.smithy.python.codegen.generators.ServiceErrorGenerator; @@ -134,6 +136,19 @@ public void generateService(GenerateServiceDirective directive) { + DirectedCodegen.super.generateOperation(directive); + + directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { + OperationGenerator generator = new OperationGenerator( + directive.context(), + writer, + directive.shape()); + generator.run(); + }); + } + @Override public void generateStructure(GenerateStructureDirective directive) { directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> { diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index c3ff1d406..b31257b87 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -255,7 +255,9 @@ with raises(KeyError): except Exception as err: fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}") """, - context.symbolProvider().toSymbol(operation), + context.symbolProvider() + .toSymbol(operation) + .expectProperty(SymbolProperties.OPERATION_METHOD), TEST_HTTP_SERVICE_ERR_SYMBOL, testCase.getMethod(), testCase.getUri(), @@ -459,7 +461,9 @@ private void generateResponseTest(OperationShape operation, HttpResponseTestCase ${C|} """, - context.symbolProvider().toSymbol(operation), + context.symbolProvider() + .toSymbol(operation) + .expectProperty(SymbolProperties.OPERATION_METHOD), (Runnable) () -> testCase.getParams().accept(new ValueNodeVisitor(outputShape)), (Runnable) () -> assertResponseEqual(testCase, operation)); }); @@ -507,7 +511,9 @@ private void generateErrorResponseTest( if type(err).__name__ != $2S: fail(f"Expected '$2L' exception to be thrown, but received {type(err).__name__}: {err}") """, - context.symbolProvider().toSymbol(operation), + context.symbolProvider() + .toSymbol(operation) + .expectProperty(SymbolProperties.OPERATION_METHOD), error.getId().getName()); // TODO: Correctly assert the status code and other values }); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java index 10394bc7d..19a4e2f96 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java @@ -276,8 +276,17 @@ public Symbol bigDecimalShape(BigDecimalShape shape) { public Symbol operationShape(OperationShape shape) { // Operation names are escaped like members because ultimately they're // properties on an object too. - var name = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service))); - return createGeneratedSymbolBuilder(shape, name, "client").build(); + var methodName = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service))); + var methodSymbol = createGeneratedSymbolBuilder(shape, methodName, "client", false) + .putProperty(SymbolProperties.IMPORTABLE, false) + .build(); + + // We add a symbol for the method in the client as a property, whereas the actual + // operation symbol points to the generated type for it + var name = CaseUtils.toSnakeCase(getDefaultShapeName(shape)).toUpperCase(Locale.ENGLISH); + return createGeneratedSymbolBuilder(shape, name, SHAPES_FILE) + .putProperty(SymbolProperties.OPERATION_METHOD, methodSymbol) + .build(); } @Override diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java index 46862aa72..aa35a1da0 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java @@ -73,5 +73,16 @@ public final class SymbolProperties { */ public static final Property DESERIALIZER = Property.named("deserializer"); + /** + * Contains a symbol pointing to an operation shape's method in the client. This is + * only used for operations. + */ + public static final Property OPERATION_METHOD = Property.named("operationMethod"); + + /** + * Whether a symbol is importable (i.e. an instance method is not "importable") + */ + public static final Property IMPORTABLE = Property.named("nonImportable"); + private SymbolProperties() {} } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java new file mode 100644 index 000000000..d04ebf82c --- /dev/null +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java @@ -0,0 +1,101 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.codegen.generators; + +import java.util.List; +import java.util.logging.Logger; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.ServiceIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.SmithyPythonDependency; +import software.amazon.smithy.python.codegen.SymbolProperties; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class OperationGenerator implements Runnable { + private static final Logger LOGGER = Logger.getLogger(OperationGenerator.class.getName()); + + private final GenerationContext context; + private final PythonWriter writer; + private final OperationShape shape; + private final SymbolProvider symbolProvider; + private final Model model; + + public OperationGenerator(GenerationContext context, PythonWriter writer, OperationShape shape) { + this.context = context; + this.writer = writer; + this.shape = shape; + this.symbolProvider = context.symbolProvider(); + this.model = context.model(); + } + + @Override + public void run() { + var opSymbol = symbolProvider.toSymbol(shape); + var inSymbol = symbolProvider.toSymbol(model.expectShape(shape.getInputShape())); + var outSymbol = symbolProvider.toSymbol(model.expectShape(shape.getOutputShape())); + + writer.addStdlibImport("dataclasses", "dataclass"); + writer.addDependency(SmithyPythonDependency.SMITHY_CORE); + writer.addImport("smithy_core.schemas", "APIOperation"); + writer.addImport("smithy_core.documents", "TypeRegistry"); + + writer.write(""" + $1L = APIOperation( + input = $2T, + output = $3T, + schema = $4T, + input_schema = $5T, + output_schema = $6T, + error_registry = TypeRegistry({ + $7C + }), + effective_auth_schemes = [ + $8C + ] + ) + """, + opSymbol.getName(), + inSymbol, + outSymbol, + opSymbol.expectProperty(SymbolProperties.SCHEMA), + inSymbol.expectProperty(SymbolProperties.SCHEMA), + outSymbol.expectProperty(SymbolProperties.SCHEMA), + writer.consumer(this::writeErrorTypeRegistry), + writer.consumer(this::writeAuthSchemes) + ); + } + + private void writeErrorTypeRegistry(PythonWriter writer) { + List errors = shape.getErrors(); + if (!errors.isEmpty()) { + writer.addImport("smithy_core.shapes", "ShapeID"); + } + for (var error : errors) { + var errSymbol = symbolProvider.toSymbol(model.expectShape(error)); + writer.write("ShapeID($S): $T,", error, errSymbol); + } + } + + private void writeAuthSchemes(PythonWriter writer) { + var authSchemes = ServiceIndex.of(model) + .getEffectiveAuthSchemes(context.settings().service(), + shape.getId(), + ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE); + + if (!authSchemes.isEmpty()) { + writer.addImport("smithy_core.shapes", "ShapeID"); + } + + for (var authSchemeId : authSchemes.keySet()) { + writer.write("ShapeID($S)", authSchemeId); + } + + } +} diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java index 4e9e7a335..20c00f885 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/writer/PythonWriter.java @@ -4,6 +4,8 @@ */ package software.amazon.smithy.python.codegen.writer; +import static software.amazon.smithy.python.codegen.SymbolProperties.IMPORTABLE; + import java.util.Map; import java.util.Optional; import java.util.Set; @@ -21,10 +23,8 @@ import software.amazon.smithy.model.node.NumberNode; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.node.StringNode; -import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.python.codegen.CodegenUtils; import software.amazon.smithy.python.codegen.PythonSettings; -import software.amazon.smithy.python.codegen.SymbolProperties; import software.amazon.smithy.utils.SmithyUnstableApi; import software.amazon.smithy.utils.StringUtils; @@ -301,16 +301,14 @@ public String toString() { private final class PythonSymbolFormatter implements BiFunction { @Override public String apply(Object type, String indent) { - if (type instanceof Symbol) { - Symbol typeSymbol = (Symbol) type; - // Check if the symbol is an operation - we shouldn't add imports for operations, since - // they are methods of the service object and *can't* be imported - if (!isOperationSymbol(typeSymbol)) { + if (type instanceof Symbol typeSymbol) { + // If a symbol has the IMPORTABLE property set to false, don't import it and + // treat the lack of the property being set as true + if (typeSymbol.getProperty(IMPORTABLE).orElse(true)) { addUseImports(typeSymbol); } return typeSymbol.getName(); - } else if (type instanceof SymbolReference) { - SymbolReference typeSymbol = (SymbolReference) type; + } else if (type instanceof SymbolReference typeSymbol) { addImport(typeSymbol.getSymbol(), typeSymbol.getAlias(), SymbolReference.ContextOption.USE); return typeSymbol.getAlias(); } else { @@ -320,10 +318,6 @@ public String apply(Object type, String indent) { } } - private Boolean isOperationSymbol(Symbol typeSymbol) { - return typeSymbol.getProperty(SymbolProperties.SHAPE).map(Shape::isOperationShape).orElse(false); - } - private final class PythonNodeFormatter implements BiFunction { private final PythonWriter writer; diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index 828f7a53e..a07304f4a 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -682,7 +682,7 @@ def __getitem__(self, shape: ShapeID): def __contains__(self, item: object, /): """Check if the registry contains the given shape. - :param shape: The shape ID to check for. + :param item: The shape ID to check for. """ return item in self._types or ( self._sub_registry is not None and item in self._sub_registry