diff --git a/packages/smithy-core/src/smithy_core/interceptors.py b/packages/smithy-core/src/smithy_core/interceptors.py index 850d1fa8d..48f4acef2 100644 --- a/packages/smithy-core/src/smithy_core/interceptors.py +++ b/packages/smithy-core/src/smithy_core/interceptors.py @@ -1,116 +1,62 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from copy import copy, deepcopy -from typing import TypeVar +from dataclasses import dataclass, field +from typing import Any from .types import TypedProperties -Request = TypeVar("Request") -Response = TypeVar("Response") -TransportRequest = TypeVar("TransportRequest") -TransportResponse = TypeVar("TransportResponse") - +@dataclass(kw_only=True, slots=True) class InterceptorContext[Request, Response, TransportRequest, TransportResponse]: - def __init__( - self, - *, - request: Request, - response: Response | Exception, - transport_request: TransportRequest, - transport_response: TransportResponse, - ): - """A container for the current data available to an interceptor. - - :param request: The modeled request for the operation being invoked. - :param response: The modeled response for the operation being invoked. This will - only be available once the transport_response has been deserialized or the - attempt/execution has failed. - :param transport_request: The transmittable request for the operation being - invoked. This will only be available once request serialization has - completed. - :param transport_response: The transmitted response for the operation being - invoked. This will only be available once transmission has completed. - """ - self._request = request - self._response = response - self._transport_request = transport_request - self._transport_response = transport_response - self._properties = TypedProperties() - - @property - def request(self) -> Request: - """Retrieve the modeled request for the operation being invoked.""" - return self._request - - @property - def response(self) -> Response | Exception: - """Retrieve the modeled response for the operation being invoked. - - This will only be available once the transport_response has been deserialized or - the attempt/execution has failed. - """ - return self._response - - # Note that TransportRequest (and TransportResponse below) aren't resolved types, - # but rather TypeVars. This is very important, because in the actual Interceptor - # interface class these are sometimes typed as None rather than, say, HTTPRequest. - # That lets us use the type system to tell people when something will be set and - # when it will not be set without leaking nullability into the cases where the - # property will ALWAYS be set. - @property - def transport_request(self) -> TransportRequest: - """Retrieve the transmittable request for the operation being invoked. - - This will only be available once request serialization has completed. - """ - return self._transport_request + """A container for the current data available to an interceptor.""" - @property - def transport_response(self) -> TransportResponse: - """Retrieve the transmitted response for the operation being invoked. + request: Request + """The modeled request for the operation being invoked.""" - This will only be available once transmission has completed. - """ - return self._transport_response + response: Response | Exception | None + """The modeled response for the operation being invoked. - @property - def properties(self) -> TypedProperties: - """Retrieve the generic property bag. + This will only be available once the transport_response has been deserialized or the + attempt/execution has failed. + """ - These untyped properties will be made available to all other interceptors or - hooks that are called for this execution. - """ - return self._properties + transport_request: TransportRequest | None + """The transmittable request for the operation being invoked. - # The static properties of this class are made 'read-only' like this to discourage - # people from trying to modify the context outside of the specific hooks where that - # is allowed. - def copy( - self, - *, - request: Request | None = None, - response: Response | Exception | None = None, - transport_request: TransportRequest | None = None, - transport_response: TransportResponse | None = None, - ) -> "InterceptorContext[Request, Response, TransportRequest, TransportResponse]": - """Copy the context object, optionally overriding certain properties.""" - if transport_request is None: - transport_request = copy(self._transport_request) - - if transport_response is None: - transport_response = copy(self._transport_response) + This will only be available once request serialization has completed. + """ + + transport_response: TransportResponse | None + """The transmitted response for the operation being invoked. + + This will only be available once transmission has completed. + """ + + properties: TypedProperties = field(default_factory=TypedProperties) + """The generic property bag. + + These properties will be made available to all other interceptors or hooks that are + called for this execution. + """ + + def expect_response(self) -> Response | Exception: + """Assert the modeled response is available and return it.""" + assert self.response is not None + return self.response + + def expect_transport_request(self) -> TransportRequest: + """Assert the transport request is available and return it.""" + assert self.transport_request is not None + return self.transport_request + + def expect_transport_response(self) -> TransportResponse: + """Assert the transport response is available and return it.""" + assert self.transport_response is not None + return self.transport_response - context: InterceptorContext[ - Request, Response, TransportRequest, TransportResponse - ] = InterceptorContext( - request=request if request is not None else self._request, - response=response if response is not None else self._response, - transport_request=transport_request, - transport_response=transport_response, - ) - context._properties = deepcopy(self._properties) - return context + +type AnyInterceptorContext = InterceptorContext[Any, Any, Any, Any] +"""An InterceptorContext alias that accepts any parametric types.""" class Interceptor[Request, Response, TransportRequest, TransportResponse]: @@ -129,7 +75,10 @@ class Interceptor[Request, Response, TransportRequest, TransportResponse]: """ def read_before_execution( - self, context: InterceptorContext[Request, None, None, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called at the start of an execution, before the SDK does anything else. @@ -152,7 +101,10 @@ def read_before_execution( """ def modify_before_serialization( - self, context: InterceptorContext[Request, None, None, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> Request: """A hook called before the request is serialized into a transport request. @@ -175,7 +127,10 @@ def modify_before_serialization( return context.request def read_before_serialization( - self, context: InterceptorContext[Request, None, None, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called before the input message is serialized into a transport request. @@ -196,7 +151,10 @@ def read_before_serialization( """ def read_after_serialization( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called after the input message is serialized into a transport request. @@ -216,7 +174,10 @@ def read_after_serialization( """ def modify_before_retry_loop( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> TransportRequest: """A hook called before the retry loop is entered. @@ -232,10 +193,13 @@ def modify_before_retry_loop( The transport request returned by this hook MUST be the same type of request passed into this hook. If not, an exception will immediately occur. """ - return context.transport_request + return context.expect_transport_request() def read_before_attempt( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called before each attempt at sending the transport request to the service. @@ -260,7 +224,10 @@ def read_before_attempt( """ def modify_before_signing( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> TransportRequest: """A hook called before the transport request is signed. @@ -283,10 +250,13 @@ def modify_before_signing( The transport request returned by this hook MUST be the same type of request passed into this hook. If not, an exception will immediately occur. """ - return context.transport_request + return context.expect_transport_request() def read_before_signing( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called before the transport request is signed. @@ -309,7 +279,10 @@ def read_before_signing( """ def read_after_signing( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called after the transport request is signed. @@ -332,7 +305,10 @@ def read_after_signing( """ def modify_before_transmit( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> TransportRequest: """A hook called before the transport request is sent to the service. @@ -355,10 +331,13 @@ def modify_before_transmit( The transport request returned by this hook MUST be the same type of request passed into this hook. If not, an exception will immediately occur. """ - return context.transport_request + return context.expect_transport_request() def read_before_transmit( - self, context: InterceptorContext[Request, None, TransportRequest, None] + self, + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called before the transport request is sent to the service. @@ -383,7 +362,9 @@ def read_before_transmit( def read_after_transmit( self, - context: InterceptorContext[Request, None, TransportRequest, TransportResponse], + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called after the transport request is sent to the service and a transport response is received. @@ -409,7 +390,9 @@ def read_after_transmit( def modify_before_deserialization( self, - context: InterceptorContext[Request, None, TransportRequest, TransportResponse], + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> TransportResponse: """A hook called before the transport response is deserialized. @@ -435,11 +418,13 @@ def modify_before_deserialization( The transport response returned by this hook MUST be the same type of response passed into this hook. If not, an exception will immediately occur. """ - return context.transport_response + return context.expect_transport_response() def read_before_deserialization( self, - context: InterceptorContext[Request, None, TransportRequest, TransportResponse], + context: InterceptorContext[ + Request, Response, TransportRequest, TransportResponse + ], ) -> None: """A hook called before the transport response is deserialized. @@ -492,7 +477,7 @@ def read_after_deserialization( def modify_before_attempt_completion( self, context: InterceptorContext[ - Request, Response, TransportRequest, TransportResponse | None + Request, Response, TransportRequest, TransportResponse ], ) -> Response | Exception: """A hook called when an attempt is completed. @@ -517,12 +502,12 @@ def modify_before_attempt_completion( exception type can be returned, replacing the `response` currently in the context. """ - return context.response + return context.expect_response() def read_after_attempt( self, context: InterceptorContext[ - Request, Response, TransportRequest, TransportResponse | None + Request, Response, TransportRequest, TransportResponse ], ) -> None: """A hook called when an attempt is completed. @@ -550,7 +535,7 @@ def read_after_attempt( def modify_before_completion( self, context: InterceptorContext[ - Request, Response, TransportRequest | None, TransportResponse | None + Request, Response, TransportRequest, TransportResponse ], ) -> Response | Exception: """A hook called when an execution is completed. @@ -570,12 +555,12 @@ def modify_before_completion( Any output returned by this hook MUST match the operation being invoked. Any exception type can be returned, replacing the `response` currently in the context. """ - return context.response + return context.expect_response() def read_after_execution( self, context: InterceptorContext[ - Request, Response, TransportRequest | None, TransportResponse | None + Request, Response, TransportRequest, TransportResponse ], ) -> None: """A hook called when an execution is completed. @@ -596,3 +581,8 @@ def read_after_execution( final response. If multiple `read_after_execution` methods throw exceptions, the latest will be used and earlier ones will be logged and dropped. """ + + +# Deliberately not a type alias because those can't be subclassed +AnyInterceptor = Interceptor[Any, Any, Any, Any] +"""An Interceptor alias that accepts any parametric types.""" diff --git a/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py b/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py index d3a5772e2..24cebcf69 100644 --- a/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py +++ b/packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py @@ -1,7 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Protocol, Self +from typing import Protocol, Self, Any +from smithy_core.interceptors import Interceptor, InterceptorContext from smithy_core.aio.interfaces import Request, Response, ClientTransport from smithy_core.aio.utils import read_streaming_blob, read_streaming_blob_async @@ -97,3 +98,16 @@ async def send( :param request_config: Configuration specific to this request. """ ... + + +# Deliberately not a type alias because those can't be subclassed +AnyHTTPInterceptor = Interceptor[Any, Any, HTTPRequest, HTTPResponse] +"""Interceptor alias for interceptors that work with HTTP transport types.""" + +type HTTPInterceptorContext[Request, Response] = InterceptorContext[ + Request, Response, HTTPRequest, HTTPResponse +] +"""Type alias for interceptor contexts that can contain HTTP transport types.""" + +type AnyHTTPInterceptorContext = InterceptorContext[Any, Any, HTTPRequest, HTTPResponse] +"""Type alias for interceptor contexts that can contain HTTP transport types.""" diff --git a/packages/smithy-http/src/smithy_http/interceptors/user_agent.py b/packages/smithy-http/src/smithy_http/interceptors/user_agent.py index ad58ddcaa..8cd3d0980 100644 --- a/packages/smithy-http/src/smithy_http/interceptors/user_agent.py +++ b/packages/smithy-http/src/smithy_http/interceptors/user_agent.py @@ -1,31 +1,30 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import platform -from typing import Self, Any +from typing import Self import smithy_core -from smithy_core.interceptors import Interceptor, InterceptorContext from smithy_http import Field -from smithy_http.aio.interfaces import HTTPRequest +from smithy_http.aio.interfaces import ( + HTTPRequest, + AnyHTTPInterceptor, + AnyHTTPInterceptorContext, +) from smithy_http.user_agent import UserAgent, UserAgentComponent -class UserAgentInterceptor(Interceptor[Any, None, HTTPRequest, None]): +class UserAgentInterceptor(AnyHTTPInterceptor): """Adds interceptors that initialize UserAgent in the context and add the user-agent header.""" - def read_before_execution( - self, context: InterceptorContext[Any, None, None, None] - ) -> None: + def read_before_execution(self, context: AnyHTTPInterceptorContext) -> None: context.properties["user_agent"] = _UserAgentBuilder.from_environment().build() - def modify_before_signing( - self, context: InterceptorContext[Any, None, HTTPRequest, None] - ) -> HTTPRequest: + def modify_before_signing(self, context: AnyHTTPInterceptorContext) -> HTTPRequest: user_agent = context.properties["user_agent"] - request = context.transport_request + request = context.expect_transport_request() request.fields.set_field(Field(name="User-Agent", values=[str(user_agent)])) - return context.transport_request + return request _USERAGENT_ALLOWED_OS_NAMES = (