Skip to content

Commit bef3a1f

Browse files
authored
s3_management: Fix pyright / mypy errors, add mypy linter (#6904)
I was observing these errors in my editor and they were annoying me so I fixed them. I don't expect any functional changes to come out of this. --------- Signed-off-by: Eli Uriegas <[email protected]>
1 parent 3be1e23 commit bef3a1f

File tree

4 files changed

+22
-23
lines changed

4 files changed

+22
-23
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ include_patterns = [
4646
'tools/**/*.pyi',
4747
'torchci/**/*.py',
4848
'torchci/**/*.pyi',
49+
's3_management/*.py',
4950
]
5051
exclude_patterns = [
5152
'aws/lambda/servicelab-ingestor/**',

s3_management/backup_conda.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import argparse
88
import hashlib
99
import os
10-
import urllib
10+
import urllib.request
1111
from typing import List, Optional
1212

13-
import boto3
14-
import conda.api
13+
import boto3 # type: ignore[import-untyped]
14+
import conda.api # type: ignore[import-not-found]
1515

1616

1717
S3 = boto3.resource("s3")

s3_management/manage.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from datetime import datetime
1212
from os import makedirs, path
1313
from re import match, search, sub
14-
from typing import Dict, Iterable, List, Optional, Set, Type, TypeVar
14+
from typing import Dict, Iterable, List, Optional, Set, Type, TypeVar, Union
1515

16-
import boto3
17-
import botocore
16+
import boto3 # type: ignore[import]
17+
import botocore # type: ignore[import]
1818
from packaging.version import InvalidVersion, parse as _parse_version, Version
1919

2020

@@ -240,13 +240,13 @@ def __lt__(self, other):
240240

241241
def safe_parse_version(ver_str: str) -> Version:
242242
try:
243-
return _parse_version(ver_str)
243+
return _parse_version(ver_str) # type: ignore[return-value]
244244
except InvalidVersion:
245245
return Version("0.0.0")
246246

247247

248248
class S3Index:
249-
def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
249+
def __init__(self, objects: List[S3Object], prefix: str) -> None:
250250
self.objects = objects
251251
self.prefix = prefix.rstrip("/")
252252
self.html_name = "index.html"
@@ -256,7 +256,7 @@ def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
256256
path.dirname(obj.key) for obj in objects if path.dirname != prefix
257257
}
258258

259-
def nightly_packages_to_show(self: S3IndexType) -> List[S3Object]:
259+
def nightly_packages_to_show(self) -> List[S3Object]:
260260
"""Finding packages to show based on a threshold we specify
261261
262262
Basically takes our S3 packages, normalizes the version for easier
@@ -326,7 +326,7 @@ def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
326326
{self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)}
327327
)
328328

329-
def normalize_package_version(self: S3IndexType, obj: S3Object) -> str:
329+
def normalize_package_version(self, obj: S3Object) -> str:
330330
# removes the GPU specifier from the package name as well as
331331
# unnecessary things like the file extension, architecture name, etc.
332332
return sub(r"%2B.*", "", "-".join(path.basename(obj.key).split("-")[:2]))
@@ -498,7 +498,7 @@ def compute_sha256(self) -> None:
498498
)
499499

500500
@classmethod
501-
def has_public_read(cls: Type[S3IndexType], key: str) -> bool:
501+
def has_public_read(cls, key: str) -> bool:
502502
def is_all_users_group(o) -> bool:
503503
return (
504504
o.get("Grantee", {}).get("URI")
@@ -512,11 +512,11 @@ def can_read(o) -> bool:
512512
return any(is_all_users_group(x) and can_read(x) for x in acl_grants)
513513

514514
@classmethod
515-
def grant_public_read(cls: Type[S3IndexType], key: str) -> None:
515+
def grant_public_read(cls, key: str) -> None:
516516
CLIENT.put_object_acl(Bucket=BUCKET.name, Key=key, ACL="public-read")
517517

518518
@classmethod
519-
def fetch_object_names(cls: Type[S3IndexType], prefix: str) -> List[str]:
519+
def fetch_object_names(cls, prefix: str) -> List[str]:
520520
obj_names = []
521521
for obj in BUCKET.objects.filter(Prefix=prefix):
522522
is_acceptable = any(
@@ -531,7 +531,7 @@ def fetch_object_names(cls: Type[S3IndexType], prefix: str) -> List[str]:
531531
obj_names.append(obj.key)
532532
return obj_names
533533

534-
def fetch_metadata(self: S3IndexType) -> None:
534+
def fetch_metadata(self) -> None:
535535
# Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
536536
regex_multipart_upload = r"^[A-Za-z0-9+/=]+=-[0-9]+$"
537537
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
@@ -565,17 +565,17 @@ def fetch_metadata(self: S3IndexType) -> None:
565565
if size := response.get("ContentLength"):
566566
self.objects[idx].size = int(size)
567567

568-
def fetch_pep658(self: S3IndexType) -> None:
568+
def fetch_pep658(self) -> None:
569569
def _fetch_metadata(key: str) -> str:
570570
try:
571571
response = CLIENT.head_object(
572572
Bucket=BUCKET.name, Key=f"{key}.metadata", ChecksumMode="Enabled"
573573
)
574574
sha256 = base64.b64decode(response.get("ChecksumSHA256")).hex()
575575
return sha256
576-
except botocore.exceptions.ClientError as e:
576+
except botocore.exceptions.ClientError as e: # type: ignore[attr-defined]
577577
if e.response["Error"]["Code"] == "404":
578-
return None
578+
return ""
579579
raise
580580

581581
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
@@ -592,9 +592,7 @@ def _fetch_metadata(key: str) -> str:
592592
self.objects[idx].pep658 = response
593593

594594
@classmethod
595-
def from_S3(
596-
cls: Type[S3IndexType], prefix: str, with_metadata: bool = True
597-
) -> S3IndexType:
595+
def from_S3(cls, prefix: str, with_metadata: bool = True) -> "S3Index":
598596
prefix = prefix.rstrip("/")
599597
obj_names = cls.fetch_object_names(prefix)
600598

@@ -622,7 +620,7 @@ def sanitize_key(key: str) -> str:
622620
return rc
623621

624622
@classmethod
625-
def undelete_prefix(cls: Type[S3IndexType], prefix: str) -> None:
623+
def undelete_prefix(cls, prefix: str) -> None:
626624
paginator = CLIENT.get_paginator("list_object_versions")
627625
for page in paginator.paginate(Bucket=BUCKET.name, Prefix=prefix):
628626
for obj in page.get("DeleteMarkers", []):

s3_management/update_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from typing import Dict, List
33

4-
import boto3
4+
import boto3 # type: ignore[import-untyped]
55

66

77
S3 = boto3.resource("s3")
@@ -225,7 +225,7 @@ def upload_missing_whls(
225225

226226
pypi_latest_packages = get_wheels_of_version(pypi_idx, selected_version)
227227

228-
download_latest_packages = []
228+
download_latest_packages: Dict[str, str] = {}
229229
if not only_pypi:
230230
download_idx = parse_simple_idx(
231231
f"https://download.pytorch.org/{prefix}/{pkg_name}"

0 commit comments

Comments (0)