@@ -26,10 +26,20 @@ def get_total_and_free_memory_in_Mb(cuda_device):
2626 "nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader"
2727 )
2828 devices_info = devices_info_str .read ().strip ().split ("\n " )
29- if "CUDA_VISIBLE_DEVICES" in os .environ :
30- visible_devices = os .environ ["CUDA_VISIBLE_DEVICES" ].split (',' )
31- cuda_device = int (visible_devices [cuda_device ])
32- total , used = devices_info [int (cuda_device )].split ("," )
29+ if len (devices_info ) > 1 :
30+ if "CUDA_VISIBLE_DEVICES" in os .environ :
31+ visible_devices = os .environ ["CUDA_VISIBLE_DEVICES" ].split (',' )
32+ cuda_device = int (visible_devices [cuda_device ])
33+ total , used = devices_info [int (cuda_device )].split ("," )
34+ else :
35+ devices_info_str = os .popen (
36+ "rocm-smi --showmeminfo vram --csv"
37+ )
38+ devices_info = devices_info_str .read ().strip ().split ("\n " )
39+ _ , total , used = devices_info [1 ].split (',' )
40+ total = int (total ) // (1024 * 1024 )
41+ used = int (used ) // (1024 * 1024 )
42+
3343 return int (total ), int (used )
3444
3545
0 commit comments