Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,10 @@ await sleep(retry_token.retry_delay)

writer.pushState(new SignRequestSection());
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
writer.addStdlibImport("re");
writer.addStdlibImport("typing", "Any");
writer.addImport("smithy_core.interfaces.identity", "Identity");
writer.addImport("smithy_core.types", "PropertyKey");
writer.write("""
# Step 7i: sign the request
if auth_option and signer:
Expand All @@ -587,6 +591,23 @@ await sleep(retry_token.retry_delay)
)
)
logger.debug("Signed HTTP request: %s", context.transport_request)

# TODO - Move this to separate resolution/population function
fields = context.transport_request.fields
auth_value = fields["Authorization"].as_string() # type: ignore
signature = re.split("Signature=", auth_value)[-1] # type: ignore
context.properties["signature"] = signature.encode('utf-8')

identity_key: PropertyKey[Identity | None] = PropertyKey(
key="identity",
value_type=Identity | None # type: ignore
)
sp_key: PropertyKey[dict[str, Any]] = PropertyKey(
key="signer_properties",
value_type=dict[str, Any] # type: ignore
)
context.properties[identity_key] = identity
context.properties[sp_key] = auth_option.signer_properties
""");
}
writer.popState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.util.List;
import java.util.Set;
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait;
import software.amazon.smithy.model.knowledge.EventStreamIndex;
import software.amazon.smithy.model.knowledge.HttpBinding;
import software.amazon.smithy.model.node.ArrayNode;
import software.amazon.smithy.model.node.ObjectNode;
Expand Down Expand Up @@ -156,6 +157,21 @@ protected void serializeDocumentBody(
writer.popState();
}

@Override
protected void writeDefaultHeaders(GenerationContext context, PythonWriter writer, OperationShape operation) {
var eventStreamIndex = EventStreamIndex.of(context.model());
if (eventStreamIndex.getInputInfo(operation).isPresent()) {
writer.addImport("smithy_http", "Field");
writer.write(
"Field(name=\"Content-Type\", values=[$S]),",
"application/vnd.amazon.eventstream");
writer.write(
"Field(name=\"X-Amz-Content-SHA256\", values=[$S]),",
"STREAMING-AWS4-HMAC-SHA256-EVENTS");
}
}


