@@ -1,5 +1,6 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
         return False
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = list(
+            map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+    else:
+        device_ids = list(range(num_dev))
+    # this checks hardware and driver support for NVLink
+    full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
         logger.warn(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
         return
-    # test P2P capability
+    # test P2P capability, this checks software/cudaruntime support
     # this is expensive to compute at the first time
     # then we cache the result
     if not _can_p2p(rank, world_size):
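For context, here is a minimal standalone sketch (not part of the patch) of the id translation this hunk performs, plus a hypothetical version of the `_can_p2p` check, whose body is outside this diff; `torch.cuda.can_device_access_peer` is one plausible way such a check could be implemented:

import os
from typing import List

import torch


def resolve_physical_device_ids(num_dev: int) -> List[int]:
    # Mirrors the hunk above: with CUDA_VISIBLE_DEVICES="2,3", torch's
    # logical device 0 is physical GPU 2, but NVML ignores the env var,
    # so the NVLink check must receive the physical ids [2, 3].
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        return list(map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
    return list(range(num_dev))


def can_p2p_sketch(rank: int, world_size: int) -> bool:
    # Hypothetical stand-in for _can_p2p (its body is not shown in this
    # diff). Unlike NVML, torch respects CUDA_VISIBLE_DEVICES, so logical
    # indices are the right ones to pass here.
    return all(
        torch.cuda.can_device_access_peer(rank, peer)
        for peer in range(world_size) if peer != rank)

Note the asymmetry: the NVLink probe runs on physical ids, while the runtime P2P probe runs on logical ids, which is exactly why the hunk parses CUDA_VISIBLE_DEVICES before calling `_is_full_nvlink` but leaves the `_can_p2p(rank, world_size)` call unchanged.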
@@ -138,23 +145,28 @@ def _nvml():
     pynvml.nvmlShutdown()
 
 
-# query if the set of gpus are fully connected by nvlink (1 hop)
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
                     return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
     return True
 
 
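As a side note, the new `i < j` guard walks each unordered GPU pair exactly once. The sketch below (again not part of the patch) expresses the same traversal with `itertools.combinations`, with explicit nvmlInit/nvmlShutdown standing in for the `@_nvml()` decorator defined elsewhere in the file:

from itertools import combinations
from typing import List

import pynvml


def is_full_nvlink_sketch(device_ids: List[int]) -> bool:
    # Equivalent formulation of the patched double loop: visit each
    # unordered pair of GPU handles exactly once.
    pynvml.nvmlInit()  # the patch does this via the @_nvml() decorator
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
        for a, b in combinations(handles, 2):
            try:
                status = pynvml.nvmlDeviceGetP2PStatus(
                    a, b, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError:
                # the patch additionally logs via logger.error(..., exc_info=)
                return False
        return True
    finally:
        pynvml.nvmlShutdown()

Under `CUDA_VISIBLE_DEVICES=2,3`, a caller would pass the physical ids `[2, 3]`, matching the resolution added to `init_custom_ar` above.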