This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit c13bbb1

Torch Distributed Configurable backend + env var init (#26)
* support configurable torch distributed backends
* formatting
* use env var initialization
* re-enable examples tests
* fixes
* update classifier model
* update horovod test
* formatting
* update
* increase test timeout
* change back to tmpdir
* increase timeout more
1 parent 1181052 commit c13bbb1

File tree

9 files changed: +106 −43 lines
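
The headline change: the torch.distributed backend used by the plugin can now be overridden through the PL_TORCH_DISTRIBUTED_BACKEND environment variable instead of being hard-wired to nccl (GPU) or gloo (CPU). A minimal usage sketch, not code from this commit, assuming the repo's BoringModel test module as a stand-in for a real LightningModule:

import os

import pytorch_lightning as pl
import ray

from ray_lightning import RayPlugin
from ray_lightning.tests.utils import BoringModel

# Force the "gloo" backend. When the variable is unset, the plugin falls
# back to "nccl" with GPUs and "gloo" without (see init_ddp_connection in
# ray_lightning/ray_ddp.py below). Setting it on the driver is enough:
# start_training forwards it to every remote worker.
os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo"

ray.init()
plugin = RayPlugin(num_workers=2, use_gpu=False)
trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
trainer.fit(BoringModel())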

.github/workflows/test.yaml

Lines changed: 14 additions & 14 deletions

@@ -22,7 +22,7 @@ jobs:
         ./format.sh --all
   test_linux_ray_master:
     runs-on: ubuntu-latest
-    timeout-minutes: 12
+    timeout-minutes: 25
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.7
@@ -49,7 +49,7 @@ jobs:

   test_linux_ray_master_examples:
     runs-on: ubuntu-latest
-    timeout-minutes: 12
+    timeout-minutes: 15
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.7
@@ -70,15 +70,15 @@ jobs:
     - name: Run Examples
       run: |
         pushd examples/
-        # echo "running ray_ddp_example.py" && python ray_ddp_example.py --smoke-test
-        # echo "running ray_ddp_example.py with Tune" && python ray_ddp_example.py --smoke-test --tune
-        # echo "running ray_ddp_tune.py" && python ray_ddp_tune.py --smoke-test
-        # echo "running ray_horovod_example.py" && python ray_horovod_example.py --smoke-test
-        # echo "running ray_horovod_example.py with Tune" && python ray_horovod_example.py --smoke-test --tune
+        echo "running ray_ddp_example.py" && python ray_ddp_example.py --smoke-test
+        echo "running ray_ddp_example.py with Tune" && python ray_ddp_example.py --smoke-test --tune
+        echo "running ray_ddp_tune.py" && python ray_ddp_tune.py --smoke-test
+        echo "running ray_horovod_example.py" && python ray_horovod_example.py --smoke-test
+        echo "running ray_horovod_example.py with Tune" && python ray_horovod_example.py --smoke-test --tune

   test_linux_ray_release:
     runs-on: ubuntu-latest
-    timeout-minutes: 12
+    timeout-minutes: 25
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.7
@@ -106,7 +106,7 @@ jobs:

   test_linux_ray_release_examples:
     runs-on: ubuntu-latest
-    timeout-minutes: 12
+    timeout-minutes: 15
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.7
@@ -127,8 +127,8 @@ jobs:
     - name: Run Examples
       run: |
         pushd examples/
-        # echo "running ray_ddp_example.py" && python ray_ddp_example.py --smoke-test
-        # echo "running ray_ddp_example.py with Tune" && python ray_ddp_example.py --smoke-test --tune
-        # echo "running ray_ddp_tune.py" && python ray_ddp_tune.py --smoke-test
-        # echo "running ray_horovod_example.py" && python ray_horovod_example.py --smoke-test
-        # echo "running ray_horovod_example.py with Tune" && python ray_horovod_example.py --smoke-test --tune
+        echo "running ray_ddp_example.py" && python ray_ddp_example.py --smoke-test
+        echo "running ray_ddp_example.py with Tune" && python ray_ddp_example.py --smoke-test --tune
+        echo "running ray_ddp_tune.py" && python ray_ddp_tune.py --smoke-test
+        echo "running ray_horovod_example.py" && python ray_horovod_example.py --smoke-test
+        echo "running ray_horovod_example.py with Tune" && python ray_horovod_example.py --smoke-test --tune

examples/ray_ddp_example.py

Lines changed: 1 addition & 1 deletion

@@ -10,9 +10,9 @@

 import ray
 from ray import tune
-from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
 from ray_lightning.tune import TuneReportCallback
 from ray_lightning import RayPlugin
+from ray_lightning.tests.utils import LightningMNISTClassifier


 class MNISTClassifier(LightningMNISTClassifier):

examples/ray_ddp_tune.py

Lines changed: 1 addition & 1 deletion

@@ -7,9 +7,9 @@
 import pytorch_lightning as pl
 import ray
 from ray import tune
-from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
 from ray_lightning.tune import TuneReportCallback
 from ray_lightning import RayPlugin
+from ray_lightning.tests.utils import LightningMNISTClassifier


 def train_mnist(config,

examples/ray_horovod_example.py

Lines changed: 1 addition & 1 deletion

@@ -10,9 +10,9 @@

 import ray
 from ray import tune
-from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
 from ray_lightning.tune import TuneReportCallback
 from ray_lightning import HorovodRayPlugin
+from ray_lightning.tests.utils import LightningMNISTClassifier


 class MNISTClassifier(LightningMNISTClassifier):

ray_lightning/ray_ddp.py

Lines changed: 26 additions & 14 deletions

@@ -1,4 +1,4 @@
-from typing import Callable, Dict
+from typing import Callable, Dict, List

 import os
 from collections import defaultdict
@@ -7,7 +7,7 @@
 import torch
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning import _logger as log, LightningModule
-from ray.util.sgd.torch.utils import setup_address
+from ray.util.sgd.utils import find_free_port

 from ray_lightning.session import init_session
 from ray_lightning.util import process_results, Queue
@@ -20,7 +20,15 @@ class RayExecutor:

     def set_env_var(self, key: str, value: str):
         """Set an environment variable with the provided value."""
-        os.environ[key] = value
+        if value is not None:
+            value = str(value)
+            os.environ[key] = value
+
+    def set_env_vars(self, keys: List[str], values: List[str]):
+        """Sets multiple env vars with the provided values."""
+        assert len(keys) == len(values)
+        for key, value in zip(keys, values):
+            self.set_env_var(key, value)

     def get_node_ip(self):
         """Returns the IP address of the node that this Ray actor is on."""
@@ -137,16 +145,19 @@ def start_training(self, trainer):
         receive intermediate results, and process those results. Finally
         retrieve the training results from the rank 0 worker and return."""

-        if "PL_GLOBAL_SEED" in os.environ:
-            seed = os.environ["PL_GLOBAL_SEED"]
-            ray.get([
-                w.set_env_var.remote("PL_GLOBAL_SEED", seed)
-                for w in self.workers
-            ])
+        # Get rank 0 worker address and port for DDP connection.
+        os.environ["MASTER_ADDR"] = ray.get(
+            self.workers[0].get_node_ip.remote())
+        os.environ["MASTER_PORT"] = str(
+            ray.get(self.workers[0].execute.remote(find_free_port)))

-        # Get the rank 0 address for DDP connection.
-        self.ddp_address = ray.get(
-            self.workers[0].execute.remote(setup_address))
+        # Set environment variables for remote workers.
+        keys = [
+            "PL_GLOBAL_SEED", "PL_TORCH_DISTRIBUTED_BACKEND", "MASTER_ADDR",
+            "MASTER_PORT"
+        ]
+        values = [os.getenv(k) for k in keys]
+        ray.get([w.set_env_vars.remote(keys, values) for w in self.workers])

         self.global_to_local = self.get_local_ranks()

@@ -235,14 +246,15 @@ def init_ddp_connection(self,
                             world_size: int,
                             is_slurm_managing_tasks: bool = False) -> None:
         """Process group creation to be executed on each remote worker."""
-        torch_backend = "nccl" if self.use_gpu else "gloo"
+        torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
+        if torch_backend is None:
+            torch_backend = "nccl" if self.use_gpu else "gloo"

         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER:"
                      f" {global_rank + 1}/{world_size}")
             torch.distributed.init_process_group(
                 backend=torch_backend,
-                init_method=self.ddp_address,
                 rank=global_rank,
                 world_size=world_size,
             )
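
Why setting only MASTER_ADDR and MASTER_PORT is enough: torch.distributed.init_process_group defaults to the env:// rendezvous when no init_method is given, reading the address and port from those environment variables along with the rank and world size passed in, which is why the explicit init_method argument can be dropped above. A standalone sketch of the mechanism the plugin now relies on, with illustrative values:

import os

import torch.distributed as dist


def init_worker(global_rank: int, world_size: int, use_gpu: bool) -> None:
    # Normally the driver sets these and forwards them to each worker via
    # set_env_vars (see start_training above); the values here are dummies.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Same backend selection logic as init_ddp_connection above.
    backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
    if backend is None:
        backend = "nccl" if use_gpu else "gloo"

    # With no init_method argument, torch.distributed falls back to
    # "env://" and reads MASTER_ADDR / MASTER_PORT itself.
    dist.init_process_group(
        backend=backend, rank=global_rank, world_size=world_size)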

ray_lightning/tests/test_ddp.py

Lines changed: 3 additions & 4 deletions

@@ -7,11 +7,10 @@
 from pytorch_lightning.callbacks import EarlyStopping

 import ray
-from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier

 from ray_lightning import RayPlugin
 from ray_lightning.tests.utils import get_trainer, train_test, \
-    load_test, predict_test, BoringModel
+    load_test, predict_test, BoringModel, LightningMNISTClassifier


 @pytest.fixture
@@ -95,7 +94,6 @@ def test_load(tmpdir, ray_start_2_cpus, num_workers):
     load_test(trainer, model)


-@pytest.mark.skip("Skip until next torchvision release.")
 @pytest.mark.parametrize("num_workers", [1, 2])
 def test_predict(tmpdir, ray_start_2_cpus, seed, num_workers):
     """Tests if trained model has high accuracy on test set."""
@@ -105,12 +103,13 @@ def test_predict(tmpdir, ray_start_2_cpus, seed, num_workers):
         "lr": 1e-2,
         "batch_size": 32,
     }
+
     model = LightningMNISTClassifier(config, tmpdir)
     dm = MNISTDataModule(
         data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"])
     plugin = RayPlugin(num_workers=num_workers, use_gpu=False)
     trainer = get_trainer(
-        tmpdir, limit_train_batches=10, max_epochs=1, plugins=[plugin])
+        tmpdir, limit_train_batches=20, max_epochs=1, plugins=[plugin])
     predict_test(trainer, model, dm)

ray_lightning/tests/test_ddp_gpu.py

Lines changed: 2 additions & 3 deletions

@@ -7,11 +7,10 @@
 from pytorch_lightning import Callback

 import ray
-from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier

 from ray_lightning import RayPlugin
 from ray_lightning.tests.utils import get_trainer, train_test, BoringModel, \
-    predict_test
+    predict_test, LightningMNISTClassifier


 @pytest.fixture
@@ -54,7 +53,7 @@ def test_predict(tmpdir, ray_start_2_gpus, seed, num_workers):
     plugin = RayPlugin(num_workers=num_workers, use_gpu=True)
     trainer = get_trainer(
         tmpdir,
-        limit_train_batches=10,
+        limit_train_batches=20,
         max_epochs=1,
         plugins=[plugin],
         use_gpu=True)

ray_lightning/tests/test_horovod.py

Lines changed: 3 additions & 5 deletions

@@ -13,11 +13,10 @@
     HOROVOD_AVAILABLE = True

 import ray
-from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier

 from ray_lightning import HorovodRayPlugin
 from ray_lightning.tests.utils import get_trainer, BoringModel, \
-    train_test, load_test, predict_test
+    train_test, load_test, predict_test, LightningMNISTClassifier


 def _nccl_available():
@@ -66,7 +65,6 @@ def test_load(tmpdir, ray_start_2_cpus, seed, num_slots):
     load_test(trainer, model)


-@pytest.mark.skip("Skip until next torchvision release.")
 @pytest.mark.parametrize("num_slots", [1, 2])
 def test_predict(tmpdir, ray_start_2_cpus, seed, num_slots):
     """Tests if trained model has high accuracy on test set."""
@@ -81,7 +79,7 @@ def test_predict(tmpdir, ray_start_2_cpus, seed, num_slots):
         data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"])
     plugin = HorovodRayPlugin(num_slots=num_slots, use_gpu=False)
     trainer = get_trainer(
-        tmpdir, limit_train_batches=10, max_epochs=1, plugins=[plugin])
+        tmpdir, limit_train_batches=20, max_epochs=1, plugins=[plugin])
     predict_test(trainer, model, dm)


@@ -130,7 +128,7 @@ def test_predict_gpu(tmpdir, ray_start_2_gpus, seed, num_slots):
     plugin = HorovodRayPlugin(num_slots=num_slots, use_gpu=True)
     trainer = get_trainer(
         tmpdir,
-        limit_train_batches=10,
+        limit_train_batches=20,
         max_epochs=1,
         plugins=[plugin],
         use_gpu=True)
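
The Horovod path picks up the same test changes: the torchvision skip marker is removed and limit_train_batches is doubled. For orientation, a minimal sketch of how these tests drive HorovodRayPlugin, not test code from this commit; Horovod must be installed, and the repo's BoringModel again stands in for a real model:

import pytorch_lightning as pl
import ray

from ray_lightning import HorovodRayPlugin
from ray_lightning.tests.utils import BoringModel

ray.init(num_cpus=2)

# num_slots matches the parametrization in the tests above: the number of
# Horovod workers to launch.
plugin = HorovodRayPlugin(num_slots=2, use_gpu=False)
trainer = pl.Trainer(max_epochs=1, limit_train_batches=20, plugins=[plugin])
trainer.fit(BoringModel())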

ray_lightning/tests/utils.py

Lines changed: 55 additions & 0 deletions

@@ -1,6 +1,8 @@
+import os
 from typing import Optional, List

 import torch
+import torch.nn.functional as F
 from pytorch_lightning.plugins import Plugin
 from torch.utils.data import Dataset

@@ -91,6 +93,58 @@ def on_load_checkpoint(self, checkpoint) -> None:
         self.val_epoch = checkpoint["val_epoch"]


+class LightningMNISTClassifier(pl.LightningModule):
+    def __init__(self, config, data_dir=None):
+        super(LightningMNISTClassifier, self).__init__()
+
+        self.data_dir = data_dir or os.getcwd()
+        self.lr = config["lr"]
+        layer_1, layer_2 = config["layer_1"], config["layer_2"]
+        self.batch_size = config["batch_size"]
+
+        # mnist images are (1, 28, 28) (channels, width, height)
+        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
+        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
+        self.layer_3 = torch.nn.Linear(layer_2, 10)
+        self.accuracy = pl.metrics.Accuracy()
+
+    def forward(self, x):
+        batch_size, channels, width, height = x.size()
+        x = x.view(batch_size, -1)
+        x = self.layer_1(x)
+        x = torch.relu(x)
+        x = self.layer_2(x)
+        x = torch.relu(x)
+        x = self.layer_3(x)
+        x = F.softmax(x, dim=1)
+        return x
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.lr)
+
+    def training_step(self, train_batch, batch_idx):
+        x, y = train_batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y)
+        acc = self.accuracy(logits, y)
+        self.log("ptl/train_loss", loss)
+        self.log("ptl/train_accuracy", acc)
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        x, y = val_batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y)
+        acc = self.accuracy(logits, y)
+        return {"val_loss": loss, "val_accuracy": acc}
+
+    def validation_epoch_end(self, outputs):
+        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
+        self.log("ptl/val_loss", avg_loss)
+        self.log("ptl/val_accuracy", avg_acc)
+
+
 def get_trainer(dir,
                 plugins: List[Plugin],
                 use_gpu: bool = False,
@@ -139,6 +193,7 @@ def predict_test(trainer: Trainer, model: LightningModule,
                  dm: LightningDataModule):
     """Checks if the trained model has high accuracy on the test set."""
     trainer.fit(model, datamodule=dm)
+    model = trainer.lightning_module
     dm.setup(stage="test")
     test_loader = dm.test_dataloader()
     acc = pl.metrics.Accuracy()
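
The new in-repo LightningMNISTClassifier replaces the copy previously imported from ray.tune.examples.mnist_ptl_mini; per its __init__ above, it expects a config dict with lr, layer_1, layer_2, and batch_size keys. A minimal construction sketch, with illustrative hyperparameter values:

import tempfile

from ray_lightning.tests.utils import LightningMNISTClassifier

config = {
    "lr": 1e-2,       # Adam learning rate (see configure_optimizers)
    "layer_1": 32,    # width of the first hidden layer
    "layer_2": 64,    # width of the second hidden layer
    "batch_size": 32,
}

# data_dir is where the accompanying MNIST data module would download MNIST.
model = LightningMNISTClassifier(config, data_dir=tempfile.mkdtemp())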
