diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index 4cca19070..ffb0233ba 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -570,6 +570,10 @@ await sleep(retry_token.retry_delay) writer.pushState(new SignRequestSection()); if (context.applicationProtocol().isHttpProtocol() && supportsAuth) { + writer.addStdlibImport("re"); + writer.addStdlibImport("typing", "Any"); + writer.addImport("smithy_core.interfaces.identity", "Identity"); + writer.addImport("smithy_core.types", "PropertyKey"); writer.write(""" # Step 7i: sign the request if auth_option and signer: @@ -587,6 +591,23 @@ await sleep(retry_token.retry_delay) ) ) logger.debug("Signed HTTP request: %s", context.transport_request) + + # TODO - Move this to separate resolution/population function + fields = context.transport_request.fields + auth_value = fields["Authorization"].as_string() # type: ignore + signature = re.split("Signature=", auth_value)[-1] # type: ignore + context.properties["signature"] = signature.encode('utf-8') + + identity_key: PropertyKey[Identity | None] = PropertyKey( + key="identity", + value_type=Identity | None # type: ignore + ) + sp_key: PropertyKey[dict[str, Any]] = PropertyKey( + key="signer_properties", + value_type=dict[str, Any] # type: ignore + ) + context.properties[identity_key] = identity + context.properties[sp_key] = auth_option.signer_properties """); } writer.popState(); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java index 0fa2e312a..6614cc390 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Set; import software.amazon.smithy.aws.traits.protocols.RestJson1Trait; +import software.amazon.smithy.model.knowledge.EventStreamIndex; import software.amazon.smithy.model.knowledge.HttpBinding; import software.amazon.smithy.model.node.ArrayNode; import software.amazon.smithy.model.node.ObjectNode; @@ -156,6 +157,21 @@ protected void serializeDocumentBody( writer.popState(); } + @Override + protected void writeDefaultHeaders(GenerationContext context, PythonWriter writer, OperationShape operation) { + var eventStreamIndex = EventStreamIndex.of(context.model()); + if (eventStreamIndex.getInputInfo(operation).isPresent()) { + writer.addImport("smithy_http", "Field"); + writer.write( + "Field(name=\"Content-Type\", values=[$S]),", + "application/vnd.amazon.eventstream"); + writer.write( + "Field(name=\"X-Amz-Content-SHA256\", values=[$S]),", + "STREAMING-AWS4-HMAC-SHA256-EVENTS"); + } + } + + @Override protected void serializePayloadBody( GenerationContext context, @@ -397,12 +413,24 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) { writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); writer.addImport("smithy_core.types", "TimestampFormat"); writer.addImport("aws_event_stream.aio", "AWSEventPublisher"); + writer.addImport("aws_sdk_signers", "AsyncEventSigner"); writer.write( """ + # TODO - Move this out of the RestJSON generator + ctx = request_context + signer_properties = ctx.properties.get("signer_properties") # type: ignore + identity = ctx.properties.get("identity") # type: ignore + signature = ctx.properties.get("signature") # type: ignore + signer = AsyncEventSigner( + signing_properties=signer_properties, # type: ignore + identity=identity, # type: ignore + initial_signature=signature, # type: ignore + ) codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) publisher = AWSEventPublisher[Any]( payload_codec=codec, async_writer=request_context.transport_request.body, # type: ignore + signer=signer, # type: ignore ) """); } diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py index 31594ea3c..f6c66aa9b 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py @@ -16,7 +16,7 @@ ) from smithy_core.shapes import ShapeType -from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long +from ..events import EventHeaderEncoder, EventMessage, HEADER_VALUE, Short, Byte, Long from ..exceptions import InvalidHeaderValue from . import ( INITIAL_REQUEST_EVENT_TYPE, @@ -43,6 +43,7 @@ def __init__( self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE else: self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE + self.event_header_encoder_cls = EventHeaderEncoder def get_result(self) -> EventMessage | None: return self._result diff --git a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py index 4403a9af7..8a80da3e5 100644 --- a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py +++ b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py @@ -14,14 +14,23 @@ from .._private.serializers import EventSerializer as _EventSerializer from .._private.deserializers import EventDeserializer as _EventDeserializer -from ..events import Event, EventMessage +from ..events import Event, EventHeaderEncoder, EventMessage from ..exceptions import EventError +from typing import Protocol + logger = logging.getLogger(__name__) -type Signer = Callable[[EventMessage], EventMessage] -"""A function that takes an event message and signs it, and returns it signed.""" +class EventSigner(Protocol): + """A signer to manage credentials and EventMessages for an Event Stream lifecyle.""" + + def sign_event( + self, + *, + event_message: EventMessage, + event_encoder_cls: type[EventHeaderEncoder], + ) -> EventMessage: ... class AWSEventPublisher[E: SerializeableShape](EventPublisher[E]): @@ -29,7 +38,7 @@ def __init__( self, payload_codec: Codec, async_writer: AsyncWriter, - signer: Signer | None = None, + signer: EventSigner | None = None, is_client_mode: bool = True, ): self._writer = async_writer @@ -50,8 +59,13 @@ async def send(self, event: E) -> None: "Expected an event message to be serialized, but was None." ) if self._signer is not None: - result = self._signer(result) + encoder = self._serializer.event_header_encoder_cls + result = await self._signer.sign_event( # type: ignore + event_message=result, + event_encoder_cls=encoder, + ) + assert isinstance(result, EventMessage) encoded_result = result.encode() try: logger.debug("Publishing serialized event: %s", result) diff --git a/packages/aws-event-stream/src/aws_event_stream/events.py b/packages/aws-event-stream/src/aws_event_stream/events.py index 50e8b1598..46dfc2c7d 100644 --- a/packages/aws-event-stream/src/aws_event_stream/events.py +++ b/packages/aws-event-stream/src/aws_event_stream/events.py @@ -387,7 +387,7 @@ def get_result(self) -> bytes: raise InvalidHeadersLength(len(result)) return result - def encode_headers(self, headers: HEADERS_DICT): + def encode_headers(self, headers: HEADERS_DICT) -> None: """Encode a map of headers. :param headers: A mapping of headers to encode. diff --git a/packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py b/packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py index bdb729c85..4c1c3041c 100644 --- a/packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py +++ b/packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py @@ -9,7 +9,12 @@ from ._http import URI, AWSRequest, Field, Fields from ._identity import AWSCredentialIdentity from ._io import AsyncBytesReader -from .signers import AsyncSigV4Signer, SigV4Signer, SigV4SigningProperties +from .signers import ( + AsyncSigV4Signer, + AsyncEventSigner, + SigV4Signer, + SigV4SigningProperties, +) __license__ = "Apache-2.0" __version__ = importlib.metadata.version("aws-sdk-signers") @@ -17,6 +22,7 @@ __all__ = ( "AsyncBytesReader", "AsyncSigV4Signer", + "AsyncEventSigner", "AWSCredentialIdentity", "AWSRequest", "Field", diff --git a/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/events.py b/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/events.py new file mode 100644 index 000000000..f4cec80d6 --- /dev/null +++ b/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/events.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import datetime +import uuid +from collections.abc import Mapping +from typing import Protocol + + +type HEADER_VALUE = bool | int | bytes | str | datetime.datetime | uuid.UUID +"""A union of valid value types for event headers.""" + + +type HEADERS_DICT = Mapping[str, HEADER_VALUE] +"""A dictionary of event headers.""" + + +class EventMessage(Protocol): + """A signable message that may be sent over an event stream.""" + + headers: HEADERS_DICT + """The headers present in the event message.""" + + payload: bytes + """The serialized bytes of the message payload.""" + + def encode(self) -> bytes: + """Encode heads and payload into bytes for transit.""" + ... + + +class EventHeaderEncoder(Protocol): + """A utility class that encodes event headers into bytes.""" + + def clear(self) -> None: + """Clear all previously encoded headers.""" + ... + + def get_result(self) -> bytes: + """Get all the encoded header bytes.""" + ... + + def encode_headers(self, headers: HEADERS_DICT) -> None: + """Encode a map of headers.""" + ... diff --git a/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py b/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py index 15d71d24b..111b1995c 100644 --- a/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py +++ b/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py @@ -1,15 +1,17 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import datetime import hmac import io import warnings from asyncio import iscoroutinefunction +from binascii import hexlify from collections.abc import AsyncIterable, Iterable from copy import deepcopy from hashlib import sha256 -from typing import Required, TypedDict +from typing import Required, TypedDict, TYPE_CHECKING from urllib.parse import parse_qsl, quote from .interfaces.io import AsyncSeekable, Seekable @@ -19,6 +21,9 @@ from ._io import AsyncBytesReader from .exceptions import AWSSDKWarning, MissingExpectedParameterException +if TYPE_CHECKING: + from .interfaces.events import EventMessage, EventHeaderEncoder + HEADERS_EXCLUDED_FROM_SIGNING: tuple[str, ...] = ( "accept", "accept-encoding", @@ -739,6 +744,12 @@ async def _format_canonical_payload( request: AWSRequest, signing_properties: SigV4SigningProperties, ) -> str: + if ( + "X-Amz-Content-SHA256" in request.fields + and len(request.fields["X-Amz-Content-SHA256"].values) == 1 + ): + return request.fields["X-Amz-Content-SHA256"].values[0] + payload_hash = await self._compute_payload_hash( request=request, signing_properties=signing_properties ) @@ -789,6 +800,113 @@ async def _compute_payload_hash( return checksum.hexdigest() +class AsyncEventSigner: + def __init__( + self, + *, + signing_properties: SigV4SigningProperties, + identity: AWSCredentialIdentity, + initial_signature: bytes, + ): + self._signing_properties = signing_properties + self._identity = identity + self._prior_signature = initial_signature + self._signing_lock = asyncio.Lock() + + async def sign_event( + self, + *, + event_message: "EventMessage", + event_encoder_cls: type["EventHeaderEncoder"], + ) -> "EventMessage": + async with self._signing_lock: + # Copy and prepopulate any missing values in the + # signing properties. + new_signing_properties = SigV4SigningProperties( # type: ignore + **self._signing_properties + ) + # TODO: If date is in properties, parse a datetime from it. + date_obj = datetime.datetime.now(datetime.UTC) + if "date" not in new_signing_properties: + new_signing_properties["date"] = date_obj.strftime( + SIGV4_TIMESTAMP_FORMAT + ) + + timestamp = new_signing_properties["date"] + headers: dict[str, str | bytes | datetime.datetime] = {":date": date_obj} + encoder = event_encoder_cls() + encoder.encode_headers(headers) + encoded_headers = encoder.get_result() + + payload = event_message.encode() + + string_to_sign = await self._event_string_to_sign( + timestamp=timestamp, + scope=self._scope(new_signing_properties), + encoded_headers=encoded_headers, + payload=payload, + prior_signature=self._prior_signature, + ) + event_signature = await self._sign_event( + timestamp=timestamp, + string_to_sign=string_to_sign, + signing_properties=new_signing_properties, + ) + headers[":chunk-signature"] = event_signature + + event_message.headers = headers + event_message.payload = payload + + # set new prior signature before releasing the lock + self._prior_signature = hexlify(event_signature) + + return event_message + + async def _event_string_to_sign( + self, + *, + timestamp: str, + scope: str, + encoded_headers: bytes, + payload: bytes, + prior_signature: bytes, + ) -> str: + return ( + "AWS4-HMAC-SHA256-PAYLOAD\n" + f"{timestamp}\n" + f"{scope}\n" + f"{prior_signature.decode('utf-8')}\n" + f"{sha256(encoded_headers).hexdigest()}\n" + f"{sha256(payload).hexdigest()}" + ) + + async def _sign_event( + self, + *, + timestamp: str, + string_to_sign: str, + signing_properties: SigV4SigningProperties, + ) -> bytes: + key = self._identity.secret_access_key.encode("utf-8") + today = timestamp[:8].encode("utf-8") + k_date = self._hash(b"AWS4" + key, today) + k_region = self._hash(k_date, signing_properties["region"].encode("utf-8")) + k_service = self._hash(k_region, signing_properties["service"].encode("utf-8")) + k_signing = self._hash(k_service, b"aws4_request") + return self._hash(k_signing, string_to_sign.encode("utf-8")) + + def _hash(self, key: bytes, msg: bytes) -> bytes: + return hmac.new(key, msg, sha256).digest() + + def _scope(self, signing_properties: SigV4SigningProperties) -> str: + assert "date" in signing_properties + formatted_date = signing_properties["date"][0:8] + region = signing_properties["region"] + service = signing_properties["service"] + # Scope format: ///aws4_request + return f"{formatted_date}/{region}/{service}/aws4_request" + + def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) -> str: """Removes dot segments from a path per :rfc:`3986#section-5.2.4`. diff --git a/packages/smithy-core/src/smithy_core/interfaces/identity.py b/packages/smithy-core/src/smithy_core/interfaces/identity.py index 765e32c22..e34a8f976 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/identity.py +++ b/packages/smithy-core/src/smithy_core/interfaces/identity.py @@ -1,9 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from datetime import datetime -from typing import Protocol, TypedDict, TypeVar +from typing import Protocol, TypedDict, TypeVar, runtime_checkable +@runtime_checkable class Identity(Protocol): """An entity available to the client representing who the user is."""