Skip to content

Commit 0ffcb39

Browse files
authored
Support specifying a torch range
1 parent d462da2 commit 0ffcb39

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ def get_dist(pkgname):
9696
return None
9797

9898
pytorch_dep = os.getenv("TORCH_PACKAGE_NAME", "torch")
99-
if os.getenv("PYTORCH_VERSION"):
100-
pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
99+
if version_pin := os.getenv("PYTORCH_VERSION"):
100+
pytorch_dep += "==" + version_pin
101+
elif (version_pin_ge := os.getenv("PYTORCH_VERSION_GE")) and (version_pin_lt := os.getenv("PYTORCH_VERSION_LT")):
102+
pytorch_dep += f">={version_pin_ge},<{version_pin_lt}"
101103

102104
requirements = [
103105
"numpy",

0 commit comments

Comments
 (0)