Skip to content

Commit fe19bf1

Browse files
committed
fix: crashes when certain NVML features are not available, closes #55
1 parent 49e5ac3 commit fe19bf1

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

dmlcloud/core/callbacks.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -576,27 +576,34 @@ class CudaCallback(Callback):
576576
Logs various properties pertaining to CUDA devices.
577577
"""
578578

579+
@staticmethod
580+
def _call_pynvml(method, *args, **kwargs):
581+
try:
582+
return method(*args, **kwargs)
583+
except pynvml.NVMLError:
584+
return None
585+
579586
def pre_run(self, pipe):
580587
handle = torch.cuda._get_pynvml_handler(pipe.device)
581588

582589
info = {
583-
'name': pynvml.nvmlDeviceGetName(handle),
584-
'uuid': pynvml.nvmlDeviceGetUUID(handle),
585-
'serial': pynvml.nvmlDeviceGetSerial(handle),
590+
'name': self._call_pynvml(pynvml.nvmlDeviceGetName, handle),
591+
'uuid': self._call_pynvml(pynvml.nvmlDeviceGetUUID, handle),
592+
'serial': self._call_pynvml(pynvml.nvmlDeviceGetSerial, handle),
586593
'torch_device': str(pipe.device),
587-
'minor_number': pynvml.nvmlDeviceGetMinorNumber(handle),
588-
'architecture': pynvml.nvmlDeviceGetArchitecture(handle),
589-
'brand': pynvml.nvmlDeviceGetBrand(handle),
590-
'vbios_version': pynvml.nvmlDeviceGetVbiosVersion(handle),
591-
'driver_version': pynvml.nvmlSystemGetDriverVersion(),
592-
'cuda_driver_version': pynvml.nvmlSystemGetCudaDriverVersion_v2(),
593-
'nvml_version': pynvml.nvmlSystemGetNVMLVersion(),
594-
'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle, pynvml.nvmlMemory_v2).total,
595-
'reserved_memory': pynvml.nvmlDeviceGetMemoryInfo(handle, pynvml.nvmlMemory_v2).reserved,
596-
'num_gpu_cores': pynvml.nvmlDeviceGetNumGpuCores(handle),
597-
'power_managment_limit': pynvml.nvmlDeviceGetPowerManagementLimit(handle),
598-
'power_managment_default_limit': pynvml.nvmlDeviceGetPowerManagementDefaultLimit(handle),
599-
'cuda_compute_capability': pynvml.nvmlDeviceGetCudaComputeCapability(handle),
594+
'minor_number': self._call_pynvml(pynvml.nvmlDeviceGetMinorNumber, handle),
595+
'architecture': self._call_pynvml(pynvml.nvmlDeviceGetArchitecture, handle),
596+
'brand': self._call_pynvml(pynvml.nvmlDeviceGetBrand, handle),
597+
'vbios_version': self._call_pynvml(pynvml.nvmlDeviceGetVbiosVersion, handle),
598+
'driver_version': self._call_pynvml(pynvml.nvmlSystemGetDriverVersion),
599+
'cuda_driver_version': self._call_pynvml(pynvml.nvmlSystemGetCudaDriverVersion_v2),
600+
'nvml_version': self._call_pynvml(pynvml.nvmlSystemGetNVMLVersion),
601+
'total_memory': self._call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handle, pynvml.nvmlMemory_v2).total,
602+
'reserved_memory': self._call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handle, pynvml.nvmlMemory_v2).reserved,
603+
'num_gpu_cores': self._call_pynvml(pynvml.nvmlDeviceGetNumGpuCores, handle),
604+
'power_managment_limit': self._call_pynvml(pynvml.nvmlDeviceGetPowerManagementLimit, handle),
605+
'power_managment_default_limit': self._call_pynvml(pynvml.nvmlDeviceGetPowerManagementDefaultLimit, handle),
606+
'cuda_compute_capability': self._call_pynvml(pynvml.nvmlDeviceGetCudaComputeCapability, handle),
600607
}
601608
all_devices = all_gather_object(info)
602609

0 commit comments

Comments
 (0)