diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index f6639a94d..2cdd09717 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -187,199 +187,200 @@ def _classify_error( """); writer.dedent(); - writer.write(""" - async def _execute_operation( - self, - input: Input, - plugins: list[$1T], - serialize: Callable[[Input, $5T], Awaitable[$2T]], - deserialize: Callable[[$3T, $5T], Awaitable[Output]], - config: $5T, - operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} - ) -> Output: - try: - return await self._handle_execution( - input, plugins, serialize, deserialize, config, operation_name, + writer.write( + """ + async def _execute_operation( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $5T], Awaitable[$2T]], + deserialize: Callable[[$3T, $5T], Awaitable[Output]], + config: $5T, + operation_name: str, ${?hasEventStream} - has_input_stream, event_deserializer, event_response_deserializer, + has_input_stream: bool = False, + event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, + event_response_deserializer: DeserializeableShape | None = None, ${/hasEventStream} - ) - except Exception as e: - # Make sure every exception that we throw is an instance of $4T so - # customers can reliably catch everything we throw. - if not isinstance(e, $4T): - raise $4T(e) from e - raise e - - async def _handle_execution( - self, - input: Input, - plugins: list[$1T], - serialize: Callable[[Input, $5T], Awaitable[$2T]], - deserialize: Callable[[$3T, $5T], Awaitable[Output]], - config: $5T, - operation_name: str, - ${?hasEventStream} - has_input_stream: bool = False, - event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, - event_response_deserializer: DeserializeableShape | None = None, - ${/hasEventStream} - ) -> Output: - logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) - context: InterceptorContext[Input, None, None, None] = InterceptorContext( - request=input, - response=None, - transport_request=None, - transport_response=None, - ) - _client_interceptors = config.interceptors - client_interceptors = cast( - list[Interceptor[Input, Output, $2T, $3T]], _client_interceptors - ) - interceptors = client_interceptors - - try: - # Step 1a: Invoke read_before_execution on client-level interceptors - for interceptor in client_interceptors: - interceptor.read_before_execution(context) - - # Step 1b: Run operation-level plugins - config = deepcopy(config) - for plugin in plugins: - plugin(config) - - _client_interceptors = config.interceptors - interceptors = cast( - list[Interceptor[Input, Output, $2T, $3T]], - _client_interceptors, - ) + ) -> Output: + try: + return await self._handle_execution( + input, plugins, serialize, deserialize, config, operation_name, + ${?hasEventStream} + has_input_stream, event_deserializer, event_response_deserializer, + ${/hasEventStream} + ) + except Exception as e: + # Make sure every exception that we throw is an instance of $4T so + # customers can reliably catch everything we throw. + if not isinstance(e, $4T): + raise $4T(e) from e + raise e + + async def _handle_execution( + self, + input: Input, + plugins: list[$1T], + serialize: Callable[[Input, $5T], Awaitable[$2T]], + deserialize: Callable[[$3T, $5T], Awaitable[Output]], + config: $5T, + operation_name: str, + ${?hasEventStream} + has_input_stream: bool = False, + event_deserializer: Callable[[ShapeDeserializer], Any] | None = None, + event_response_deserializer: DeserializeableShape | None = None, + ${/hasEventStream} + ) -> Output: + logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) + context: InterceptorContext[Input, None, None, None] = InterceptorContext( + request=input, + response=None, + transport_request=None, + transport_response=None, + ) + _client_interceptors = config.interceptors + client_interceptors = cast( + list[Interceptor[Input, Output, $2T, $3T]], _client_interceptors + ) + interceptors = client_interceptors + + try: + # Step 1a: Invoke read_before_execution on client-level interceptors + for interceptor in client_interceptors: + interceptor.read_before_execution(context) + + # Step 1b: Run operation-level plugins + config = deepcopy(config) + for plugin in plugins: + plugin(config) + + _client_interceptors = config.interceptors + interceptors = cast( + list[Interceptor[Input, Output, $2T, $3T]], + _client_interceptors, + ) - # Step 1c: Invoke the read_before_execution hooks on newly added - # interceptors. - for interceptor in interceptors: - if interceptor not in client_interceptors: - interceptor.read_before_execution(context) + # Step 1c: Invoke the read_before_execution hooks on newly added + # interceptors. + for interceptor in interceptors: + if interceptor not in client_interceptors: + interceptor.read_before_execution(context) - # Step 2: Invoke the modify_before_serialization hooks - for interceptor in interceptors: - context._request = interceptor.modify_before_serialization(context) + # Step 2: Invoke the modify_before_serialization hooks + for interceptor in interceptors: + context._request = interceptor.modify_before_serialization(context) - # Step 3: Invoke the read_before_serialization hooks - for interceptor in interceptors: - interceptor.read_before_serialization(context) + # Step 3: Invoke the read_before_serialization hooks + for interceptor in interceptors: + interceptor.read_before_serialization(context) - # Step 4: Serialize the request - context_with_transport_request = cast( - InterceptorContext[Input, None, $2T, None], context - ) - logger.debug("Serializing request for: %s", context_with_transport_request.request) - context_with_transport_request._transport_request = await serialize( - context_with_transport_request.request, config - ) - logger.debug("Serialization complete. Transport request: %s", context_with_transport_request._transport_request) + # Step 4: Serialize the request + context_with_transport_request = cast( + InterceptorContext[Input, None, $2T, None], context + ) + logger.debug("Serializing request for: %s", context_with_transport_request.request) + context_with_transport_request._transport_request = await serialize( + context_with_transport_request.request, config + ) + logger.debug("Serialization complete. Transport request: %s", context_with_transport_request._transport_request) - # Step 5: Invoke read_after_serialization - for interceptor in interceptors: - interceptor.read_after_serialization(context_with_transport_request) + # Step 5: Invoke read_after_serialization + for interceptor in interceptors: + interceptor.read_after_serialization(context_with_transport_request) - # Step 6: Invoke modify_before_retry_loop - for interceptor in interceptors: - context_with_transport_request._transport_request = ( - interceptor.modify_before_retry_loop(context_with_transport_request) - ) + # Step 6: Invoke modify_before_retry_loop + for interceptor in interceptors: + context_with_transport_request._transport_request = ( + interceptor.modify_before_retry_loop(context_with_transport_request) + ) - # Step 7: Acquire the retry token. - retry_strategy = config.retry_strategy - retry_token = retry_strategy.acquire_initial_retry_token() - - while True: - # Make an attempt, creating a copy of the context so we don't pass - # around old data. - context_with_response = await self._handle_attempt( - deserialize, - interceptors, - context_with_transport_request.copy(), - config, - operation_name, - ) + # Step 7: Acquire the retry token. + retry_strategy = config.retry_strategy + retry_token = retry_strategy.acquire_initial_retry_token() + + while True: + # Make an attempt, creating a copy of the context so we don't pass + # around old data. + context_with_response = await self._handle_attempt( + deserialize, + interceptors, + context_with_transport_request.copy(), + config, + operation_name, + ) - # We perform this type-ignored re-assignment because `context` needs - # to point at the latest context so it can be generically handled - # later on. This is only an issue here because we've created a copy, - # so we're no longer simply pointing at the same object in memory - # with different names and type hints. It is possible to address this - # without having to fall back to the type ignore, but it would impose - # unnecessary runtime costs. - context = context_with_response # type: ignore - - if isinstance(context_with_response.response, Exception): - # Step 7u: Reacquire retry token if the attempt failed - try: - retry_token = retry_strategy.refresh_retry_token_for_retry( - token_to_renew=retry_token, - error_info=self._classify_error( - error=context_with_response.response, - context=context_with_response, + # We perform this type-ignored re-assignment because `context` needs + # to point at the latest context so it can be generically handled + # later on. This is only an issue here because we've created a copy, + # so we're no longer simply pointing at the same object in memory + # with different names and type hints. It is possible to address this + # without having to fall back to the type ignore, but it would impose + # unnecessary runtime costs. + context = context_with_response # type: ignore + + if isinstance(context_with_response.response, Exception): + # Step 7u: Reacquire retry token if the attempt failed + try: + retry_token = retry_strategy.refresh_retry_token_for_retry( + token_to_renew=retry_token, + error_info=self._classify_error( + error=context_with_response.response, + context=context_with_response, + ) + ) + except SmithyRetryException: + raise context_with_response.response + logger.debug( + "Retry needed. Attempting request #%s in %.4f seconds.", + retry_token.retry_count + 1, + retry_token.retry_delay ) - ) - except SmithyRetryException: - raise context_with_response.response - logger.debug( - "Retry needed. Attempting request #%s in %.4f seconds.", - retry_token.retry_count + 1, - retry_token.retry_delay - ) - await sleep(retry_token.retry_delay) - current_body = context_with_transport_request.transport_request.body - if (seek := getattr(current_body, "seek", None)) is not None: - await seek(0) + await sleep(retry_token.retry_delay) + current_body = context_with_transport_request.transport_request.body + if (seek := getattr(current_body, "seek", None)) is not None: + await seek(0) + else: + # Step 8: Invoke record_success + retry_strategy.record_success(token=retry_token) + break + except Exception as e: + if context.response is not None: + logger.exception("Exception occurred while handling: %s", context.response) + pass + context._response = e + + # At this point, the context's request will have been definitively set, and + # The response will be set either with the modeled output or an exception. The + # transport_request and transport_response may be set or None. + execution_context = cast( + InterceptorContext[Input, Output, $2T | None, $3T | None], context + ) + ${^hasEventStream} + return await self._finalize_execution(interceptors, execution_context) + ${/hasEventStream} + ${?hasEventStream} + operation_output = await self._finalize_execution(interceptors, execution_context) + if has_input_stream or event_deserializer is not None: + ${6C|} else: - # Step 8: Invoke record_success - retry_strategy.record_success(token=retry_token) - break - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e + return operation_output + ${/hasEventStream} - # At this point, the context's request will have been definitively set, and - # The response will be set either with the modeled output or an exception. The - # transport_request and transport_response may be set or None. - execution_context = cast( - InterceptorContext[Input, Output, $2T | None, $3T | None], context - ) - ${^hasEventStream} - return await self._finalize_execution(interceptors, execution_context) - ${/hasEventStream} - ${?hasEventStream} - operation_output = await self._finalize_execution(interceptors, execution_context) - if has_input_stream or event_deserializer is not None: - ${6C|} - else: - return operation_output - ${/hasEventStream} - - async def _handle_attempt( - self, - deserialize: Callable[[$3T, $5T], Awaitable[Output]], - interceptors: list[Interceptor[Input, Output, $2T, $3T]], - context: InterceptorContext[Input, None, $2T, None], - config: $5T, - operation_name: str, - ) -> InterceptorContext[Input, Output, $2T, $3T | None]: - try: - # assert config.interceptors is not None - # Step 7a: Invoke read_before_attempt - for interceptor in interceptors: - interceptor.read_before_attempt(context) + async def _handle_attempt( + self, + deserialize: Callable[[$3T, $5T], Awaitable[Output]], + interceptors: list[Interceptor[Input, Output, $2T, $3T]], + context: InterceptorContext[Input, None, $2T, None], + config: $5T, + operation_name: str, + ) -> InterceptorContext[Input, Output, $2T, $3T | None]: + try: + # assert config.interceptors is not None + # Step 7a: Invoke read_before_attempt + for interceptor in interceptors: + interceptor.read_before_attempt(context) - """, + """, pluginSymbol, transportRequest, transportResponse, diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java index 6d7f107ae..88b825e18 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java @@ -138,8 +138,7 @@ private static List getHttpProperties(GenerationContext context) if (usesHttp2(context)) { clientBuilder .initialize(writer -> { - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addDependency(SmithyPythonDependency.AWS_CRT); + writer.addDependency(SmithyPythonDependency.SMITHY_HTTP.withOptionalDependencies("awscrt")); writer.addImport("smithy_http.aio.crt", "AWSCRTHTTPClient"); writer.write("self.http_client = http_client or AWSCRTHTTPClient()"); }); @@ -147,8 +146,7 @@ private static List getHttpProperties(GenerationContext context) } else { clientBuilder .initialize(writer -> { - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addDependency(SmithyPythonDependency.AIO_HTTP); + writer.addDependency(SmithyPythonDependency.SMITHY_HTTP.withOptionalDependencies("aiohttp")); writer.addImport("smithy_http.aio.aiohttp", "AIOHTTPClient"); writer.write("self.http_client = http_client or AIOHTTPClient()"); }); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java index 0beffbe91..307f5af03 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/SetupGenerator.java @@ -9,6 +9,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.TreeMap; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.SymbolDependency; import software.amazon.smithy.codegen.core.WriterDelegator; import software.amazon.smithy.model.traits.DocumentationTrait; @@ -30,17 +36,70 @@ @SmithyInternalApi public final class SetupGenerator { - private SetupGenerator() {} + private SetupGenerator() { + } public static void generateSetup( PythonSettings settings, GenerationContext context ) { - var dependencies = SymbolDependency.gatherDependencies(context.writerDelegator().getDependencies().stream()); + var dependencies = gatherDependencies(context.writerDelegator().getDependencies().stream()); writePyproject(settings, context.writerDelegator(), dependencies); writeReadme(settings, context); } + /** + * Merge all the symbol dependencies. Also merges optional dependencies. + * Modification of : SymbolDependency.gatherDependencies that also considers the OPTIONAL_DEPENDENCIES + * property. + */ + @SuppressWarnings("unchecked") + private static Map> gatherDependencies( + Stream symbolStream + ) { + BinaryOperator guardedMergeWithProperties = (a, b) -> { + if (!a.getVersion().equals(b.getVersion())) { + throw new CodegenException(String.format( + "Found a conflicting `%s` dependency for `%s`: `%s` conflicts with `%s`", + a.getDependencyType(), + a.getPackageName(), + a.getVersion(), + b.getVersion())); + } + // For our purposes, we need only consider OPTIONAL_DEPENDENCIES property. + // The only other property currently used is IS_LINK, and it is consistent across all usages of + // a given SymbolDependency. + if (!b.getTypedProperties().isEmpty()) { + var optional_a = a.getProperty(SymbolProperties.OPTIONAL_DEPENDENCIES).orElse(List.of()); + var optional_b = b.getProperty(SymbolProperties.OPTIONAL_DEPENDENCIES).orElse(List.of()); + + if (optional_b.isEmpty()) { + return a; + } + + if (optional_a.isEmpty()) { + return b; + } + + var merged = Stream.concat(optional_a.stream(), optional_b.stream()) + .distinct() + .toList(); + + return a.toBuilder() + .putProperty(SymbolProperties.OPTIONAL_DEPENDENCIES, merged) + .build(); + } else { + return a; + } + }; + return symbolStream.sorted() + .collect(Collectors.groupingBy(SymbolDependency::getDependencyType, + Collectors.toMap(SymbolDependency::getPackageName, + Function.identity(), + guardedMergeWithProperties, + TreeMap::new))); + } + /** * Write a pyproject.toml file. * @@ -64,7 +123,7 @@ private static void writePyproject( [build-system] requires = ["setuptools", "setuptools-scm", "wheel"] build-backend = "setuptools.build_meta" - + [project] name = $1S version = $2S @@ -100,7 +159,7 @@ private static void writePyproject( writer.write(""" [tool.setuptools.packages.find] exclude=["tests*"] - + [tool.pyright] typeCheckingMode = "strict" reportPrivateUsage = false @@ -108,10 +167,10 @@ private static void writePyproject( reportUnusedVariable = false reportUnnecessaryComparison = false reportUnusedClass = false - + [tool.black] target-version = ["py311"] - + [tool.pytest.ini_options] python_classes = ["!Test"] asyncio_mode = "auto" @@ -122,17 +181,17 @@ private static void writePyproject( } private static void writeDependencyList(PythonWriter writer, Collection dependencies) { - for (var iter = dependencies.iterator(); iter.hasNext();) { + for (var iter = dependencies.iterator(); iter.hasNext(); ) { writer.pushState(); var dependency = iter.next(); writer.putContext("deps", getOptionalDependencies(dependency)); writer.putContext("isLink", dependency.getProperty(SymbolProperties.IS_LINK).orElse(false)); writer.putContext("last", !iter.hasNext()); writer.write(""" - "$L\ - ${?deps}[${#deps}${value:L}${^key.last}, ${/key.last}${/deps}]${/deps}\ - ${?isLink} @ ${/isLink}$L"\ - ${^last},${/last}""", + "$L\ + ${?deps}[${#deps}${value:L}${^key.last}, ${/key.last}${/deps}]${/deps}\ + ${?isLink} @ ${/isLink}$L"\ + ${^last},${/last}""", dependency.getPackageName(), dependency.getVersion()); writer.popState(); @@ -152,7 +211,7 @@ private static List getOptionalDependencies(SymbolDependency dependency) }) .orElse(Collections.emptyList()); try { - return (List) optionals; + return optionals; } catch (Exception e) { return Collections.emptyList(); } @@ -177,7 +236,7 @@ private static void writeReadme( writer.pushState(new ReadmeSection()); writer.write(""" ## $L Client - + $L """, title, description); @@ -190,7 +249,7 @@ private static void writeReadme( // since the python code docs are RST format. writer.write(""" ### Documentation - + $L """, documentation); });