This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit c8bcae7

Fix for fractional GPU (#125)
Closes #124. Fixes the device calculation to take fractional GPUs into account, and also raises a warning advising against this in the multi-worker case, since sharing GPUs across workers will often lead to failures with NCCL training. The test was run manually and passes.
1 parent fac8b8e · commit c8bcae7
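For context, here is a minimal usage sketch (not part of this commit) of the configuration the new warning targets: fractional GPUs per worker combined with the gloo backend. The RayPlugin name and the resources_per_worker={"GPU": 0.5} route for requesting a fraction of a GPU are assumptions inferred from the constructor arguments visible in the diff below; verify them against the plugin's docstring.

import os

import pytorch_lightning as pl
from ray_lightning import RayPlugin

# NCCL cannot share a device across worker processes, so the commit's warning
# suggests switching PyTorch Lightning's distributed backend to gloo.
os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"

plugin = RayPlugin(
    num_workers=2,
    use_gpu=True,
    # Assumed mechanism for requesting half a GPU per worker (hypothetical).
    resources_per_worker={"GPU": 0.5},
)
trainer = pl.Trainer(plugins=[plugin], max_epochs=1)
# trainer.fit(model)  # model: any LightningModule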

2 files changed: +41, -6 lines

ray_lightning/ray_ddp.py
16 additions, 2 deletions

@@ -4,6 +4,7 @@
 from contextlib import closing
 import os
 import socket
+import warnings

 import numpy as np
 import torch
@@ -137,6 +138,18 @@ def __init__(self,
         self.num_gpus_per_worker = int(use_gpu)

         self.use_gpu = self.num_gpus_per_worker > 0
+
+        if self.use_gpu and self.num_gpus_per_worker < 1 and num_workers > 1:
+            warnings.warn("Identified less than 1 GPU being set per worker. "
+                          "If using NCCL backend (which is the default for "
+                          "GPU training), GPU devices cannot be shared "
+                          "across processes/workers and training is likely "
+                          "to fail. It is recommended to use 1 GPU per "
+                          "worker for training, or if you must use "
+                          "fractional GPUs, then use the gloo backend by "
+                          "setting PL_TORCH_DISTRIBUTED_BACKEND=gloo "
+                          "environment variable.")
+
         self.additional_resources_per_worker = resources_per_worker
         self.workers = []
         self.init_hook = init_hook
@@ -514,8 +527,9 @@ def node_rank(self) -> int:
     def root_device(self):
         if self.use_gpu and torch.cuda.is_available():
             if self._is_remote:
-                # Adjust for if there are multiple GPUs per worker.
-                device_id = self.local_rank * self.num_gpus_per_worker
+                # Adjust to support multiple GPUs per worker or fractional
+                # GPUs per worker.
+                device_id = ray.get_gpu_ids()[0]
                 return torch.device("cuda", device_id)
             else:
                 # If the root device is requested on the driver, just return
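The device fix above relies on ray.get_gpu_ids(), which returns the IDs of the GPUs that Ray assigned to the calling worker, so element 0 identifies the GPU the worker was placed on regardless of whether it holds a whole GPU or only a fraction of one. A small standalone sketch of that behaviour (not part of this commit; the GPU count and fractions are illustrative):

import ray

ray.init(num_gpus=1)

@ray.remote(num_gpus=0.5)
def which_gpu():
    # Returns the IDs of the GPUs Ray assigned to this worker; with
    # fractional requests, multiple workers can be assigned the same ID.
    return ray.get_gpu_ids()

# Both tasks fit on the single GPU, so both should report the same ID.
print(ray.get([which_gpu.remote(), which_gpu.remote()]))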

ray_lightning/tests/test_ddp_gpu.py
25 additions, 4 deletions

@@ -80,15 +80,36 @@ def on_epoch_end(self, trainer, pl_module):


 @pytest.mark.skipif(
-    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.parametrize("num_gpus_per_worker", [1, 2])
-def test_correct_devices(tmpdir, ray_start_4_gpus, num_gpus_per_worker):
+    torch.cuda.device_count() < 4, reason="test requires multi-GPU machine")
+@pytest.mark.parametrize("num_gpus_per_worker", [0.4, 0.5, 1, 2])
+def test_correct_devices(tmpdir, ray_start_4_gpus, num_gpus_per_worker,
+                         monkeypatch):
     """Tests if GPU devices are correctly set."""
     model = BoringModel()

+    if num_gpus_per_worker < 1:
+        monkeypatch.setenv("PL_TORCH_DISTRIBUTED_BACKEND", "gloo")
+
+    def get_gpu_placement(current_worker_index, num_gpus_per_worker):
+        """Simulates GPU resource bin packing."""
+        next_gpu_index = 0
+        starting_resource_count = num_gpus_per_worker
+        for _ in range(current_worker_index + 1):
+            current_gpu_index = next_gpu_index
+            next_resources = starting_resource_count + \
+                num_gpus_per_worker - 0.0001
+            # If the next worker cannot fit on the current GPU, then we move
+            # onto the next GPU.
+            if int(next_resources) != current_gpu_index:
+                increment = max(1, int(num_gpus_per_worker))
+                next_gpu_index = current_gpu_index + increment
+
+        return current_gpu_index
+
     class CheckDevicesCallback(Callback):
         def on_epoch_end(self, trainer, pl_module):
-            assert trainer.root_gpu == trainer.local_rank * num_gpus_per_worker
+            assert trainer.root_gpu == get_gpu_placement(
+                trainer.local_rank, num_gpus_per_worker)
             assert trainer.root_gpu == pl_module.device.index
             assert torch.cuda.current_device() == trainer.root_gpu
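For intuition, the get_gpu_placement helper added above simulates Ray's bin packing of fractional GPU requests onto whole devices: worker 0 always maps to GPU 0; with 0.4 or 0.5 GPUs per worker, worker 1 also maps to GPU 0; with 1 GPU per worker it maps to GPU 1; and with 2 GPUs per worker it maps to GPU 2 (occupying GPUs 2 and 3). The 0.0001 subtraction keeps allocations that land exactly on a device boundary, such as 0.5 + 0.5, on the same GPU instead of spilling the second worker onto the next one.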
