Skip to content

Commit 5ca2dcf

Browse files
authored
Merge pull request #18 from tweak-wtf/enhancement/dynamic-torch-pypi-index
Dynamic torch pypi index urls
2 parents cc4d461 + 16388c9 commit 5ca2dcf

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed

client/ayon_comfyui/hooks/pre_launch.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,49 @@ def pre_process(self):
8282
cache_tmpl = self.addon_settings["caching"]["cache_dir_template"]
8383
self.cache_dir = StringTemplate(cache_tmpl).format_strict(self.tmpl_data)
8484

85+
# get installed CUDA version and build correct pypi index url
86+
try:
87+
smi_version_details = subprocess.check_output(
88+
["nvidia-smi", "--version"], text=True
89+
).strip()
90+
except subprocess.CalledProcessError as e:
91+
log.error(f"Failed to execute `nvidia-smi`: {e} Please ensure NVIDIA drivers are installed.")
92+
93+
cuda_version = None
94+
for line in smi_version_details.splitlines():
95+
if "CUDA Version" in line:
96+
parts = line.split(":")
97+
cuda_version = parts[1].strip()
98+
break
99+
if not cuda_version:
100+
log.error("Could not determine CUDA version from `nvidia-smi` output.")
101+
raise RuntimeError("CUDA version could not be determined.")
102+
103+
pypi_url_map = {
104+
"11.8": {
105+
"stable": "https://download.pytorch.org/whl/cu118",
106+
"nightly": None,
107+
},
108+
"12.6": {
109+
"stable": "https://download.pytorch.org/whl/cu126",
110+
"nightly": "https://download.pytorch.org/whl/nightly/cu126",
111+
},
112+
"12.8": {
113+
"stable": "https://download.pytorch.org/whl/cu128",
114+
"nightly": "https://download.pytorch.org/whl/nightly/cu128",
115+
},
116+
"12.9": {
117+
"stable": None,
118+
"nightly": "https://download.pytorch.org/whl/nightly/cu129",
119+
},
120+
}
121+
if bool(self.addon_settings["venv"]["use_torch_nightly"]):
122+
self.pypi_url = pypi_url_map[cuda_version]["nightly"]
123+
else:
124+
self.pypi_url = pypi_url_map[cuda_version]["stable"]
125+
126+
self.py_version = self.addon_settings["venv"]["python_version"]
127+
85128
def clone_repositories(self):
86129
def git_clone(url: str, dest: Path, tag: str = "") -> git.Repo:
87130
if not dest.exists():
@@ -182,6 +225,12 @@ def run_server(self):
182225
if self.extra_flags:
183226
launch_args.append("-extraFlags")
184227
launch_args.append(",".join(self.extra_flags))
228+
if self.pypi_url:
229+
launch_args.append("-pypiUrl")
230+
launch_args.append(self.pypi_url)
231+
if self.py_version:
232+
launch_args.append("-pythonVersion")
233+
launch_args.append(self.py_version)
185234

186235
_cmd.extend(launch_args)
187236
cmd = " ".join([str(arg) for arg in _cmd])

client/ayon_comfyui/tools/install_and_run_server_venv.ps1

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# assumes to be in comfyui directory
22
param(
33
[string]$cacheDir = "",
4+
[string]$pypiUrl = "",
5+
[string]$pythonVersion = "",
46
[string[]]$plugins = @(),
57
[string[]]$extraFlags = @(),
68
[string[]]$extraDependencies = @()
@@ -18,16 +20,18 @@ if ($cacheDir) {
1820
}
1921

2022
# create local venv
21-
uv venv --allow-existing --python 3.12
23+
uv venv --allow-existing --python $pythonVersion
2224
if (-not $?){
2325
Write-Output "Failed to create venv"
2426
exit 1
2527
}
2628
.venv\Scripts\activate
2729

2830
# Install requirements
31+
Write-Output "Installing PyTorch with CUDA support"
32+
uv pip install --pre torch torchvision torchaudio --index-url $pypiUrl
2933
Write-Output "Installing ComfyUI requirements"
30-
uv pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu126
34+
uv pip install -r requirements.txt
3135

3236
# install plugins dependencies
3337
foreach ($plugin in $plugins) {

server/settings.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,19 @@ def __init__(self, **data):
1818
self.name = Path(self.url).stem
1919

2020

21+
class VirtualEnvSettings(BaseSettingsModel):
22+
python_version: str = SettingsField(
23+
default="3.12",
24+
title="Python Version",
25+
description="Python version to use for the virtual environment.",
26+
)
27+
use_torch_nightly: bool = SettingsField(
28+
default=True,
29+
title="Use PyTorch Nightly",
30+
description="Use the nightly version of PyTorch.",
31+
)
32+
33+
2134
class CustomNodeSettings(RepositorySettings):
2235
extra_dependencies: list[str] = SettingsField(
2336
default_factory=list,
@@ -76,6 +89,11 @@ class AddonSettings(BaseSettingsModel):
7689
title="Extra Flags",
7790
description="Extra argument flags to pass when launching the ComfyUI server.",
7891
)
92+
venv: VirtualEnvSettings = SettingsField(
93+
default_factory=VirtualEnvSettings,
94+
title="Virtual Environment Settings",
95+
description="Virtual Environment Settings for the ComfyUI server.",
96+
)
7997
repositories: ComfyUIRepositorySettings = SettingsField(
8098
default_factory=ComfyUIRepositorySettings,
8199
title="Repository Settings",

0 commit comments

Comments
 (0)