@@ -82,6 +82,49 @@ def pre_process(self):
8282 cache_tmpl = self .addon_settings ["caching" ]["cache_dir_template" ]
8383 self .cache_dir = StringTemplate (cache_tmpl ).format_strict (self .tmpl_data )
8484
85+ # get installed CUDA version and build correct pypi index url
86+ try :
87+ smi_version_details = subprocess .check_output (
88+ ["nvidia-smi" , "--version" ], text = True
89+ ).strip ()
90+ except subprocess .CalledProcessError as e :
91+ log .error (f"Failed to execute `nvidia-smi`: { e } Please ensure NVIDIA drivers are installed." )
92+
93+ cuda_version = None
94+ for line in smi_version_details .splitlines ():
95+ if "CUDA Version" in line :
96+ parts = line .split (":" )
97+ cuda_version = parts [1 ].strip ()
98+ break
99+ if not cuda_version :
100+ log .error ("Could not determine CUDA version from `nvidia-smi` output." )
101+ raise RuntimeError ("CUDA version could not be determined." )
102+
103+ pypi_url_map = {
104+ "11.8" : {
105+ "stable" : "https://download.pytorch.org/whl/cu118" ,
106+ "nightly" : None ,
107+ },
108+ "12.6" : {
109+ "stable" : "https://download.pytorch.org/whl/cu126" ,
110+ "nightly" : "https://download.pytorch.org/whl/nightly/cu126" ,
111+ },
112+ "12.8" : {
113+ "stable" : "https://download.pytorch.org/whl/cu128" ,
114+ "nightly" : "https://download.pytorch.org/whl/nightly/cu128" ,
115+ },
116+ "12.9" : {
117+ "stable" : None ,
118+ "nightly" : "https://download.pytorch.org/whl/nightly/cu129" ,
119+ },
120+ }
121+ if bool (self .addon_settings ["venv" ]["use_torch_nightly" ]):
122+ self .pypi_url = pypi_url_map [cuda_version ]["nightly" ]
123+ else :
124+ self .pypi_url = pypi_url_map [cuda_version ]["stable" ]
125+
126+ self .py_version = self .addon_settings ["venv" ]["python_version" ]
127+
85128 def clone_repositories (self ):
86129 def git_clone (url : str , dest : Path , tag : str = "" ) -> git .Repo :
87130 if not dest .exists ():
@@ -182,6 +225,12 @@ def run_server(self):
182225 if self .extra_flags :
183226 launch_args .append ("-extraFlags" )
184227 launch_args .append ("," .join (self .extra_flags ))
228+ if self .pypi_url :
229+ launch_args .append ("-pypiUrl" )
230+ launch_args .append (self .pypi_url )
231+ if self .py_version :
232+ launch_args .append ("-pythonVersion" )
233+ launch_args .append (self .py_version )
185234
186235 _cmd .extend (launch_args )
187236 cmd = " " .join ([str (arg ) for arg in _cmd ])
0 commit comments