File tree Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -356,12 +356,23 @@ def error_on_warning():
356356 yield
357357
358358
359+ def get_physical_device_indices (devices ):
360+ visible_devices = os .environ .get ("CUDA_VISIBLE_DEVICES" )
361+ if visible_devices is None :
362+ return devices
363+
364+ visible_indices = [int (x ) for x in visible_devices .split ("," )]
365+ index_mapping = {i : physical for i , physical in enumerate (visible_indices )}
366+ return [index_mapping [i ] for i in devices if i in index_mapping ]
367+
368+
359369@_nvml ()
360370def wait_for_gpu_memory_to_clear (devices : List [int ],
361371 threshold_bytes : int ,
362372 timeout_s : float = 120 ) -> None :
363373 # Use nvml instead of pytorch to reduce measurement error from torch cuda
364374 # context.
375+ devices = get_physical_device_indices (devices )
365376 start_time = time .time ()
366377 while True :
367378 output : Dict [int , str ] = {}
You can’t perform that action at this time.
0 commit comments