@Override
protected void serializePayloadBody(
GenerationContext context,
Expand Down Expand Up @@ -397,12 +413,24 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) {
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
writer.addImport("smithy_core.types", "TimestampFormat");
writer.addImport("aws_event_stream.aio", "AWSEventPublisher");
writer.addImport("aws_sdk_signers", "AsyncEventSigner");
writer.write(
"""
# TODO - Move this out of the RestJSON generator
ctx = request_context
signer_properties = ctx.properties.get("signer_properties") # type: ignore
identity = ctx.properties.get("identity") # type: ignore
signature = ctx.properties.get("signature") # type: ignore
signer = AsyncEventSigner(
signing_properties=signer_properties, # type: ignore
identity=identity, # type: ignore
initial_signature=signature, # type: ignore
)
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
publisher = AWSEventPublisher[Any](
payload_codec=codec,
async_writer=request_context.transport_request.body, # type: ignore
signer=signer, # type: ignore
)
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from smithy_core.shapes import ShapeType

from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
from ..events import EventHeaderEncoder, EventMessage, HEADER_VALUE, Short, Byte, Long
from ..exceptions import InvalidHeaderValue
from . import (
INITIAL_REQUEST_EVENT_TYPE,
Expand All @@ -43,6 +43,7 @@ def __init__(
self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE
else:
self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE
self.event_header_encoder_cls = EventHeaderEncoder

def get_result(self) -> EventMessage | None:
return self._result
Expand Down
24 changes: 19 additions & 5 deletions packages/aws-event-stream/src/aws_event_stream/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,31 @@

from .._private.serializers import EventSerializer as _EventSerializer
from .._private.deserializers import EventDeserializer as _EventDeserializer
from ..events import Event, EventMessage
from ..events import Event, EventHeaderEncoder, EventMessage
from ..exceptions import EventError

from typing import Protocol

logger = logging.getLogger(__name__)


type Signer = Callable[[EventMessage], EventMessage]
"""A function that takes an event message and signs it, and returns it signed."""
class EventSigner(Protocol):
"""A signer to manage credentials and EventMessages for an Event Stream lifecyle."""

def sign_event(
self,
*,
event_message: EventMessage,
event_encoder_cls: type[EventHeaderEncoder],
) -> EventMessage: ...


class AWSEventPublisher[E: SerializeableShape](EventPublisher[E]):
def __init__(
self,
payload_codec: Codec,
async_writer: AsyncWriter,
signer: Signer | None = None,
signer: EventSigner | None = None,
is_client_mode: bool = True,
):
self._writer = async_writer
Expand All @@ -50,8 +59,13 @@ async def send(self, event: E) -> None:
"Expected an event message to be serialized, but was None."
)
if self._signer is not None:
result = self._signer(result)
encoder = self._serializer.event_header_encoder_cls
result = await self._signer.sign_event( # type: ignore
event_message=result,
event_encoder_cls=encoder,
)

assert isinstance(result, EventMessage)
encoded_result = result.encode()
try:
logger.debug("Publishing serialized event: %s", result)
Expand Down
2 changes: 1 addition & 1 deletion packages/aws-event-stream/src/aws_event_stream/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def get_result(self) -> bytes:
raise InvalidHeadersLength(len(result))
return result

def encode_headers(self, headers: HEADERS_DICT):
def encode_headers(self, headers: HEADERS_DICT) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand why this method builds an internal buffer - but the interface does seem weird. It seems like we could either:

  1. have this as a class method which would construct an instance, call _encode_headers and then return the result. or
  2. Have a private class that handles the build/state, _EventHeaderEncoder which encode_headers would instantiate and return the result from.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'd generally agree with this, the encode workflow here is a bit odd. I think we can look at a refactor as follow up, this PR is just fixing the typing issue from the original.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this was intended for streaming out without an intermediary. Then we switched to have an intermediary

"""Encode a map of headers.

:param headers: A mapping of headers to encode.
Expand Down
8 changes: 7 additions & 1 deletion packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@
from ._http import URI, AWSRequest, Field, Fields
from ._identity import AWSCredentialIdentity
from ._io import AsyncBytesReader
from .signers import AsyncSigV4Signer, SigV4Signer, SigV4SigningProperties
from .signers import (
AsyncSigV4Signer,
AsyncEventSigner,
SigV4Signer,
SigV4SigningProperties,
)

__license__ = "Apache-2.0"
__version__ = importlib.metadata.version("aws-sdk-signers")

__all__ = (
"AsyncBytesReader",
"AsyncSigV4Signer",
"AsyncEventSigner",
"AWSCredentialIdentity",
"AWSRequest",
"Field",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import datetime
import uuid
from collections.abc import Mapping
from typing import Protocol


type HEADER_VALUE = bool | int | bytes | str | datetime.datetime | uuid.UUID
"""A union of valid value types for event headers."""


type HEADERS_DICT = Mapping[str, HEADER_VALUE]
"""A dictionary of event headers."""


class EventMessage(Protocol):
"""A signable message that may be sent over an event stream."""

headers: HEADERS_DICT
"""The headers present in the event message."""

payload: bytes
"""The serialized bytes of the message payload."""

def encode(self) -> bytes:
"""Encode heads and payload into bytes for transit."""
...


class EventHeaderEncoder(Protocol):
"""A utility class that encodes event headers into bytes."""

def clear(self) -> None:
"""Clear all previously encoded headers."""
...

def get_result(self) -> bytes:
"""Get all the encoded header bytes."""
...

def encode_headers(self, headers: HEADERS_DICT) -> None:
"""Encode a map of headers."""
...
120 changes: 119 additions & 1 deletion packages/aws-sdk-signers/src/aws_sdk_signers/signers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import datetime
import hmac
import io
import warnings
from asyncio import iscoroutinefunction
from binascii import hexlify
from collections.abc import AsyncIterable, Iterable
from copy import deepcopy
from hashlib import sha256
from typing import Required, TypedDict
from typing import Required, TypedDict, TYPE_CHECKING
from urllib.parse import parse_qsl, quote

from .interfaces.io import AsyncSeekable, Seekable
Expand All @@ -19,6 +21,9 @@
from ._io import AsyncBytesReader
from .exceptions import AWSSDKWarning, MissingExpectedParameterException

if TYPE_CHECKING:
from .interfaces.events import EventMessage, EventHeaderEncoder

HEADERS_EXCLUDED_FROM_SIGNING: tuple[str, ...] = (
"accept",
"accept-encoding",
Expand Down Expand Up @@ -739,6 +744,12 @@ async def _format_canonical_payload(
request: AWSRequest,
signing_properties: SigV4SigningProperties,
) -> str:
if (
"X-Amz-Content-SHA256" in request.fields
and len(request.fields["X-Amz-Content-SHA256"].values) == 1
):
return request.fields["X-Amz-Content-SHA256"].values[0]

payload_hash = await self._compute_payload_hash(
request=request, signing_properties=signing_properties
)
Expand Down Expand Up @@ -789,6 +800,113 @@ async def _compute_payload_hash(
return checksum.hexdigest()


class AsyncEventSigner:
def __init__(
self,
*,
signing_properties: SigV4SigningProperties,
identity: AWSCredentialIdentity,
initial_signature: bytes,
):
self._signing_properties = signing_properties
self._identity = identity
self._prior_signature = initial_signature
self._signing_lock = asyncio.Lock()

async def sign_event(
self,
*,
event_message: "EventMessage",
event_encoder_cls: type["EventHeaderEncoder"],
) -> "EventMessage":
async with self._signing_lock:
# Copy and prepopulate any missing values in the
# signing properties.
new_signing_properties = SigV4SigningProperties( # type: ignore
**self._signing_properties
)
# TODO: If date is in properties, parse a datetime from it.
date_obj = datetime.datetime.now(datetime.UTC)
if "date" not in new_signing_properties:
new_signing_properties["date"] = date_obj.strftime(
SIGV4_TIMESTAMP_FORMAT
)

timestamp = new_signing_properties["date"]
headers: dict[str, str | bytes | datetime.datetime] = {":date": date_obj}
encoder = event_encoder_cls()
encoder.encode_headers(headers)
encoded_headers = encoder.get_result()

payload = event_message.encode()

string_to_sign = await self._event_string_to_sign(
timestamp=timestamp,
scope=self._scope(new_signing_properties),
encoded_headers=encoded_headers,
payload=payload,
prior_signature=self._prior_signature,
)
event_signature = await self._sign_event(
timestamp=timestamp,
string_to_sign=string_to_sign,
signing_properties=new_signing_properties,
)
headers[":chunk-signature"] = event_signature

event_message.headers = headers
event_message.payload = payload

# set new prior signature before releasing the lock
self._prior_signature = hexlify(event_signature)

return event_message

async def _event_string_to_sign(
self,
*,
timestamp: str,
scope: str,
encoded_headers: bytes,
payload: bytes,
prior_signature: bytes,
) -> str:
return (
"AWS4-HMAC-SHA256-PAYLOAD\n"
f"{timestamp}\n"
f"{scope}\n"
f"{prior_signature.decode('utf-8')}\n"
f"{sha256(encoded_headers).hexdigest()}\n"
f"{sha256(payload).hexdigest()}"
)

async def _sign_event(
self,
*,
timestamp: str,
string_to_sign: str,
signing_properties: SigV4SigningProperties,
) -> bytes:
key = self._identity.secret_access_key.encode("utf-8")
today = timestamp[:8].encode("utf-8")
k_date = self._hash(b"AWS4" + key, today)
k_region = self._hash(k_date, signing_properties["region"].encode("utf-8"))
k_service = self._hash(k_region, signing_properties["service"].encode("utf-8"))
k_signing = self._hash(k_service, b"aws4_request")
return self._hash(k_signing, string_to_sign.encode("utf-8"))

def _hash(self, key: bytes, msg: bytes) -> bytes:
return hmac.new(key, msg, sha256).digest()

def _scope(self, signing_properties: SigV4SigningProperties) -> str:
assert "date" in signing_properties
formatted_date = signing_properties["date"][0:8]
region = signing_properties["region"]
service = signing_properties["service"]
# Scope format: <YYYYMMDD>/<AWS Region>/<AWS Service>/aws4_request
return f"{formatted_date}/{region}/{service}/aws4_request"


def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) -> str:
"""Removes dot segments from a path per :rfc:`3986#section-5.2.4`.

Expand Down
Loading