Skip to content

Commit 0106d5e

Browse files
committed
Initial signer pass
1 parent 94e9f57 commit 0106d5e

File tree

6 files changed

+188
-12
lines changed

6 files changed

+188
-12
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,14 @@ async def _handle_attempt(
501501
identity=identity,
502502
signing_properties=auth_option.signer_properties,
503503
)
504+
505+
# TODO - Move this to separate resolution/population function
506+
fields = context._transport_request.fields
507+
auth_value = fields["Authorization"].as_string() # type: ignore
508+
signature = re.split("Signature=", auth_value)[-1] # type: ignore
509+
context._properties["signature"] = hexlify(signature.encode('utf-8')) # type: ignore
510+
context._properties["identity"] = identity
511+
context._properties["signer_properties"] = auth_option.signer_properties
504512
logger.debug("Signed HTTP request: %s", context._transport_request)
505513
""");
506514
}

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# SPDX-License-Identifier: Apache-2.0
33
import asyncio
44
import datetime
5-
from collections.abc import Callable, Iterator
5+
from collections.abc import Iterator
66
from contextlib import contextmanager
77
from io import BytesIO
8-
from typing import Never
8+
from typing import Never, Protocol
99

1010
from smithy_core.aio.interfaces import AsyncWriter
1111
from smithy_core.codecs import Codec
@@ -20,7 +20,7 @@
2020
from smithy_core.shapes import ShapeType
2121
from smithy_event_stream.aio.interfaces import AsyncEventPublisher
2222

23-
from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
23+
from ..events import EventHeaderEncoder, EventMessage, HEADER_VALUE, Short, Byte, Long
2424
from ..exceptions import InvalidHeaderValue
2525
from . import (
2626
INITIAL_REQUEST_EVENT_TYPE,
@@ -33,16 +33,23 @@
3333
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"
3434

3535

36-
type Signer = Callable[[EventMessage], EventMessage]
37-
"""A function that takes an event message and signs it, and returns it signed."""
36+
class EventSigner(Protocol):
37+
"""A signer to manage credentials and EventMessages for an Event Stream lifecyle."""
38+
39+
def sign_event(
40+
self,
41+
*,
42+
event_message: EventMessage,
43+
event_encoder: type[EventHeaderEncoder],
44+
) -> EventMessage: ...
3845

3946

4047
class AWSAsyncEventPublisher[E: SerializeableShape](AsyncEventPublisher[E]):
4148
def __init__(
4249
self,
4350
payload_codec: Codec,
4451
async_writer: AsyncWriter,
45-
signer: Signer | None = None,
52+
signer: EventSigner | None = None,
4653
is_client_mode: bool = True,
4754
):
4855
self._writer = async_writer
@@ -59,7 +66,11 @@ async def send(self, event: E) -> None:
5966
"Expected an event message to be serialized, but was None."
6067
)
6168
if self._signer is not None:
62-
result = self._signer(result)
69+
encoder = self._serializer.event_header_encoder_cls
70+
result = self._signer.sign_event(
71+
event_message=result,
72+
event_encoder=encoder,
73+
)
6374
await self._writer.write(result.encode())
6475

6576
async def close(self) -> None:
@@ -80,6 +91,7 @@ def __init__(
8091
self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE
8192
else:
8293
self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE
94+
self.event_header_encoder_cls = EventHeaderEncoder
8395

8496
def get_result(self) -> EventMessage | None:
8597
return self._result

packages/aws-event-stream/src/aws_event_stream/aio/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from .._private.deserializers import AWSAsyncEventReceiver as _AWSEventReceiver
1919
from .._private.serializers import AWSAsyncEventPublisher as _AWSEventPublisher
20-
from .._private.serializers import Signer
20+
from .._private.serializers import EventSigner
2121
from ..exceptions import MissingInitialResponse
2222

2323

@@ -36,7 +36,7 @@ def __init__(
3636
async_reader: AsyncByteStream | None = None,
3737
initial_response: R | None = None,
3838
deserializeable_response: type[R] | None = None,
39-
signer: Signer | None = None,
39+
signer: EventSigner | None = None,
4040
is_client_mode: bool = True,
4141
) -> None:
4242
"""Construct an AWSDuplexEventStream.
@@ -134,7 +134,7 @@ def __init__(
134134
payload_codec: Codec,
135135
async_writer: AsyncWriter,
136136
initial_response: R | None = None,
137-
signer: Signer | None = None,
137+
signer: EventSigner | None = None,
138138
is_client_mode: bool = True,
139139
) -> None:
140140
"""Construct an AWSInputEventStream.

packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from ._http import URI, AWSRequest, Field, Fields
1010
from ._identity import AWSCredentialIdentity
1111
from ._io import AsyncBytesReader
12-
from .signers import AsyncSigV4Signer, SigV4Signer, SigV4SigningProperties
12+
from .signers import (
13+
AsyncSigV4Signer,
14+
AsyncEventSigner,
15+
SigV4Signer,
16+
SigV4SigningProperties,
17+
)
1318

1419
__license__ = "Apache-2.0"
1520
__version__ = importlib.metadata.version("aws-sdk-signers")
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import datetime
7+
import uuid
8+
from collections.abc import Mapping
9+
from typing import Protocol
10+
11+
12+
type HEADER_VALUE = bool | int | bytes | str | datetime.datetime | uuid.UUID
13+
"""A union of valid value types for event headers."""
14+
15+
16+
type HEADERS_DICT = Mapping[str, HEADER_VALUE]
17+
"""A dictionary of event headers."""
18+
19+
20+
class EventMessage(Protocol):
21+
"""A signable message that may be sent over an event stream."""
22+
23+
headers: HEADERS_DICT
24+
"""The headers present in the event message."""
25+
26+
payload: bytes
27+
"""The serialized bytes of the message payload."""
28+
29+
def encode(self) -> bytes:
30+
"""Encode heads and payload into bytes for transit."""
31+
...
32+
33+
34+
class EventHeaderEncoder(Protocol):
35+
"""A utility class that encodes event headers into bytes."""
36+
37+
def clear(self) -> None:
38+
"""Clear all previously encoded headers."""
39+
...
40+
41+
def get_result(self) -> bytes:
42+
"""Get all the encoded header bytes."""
43+
...
44+
45+
def encode_headers(self, headers: HEADERS_DICT) -> None:
46+
"""Encode a map of headers."""
47+
...

packages/aws-sdk-signers/src/aws_sdk_signers/signers.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import asyncio
45
import datetime
56
import hmac
67
import io
78
import warnings
89
from asyncio import iscoroutinefunction
10+
from binascii import hexlify
911
from collections.abc import AsyncIterable, Iterable
1012
from copy import deepcopy
1113
from hashlib import sha256
12-
from typing import Required, TypedDict
14+
from typing import Required, TypedDict, TYPE_CHECKING
1315
from urllib.parse import parse_qsl, quote
1416

1517
from .interfaces.io import AsyncSeekable, Seekable
@@ -19,6 +21,9 @@
1921
from ._io import AsyncBytesReader
2022
from .exceptions import AWSSDKWarning, MissingExpectedParameterException
2123

24+
if TYPE_CHECKING:
25+
from .interfaces.events import EventMessage, EventHeaderEncoder
26+
2227
HEADERS_EXCLUDED_FROM_SIGNING: tuple[str, ...] = (
2328
"accept",
2429
"accept-encoding",
@@ -774,6 +779,105 @@ async def _compute_payload_hash(
774779
return checksum.hexdigest()
775780

776781

782+
class AsyncEventSigner:
783+
def __init__(
784+
self,
785+
*,
786+
signing_properties: SigV4SigningProperties,
787+
identity: AWSCredentialIdentity,
788+
):
789+
self._signing_properties = signing_properties
790+
self._identity = identity
791+
self._prior_signature = initial_signature
792+
self._signing_lock: asyncio.Lock = asyncio.Lock()
793+
794+
async def sign_event(
795+
self,
796+
*,
797+
event_message: "EventMessage",
798+
event_encoder_cls: type["EventHeaderEncoder"],
799+
):
800+
async with self._signing_lock:
801+
# Copy and prepopulate any missing values in the
802+
# signing properties.
803+
new_signing_properties = SigV4SigningProperties( # type: ignore
804+
**self._signing_properties
805+
)
806+
if "date" not in new_signing_properties:
807+
date_obj = datetime.datetime.now(datetime.UTC)
808+
new_signing_properties["date"] = date_obj.strftime(
809+
SIGV4_TIMESTAMP_FORMAT
810+
)
811+
812+
timestamp = new_signing_properties["date"]
813+
headers: dict[str, str | bytes] = {":date": timestamp}
814+
encoder = event_encoder_cls()
815+
encoder.encode_headers(event_message.headers)
816+
encoded_headers = encoder.get_result()
817+
818+
string_to_sign = await self._event_string_to_sign(
819+
timestamp=timestamp,
820+
scope=self._scope(new_signing_properties),
821+
encoded_headers=encoded_headers,
822+
payload=event_message.payload,
823+
prior_signature=self._prior_signature,
824+
)
825+
event_signature = await self._sign_event(
826+
timestamp=timestamp,
827+
string_to_sign=string_to_sign,
828+
signing_properties=new_signing_properties,
829+
)
830+
headers[":chunk-signature"] = event_signature
831+
event_message.headers.update(headers) # type: ignore
832+
833+
# set new prior signature before releasing the lock
834+
self._prior_signature = event_signature
835+
836+
async def _event_string_to_sign(
837+
self,
838+
*,
839+
timestamp: str,
840+
scope: str,
841+
encoded_headers: bytes,
842+
payload: bytes,
843+
prior_signature: bytes,
844+
) -> str:
845+
return (
846+
"AWS-HMAC-SHA256-PAYLOAD\n"
847+
f"{timestamp}\n"
848+
f"{scope}\n"
849+
f"{hexlify(prior_signature).decode('utf-8')}\n"
850+
f"{sha256(encoded_headers).hexdigest()}\n"
851+
f"{sha256(payload).hexdigest()}\n"
852+
)
853+
854+
async def _sign_event(
855+
self,
856+
*,
857+
timestamp: str,
858+
string_to_sign: str,
859+
signing_properties: SigV4SigningProperties,
860+
) -> bytes:
861+
key = self._identity.secret_access_key.encode("utf-8")
862+
today = timestamp[:8].encode("utf-8")
863+
k_date = self._hash(b"AWS4" + key, today)
864+
k_region = self._hash(k_date, signing_properties["region"].encode("utf-8"))
865+
k_service = self._hash(k_region, signing_properties["service"].encode("utf-8"))
866+
k_signing = self._hash(k_service, b"aws4_request")
867+
return self._hash(k_signing, string_to_sign.encode("utf-8"))
868+
869+
def _hash(self, key: bytes, msg: bytes) -> bytes:
870+
return hmac.new(key, msg, sha256).digest()
871+
872+
def _scope(self, signing_properties: SigV4SigningProperties) -> str:
873+
assert "date" in signing_properties
874+
formatted_date = signing_properties["date"][0:8]
875+
region = signing_properties["region"]
876+
service = signing_properties["service"]
877+
# Scope format: <YYYYMMDD>/<AWS Region>/<AWS Service>/aws4_request
878+
return f"{formatted_date}/{region}/{service}/aws4_request"
879+
880+
777881
def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) -> str:
778882
"""Removes dot segments from a path per :rfc:`3986#section-5.2.4`.
779883

0 commit comments

Comments
 (0)