Skip to content

Commit 0fd8a21

Browse files
committed
pyright fixes
1 parent 8b0ec44 commit 0fd8a21

File tree

6 files changed

+70
-59
lines changed

6 files changed

+70
-59
lines changed

packages/aws-sdk-signers/src/aws_sdk_signers/_http.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ def __init__(
347347
self.fields = fields
348348

349349
def __deepcopy__(
350-
self, memo: dict[int, interfaces_http.Request] | None = None
351-
) -> interfaces_http.Request:
350+
self, memo: dict[int, AWSRequest] | None = None
351+
) -> AWSRequest:
352352
if memo is None:
353353
memo = {}
354354

@@ -358,7 +358,7 @@ def __deepcopy__(
358358
# the destination doesn't need to be copied because it's immutable
359359
# the body can't be copied because it's an iterator
360360
new_instance = self.__class__(
361-
destination=self.destination,
361+
destination=self.destination, # pyright: ignore [reportArgumentType]
362362
body=self.body,
363363
method=self.method,
364364
fields=deepcopy(self.fields, memo),

packages/aws-sdk-signers/src/aws_sdk_signers/_io.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, data: StreamingBlob):
3939
else:
4040
self._data = data
4141

42-
async def read(self, size: int = -1) -> bytes:
42+
async def read(self, size: int | None = -1) -> bytes:
4343
"""Read a number of bytes from the stream.
4444
4545
:param size: The maximum number of bytes to read. If less than 0, all bytes will
@@ -48,14 +48,17 @@ async def read(self, size: int = -1) -> bytes:
4848
if self._closed or not self._data:
4949
raise ValueError("I/O operation on closed file.")
5050

51-
if isinstance(self._data, ByteStream) and not iscoroutinefunction(
51+
if size is None:
52+
size = -1
53+
54+
if isinstance(self._data, ByteStream) and not iscoroutinefunction( # type: ignore - TODO(pyright)
5255
self._data.read
5356
):
5457
# Python's runtime_checkable can't actually tell the difference between
5558
# sync and async, so we have to check ourselves.
5659
return self._data.read(size)
5760

58-
if isinstance(self._data, AsyncByteStream):
61+
if isinstance(self._data, AsyncByteStream): # type: ignore - TODO(pyright)
5962
return await self._data.read(size)
6063

6164
return await self._read_from_iterable(
@@ -135,7 +138,7 @@ def __init__(self, data: StreamingBlob):
135138
if isinstance(data, bytes | bytearray):
136139
self._buffer = BytesIO(data)
137140
self._data_source = None
138-
elif isinstance(data, AsyncByteStream) and iscoroutinefunction(data.read):
141+
elif isinstance(data, AsyncByteStream) and iscoroutinefunction(data.read): # type: ignore - TODO(pyright)
139142
# Note that we need that iscoroutine check because python won't actually check
140143
# whether or not the read function is async.
141144
self._buffer = BytesIO()
@@ -144,12 +147,15 @@ def __init__(self, data: StreamingBlob):
144147
self._buffer = BytesIO()
145148
self._data_source = AsyncBytesReader(data)
146149

147-
async def read(self, size: int = -1) -> bytes:
150+
async def read(self, size: int | None = -1) -> bytes:
148151
"""Read a number of bytes from the stream.
149152
150153
:param size: The maximum number of bytes to read. If less than 0, all bytes will
151154
be read.
152155
"""
156+
if size is None:
157+
size = -1
158+
153159
if self._data_source is None or size == 0:
154160
return self._buffer.read(size)
155161

packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/http.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import OrderedDict
77
from collections.abc import AsyncIterable, Iterable, Iterator
88
from enum import Enum
9-
from typing import Protocol
9+
from typing import Protocol, runtime_checkable
1010

1111

1212
class FieldPosition(Enum):
@@ -123,6 +123,7 @@ class Request(Protocol):
123123
body: AsyncIterable[bytes] | Iterable[bytes] | None
124124

125125

126+
@runtime_checkable
126127
class URI(Protocol):
127128
"""Universal Resource Identifier, target location for a :py:class:`Request`."""
128129

packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/io.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,20 @@
88
class ByteStream(Protocol):
99
"""A file-like object with a read method that returns bytes."""
1010

11-
def read(self, size: int = -1) -> bytes: ...
11+
def read(self, size: int | None = -1, /) -> bytes: ...
1212

1313

1414
@runtime_checkable
1515
class AsyncByteStream(Protocol):
1616
"""A file-like object with an async read method."""
1717

18-
async def read(self, size: int = -1) -> bytes: ...
18+
async def read(self, size: int | None = -1, /) -> bytes: ...
19+
20+
21+
@runtime_checkable
22+
class Seekable(Protocol):
23+
"""A file-like object with seek and tell implemented."""
24+
25+
def seek(self, offset: int, whence: int = 0, /) -> int: ...
26+
27+
def tell(self) -> int: ...

packages/aws-sdk-signers/src/aws_sdk_signers/signers.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
import hmac
66
import io
77
import warnings
8-
from collections.abc import AsyncIterable
8+
from collections.abc import AsyncIterable, Iterable
99
from copy import deepcopy
1010
from hashlib import sha256
1111
from typing import Required, TypedDict
1212
from urllib.parse import parse_qsl, quote
1313

14+
from .interfaces.io import Seekable
1415
from ._http import URI, AWSRequest, Field
1516
from ._identity import AWSCredentialIdentity
1617
from ._io import AsyncBytesReader
@@ -64,6 +65,8 @@ def sign(
6465
new_signing_properties = self._normalize_signing_properties(
6566
signing_properties=signing_properties
6667
)
68+
assert "date" in new_signing_properties
69+
6770
new_request = self._generate_new_request(request=request)
6871
self._apply_required_fields(
6972
request=new_request,
@@ -133,14 +136,14 @@ def _signature(
133136
service. The date, region, service and resulting signing key are individually
134137
hashed, then the composite hash is used to sign the string to sign.
135138
"""
136-
assert signing_properties["date"] is not None
137139

138140
# Components of Signing Key Calculation
139141
#
140142
# DateKey = HMAC-SHA256("AWS4"+"<SecretAccessKey>", "<YYYYMMDD>")
141143
# DateRegionKey = HMAC-SHA256(<DateKey>, "<aws-region>")
142144
# DateRegionServiceKey = HMAC-SHA256(<DateRegionKey>, "<aws-service>")
143145
# SigningKey = HMAC-SHA256(<DateRegionServiceKey>, "aws4_request")
146+
assert "date" in signing_properties
144147
k_date = self._hash(
145148
key=f"AWS4{secret_key}".encode(), value=signing_properties["date"][0:8]
146149
)
@@ -155,7 +158,7 @@ def _hash(self, key: bytes, value: str) -> bytes:
155158

156159
def _validate_identity(self, *, identity: AWSCredentialIdentity) -> None:
157160
"""Perform runtime and expiration checks before attempting signing."""
158-
if not isinstance(identity, AWSCredentialIdentity):
161+
if not isinstance(identity, AWSCredentialIdentity): #pyright: ignore
159162
raise ValueError(
160163
"Received unexpected value for identity parameter. Expected "
161164
f"AWSCredentialIdentity but received {type(identity)}."
@@ -171,20 +174,14 @@ def _normalize_signing_properties(
171174
) -> SigV4SigningProperties:
172175
# Create copy of signing properties to avoid mutating the original
173176
new_signing_properties = SigV4SigningProperties(**signing_properties)
174-
new_signing_properties["date"] = self._resolve_signing_date(
175-
date=new_signing_properties.get("date")
176-
)
177+
if "date" not in new_signing_properties:
178+
date_obj = datetime.datetime.now(datetime.UTC)
179+
new_signing_properties["date"] = date_obj.strftime(SIGV4_TIMESTAMP_FORMAT)
177180
return new_signing_properties
178181

179182
def _generate_new_request(self, *, request: AWSRequest) -> AWSRequest:
180183
return deepcopy(request)
181184

182-
def _resolve_signing_date(self, *, date: str | None) -> str:
183-
if date is None:
184-
date_obj = datetime.datetime.now(datetime.UTC)
185-
date = date_obj.strftime(SIGV4_TIMESTAMP_FORMAT)
186-
return date
187-
188185
def _apply_required_fields(
189186
self,
190187
*,
@@ -194,6 +191,7 @@ def _apply_required_fields(
194191
) -> None:
195192
# Apply required X-Amz-Date if neither X-Amz-Date nor Date are present.
196193
if "Date" not in request.fields and "X-Amz-Date" not in request.fields:
194+
assert "date" in signing_properties
197195
request.fields.set_field(
198196
Field(name="X-Amz-Date", values=[signing_properties["date"]])
199197
)
@@ -283,6 +281,7 @@ def string_to_sign(
283281
)
284282

285283
def _scope(self, signing_properties: SigV4SigningProperties) -> str:
284+
assert "date" in signing_properties
286285
formatted_date = signing_properties["date"][0:8]
287286
region = signing_properties["region"]
288287
service = signing_properties["service"]
@@ -315,13 +314,13 @@ def _normalize_signing_fields(self, *, request: AWSRequest) -> dict[str, str]:
315314
}
316315
if "host" not in normalized_fields:
317316
normalized_fields["host"] = self._normalize_host_field(
318-
uri=request.destination
317+
uri=request.destination # type: ignore - TODO(pyright)
319318
)
320319

321320
return dict(sorted(normalized_fields.items()))
322321

323-
def _is_signable_header(self, field):
324-
if field in HEADERS_EXCLUDED_FROM_SIGNING:
322+
def _is_signable_header(self, field_name: str):
323+
if field_name in HEADERS_EXCLUDED_FROM_SIGNING:
325324
return False
326325
return True
327326

@@ -355,12 +354,6 @@ def _format_canonical_payload(
355354
request: AWSRequest,
356355
signing_properties: SigV4SigningProperties,
357356
) -> str:
358-
if isinstance(request.body, AsyncIterable):
359-
raise TypeError(
360-
"An async body was attached to a synchronous signer. Please use "
361-
"AsyncSigV4Signer for async AWSRequests or ensure your body is "
362-
"of type Iterable[bytes]."
363-
)
364357
payload_hash = self._compute_payload_hash(
365358
request=request, signing_properties=signing_properties
366359
)
@@ -383,21 +376,28 @@ def _compute_payload_hash(
383376
if body is None:
384377
return EMPTY_SHA256_HASH
385378

379+
if not isinstance(body, Iterable):
380+
raise TypeError(
381+
"An async body was attached to a synchronous signer. Please use "
382+
"AsyncSigV4Signer for async AWSRequests or ensure your body is "
383+
"of type Iterable[bytes]."
384+
)
385+
386386
warnings.warn(
387387
"Payload signing is enabled. This may result in "
388388
"decreased performance for large request bodies.",
389389
AWSSDKWarning,
390390
)
391391

392392
checksum = sha256()
393-
if hasattr(body, "seek") and hasattr(body, "tell"):
393+
if isinstance(body, Seekable):
394394
position = body.tell()
395-
for chunk in body: # type: ignore[union-attr]
395+
for chunk in body:
396396
checksum.update(chunk)
397397
body.seek(position)
398398
else:
399399
buffer = io.BytesIO()
400-
for chunk in body: # type: ignore[union-attr]
400+
for chunk in body:
401401
buffer.write(chunk)
402402
checksum.update(chunk)
403403
buffer.seek(0)
@@ -498,14 +498,14 @@ async def _signature(
498498
service. The date, region, service and resulting signing key are individually
499499
hashed, then the composite hash is used to sign the string to sign.
500500
"""
501-
assert signing_properties.get("date") is not None
502501

503502
# Components of Signing Key Calculation
504503
#
505504
# DateKey = HMAC-SHA256("AWS4"+"<SecretAccessKey>", "<YYYYMMDD>")
506505
# DateRegionKey = HMAC-SHA256(<DateKey>, "<aws-region>")
507506
# DateRegionServiceKey = HMAC-SHA256(<DateRegionKey>, "<aws-service>")
508507
# SigningKey = HMAC-SHA256(<DateRegionServiceKey>, "aws4_request")
508+
assert "date" in signing_properties
509509
k_date = await self._hash(
510510
key=f"AWS4{secret_key}".encode(), value=signing_properties["date"][0:8]
511511
)
@@ -521,7 +521,7 @@ async def _hash(self, key: bytes, value: str) -> bytes:
521521

522522
async def _validate_identity(self, *, identity: AWSCredentialIdentity) -> None:
523523
"""Perform runtime and expiration checks before attempting signing."""
524-
if not isinstance(identity, AWSCredentialIdentity):
524+
if not isinstance(identity, AWSCredentialIdentity): #pyright: ignore
525525
raise ValueError(
526526
"Received unexpected value for identity parameter. Expected "
527527
f"AWSCredentialIdentity but received {type(identity)}."
@@ -537,20 +537,14 @@ async def _normalize_signing_properties(
537537
) -> SigV4SigningProperties:
538538
# Create copy of signing properties to avoid mutating the original
539539
new_signing_properties = SigV4SigningProperties(**signing_properties)
540-
new_signing_properties["date"] = await self._resolve_signing_date(
541-
date=new_signing_properties.get("date")
542-
)
540+
if "date" not in new_signing_properties:
541+
date_obj = datetime.datetime.now(datetime.UTC)
542+
new_signing_properties["date"] = date_obj.strftime(SIGV4_TIMESTAMP_FORMAT)
543543
return new_signing_properties
544544

545545
async def _generate_new_request(self, *, request: AWSRequest) -> AWSRequest:
546546
return deepcopy(request)
547547

548-
async def _resolve_signing_date(self, *, date: str | None) -> str:
549-
if date is None:
550-
date_obj = datetime.datetime.now(datetime.UTC)
551-
date = date_obj.strftime(SIGV4_TIMESTAMP_FORMAT)
552-
return date
553-
554548
async def _apply_required_fields(
555549
self,
556550
*,
@@ -560,6 +554,7 @@ async def _apply_required_fields(
560554
) -> None:
561555
# Apply required X-Amz-Date if neither X-Amz-Date nor Date are present.
562556
if "Date" not in request.fields and "X-Amz-Date" not in request.fields:
557+
assert "date" in signing_properties
563558
request.fields.set_field(
564559
Field(name="X-Amz-Date", values=[signing_properties["date"]])
565560
)
@@ -654,6 +649,7 @@ async def string_to_sign(
654649
)
655650

656651
async def _scope(self, signing_properties: SigV4SigningProperties) -> str:
652+
assert "date" in signing_properties
657653
formatted_date = signing_properties["date"][0:8]
658654
region = signing_properties["region"]
659655
service = signing_properties["service"]
@@ -686,13 +682,13 @@ async def _normalize_signing_fields(self, *, request: AWSRequest) -> dict[str, s
686682
}
687683
if "host" not in normalized_fields:
688684
normalized_fields["host"] = await self._normalize_host_field(
689-
uri=request.destination
685+
uri=request.destination # type: ignore - TODO(pyright)
690686
)
691687

692688
return dict(sorted(normalized_fields.items()))
693689

694-
def _is_signable_header(self, field):
695-
if field in HEADERS_EXCLUDED_FROM_SIGNING:
690+
def _is_signable_header(self, field_name: str):
691+
if field_name in HEADERS_EXCLUDED_FROM_SIGNING:
696692
return False
697693
return True
698694

@@ -748,7 +744,7 @@ async def _compute_payload_hash(
748744
if body is None:
749745
return EMPTY_SHA256_HASH
750746

751-
if not isinstance(request.body, AsyncIterable):
747+
if not isinstance(body, AsyncIterable):
752748
raise TypeError(
753749
"A sync body was attached to an asynchronous signer. Please use "
754750
"SigV4Signer for sync AWSRequests or ensure your body is "
@@ -761,14 +757,14 @@ async def _compute_payload_hash(
761757
)
762758

763759
checksum = sha256()
764-
if hasattr(body, "seek") and hasattr(body, "tell"):
760+
if isinstance(body, Seekable):
765761
position = body.tell()
766-
async for chunk in body: # type: ignore[union-attr]
762+
async for chunk in body:
767763
checksum.update(chunk)
768764
body.seek(position)
769765
else:
770766
buffer = io.BytesIO()
771-
async for chunk in body: # type: ignore[union-attr]
767+
async for chunk in body:
772768
buffer.write(chunk)
773769
checksum.update(chunk)
774770
buffer.seek(0)
@@ -784,7 +780,7 @@ def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) ->
784780
:param remove_consecutive_slashes: Whether to remove consecutive slashes.
785781
:returns: The path with dot segments removed.
786782
"""
787-
output = []
783+
output: list[str] = []
788784
for segment in path.split("/"):
789785
if segment == ".":
790786
continue

0 commit comments

Comments
 (0)