|
60 | 60 | if not _version.hip: |
61 | 61 | import pynvml # type: ignore[import] |
62 | 62 | else: |
63 | | - import amdsmi # type: ignore[import] |
| 63 | + import ctypes |
| 64 | + from pathlib import Path |
| 65 | + |
| 66 | + # In ROCm (at least up through 6.3.2) there are two copies of libamd_smi.so: |
| 67 | + # - One at lib/libamd_smi.so |
| 68 | + # - One at share/amd_smi/amdsmi/libamd_smi.so |
| 69 | + # |
| 70 | + # The amdsmi Python module hardcodes loading the second one, in share: |
| 71 | + # https://github.com/ROCm/amdsmi/blob/1d305dc9708e87080f64f668402887794cd46584/py-interface/amdsmi_wrapper.py#L174 |
| 72 | + # |
| 73 | + # See also https://github.com/ROCm/amdsmi/issues/72. |
| 74 | + # |
| 75 | + # This creates an ODR violation if the copy of libamd_smi.so from lib |
| 76 | + # is also loaded (via `ld` linking, `LD_LIBRARY_PATH` or `rpath`). |
| 77 | + # |
| 78 | + # In order to avoid the violation we hook CDLL and try using the |
| 79 | + # already loaded version of amdsmi, or any version in the process's |
| 80 | + # rpath/LD_LIBRARY_PATH first, so that we only load a single copy |
| 81 | + # of the .so. |
class amdsmi_cdll_hook:
    """Context manager that reroutes ``libamd_smi.so`` loads made via ``ctypes.CDLL``.

    While the context is active, ``ctypes.CDLL`` is replaced by
    :meth:`hooked_CDLL`. Any request whose file name is ``libamd_smi.so`` is
    first attempted by bare name, so an absolute path hardcoded by the caller
    is only used as a fallback. On exit the loader captured at construction
    time is put back.
    """

    def __init__(self) -> None:
        # Capture whatever loader is installed right now: it is both the
        # delegate used while hooked and the value restored in __exit__.
        self.original_CDLL = ctypes.CDLL  # type: ignore[misc,assignment]

    def hooked_CDLL(
        self, name: Union[str, Path, None], *args: Any, **kwargs: Any
    ) -> ctypes.CDLL:
        """Load ``name``, preferring the bare ``libamd_smi.so`` when it matches.

        Args:
            name: Library name or path as passed to ``ctypes.CDLL``; a falsy
                value (e.g. ``None``) is forwarded untouched.
            *args: Extra positional arguments forwarded to the original loader.
            **kwargs: Extra keyword arguments forwarded to the original loader.

        Returns:
            Whatever the original loader returns for the library that was
            actually opened.

        Raises:
            OSError: Propagated from the original loader when the final
                attempt with the caller-supplied ``name`` fails.
        """
        if name and Path(name).name == "libamd_smi.so":
            try:
                # Drop the directory part so the bare name is handed to the
                # original loader instead of the caller's hardcoded path.
                return self.original_CDLL("libamd_smi.so", *args, **kwargs)
            except OSError:
                # Deliberate best-effort: the bare name could not be loaded,
                # so fall through and honour the path the caller asked for.
                pass
        return self.original_CDLL(name, *args, **kwargs)  # type: ignore[arg-type]

    def __enter__(self) -> "amdsmi_cdll_hook":
        """Install the hook and return this instance for use with ``as``."""
        ctypes.CDLL = self.hooked_CDLL  # type: ignore[misc,assignment]
        return self

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        """Restore the original loader; any in-flight exception propagates."""
        ctypes.CDLL = self.original_CDLL  # type: ignore[misc]
| 101 | + |
| 102 | + with amdsmi_cdll_hook(): |
| 103 | + import amdsmi # type: ignore[import] |
64 | 104 |
|
65 | 105 | _HAS_PYNVML = True |
66 | 106 | except ModuleNotFoundError: |
|
0 commit comments