diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 9d506c7e7..9f35ea781 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -60,12 +60,6 @@ private void generateService(PythonWriter writer) { var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings()); writer.addLogger(); - writer.addStdlibImport("typing", "TypeVar"); - writer.write(""" - Input = TypeVar("Input") - Output = TypeVar("Output") - """); - writer.openBlock("class $L:", "", serviceSymbol.getName(), () -> { var docs = service.getTrait(DocumentationTrait.class) .map(StringTrait::getValue) @@ -144,12 +138,22 @@ private void generateOperationExecutor(PythonWriter writer) { writer.addStdlibImport("copy", "deepcopy"); writer.addStdlibImport("asyncio"); writer.addStdlibImports("asyncio", Set.of("sleep", "Future")); + writer.addStdlibImport("dataclasses", "replace"); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addImport("smithy_core.exceptions", "SmithyRetryException"); - writer.addImports("smithy_core.interceptors", Set.of("Interceptor", "InterceptorContext")); + writer.addImports("smithy_core.interceptors", + Set.of("Interceptor", + "InterceptorChain", + "InputContext", + "OutputContext", + "RequestContext", + "ResponseContext")); writer.addImports("smithy_core.interfaces.retries", Set.of("RetryErrorInfo", "RetryErrorType")); writer.addImport("smithy_core.interfaces.exceptions", "HasFault"); + writer.addImport("smithy_core.types", "TypedProperties"); + writer.addImport("smithy_core.serializers", "SerializeableShape"); + writer.addImport("smithy_core.deserializers", "DeserializeableShape"); writer.indent(); writer.write(""" @@ -157,7 +161,7 @@ def _classify_error( self, *, error: Exception, - context: InterceptorContext[Input, Output, $1T, $2T | None] + context: ResponseContext[Any, $1T, $2T | None] ) -> RetryErrorInfo: logger.debug("Classifying error: %s", error) """, transportRequest, transportResponse); @@ -198,7 +202,7 @@ def _classify_error( writer.addStdlibImport("asyncio"); writer.write( """ - async def _input_stream( + async def _input_stream[Input: SerializeableShape, Output: DeserializeableShape]( self, input: Input, plugins: list[$1T], @@ -207,7 +211,7 @@ async def _input_stream( config: $4T, operation_name: str, ) -> Any: - request_future = Future[InterceptorContext[Any, Any, $2T, Any]]() + request_future = Future[RequestContext[Any, $2T]]() awaitable_output = asyncio.create_task(self._execute_operation( input, plugins, serialize, deserialize, config, operation_name, request_future=request_future @@ -215,7 +219,7 @@ async def _input_stream( request_context = await request_future ${5C|} - async def _output_stream( + async def _output_stream[Input: SerializeableShape, Output: DeserializeableShape]( self, input: Input, plugins: list[$1T], @@ -233,7 +237,7 @@ async def _output_stream( transport_response = await response_future ${6C|} - async def _duplex_stream( + async def _duplex_stream[Input: SerializeableShape, Output: DeserializeableShape]( self, input: Input, plugins: list[$1T], @@ -243,7 +247,7 @@ async def _duplex_stream( operation_name: str, event_deserializer: Callable[[ShapeDeserializer], Any], ) -> Any: - request_future = Future[InterceptorContext[Any, Any, $2T, Any]]() + request_future = Future[RequestContext[Any, $2T]]() response_future = Future[$3T]() awaitable_output = asyncio.create_task(self._execute_operation( input, plugins, serialize, deserialize, config, operation_name, @@ -262,9 +266,10 @@ async def _duplex_stream( writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w))); } writer.addStdlibImport("typing", "Any"); + writer.addStdlibImport("asyncio", "iscoroutine"); writer.write( """ - async def _execute_operation( + async def _execute_operation[Input: SerializeableShape, Output: DeserializeableShape]( self, input: Input, plugins: list[$1T], @@ -272,7 +277,7 @@ async def _execute_operation( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None = None, + request_future: Future[RequestContext[Any, $2T]] | None = None, response_future: Future[$3T] | None = None, ) -> Output: try: @@ -292,7 +297,7 @@ async def _execute_operation( raise $4T(e) from e raise - async def _handle_execution( + async def _handle_execution[Input: SerializeableShape, Output: DeserializeableShape]( self, input: Input, plugins: list[$1T], @@ -300,154 +305,128 @@ async def _handle_execution( deserialize: Callable[[$3T, $5T], Awaitable[Output]], config: $5T, operation_name: str, - request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None, + request_future: Future[RequestContext[Any, $2T]] | None, response_future: Future[$3T] | None, ) -> Output: logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input) - context: InterceptorContext[Input, None, None, None] = InterceptorContext( - request=input, - response=None, - transport_request=None, - transport_response=None, - ) - _client_interceptors = config.interceptors + config = deepcopy(config) + for plugin in plugins: + plugin(config) + + input_context = InputContext(request=input, properties=TypedProperties()) + transport_request: $2T | None = None + output_context: OutputContext[Input, Output, $2T | None, $3T | None] | None = None + client_interceptors = cast( - list[Interceptor[Input, Output, $2T, $3T]], _client_interceptors + list[Interceptor[Input, Output, $2T, $3T]], list(config.interceptors) ) - interceptors = client_interceptors + interceptor_chain = InterceptorChain(client_interceptors) try: - # Step 1a: Invoke read_before_execution on client-level interceptors - for interceptor in client_interceptors: - interceptor.read_before_execution(context) - - # Step 1b: Run operation-level plugins - config = deepcopy(config) - for plugin in plugins: - plugin(config) - - _client_interceptors = config.interceptors - interceptors = cast( - list[Interceptor[Input, Output, $2T, $3T]], - _client_interceptors, - ) - - # Step 1c: Invoke the read_before_execution hooks on newly added - # interceptors. - for interceptor in interceptors: - if interceptor not in client_interceptors: - interceptor.read_before_execution(context) + # Step 1: Invoke read_before_execution + interceptor_chain.read_before_execution(input_context) # Step 2: Invoke the modify_before_serialization hooks - for interceptor in interceptors: - context._request = interceptor.modify_before_serialization(context) + input_context = replace( + input_context, + request=interceptor_chain.modify_before_serialization(input_context) + ) # Step 3: Invoke the read_before_serialization hooks - for interceptor in interceptors: - interceptor.read_before_serialization(context) + interceptor_chain.read_before_serialization(input_context) # Step 4: Serialize the request - context_with_transport_request = cast( - InterceptorContext[Input, None, $2T, None], context + logger.debug("Serializing request for: %s", input_context.request) + transport_request = await serialize(input_context.request, config) + request_context = RequestContext( + request=input_context.request, + transport_request=transport_request, + properties=input_context.properties, ) - logger.debug("Serializing request for: %s", context_with_transport_request.request) - context_with_transport_request._transport_request = await serialize( - context_with_transport_request.request, config - ) - logger.debug("Serialization complete. Transport request: %s", context_with_transport_request._transport_request) + logger.debug("Serialization complete. Transport request: %s", request_context.transport_request) # Step 5: Invoke read_after_serialization - for interceptor in interceptors: - interceptor.read_after_serialization(context_with_transport_request) + interceptor_chain.read_after_serialization(request_context) # Step 6: Invoke modify_before_retry_loop - for interceptor in interceptors: - context_with_transport_request._transport_request = ( - interceptor.modify_before_retry_loop(context_with_transport_request) - ) + request_context = replace( + request_context, + transport_request=interceptor_chain.modify_before_retry_loop(request_context) + ) # Step 7: Acquire the retry token. retry_strategy = config.retry_strategy retry_token = retry_strategy.acquire_initial_retry_token() while True: - # Make an attempt, creating a copy of the context so we don't pass - # around old data. - context_with_response = await self._handle_attempt( + # Make an attempt + output_context = await self._handle_attempt( deserialize, - interceptors, - context_with_transport_request.copy(), + interceptor_chain, + request_context, config, operation_name, request_future, ) - # We perform this type-ignored re-assignment because `context` needs - # to point at the latest context so it can be generically handled - # later on. This is only an issue here because we've created a copy, - # so we're no longer simply pointing at the same object in memory - # with different names and type hints. It is possible to address this - # without having to fall back to the type ignore, but it would impose - # unnecessary runtime costs. - context = context_with_response # type: ignore - - if isinstance(context_with_response.response, Exception): + if isinstance(output_context.response, Exception): # Step 7u: Reacquire retry token if the attempt failed try: retry_token = retry_strategy.refresh_retry_token_for_retry( token_to_renew=retry_token, error_info=self._classify_error( - error=context_with_response.response, - context=context_with_response, + error=output_context.response, + context=output_context, ) ) except SmithyRetryException: - raise context_with_response.response + raise output_context.response logger.debug( "Retry needed. Attempting request #%s in %.4f seconds.", retry_token.retry_count + 1, retry_token.retry_delay ) await sleep(retry_token.retry_delay) - current_body = context_with_transport_request.transport_request.body + current_body = output_context.transport_request.body if (seek := getattr(current_body, "seek", None)) is not None: - await seek(0) + if iscoroutine((result := seek(0))): + await result else: # Step 8: Invoke record_success retry_strategy.record_success(token=retry_token) if response_future is not None: response_future.set_result( - context_with_response.transport_response # type: ignore + output_context.transport_response # type: ignore ) break except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # At this point, the context's request will have been definitively set, and - # The response will be set either with the modeled output or an exception. The - # transport_request and transport_response may be set or None. - execution_context = cast( - InterceptorContext[Input, Output, $2T | None, $3T | None], context - ) - return await self._finalize_execution(interceptors, execution_context) + if output_context is not None: + logger.exception("Exception occurred while handling: %s", output_context.response) + output_context = replace(output_context, response=e) + else: + output_context = OutputContext( + request=input_context.request, + response=e, + transport_request=transport_request, + transport_response=None, + properties=input_context.properties + ) + + return await self._finalize_execution(interceptor_chain, output_context) - async def _handle_attempt( + async def _handle_attempt[Input: SerializeableShape, Output: DeserializeableShape]( self, deserialize: Callable[[$3T, $5T], Awaitable[Output]], - interceptors: list[Interceptor[Input, Output, $2T, $3T]], - context: InterceptorContext[Input, None, $2T, None], + interceptor: Interceptor[Input, Output, $2T, $3T], + context: RequestContext[Input, $2T], config: $5T, operation_name: str, - request_future: Future[InterceptorContext[Any, Any, $2T, Any]] | None, - ) -> InterceptorContext[Input, Output, $2T, $3T | None]: + request_future: Future[RequestContext[Input, $2T]] | None, + ) -> OutputContext[Input, Output, $2T, $3T | None]: + transport_response: $3T | None = None try: - # assert config.interceptors is not None # Step 7a: Invoke read_before_attempt - for interceptor in interceptors: - interceptor.read_before_attempt(context) + interceptor.read_before_attempt(context) """, pluginSymbol, @@ -529,14 +508,14 @@ async def _handle_attempt( path = endpoint.uri.path if context.transport_request.destination.path: path += context.transport_request.destination.path - context._transport_request.destination = URI( + context.transport_request.destination = URI( scheme=endpoint.uri.scheme, host=context.transport_request.destination.host + endpoint.uri.host, path=path, port=endpoint.uri.port, query=context.transport_request.destination.query, ) - context._transport_request.fields.extend(endpoint.headers) + context.transport_request.fields.extend(endpoint.headers) """, CodegenUtils.getEndpointParametersSymbol(context.settings())); @@ -545,12 +524,13 @@ async def _handle_attempt( writer.write(""" # Step 7g: Invoke modify_before_signing - for interceptor in interceptors: - context._transport_request = interceptor.modify_before_signing(context) + context = replace( + context, + transport_request=interceptor.modify_before_signing(context) + ) # Step 7h: Invoke read_before_signing - for interceptor in interceptors: - interceptor.read_before_signing(context) + interceptor.read_before_signing(context) """); @@ -564,28 +544,31 @@ async def _handle_attempt( "Signer properties: %s", auth_option.signer_properties ) - context._transport_request = await signer.sign( - http_request=context.transport_request, - identity=identity, - signing_properties=auth_option.signer_properties, + context = replace( + context, + transport_request= await signer.sign( + http_request=context.transport_request, + identity=identity, + signing_properties=auth_option.signer_properties, + ) ) - logger.debug("Signed HTTP request: %s", context._transport_request) + logger.debug("Signed HTTP request: %s", context.transport_request) """); } writer.popState(); writer.write(""" # Step 7j: Invoke read_after_signing - for interceptor in interceptors: - interceptor.read_after_signing(context) + interceptor.read_after_signing(context) # Step 7k: Invoke modify_before_transmit - for interceptor in interceptors: - context._transport_request = interceptor.modify_before_transmit(context) + context = replace( + context, + transport_request=interceptor.modify_before_transmit(context) + ) # Step 7l: Invoke read_before_transmit - for interceptor in interceptors: - interceptor.read_before_transmit(context) + interceptor.read_before_transmit(context) """); @@ -596,112 +579,108 @@ async def _handle_attempt( writer.write(""" # Step 7m: Invoke http_client.send request_config = config.http_request_config or HTTPRequestConfiguration() - context_with_response = cast( - InterceptorContext[Input, None, $1T, $2T], context - ) logger.debug("HTTP request config: %s", request_config) - logger.debug("Sending HTTP request: %s", context_with_response.transport_request) + logger.debug("Sending HTTP request: %s", context.transport_request) if request_future is not None: response_task = asyncio.create_task(config.http_client.send( - request=context_with_response.transport_request, + request=context.transport_request, request_config=request_config, )) - request_future.set_result(context_with_response) - context_with_response._transport_response = await response_task + request_future.set_result(context) + transport_response = await response_task else: - context_with_response._transport_response = await config.http_client.send( - request=context_with_response.transport_request, + transport_response = await config.http_client.send( + request=context.transport_request, request_config=request_config, ) - logger.debug("Received HTTP response: %s", context_with_response.transport_response) - """, transportRequest, transportResponse); + response_context = ResponseContext( + request=context.request, + transport_request=context.transport_request, + transport_response=transport_response, + properties=context.properties + ) + logger.debug("Received HTTP response: %s", response_context.transport_response) + + """); } writer.popState(); writer.write(""" # Step 7n: Invoke read_after_transmit - for interceptor in interceptors: - interceptor.read_after_transmit(context_with_response) + interceptor.read_after_transmit(response_context) # Step 7o: Invoke modify_before_deserialization - for interceptor in interceptors: - context_with_response._transport_response = ( - interceptor.modify_before_deserialization(context_with_response) - ) + response_context = replace( + response_context, + transport_response=interceptor.modify_before_deserialization(response_context) + ) # Step 7p: Invoke read_before_deserialization - for interceptor in interceptors: - interceptor.read_before_deserialization(context_with_response) + interceptor.read_before_deserialization(response_context) # Step 7q: deserialize - context_with_output = cast( - InterceptorContext[Input, Output, $1T, $2T], - context_with_response, + logger.debug("Deserializing transport response: %s", response_context.transport_response) + output = await deserialize( + response_context.transport_response, config ) - logger.debug("Deserializing transport response: %s", context_with_output._transport_response) - context_with_output._response = await deserialize( - context_with_output._transport_response, config + output_context = OutputContext( + request=response_context.request, + response=output, + transport_request=response_context.transport_request, + transport_response=response_context.transport_response, + properties=response_context.properties ) - logger.debug("Deserialization complete. Response: %s", context_with_output._response) + logger.debug("Deserialization complete. Response: %s", output_context.response) # Step 7r: Invoke read_after_deserialization - for interceptor in interceptors: - interceptor.read_after_deserialization(context_with_output) + interceptor.read_after_deserialization(output_context) except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e - - # At this point, the context's request and transport_request have definitively been set, - # the response is either set or an exception, and the transport_resposne is either set or - # None. This will also be true after _finalize_attempt because there is no opportunity - # there to set the transport_response. - attempt_context = cast( - InterceptorContext[Input, Output, $1T, $2T | None], context - ) - return await self._finalize_attempt(interceptors, attempt_context) - - async def _finalize_attempt( + output_context: OutputContext[Input, Output, $1T, $2T] = OutputContext( + request=context.request, + response=e, # type: ignore + transport_request=context.transport_request, + transport_response=transport_response, + properties=context.properties + ) + + return await self._finalize_attempt(interceptor, output_context) + + async def _finalize_attempt[Input: SerializeableShape, Output: DeserializeableShape]( self, - interceptors: list[Interceptor[Input, Output, $1T, $2T]], - context: InterceptorContext[Input, Output, $1T, $2T | None], - ) -> InterceptorContext[Input, Output, $1T, $2T | None]: + interceptor: Interceptor[Input, Output, $1T, $2T], + context: OutputContext[Input, Output, $1T, $2T | None], + ) -> OutputContext[Input, Output, $1T, $2T | None]: # Step 7s: Invoke modify_before_attempt_completion try: - for interceptor in interceptors: - context._response = interceptor.modify_before_attempt_completion( - context - ) + context = replace( + context, + response=interceptor.modify_before_attempt_completion(context) + ) except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e + logger.exception("Exception occurred while handling: %s", context.response) + context = replace(context, response=e) # Step 7t: Invoke read_after_attempt - for interceptor in interceptors: - try: - interceptor.read_after_attempt(context) - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e + try: + interceptor.read_after_attempt(context) + except Exception as e: + context = replace(context, response=e) return context - async def _finalize_execution( + async def _finalize_execution[Input: SerializeableShape, Output: DeserializeableShape]( self, - interceptors: list[Interceptor[Input, Output, $1T, $2T]], - context: InterceptorContext[Input, Output, $1T | None, $2T | None], + interceptor: Interceptor[Input, Output, $1T, $2T], + context: OutputContext[Input, Output, $1T | None, $2T | None], ) -> Output: try: # Step 9: Invoke modify_before_completion - for interceptor in interceptors: - context._response = interceptor.modify_before_completion(context) + context = replace( + context, + response=interceptor.modify_before_completion(context) + ) # Step 10: Invoke trace_probe.dispatch_events try: @@ -711,20 +690,14 @@ async def _finalize_execution( logger.exception("Exception occurred while dispatching trace events: %s", e) pass except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e + logger.exception("Exception occurred while handling: %s", context.response) + context = replace(context, response=e) # Step 11: Invoke read_after_execution - for interceptor in interceptors: - try: - interceptor.read_after_execution(context) - except Exception as e: - if context.response is not None: - logger.exception("Exception occurred while handling: %s", context.response) - pass - context._response = e + try: + interceptor.read_after_execution(context) + except Exception as e: + context = replace(context, response=e) # Step 12: Return / throw if isinstance(context.response, Exception): diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java index d04ebf82c..c62b425be 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java @@ -68,8 +68,7 @@ public void run() { inSymbol.expectProperty(SymbolProperties.SCHEMA), outSymbol.expectProperty(SymbolProperties.SCHEMA), writer.consumer(this::writeErrorTypeRegistry), - writer.consumer(this::writeAuthSchemes) - ); + writer.consumer(this::writeAuthSchemes)); } private void writeErrorTypeRegistry(PythonWriter writer) { diff --git a/packages/smithy-aws-core/src/smithy_aws_core/interceptors/user_agent.py b/packages/smithy-aws-core/src/smithy_aws_core/interceptors/user_agent.py index 3e5cb388a..c7acc281b 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/interceptors/user_agent.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/interceptors/user_agent.py @@ -5,7 +5,7 @@ import smithy_aws_core import smithy_core -from smithy_core.interceptors import Interceptor, InterceptorContext +from smithy_core.interceptors import Interceptor, RequestContext from smithy_http.interceptors.user_agent import USER_AGENT from smithy_http.user_agent import UserAgentComponent, RawStringUserAgentComponent @@ -37,9 +37,7 @@ def __init__( self._sdk_version = sdk_version self._service_id = service_id - def read_after_serialization( - self, context: InterceptorContext[Any, Any, Any, Any] - ) -> None: + def read_after_serialization(self, context: RequestContext[Any, Any]) -> None: if USER_AGENT in context.properties: user_agent = context.properties[USER_AGENT] user_agent.sdk_metadata = self._build_sdk_metadata() diff --git a/packages/smithy-core/src/smithy_core/interceptors.py b/packages/smithy-core/src/smithy_core/interceptors.py index 850d1fa8d..2c40c5476 100644 --- a/packages/smithy-core/src/smithy_core/interceptors.py +++ b/packages/smithy-core/src/smithy_core/interceptors.py @@ -1,119 +1,55 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from copy import copy, deepcopy -from typing import TypeVar +from dataclasses import dataclass, replace +from typing import Any, Sequence -from .types import TypedProperties +from .interfaces import TypedProperties +from .serializers import SerializeableShape +from .deserializers import DeserializeableShape -Request = TypeVar("Request") -Response = TypeVar("Response") -TransportRequest = TypeVar("TransportRequest") -TransportResponse = TypeVar("TransportResponse") +@dataclass(kw_only=True, frozen=True, slots=True) +class InputContext[Request: SerializeableShape]: + request: Request + """The modeled request for the operation being invoked.""" -class InterceptorContext[Request, Response, TransportRequest, TransportResponse]: - def __init__( - self, - *, - request: Request, - response: Response | Exception, - transport_request: TransportRequest, - transport_response: TransportResponse, - ): - """A container for the current data available to an interceptor. - - :param request: The modeled request for the operation being invoked. - :param response: The modeled response for the operation being invoked. This will - only be available once the transport_response has been deserialized or the - attempt/execution has failed. - :param transport_request: The transmittable request for the operation being - invoked. This will only be available once request serialization has - completed. - :param transport_response: The transmitted response for the operation being - invoked. This will only be available once transmission has completed. - """ - self._request = request - self._response = response - self._transport_request = transport_request - self._transport_response = transport_response - self._properties = TypedProperties() - - @property - def request(self) -> Request: - """Retrieve the modeled request for the operation being invoked.""" - return self._request - - @property - def response(self) -> Response | Exception: - """Retrieve the modeled response for the operation being invoked. - - This will only be available once the transport_response has been deserialized or - the attempt/execution has failed. - """ - return self._response - - # Note that TransportRequest (and TransportResponse below) aren't resolved types, - # but rather TypeVars. This is very important, because in the actual Interceptor - # interface class these are sometimes typed as None rather than, say, HTTPRequest. - # That lets us use the type system to tell people when something will be set and - # when it will not be set without leaking nullability into the cases where the - # property will ALWAYS be set. - @property - def transport_request(self) -> TransportRequest: - """Retrieve the transmittable request for the operation being invoked. - - This will only be available once request serialization has completed. - """ - return self._transport_request + properties: TypedProperties + """A typed context property bag.""" - @property - def transport_response(self) -> TransportResponse: - """Retrieve the transmitted response for the operation being invoked. - This will only be available once transmission has completed. - """ - return self._transport_response +@dataclass(kw_only=True, frozen=True, slots=True) +class RequestContext[Request: SerializeableShape, TransportRequest]( + InputContext[Request] +): + transport_request: TransportRequest + """The transmittable request for the operation being invoked.""" - @property - def properties(self) -> TypedProperties: - """Retrieve the generic property bag. - These untyped properties will be made available to all other interceptors or - hooks that are called for this execution. - """ - return self._properties +@dataclass(kw_only=True, frozen=True, slots=True) +class ResponseContext[Request: SerializeableShape, TransportRequest, TransportResponse]( + RequestContext[Request, TransportRequest] +): + transport_response: TransportResponse + """The transmitted response for the operation being invoked.""" - # The static properties of this class are made 'read-only' like this to discourage - # people from trying to modify the context outside of the specific hooks where that - # is allowed. - def copy( - self, - *, - request: Request | None = None, - response: Response | Exception | None = None, - transport_request: TransportRequest | None = None, - transport_response: TransportResponse | None = None, - ) -> "InterceptorContext[Request, Response, TransportRequest, TransportResponse]": - """Copy the context object, optionally overriding certain properties.""" - if transport_request is None: - transport_request = copy(self._transport_request) - - if transport_response is None: - transport_response = copy(self._transport_response) - - context: InterceptorContext[ - Request, Response, TransportRequest, TransportResponse - ] = InterceptorContext( - request=request if request is not None else self._request, - response=response if response is not None else self._response, - transport_request=transport_request, - transport_response=transport_response, - ) - context._properties = deepcopy(self._properties) - return context - - -class Interceptor[Request, Response, TransportRequest, TransportResponse]: + +@dataclass(kw_only=True, frozen=True, slots=True) +class OutputContext[ + Request: SerializeableShape, + Response: DeserializeableShape, + TransportRequest, + TransportResponse, +](ResponseContext[Request, TransportRequest, TransportResponse]): + response: Response | Exception + """The modeled response for the operation being invoked.""" + + +class Interceptor[ + Request: SerializeableShape, + Response: DeserializeableShape, + TransportRequest, + TransportResponse, +]: """Allows injecting code into the SDK's request execution pipeline. Terminology: @@ -128,9 +64,7 @@ class Interceptor[Request, Response, TransportRequest, TransportResponse]: requests or responses. """ - def read_before_execution( - self, context: InterceptorContext[Request, None, None, None] - ) -> None: + def read_before_execution(self, context: InputContext[Request]) -> None: """A hook called at the start of an execution, before the SDK does anything else. @@ -151,9 +85,7 @@ def read_before_execution( the latest will be used and earlier ones will be logged and dropped. """ - def modify_before_serialization( - self, context: InterceptorContext[Request, None, None, None] - ) -> Request: + def modify_before_serialization(self, context: InputContext[Request]) -> Request: """A hook called before the request is serialized into a transport request. This method has the ability to modify and return a new request of the same @@ -174,9 +106,7 @@ def modify_before_serialization( """ return context.request - def read_before_serialization( - self, context: InterceptorContext[Request, None, None, None] - ) -> None: + def read_before_serialization(self, context: InputContext[Request]) -> None: """A hook called before the input message is serialized into a transport request. @@ -196,7 +126,7 @@ def read_before_serialization( """ def read_after_serialization( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> None: """A hook called after the input message is serialized into a transport request. @@ -216,7 +146,7 @@ def read_after_serialization( """ def modify_before_retry_loop( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> TransportRequest: """A hook called before the retry loop is entered. @@ -235,7 +165,7 @@ def modify_before_retry_loop( return context.transport_request def read_before_attempt( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> None: """A hook called before each attempt at sending the transport request to the service. @@ -260,7 +190,7 @@ def read_before_attempt( """ def modify_before_signing( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> TransportRequest: """A hook called before the transport request is signed. @@ -286,7 +216,7 @@ def modify_before_signing( return context.transport_request def read_before_signing( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> None: """A hook called before the transport request is signed. @@ -309,7 +239,7 @@ def read_before_signing( """ def read_after_signing( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> None: """A hook called after the transport request is signed. @@ -332,7 +262,7 @@ def read_after_signing( """ def modify_before_transmit( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> TransportRequest: """A hook called before the transport request is sent to the service. @@ -358,7 +288,7 @@ def modify_before_transmit( return context.transport_request def read_before_transmit( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, context: RequestContext[Request, TransportRequest] ) -> None: """A hook called before the transport request is sent to the service. @@ -383,7 +313,7 @@ def read_before_transmit( def read_after_transmit( self, - context: InterceptorContext[Request, None, TransportRequest, TransportResponse], + context: ResponseContext[Request, TransportRequest, TransportResponse], ) -> None: """A hook called after the transport request is sent to the service and a transport response is received. @@ -409,7 +339,7 @@ def read_after_transmit( def modify_before_deserialization( self, - context: InterceptorContext[Request, None, TransportRequest, TransportResponse], + context: ResponseContext[Request, TransportRequest, TransportResponse], ) -> TransportResponse: """A hook called before the transport response is deserialized. @@ -439,7 +369,7 @@ def modify_before_deserialization( def read_before_deserialization( self, - context: InterceptorContext[Request, None, TransportRequest, TransportResponse], + context: ResponseContext[Request, TransportRequest, TransportResponse], ) -> None: """A hook called before the transport response is deserialized. @@ -464,9 +394,7 @@ def read_before_deserialization( def read_after_deserialization( self, - context: InterceptorContext[ - Request, Response, TransportRequest, TransportResponse - ], + context: OutputContext[Request, Response, TransportRequest, TransportResponse], ) -> None: """A hook called after the transport response is deserialized. @@ -491,7 +419,7 @@ def read_after_deserialization( def modify_before_attempt_completion( self, - context: InterceptorContext[ + context: OutputContext[ Request, Response, TransportRequest, TransportResponse | None ], ) -> Response | Exception: @@ -521,7 +449,7 @@ def modify_before_attempt_completion( def read_after_attempt( self, - context: InterceptorContext[ + context: OutputContext[ Request, Response, TransportRequest, TransportResponse | None ], ) -> None: @@ -549,7 +477,7 @@ def read_after_attempt( def modify_before_completion( self, - context: InterceptorContext[ + context: OutputContext[ Request, Response, TransportRequest | None, TransportResponse | None ], ) -> Response | Exception: @@ -574,7 +502,7 @@ def modify_before_completion( def read_after_execution( self, - context: InterceptorContext[ + context: OutputContext[ Request, Response, TransportRequest | None, TransportResponse | None ], ) -> None: @@ -596,3 +524,135 @@ def read_after_execution( final response. If multiple `read_after_execution` methods throw exceptions, the latest will be used and earlier ones will be logged and dropped. """ + + +AnyInterceptor = Interceptor[Any, Any, Any, Any] + + +class InterceptorChain(AnyInterceptor): + """An interceptor that contains an ordered list of delegate interceptors. + + This is primarily intended for use within the client itself. + """ + + def __init__(self, chain: Sequence[AnyInterceptor]) -> None: + """Construct an InterceptorChain. + + :param chain: The ordered interceptors to chain together. + """ + self._chain = tuple(chain) + + def __repr__(self) -> str: + return f"InterceptorChain(chain={self._chain!r})" + + def read_before_execution(self, context: InputContext[Any]) -> None: + for interceptor in self._chain: + interceptor.read_before_execution(context) + + def modify_before_serialization(self, context: InputContext[Any]) -> Any: + request = context.request + for interceptor in self._chain: + request = interceptor.modify_before_serialization(context) + return request + + def read_before_serialization(self, context: InputContext[Any]) -> None: + for interceptor in self._chain: + interceptor.read_before_serialization(context) + + def read_after_serialization(self, context: RequestContext[Any, Any]) -> None: + for interceptor in self._chain: + interceptor.read_after_serialization(context) + + def modify_before_retry_loop(self, context: RequestContext[Any, Any]) -> Any: + transport_request = context.transport_request + for interceptor in self._chain: + transport_request = interceptor.modify_before_retry_loop(context) + return transport_request + + def read_before_attempt(self, context: RequestContext[Any, Any]) -> None: + for interceptor in self._chain: + interceptor.read_before_attempt(context) + + def modify_before_signing(self, context: RequestContext[Any, Any]) -> Any: + transport_request = context.transport_request + for interceptor in self._chain: + transport_request = interceptor.modify_before_retry_loop(context) + return transport_request + + def read_before_signing(self, context: RequestContext[Any, Any]) -> None: + for interceptor in self._chain: + interceptor.read_before_signing(context) + + def read_after_signing(self, context: RequestContext[Any, Any]) -> None: + for interceptor in self._chain: + interceptor.read_after_signing(context) + + def modify_before_transmit(self, context: RequestContext[Any, Any]) -> Any: + transport_request = context.transport_request + for interceptor in self._chain: + transport_request = interceptor.modify_before_retry_loop(context) + return transport_request + + def read_before_transmit(self, context: RequestContext[Any, Any]) -> None: + for interceptor in self._chain: + interceptor.read_before_transmit(context) + + def read_after_transmit(self, context: ResponseContext[Any, Any, Any]) -> None: + for interceptor in self._chain: + interceptor.read_after_transmit(context) + + def modify_before_deserialization( + self, context: ResponseContext[Any, Any, Any] + ) -> Any: + transport_response = context.transport_response + for interceptor in self._chain: + transport_response = interceptor.modify_before_deserialization(context) + return transport_response + + def read_before_deserialization( + self, context: ResponseContext[Any, Any, Any] + ) -> None: + for interceptor in self._chain: + interceptor.read_before_deserialization(context) + + def read_after_deserialization( + self, context: OutputContext[Any, Any, Any, Any] + ) -> None: + for interceptor in self._chain: + interceptor.read_after_deserialization(context) + + def modify_before_attempt_completion( + self, context: OutputContext[Any, Any, Any, Any | None] + ) -> Any | Exception: + response = context.response + for interceptor in self._chain: + response = interceptor.modify_before_attempt_completion(context) + return response + + def read_after_attempt( + self, context: OutputContext[Any, Any, Any, Any | None] + ) -> None: + for interceptor in self._chain: + interceptor.read_after_attempt(context) + + def modify_before_completion( + self, context: OutputContext[Any, Any, Any | None, Any | None] + ) -> Any | Exception: + response = context.response + for interceptor in self._chain: + response = interceptor.modify_before_attempt_completion(context) + return response + + def read_after_execution( + self, context: OutputContext[Any, Any, Any | None, Any | None] + ) -> None: + exception: Exception | None = None + for interceptor in self._chain: + # Every one of these is supposed to be guaranteed to be called. + try: + interceptor.read_after_execution(context) + except Exception as e: + context = replace(context, response=e) + exception = e + if exception is not None: + raise exception diff --git a/packages/smithy-http/src/smithy_http/interceptors/user_agent.py b/packages/smithy-http/src/smithy_http/interceptors/user_agent.py index cd9fd1ea2..5690a7190 100644 --- a/packages/smithy-http/src/smithy_http/interceptors/user_agent.py +++ b/packages/smithy-http/src/smithy_http/interceptors/user_agent.py @@ -4,7 +4,7 @@ from typing import Self, Any import smithy_core -from smithy_core.interceptors import Interceptor, InterceptorContext +from smithy_core.interceptors import Interceptor, InputContext, RequestContext from smithy_core.types import PropertyKey from smithy_http import Field from smithy_http.aio.interfaces import HTTPRequest @@ -13,17 +13,15 @@ USER_AGENT = PropertyKey(key="user_agent", value_type=UserAgent) -class UserAgentInterceptor(Interceptor[Any, None, HTTPRequest, None]): +class UserAgentInterceptor(Interceptor[Any, Any, HTTPRequest, None]): """Adds interceptors that initialize UserAgent in the context and add the user-agent header.""" - def read_before_execution( - self, context: InterceptorContext[Any, None, None, None] - ) -> None: + def read_before_execution(self, context: InputContext[Any]) -> None: context.properties[USER_AGENT] = _UserAgentBuilder.from_environment().build() def modify_before_signing( - self, context: InterceptorContext[Any, None, HTTPRequest, None] + self, context: RequestContext[Any, HTTPRequest] ) -> HTTPRequest: user_agent = context.properties[USER_AGENT] request = context.transport_request