Skip to content

Commit 42872bc

Browse files
committed
Initial signer pass
1 parent bd17d3e commit 42872bc

File tree

7 files changed

+195
-12
lines changed

7 files changed

+195
-12
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,8 @@ async def _handle_attempt(
548548

549549
writer.pushState(new SignRequestSection());
550550
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
551+
writer.addStdlibImport("binascii", "hexlify");
552+
writer.addStdlibImport("re");
551553
writer.write("""
552554
# Step 7i: sign the request
553555
if auth_option and signer:
@@ -561,6 +563,14 @@ async def _handle_attempt(
561563
identity=identity,
562564
signing_properties=auth_option.signer_properties,
563565
)
566+
567+
# TODO - Move this to separate resolution/population function
568+
fields = context._transport_request.fields
569+
auth_value = fields["Authorization"].as_string() # type: ignore
570+
signature = re.split("Signature=", auth_value)[-1] # type: ignore
571+
context._properties["signature"] = hexlify(signature.encode('utf-8')) # type: ignore
572+
context._properties["identity"] = identity
573+
context._properties["signer_properties"] = auth_option.signer_properties
564574
logger.debug("Signed HTTP request: %s", context._transport_request)
565575
""");
566576
}

codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ public void wrapOutputStream(GenerationContext context, PythonWriter writer) {
426426
transport_response.body # type: ignore
427427
),
428428
deserializer=event_deserializer, # type: ignore
429+
signer=signer, # type: ignore
429430
)
430431
""");
431432
}

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# SPDX-License-Identifier: Apache-2.0
33
import asyncio
44
import datetime
5-
from collections.abc import Callable, Iterator
5+
from collections.abc import Iterator
66
from contextlib import contextmanager
77
from io import BytesIO
8-
from typing import Never
8+
from typing import Never, Protocol
99

1010
from smithy_core.aio.interfaces import AsyncWriter
1111
from smithy_core.codecs import Codec
@@ -20,7 +20,7 @@
2020
from smithy_core.shapes import ShapeType
2121
from smithy_event_stream.aio.interfaces import AsyncEventPublisher
2222

23-
from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
23+
from ..events import EventHeaderEncoder, EventMessage, HEADER_VALUE, Short, Byte, Long
2424
from ..exceptions import InvalidHeaderValue
2525
from . import (
2626
INITIAL_REQUEST_EVENT_TYPE,
@@ -33,16 +33,23 @@
3333
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"
3434

3535

36-
type Signer = Callable[[EventMessage], EventMessage]
37-
"""A function that takes an event message and signs it, and returns it signed."""
36+
class EventSigner(Protocol):
37+
"""A signer to manage credentials and EventMessages for an Event Stream lifecyle."""
38+
39+
def sign_event(
40+
self,
41+
*,
42+
event_message: EventMessage,
43+
event_encoder_cls: type[EventHeaderEncoder],
44+
) -> EventMessage: ...
3845

3946

4047
class AWSAsyncEventPublisher[E: SerializeableShape](AsyncEventPublisher[E]):
4148
def __init__(
4249
self,
4350
payload_codec: Codec,
4451
async_writer: AsyncWriter,
45-
signer: Signer | None = None,
52+
signer: EventSigner | None = None,
4653
is_client_mode: bool = True,
4754
):
4855
self._writer = async_writer
@@ -59,7 +66,11 @@ async def send(self, event: E) -> None:
5966
"Expected an event message to be serialized, but was None."
6067
)
6168
if self._signer is not None:
62-
result = self._signer(result)
69+
encoder = self._serializer.event_header_encoder_cls
70+
result = await self._signer.sign_event(
71+
event_message=result,
72+
event_encoder_cls=encoder,
73+
)
6374
await self._writer.write(result.encode())
6475

6576
async def close(self) -> None:
@@ -80,6 +91,7 @@ def __init__(
8091
self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE
8192
else:
8293
self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE
94+
self.event_header_encoder_cls = EventHeaderEncoder
8395

8496
def get_result(self) -> EventMessage | None:
8597
return self._result

packages/aws-event-stream/src/aws_event_stream/aio/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from .._private.deserializers import AWSAsyncEventReceiver as _AWSEventReceiver
1919
from .._private.serializers import AWSAsyncEventPublisher as _AWSEventPublisher
20-
from .._private.serializers import Signer
20+
from .._private.serializers import EventSigner
2121
from ..exceptions import MissingInitialResponse
2222

2323

@@ -36,7 +36,7 @@ def __init__(
3636
awaitable_response: Awaitable[Response],
3737
awaitable_output: Awaitable[R],
3838
deserializeable_response: type[R] | None = None,
39-
signer: Signer | None = None,
39+
signer: EventSigner | None = None,
4040
is_client_mode: bool = True,
4141
) -> None:
4242
"""Construct an AWSDuplexEventStream.
@@ -112,7 +112,7 @@ def __init__(
112112
payload_codec: Codec,
113113
async_writer: AsyncWriter,
114114
awaitable_output: Awaitable[R],
115-
signer: Signer | None = None,
115+
signer: EventSigner | None = None,
116116
is_client_mode: bool = True,
117117
) -> None:
118118
"""Construct an AWSInputEventStream.

packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,20 @@
99
from ._http import URI, AWSRequest, Field, Fields
1010
from ._identity import AWSCredentialIdentity
1111
from ._io import AsyncBytesReader
12-
from .signers import AsyncSigV4Signer, SigV4Signer, SigV4SigningProperties
12+
from .signers import (
13+
AsyncSigV4Signer,
14+
AsyncEventSigner,
15+
SigV4Signer,
16+
SigV4SigningProperties,
17+
)
1318

1419
__license__ = "Apache-2.0"
1520
__version__ = importlib.metadata.version("aws-sdk-signers")
1621

1722
__all__ = (
1823
"AsyncBytesReader",
1924
"AsyncSigV4Signer",
25+
"AsyncEventSigner",
2026
"AWSCredentialIdentity",
2127
"AWSRequest",
2228
"Field",
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import datetime
7+
import uuid
8+
from collections.abc import Mapping
9+
from typing import Protocol
10+
11+
12+
type HEADER_VALUE = bool | int | bytes | str | datetime.datetime | uuid.UUID
13+
"""A union of valid value types for event headers."""
14+
15+
16+
type HEADERS_DICT = Mapping[str, HEADER_VALUE]
17+
"""A dictionary of event headers."""
18+
19+
20+
class EventMessage(Protocol):
21+
"""A signable message that may be sent over an event stream."""
22+
23+
headers: HEADERS_DICT
24+
"""The headers present in the event message."""
25+
26+
payload: bytes
27+
"""The serialized bytes of the message payload."""
28+
29+
def encode(self) -> bytes:
30+
"""Encode heads and payload into bytes for transit."""
31+
...
32+
33+
34+
class EventHeaderEncoder(Protocol):
35+
"""A utility class that encodes event headers into bytes."""
36+
37+
def clear(self) -> None:
38+
"""Clear all previously encoded headers."""
39+
...
40+
41+
def get_result(self) -> bytes:
42+
"""Get all the encoded header bytes."""
43+
...
44+
45+
def encode_headers(self, headers: HEADERS_DICT) -> None:
46+
"""Encode a map of headers."""
47+
...

packages/aws-sdk-signers/src/aws_sdk_signers/signers.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import asyncio
45
import datetime
56
import hmac
67
import io
78
import warnings
89
from asyncio import iscoroutinefunction
10+
from binascii import hexlify
911
from collections.abc import AsyncIterable, Iterable
1012
from copy import deepcopy
1113
from hashlib import sha256
12-
from typing import Required, TypedDict
14+
from typing import Required, TypedDict, TYPE_CHECKING
1315
from urllib.parse import parse_qsl, quote
1416

1517
from .interfaces.io import AsyncSeekable, Seekable
@@ -19,6 +21,9 @@
1921
from ._io import AsyncBytesReader
2022
from .exceptions import AWSSDKWarning, MissingExpectedParameterException
2123

24+
if TYPE_CHECKING:
25+
from .interfaces.events import EventMessage, EventHeaderEncoder
26+
2227
HEADERS_EXCLUDED_FROM_SIGNING: tuple[str, ...] = (
2328
"accept",
2429
"accept-encoding",
@@ -774,6 +779,108 @@ async def _compute_payload_hash(
774779
return checksum.hexdigest()
775780

776781

782+
class AsyncEventSigner:
783+
def __init__(
784+
self,
785+
*,
786+
signing_properties: SigV4SigningProperties,
787+
identity: AWSCredentialIdentity,
788+
initial_signature: bytes,
789+
):
790+
self._signing_properties = signing_properties
791+
self._identity = identity
792+
self._prior_signature = initial_signature
793+
self._signing_lock = asyncio.Lock()
794+
795+
async def sign_event(
796+
self,
797+
*,
798+
event_message: "EventMessage",
799+
event_encoder_cls: type["EventHeaderEncoder"],
800+
) -> "EventMessage":
801+
async with self._signing_lock:
802+
# Copy and prepopulate any missing values in the
803+
# signing properties.
804+
new_signing_properties = SigV4SigningProperties( # type: ignore
805+
**self._signing_properties
806+
)
807+
if "date" not in new_signing_properties:
808+
date_obj = datetime.datetime.now(datetime.UTC)
809+
new_signing_properties["date"] = date_obj.strftime(
810+
SIGV4_TIMESTAMP_FORMAT
811+
)
812+
813+
timestamp = new_signing_properties["date"]
814+
headers: dict[str, str | bytes] = {":date": timestamp}
815+
encoder = event_encoder_cls()
816+
encoder.encode_headers(event_message.headers)
817+
encoded_headers = encoder.get_result()
818+
819+
string_to_sign = await self._event_string_to_sign(
820+
timestamp=timestamp,
821+
scope=self._scope(new_signing_properties),
822+
encoded_headers=encoded_headers,
823+
payload=event_message.payload,
824+
prior_signature=self._prior_signature,
825+
)
826+
event_signature = await self._sign_event(
827+
timestamp=timestamp,
828+
string_to_sign=string_to_sign,
829+
signing_properties=new_signing_properties,
830+
)
831+
headers[":chunk-signature"] = event_signature
832+
event_message.headers.update(headers) # type: ignore
833+
834+
# set new prior signature before releasing the lock
835+
self._prior_signature = event_signature
836+
837+
return event_message
838+
839+
async def _event_string_to_sign(
840+
self,
841+
*,
842+
timestamp: str,
843+
scope: str,
844+
encoded_headers: bytes,
845+
payload: bytes,
846+
prior_signature: bytes,
847+
) -> str:
848+
return (
849+
"AWS-HMAC-SHA256-PAYLOAD\n"
850+
f"{timestamp}\n"
851+
f"{scope}\n"
852+
f"{hexlify(prior_signature).decode('utf-8')}\n"
853+
f"{sha256(encoded_headers).hexdigest()}\n"
854+
f"{sha256(payload).hexdigest()}\n"
855+
)
856+
857+
async def _sign_event(
858+
self,
859+
*,
860+
timestamp: str,
861+
string_to_sign: str,
862+
signing_properties: SigV4SigningProperties,
863+
) -> bytes:
864+
key = self._identity.secret_access_key.encode("utf-8")
865+
today = timestamp[:8].encode("utf-8")
866+
k_date = self._hash(b"AWS4" + key, today)
867+
k_region = self._hash(k_date, signing_properties["region"].encode("utf-8"))
868+
k_service = self._hash(k_region, signing_properties["service"].encode("utf-8"))
869+
k_signing = self._hash(k_service, b"aws4_request")
870+
return self._hash(k_signing, string_to_sign.encode("utf-8"))
871+
872+
def _hash(self, key: bytes, msg: bytes) -> bytes:
873+
return hmac.new(key, msg, sha256).digest()
874+
875+
def _scope(self, signing_properties: SigV4SigningProperties) -> str:
876+
assert "date" in signing_properties
877+
formatted_date = signing_properties["date"][0:8]
878+
region = signing_properties["region"]
879+
service = signing_properties["service"]
880+
# Scope format: <YYYYMMDD>/<AWS Region>/<AWS Service>/aws4_request
881+
return f"{formatted_date}/{region}/{service}/aws4_request"
882+
883+
777884
def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) -> str:
778885
"""Removes dot segments from a path per :rfc:`3986#section-5.2.4`.
779886

0 commit comments

Comments
 (0)