Skip to content

Commit 2af6c96

Browse files
authored
[BE] Refactor dependency update code (#6735)
This should enable Nvidia, Intel, AMD to submit PRs for updating the dependent libraries
1 parent 5c88794 commit 2af6c96

File tree

1 file changed

+138
-101
lines changed

1 file changed

+138
-101
lines changed

s3_management/update_dependencies.py

Lines changed: 138 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -8,100 +8,103 @@
88
CLIENT = boto3.client("s3")
99
BUCKET = S3.Bucket("pytorch")
1010

11-
PACKAGES_PER_PROJECT = {
12-
"torch": [
13-
"sympy",
14-
"mpmath",
15-
"pillow",
16-
"networkx",
17-
"numpy",
18-
"jinja2",
19-
"filelock",
20-
"fsspec",
21-
"nvidia-cudnn-cu11",
22-
"nvidia-cudnn-cu12",
23-
"typing-extensions",
24-
],
25-
"triton": [
26-
"arpeggio",
27-
"caliper-reader",
28-
"contourpy",
29-
"cycler",
30-
"dill",
31-
"fonttools",
32-
"kiwisolver",
33-
"llnl-hatchet",
34-
"matplotlib",
35-
"pandas",
36-
"pydot",
37-
"pyparsing",
38-
"pytz",
39-
"textX",
40-
"tzdata",
41-
"importlib-metadata",
42-
"importlib-resources",
43-
"zipp",
44-
],
45-
"torchtune": [
46-
"aiohttp",
47-
"aiosignal",
48-
"antlr4-python3-runtime",
49-
"attrs",
50-
"blobfile",
51-
"certifi",
52-
"charset-normalizer",
53-
"datasets",
54-
"dill",
55-
"frozenlist",
56-
"huggingface-hub",
57-
"idna",
58-
"lxml",
59-
"markupsafe",
60-
"multidict",
61-
"multiprocess",
62-
"omegaconf",
63-
"pandas",
64-
"pyarrow",
65-
"pyarrow-hotfix",
66-
"pycryptodomex",
67-
"python-dateutil",
68-
"pytz",
69-
"pyyaml",
70-
"regex",
71-
"requests",
72-
"safetensors",
73-
"sentencepiece",
74-
"six",
75-
"tiktoken",
76-
"tqdm",
77-
"tzdata",
78-
"urllib3",
79-
"xxhash",
80-
"yarl",
81-
],
82-
"torch_xpu": [
83-
"dpcpp-cpp-rt",
84-
"intel-cmplr-lib-rt",
85-
"intel-cmplr-lib-ur",
86-
"intel-cmplr-lic-rt",
87-
"intel-opencl-rt",
88-
"intel-sycl-rt",
89-
"intel-openmp",
90-
"tcmlib",
91-
"umf",
92-
"intel-pti",
93-
"tbb",
94-
"oneccl-devel",
95-
"oneccl",
96-
"impi-rt",
97-
"onemkl-sycl-blas",
98-
"onemkl-sycl-dft",
99-
"onemkl-sycl-lapack",
100-
"onemkl-sycl-sparse",
101-
"onemkl-sycl-rng",
102-
"mkl",
103-
]
104-
}
11+
PACKAGES_PER_PROJECT = [
12+
{"package": "sympy", "version": "latest", "project": "torch"},
13+
{"package": "mpmath", "version": "latest", "project": "torch"},
14+
{"package": "pillow", "version": "latest", "project": "torch"},
15+
{"package": "networkx", "version": "latest", "project": "torch"},
16+
{"package": "numpy", "version": "latest", "project": "torch"},
17+
{"package": "jinja2", "version": "latest", "project": "torch"},
18+
{"package": "filelock", "version": "latest", "project": "torch"},
19+
{"package": "fsspec", "version": "latest", "project": "torch"},
20+
{"package": "nvidia-cudnn-cu11", "version": "latest", "project": "torch"},
21+
{"package": "nvidia-cudnn-cu12", "version": "latest", "project": "torch"},
22+
{"package": "typing-extensions", "version": "latest", "project": "torch"},
23+
{"package": "nvidia-cuda-nvrtc-cu12", "version": "12.9.86", "project": "torch", "target": "cu129"},
24+
{"package": "nvidia-cuda-runtime-cu12", "version": "12.9.79", "project": "torch", "target": "cu129"},
25+
{"package": "nvidia-cuda-cupti-cu12", "version": "12.9.79", "project": "torch", "target": "cu129"},
26+
{"package": "nvidia-cublas-cu12", "version": "12.9.1.4", "project": "torch", "target": "cu129"},
27+
{"package": "nvidia-cufft-cu12", "version": "11.4.1.4", "project": "torch", "target": "cu129"},
28+
{"package": "nvidia-curand-cu12", "version": "10.3.10.19", "project": "torch", "target": "cu129"},
29+
{"package": "nvidia-cusolver-cu12", "version": "11.7.5.82", "project": "torch", "target": "cu129"},
30+
{"package": "nvidia-cusparse-cu12", "version": "12.5.10.65", "project": "torch", "target": "cu129"},
31+
{"package": "nvidia-nvtx-cu12", "version": "12.9.79", "project": "torch", "target": "cu129"},
32+
{"package": "nvidia-nvjitlink-cu12", "version": "12.9.86", "project": "torch", "target": "cu129"},
33+
{"package": "nvidia-cufile-cu12", "version": "1.14.1.1", "project": "torch", "target": "cu129"},
34+
{"package": "arpeggio", "version": "latest", "project": "triton"},
35+
{"package": "caliper-reader", "version": "latest", "project": "triton"},
36+
{"package": "contourpy", "version": "latest", "project": "triton"},
37+
{"package": "cycler", "version": "latest", "project": "triton"},
38+
{"package": "dill", "version": "latest", "project": "triton"},
39+
{"package": "fonttools", "version": "latest", "project": "triton"},
40+
{"package": "kiwisolver", "version": "latest", "project": "triton"},
41+
{"package": "llnl-hatchet", "version": "latest", "project": "triton"},
42+
{"package": "matplotlib", "version": "latest", "project": "triton"},
43+
{"package": "pandas", "version": "latest", "project": "triton"},
44+
{"package": "pydot", "version": "latest", "project": "triton"},
45+
{"package": "pyparsing", "version": "latest", "project": "triton"},
46+
{"package": "pytz", "version": "latest", "project": "triton"},
47+
{"package": "textX", "version": "latest", "project": "triton"},
48+
{"package": "tzdata", "version": "latest", "project": "triton"},
49+
{"package": "importlib-metadata", "version": "latest", "project": "triton"},
50+
{"package": "importlib-resources", "version": "latest", "project": "triton"},
51+
{"package": "zipp", "version": "latest", "project": "triton"},
52+
{"package": "aiohttp", "version": "latest", "project": "torchtune"},
53+
{"package": "aiosignal", "version": "latest", "project": "torchtune"},
54+
{"package": "antlr4-python3-runtime", "version": "latest", "project": "torchtune"},
55+
{"package": "attrs", "version": "latest", "project": "torchtune"},
56+
{"package": "blobfile", "version": "latest", "project": "torchtune"},
57+
{"package": "certifi", "version": "latest", "project": "torchtune"},
58+
{"package": "charset-normalizer", "version": "latest", "project": "torchtune"},
59+
{"package": "datasets", "version": "latest", "project": "torchtune"},
60+
{"package": "dill", "version": "latest", "project": "torchtune"},
61+
{"package": "frozenlist", "version": "latest", "project": "torchtune"},
62+
{"package": "huggingface-hub", "version": "latest", "project": "torchtune"},
63+
{"package": "idna", "version": "latest", "project": "torchtune"},
64+
{"package": "lxml", "version": "latest", "project": "torchtune"},
65+
{"package": "markupsafe", "version": "latest", "project": "torchtune"},
66+
{"package": "multidict", "version": "latest", "project": "torchtune"},
67+
{"package": "multiprocess", "version": "latest", "project": "torchtune"},
68+
{"package": "omegaconf", "version": "latest", "project": "torchtune"},
69+
{"package": "pandas", "version": "latest", "project": "torchtune"},
70+
{"package": "pyarrow", "version": "latest", "project": "torchtune"},
71+
{"package": "pyarrow-hotfix", "version": "latest", "project": "torchtune"},
72+
{"package": "pycryptodomex", "version": "latest", "project": "torchtune"},
73+
{"package": "python-dateutil", "version": "latest", "project": "torchtune"},
74+
{"package": "pytz", "version": "latest", "project": "torchtune"},
75+
{"package": "pyyaml", "version": "latest", "project": "torchtune"},
76+
{"package": "regex", "version": "latest", "project": "torchtune"},
77+
{"package": "requests", "version": "latest", "project": "torchtune"},
78+
{"package": "safetensors", "version": "latest", "project": "torchtune"},
79+
{"package": "sentencepiece", "version": "latest", "project": "torchtune"},
80+
{"package": "six", "version": "latest", "project": "torchtune"},
81+
{"package": "tiktoken", "version": "latest", "project": "torchtune"},
82+
{"package": "tqdm", "version": "latest", "project": "torchtune"},
83+
{"package": "tzdata", "version": "latest", "project": "torchtune"},
84+
{"package": "urllib3", "version": "latest", "project": "torchtune"},
85+
{"package": "xxhash", "version": "latest", "project": "torchtune"},
86+
{"package": "yarl", "version": "latest", "project": "torchtune"},
87+
{"package": "dpcpp-cpp-rt", "version": "latest", "project": "torch_xpu"},
88+
{"package": "intel-cmplr-lib-rt", "version": "latest", "project": "torch_xpu"},
89+
{"package": "intel-cmplr-lib-ur", "version": "latest", "project": "torch_xpu"},
90+
{"package": "intel-cmplr-lic-rt", "version": "latest", "project": "torch_xpu"},
91+
{"package": "intel-opencl-rt", "version": "latest", "project": "torch_xpu"},
92+
{"package": "intel-sycl-rt", "version": "latest", "project": "torch_xpu"},
93+
{"package": "intel-openmp", "version": "latest", "project": "torch_xpu"},
94+
{"package": "tcmlib", "version": "latest", "project": "torch_xpu"},
95+
{"package": "umf", "version": "latest", "project": "torch_xpu"},
96+
{"package": "intel-pti", "version": "latest", "project": "torch_xpu"},
97+
{"package": "tbb", "version": "latest", "project": "torch_xpu"},
98+
{"package": "oneccl-devel", "version": "latest", "project": "torch_xpu"},
99+
{"package": "oneccl", "version": "latest", "project": "torch_xpu"},
100+
{"package": "impi-rt", "version": "latest", "project": "torch_xpu"},
101+
{"package": "onemkl-sycl-blas", "version": "latest", "project": "torch_xpu"},
102+
{"package": "onemkl-sycl-dft", "version": "latest", "project": "torch_xpu"},
103+
{"package": "onemkl-sycl-lapack", "version": "latest", "project": "torch_xpu"},
104+
{"package": "onemkl-sycl-sparse", "version": "latest", "project": "torch_xpu"},
105+
{"package": "onemkl-sycl-rng", "version": "latest", "project": "torch_xpu"},
106+
{"package": "mkl", "version": "latest", "project": "torch_xpu"},
107+
]
105108

106109

107110
def download(url: str) -> bytes:
@@ -136,16 +139,34 @@ def get_wheels_of_version(idx: Dict[str, str], version: str) -> Dict[str, str]:
136139

137140

138141
def upload_missing_whls(
139-
pkg_name: str = "numpy", prefix: str = "whl/test", *, dry_run: bool = False, only_pypi: bool = False
142+
pkg_name: str = "numpy",
143+
prefix: str = "whl/test", *,
144+
dry_run: bool = False,
145+
only_pypi: bool = False,
146+
target_version: str = "latest"
140147
) -> None:
141148
pypi_idx = parse_simple_idx(f"https://pypi.org/simple/{pkg_name}")
142149
pypi_versions = get_whl_versions(pypi_idx)
143-
pypi_latest_packages = get_wheels_of_version(pypi_idx, pypi_versions[-1])
150+
151+
# Determine which version to use
152+
if target_version == "latest" or not target_version:
153+
selected_version = pypi_versions[-1] if pypi_versions else None
154+
elif target_version in pypi_versions:
155+
selected_version = target_version
156+
else:
157+
print(f"Warning: Version {target_version} not found for {pkg_name}, using latest")
158+
selected_version = pypi_versions[-1] if pypi_versions else None
159+
160+
if not selected_version:
161+
print(f"No stable versions found for {pkg_name}")
162+
return
163+
164+
pypi_latest_packages = get_wheels_of_version(pypi_idx, selected_version)
144165

145166
download_latest_packages = []
146167
if not only_pypi:
147168
download_idx = parse_simple_idx(f"https://download.pytorch.org/{prefix}/{pkg_name}")
148-
download_latest_packages = get_wheels_of_version(download_idx, pypi_versions[-1])
169+
download_latest_packages = get_wheels_of_version(download_idx, selected_version)
149170

150171
has_updates = False
151172
for pkg in pypi_latest_packages:
@@ -163,6 +184,7 @@ def upload_missing_whls(
163184
print(f"Downloading {pkg}")
164185
if dry_run:
165186
has_updates = True
187+
print(f"Dry Run - not Uploading {pkg} to s3://pytorch/{prefix}/")
166188
continue
167189
data = download(pypi_idx[pkg])
168190
print(f"Uploading {pkg} to s3://pytorch/{prefix}/")
@@ -172,15 +194,17 @@ def upload_missing_whls(
172194
has_updates = True
173195
if not has_updates:
174196
print(
175-
f"{pkg_name} is already at latest version {pypi_versions[-1]} for {prefix}"
197+
f"{pkg_name} is already at version {selected_version} for {prefix}"
176198
)
177199

178200

179201
def main() -> None:
180202
from argparse import ArgumentParser
181203

182204
parser = ArgumentParser("Upload dependent packages to s3://pytorch")
183-
parser.add_argument("--package", choices=PACKAGES_PER_PROJECT.keys(), default="torch")
205+
# Get unique paths from the packages list
206+
project_paths = list(set(pkg["project"] for pkg in PACKAGES_PER_PROJECT))
207+
parser.add_argument("--package", choices=project_paths, default="torch")
184208
parser.add_argument("--dry-run", action="store_true")
185209
parser.add_argument("--only-pypi", action="store_true")
186210
parser.add_argument("--include-stable", action="store_true")
@@ -191,8 +215,21 @@ def main() -> None:
191215
SUBFOLDERS.append("whl")
192216

193217
for prefix in SUBFOLDERS:
194-
for package in PACKAGES_PER_PROJECT[args.package]:
195-
upload_missing_whls(package, prefix, dry_run=args.dry_run, only_pypi=args.only_pypi)
218+
# Filter packages by the selected project path
219+
selected_packages = [pkg for pkg in PACKAGES_PER_PROJECT if pkg["project"] == args.package]
220+
for pkg_info in selected_packages:
221+
if( hasattr(pkg_info, "target") and pkg_info["target"] != ""):
222+
full_path=f'{prefix}/{pkg_info["target"]}'
223+
else:
224+
full_path=f'{prefix}'
225+
226+
upload_missing_whls(
227+
pkg_info["package"],
228+
full_path,
229+
dry_run=args.dry_run,
230+
only_pypi=args.only_pypi,
231+
target_version=pkg_info["version"]
232+
)
196233

197234

198235
if __name__ == "__main__":

0 commit comments

Comments
 (0)