55import hmac
66import io
77import warnings
8- from collections .abc import AsyncIterable
8+ from collections .abc import AsyncIterable , Iterable
99from copy import deepcopy
1010from hashlib import sha256
1111from typing import Required , TypedDict
1212from urllib .parse import parse_qsl , quote
1313
14+ from .interfaces .io import Seekable
1415from ._http import URI , AWSRequest , Field
1516from ._identity import AWSCredentialIdentity
1617from ._io import AsyncBytesReader
@@ -64,6 +65,8 @@ def sign(
6465 new_signing_properties = self ._normalize_signing_properties (
6566 signing_properties = signing_properties
6667 )
68+ assert "date" in new_signing_properties
69+
6770 new_request = self ._generate_new_request (request = request )
6871 self ._apply_required_fields (
6972 request = new_request ,
@@ -133,14 +136,14 @@ def _signature(
133136 service. The date, region, service and resulting signing key are individually
134137 hashed, then the composite hash is used to sign the string to sign.
135138 """
136- assert signing_properties ["date" ] is not None
137139
138140 # Components of Signing Key Calculation
139141 #
140142 # DateKey = HMAC-SHA256("AWS4"+"<SecretAccessKey>", "<YYYYMMDD>")
141143 # DateRegionKey = HMAC-SHA256(<DateKey>, "<aws-region>")
142144 # DateRegionServiceKey = HMAC-SHA256(<DateRegionKey>, "<aws-service>")
143145 # SigningKey = HMAC-SHA256(<DateRegionServiceKey>, "aws4_request")
146+ assert "date" in signing_properties
144147 k_date = self ._hash (
145148 key = f"AWS4{ secret_key } " .encode (), value = signing_properties ["date" ][0 :8 ]
146149 )
@@ -155,7 +158,7 @@ def _hash(self, key: bytes, value: str) -> bytes:
155158
156159 def _validate_identity (self , * , identity : AWSCredentialIdentity ) -> None :
157160 """Perform runtime and expiration checks before attempting signing."""
158- if not isinstance (identity , AWSCredentialIdentity ):
161+ if not isinstance (identity , AWSCredentialIdentity ): # pyright: ignore
159162 raise ValueError (
160163 "Received unexpected value for identity parameter. Expected "
161164 f"AWSCredentialIdentity but received { type (identity )} ."
@@ -171,20 +174,14 @@ def _normalize_signing_properties(
171174 ) -> SigV4SigningProperties :
172175 # Create copy of signing properties to avoid mutating the original
173176 new_signing_properties = SigV4SigningProperties (** signing_properties )
174- new_signing_properties [ "date" ] = self . _resolve_signing_date (
175- date = new_signing_properties . get ( "date" )
176- )
177+ if "date" not in new_signing_properties :
178+ date_obj = datetime . datetime . now ( datetime . UTC )
179+ new_signing_properties [ "date" ] = date_obj . strftime ( SIGV4_TIMESTAMP_FORMAT )
177180 return new_signing_properties
178181
179182 def _generate_new_request (self , * , request : AWSRequest ) -> AWSRequest :
180183 return deepcopy (request )
181184
182- def _resolve_signing_date (self , * , date : str | None ) -> str :
183- if date is None :
184- date_obj = datetime .datetime .now (datetime .UTC )
185- date = date_obj .strftime (SIGV4_TIMESTAMP_FORMAT )
186- return date
187-
188185 def _apply_required_fields (
189186 self ,
190187 * ,
@@ -194,6 +191,7 @@ def _apply_required_fields(
194191 ) -> None :
195192 # Apply required X-Amz-Date if neither X-Amz-Date nor Date are present.
196193 if "Date" not in request .fields and "X-Amz-Date" not in request .fields :
194+ assert "date" in signing_properties
197195 request .fields .set_field (
198196 Field (name = "X-Amz-Date" , values = [signing_properties ["date" ]])
199197 )
@@ -283,6 +281,7 @@ def string_to_sign(
283281 )
284282
285283 def _scope (self , signing_properties : SigV4SigningProperties ) -> str :
284+ assert "date" in signing_properties
286285 formatted_date = signing_properties ["date" ][0 :8 ]
287286 region = signing_properties ["region" ]
288287 service = signing_properties ["service" ]
@@ -315,13 +314,13 @@ def _normalize_signing_fields(self, *, request: AWSRequest) -> dict[str, str]:
315314 }
316315 if "host" not in normalized_fields :
317316 normalized_fields ["host" ] = self ._normalize_host_field (
318- uri = request .destination
317+ uri = request .destination # type: ignore - TODO(pyright)
319318 )
320319
321320 return dict (sorted (normalized_fields .items ()))
322321
323- def _is_signable_header (self , field ):
324- if field in HEADERS_EXCLUDED_FROM_SIGNING :
322+ def _is_signable_header (self , field_name : str ):
323+ if field_name in HEADERS_EXCLUDED_FROM_SIGNING :
325324 return False
326325 return True
327326
@@ -355,12 +354,6 @@ def _format_canonical_payload(
355354 request : AWSRequest ,
356355 signing_properties : SigV4SigningProperties ,
357356 ) -> str :
358- if isinstance (request .body , AsyncIterable ):
359- raise TypeError (
360- "An async body was attached to a synchronous signer. Please use "
361- "AsyncSigV4Signer for async AWSRequests or ensure your body is "
362- "of type Iterable[bytes]."
363- )
364357 payload_hash = self ._compute_payload_hash (
365358 request = request , signing_properties = signing_properties
366359 )
@@ -383,21 +376,28 @@ def _compute_payload_hash(
383376 if body is None :
384377 return EMPTY_SHA256_HASH
385378
379+ if not isinstance (body , Iterable ):
380+ raise TypeError (
381+ "An async body was attached to a synchronous signer. Please use "
382+ "AsyncSigV4Signer for async AWSRequests or ensure your body is "
383+ "of type Iterable[bytes]."
384+ )
385+
386386 warnings .warn (
387387 "Payload signing is enabled. This may result in "
388388 "decreased performance for large request bodies." ,
389389 AWSSDKWarning ,
390390 )
391391
392392 checksum = sha256 ()
393- if hasattr (body , "seek" ) and hasattr ( body , "tell" ):
393+ if isinstance (body , Seekable ):
394394 position = body .tell ()
395- for chunk in body : # type: ignore[union-attr]
395+ for chunk in body :
396396 checksum .update (chunk )
397397 body .seek (position )
398398 else :
399399 buffer = io .BytesIO ()
400- for chunk in body : # type: ignore[union-attr]
400+ for chunk in body :
401401 buffer .write (chunk )
402402 checksum .update (chunk )
403403 buffer .seek (0 )
@@ -498,14 +498,14 @@ async def _signature(
498498 service. The date, region, service and resulting signing key are individually
499499 hashed, then the composite hash is used to sign the string to sign.
500500 """
501- assert signing_properties .get ("date" ) is not None
502501
503502 # Components of Signing Key Calculation
504503 #
505504 # DateKey = HMAC-SHA256("AWS4"+"<SecretAccessKey>", "<YYYYMMDD>")
506505 # DateRegionKey = HMAC-SHA256(<DateKey>, "<aws-region>")
507506 # DateRegionServiceKey = HMAC-SHA256(<DateRegionKey>, "<aws-service>")
508507 # SigningKey = HMAC-SHA256(<DateRegionServiceKey>, "aws4_request")
508+ assert "date" in signing_properties
509509 k_date = await self ._hash (
510510 key = f"AWS4{ secret_key } " .encode (), value = signing_properties ["date" ][0 :8 ]
511511 )
@@ -521,7 +521,7 @@ async def _hash(self, key: bytes, value: str) -> bytes:
521521
522522 async def _validate_identity (self , * , identity : AWSCredentialIdentity ) -> None :
523523 """Perform runtime and expiration checks before attempting signing."""
524- if not isinstance (identity , AWSCredentialIdentity ):
524+ if not isinstance (identity , AWSCredentialIdentity ): # pyright: ignore
525525 raise ValueError (
526526 "Received unexpected value for identity parameter. Expected "
527527 f"AWSCredentialIdentity but received { type (identity )} ."
@@ -537,20 +537,14 @@ async def _normalize_signing_properties(
537537 ) -> SigV4SigningProperties :
538538 # Create copy of signing properties to avoid mutating the original
539539 new_signing_properties = SigV4SigningProperties (** signing_properties )
540- new_signing_properties [ "date" ] = await self . _resolve_signing_date (
541- date = new_signing_properties . get ( "date" )
542- )
540+ if "date" not in new_signing_properties :
541+ date_obj = datetime . datetime . now ( datetime . UTC )
542+ new_signing_properties [ "date" ] = date_obj . strftime ( SIGV4_TIMESTAMP_FORMAT )
543543 return new_signing_properties
544544
545545 async def _generate_new_request (self , * , request : AWSRequest ) -> AWSRequest :
546546 return deepcopy (request )
547547
548- async def _resolve_signing_date (self , * , date : str | None ) -> str :
549- if date is None :
550- date_obj = datetime .datetime .now (datetime .UTC )
551- date = date_obj .strftime (SIGV4_TIMESTAMP_FORMAT )
552- return date
553-
554548 async def _apply_required_fields (
555549 self ,
556550 * ,
@@ -560,6 +554,7 @@ async def _apply_required_fields(
560554 ) -> None :
561555 # Apply required X-Amz-Date if neither X-Amz-Date nor Date are present.
562556 if "Date" not in request .fields and "X-Amz-Date" not in request .fields :
557+ assert "date" in signing_properties
563558 request .fields .set_field (
564559 Field (name = "X-Amz-Date" , values = [signing_properties ["date" ]])
565560 )
@@ -654,6 +649,7 @@ async def string_to_sign(
654649 )
655650
656651 async def _scope (self , signing_properties : SigV4SigningProperties ) -> str :
652+ assert "date" in signing_properties
657653 formatted_date = signing_properties ["date" ][0 :8 ]
658654 region = signing_properties ["region" ]
659655 service = signing_properties ["service" ]
@@ -686,13 +682,13 @@ async def _normalize_signing_fields(self, *, request: AWSRequest) -> dict[str, s
686682 }
687683 if "host" not in normalized_fields :
688684 normalized_fields ["host" ] = await self ._normalize_host_field (
689- uri = request .destination
685+ uri = request .destination # type: ignore - TODO(pyright)
690686 )
691687
692688 return dict (sorted (normalized_fields .items ()))
693689
694- def _is_signable_header (self , field ):
695- if field in HEADERS_EXCLUDED_FROM_SIGNING :
690+ def _is_signable_header (self , field_name : str ):
691+ if field_name in HEADERS_EXCLUDED_FROM_SIGNING :
696692 return False
697693 return True
698694
@@ -748,7 +744,7 @@ async def _compute_payload_hash(
748744 if body is None :
749745 return EMPTY_SHA256_HASH
750746
751- if not isinstance (request . body , AsyncIterable ):
747+ if not isinstance (body , AsyncIterable ):
752748 raise TypeError (
753749 "A sync body was attached to an asynchronous signer. Please use "
754750 "SigV4Signer for sync AWSRequests or ensure your body is "
@@ -761,14 +757,14 @@ async def _compute_payload_hash(
761757 )
762758
763759 checksum = sha256 ()
764- if hasattr (body , "seek" ) and hasattr ( body , "tell" ):
760+ if isinstance (body , Seekable ):
765761 position = body .tell ()
766- async for chunk in body : # type: ignore[union-attr]
762+ async for chunk in body :
767763 checksum .update (chunk )
768764 body .seek (position )
769765 else :
770766 buffer = io .BytesIO ()
771- async for chunk in body : # type: ignore[union-attr]
767+ async for chunk in body :
772768 buffer .write (chunk )
773769 checksum .update (chunk )
774770 buffer .seek (0 )
@@ -784,7 +780,7 @@ def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) ->
784780 :param remove_consecutive_slashes: Whether to remove consecutive slashes.
785781 :returns: The path with dot segments removed.
786782 """
787- output = []
783+ output : list [ str ] = []
788784 for segment in path .split ("/" ):
789785 if segment == "." :
790786 continue
0 commit comments