
Commit 6aed848

ray_ddp: support logged_metrics as part of remote worker return value (#156)
1 parent 771d5a4 commit 6aed848

4 files changed: +107 additions, −3 deletions
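The net effect: metrics logged via `self.log(...)` inside remote workers now show up in `trainer.logged_metrics` on the driver after `trainer.fit(...)` returns, mirroring the existing handling of `trainer.callback_metrics`. A minimal driver-side sketch, using the `XORModel`/`XORDataModule` test helpers added in this commit and assuming the Trainer `plugins` argument of Lightning releases from this era:

```python
import pytorch_lightning as pl

from ray_lightning import RayPlugin
from ray_lightning.tests.utils import XORModel, XORDataModule  # added below

model = XORModel()
plugin = RayPlugin(num_workers=2)
trainer = pl.Trainer(
    max_epochs=1,
    plugins=[plugin],
    num_sanity_val_steps=0,
    # the XOR data module hands out plain iterators, so reload them each epoch
    reload_dataloaders_every_n_epochs=1)
trainer.fit(model, XORDataModule())

# Previously only callback_metrics was synced back from worker 0;
# logged_metrics stayed empty on the driver.
print(trainer.callback_metrics)  # e.g. {"avg_val_loss": ..., "val_foo": tensor(1.2340)}
print(trainer.logged_metrics)    # now also populated, incl. forked "*_step" keys
```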

ray_lightning/ray_ddp.py

Lines changed: 14 additions & 2 deletions
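The plumbing mirrors what `callback_metrics` already does: in `execute_remote`, worker 0 converts its `trainer.logged_metrics` tensors to numpy and appends them to the return tuple; in `post_dispatch`, the driver converts them back to tensors and merges them into its own trainer.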
```diff
@@ -370,14 +370,20 @@ def post_dispatch(self, trainer: "pl.Trainer"):

         results = ray.get(self._futures)
         # Get the results, checkpoint path, and model weights from worker 0.
-        results, best_path, state_stream, callback_metrics = results[0]
+        results, best_path, state_stream, callback_metrics, logged_metrics \
+            = results[0]
         self._results = results

         # From DDPSpawn.get_queue
         self.lightning_module.trainer.callback_metrics.update(
             apply_to_collection(callback_metrics,
                                 np.ndarray, lambda x: torch.tensor(x)))

+        # Same for logged_metrics
+        self.lightning_module.trainer.logged_metrics.update(
+            apply_to_collection(logged_metrics,
+                                np.ndarray, lambda x: torch.tensor(x)))
+
         # DDPSpawnPlugin.__recover_child_process_weights begin
         # Difference here is that instead of writing the model weights to a
         # file and loading it, we use the state dict of the model directly.
@@ -500,8 +506,14 @@ def execute_remote(self,
                 torch.Tensor, lambda x: x.cpu().numpy(
                 ))  # send as numpy to avoid issues with memory sharing

+            # Same for logged_metrics
+            logged_metrics: dict = apply_to_collection(
+                self.lightning_module.trainer.logged_metrics,
+                torch.Tensor, lambda x: x.cpu().numpy(
+                ))  # send as numpy to avoid issues with memory sharing
+
             return_val = results, best_model_path, model_state_stream, \
-                callback_metrics
+                callback_metrics, logged_metrics
         else:
             return_val = None
         # __transfer_distrib_spawn_state_on_fit_end end
```
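A standalone sketch of that numpy round-trip, assuming the `pytorch_lightning.utilities.apply_func` import path used by Lightning releases of this era:

```python
import numpy as np
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

logged_metrics = {"val_loss": torch.tensor(0.25), "val_foo": torch.tensor(1.234)}

# Worker side: tensors -> numpy, so the values survive the trip through
# Ray's object store without shared-memory issues.
as_numpy = apply_to_collection(logged_metrics, torch.Tensor,
                               lambda x: x.cpu().numpy())

# Driver side: numpy -> tensors, restoring the original structure.
restored = apply_to_collection(as_numpy, np.ndarray,
                               lambda x: torch.tensor(x))

assert torch.equal(restored["val_foo"], logged_metrics["val_foo"])
```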

ray_lightning/tests/test_ddp.py

Lines changed: 30 additions & 1 deletion
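The new `test_metrics` test runs a two-worker fit on a toy XOR task and asserts that both metric stores are populated on the driver, including the forked `*_step` keys that only `logged_metrics` should contain.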
```diff
@@ -1,5 +1,6 @@
 import pytest
 from ray.util.client.ray_client_helpers import ray_start_client_server
+import torch
 from torch.utils.data import DistributedSampler

 from pl_bolts.datamodules import MNISTDataModule
@@ -12,7 +13,8 @@

 from ray_lightning import RayPlugin
 from ray_lightning.tests.utils import get_trainer, train_test, \
-    load_test, predict_test, BoringModel, LightningMNISTClassifier
+    load_test, predict_test, BoringModel, LightningMNISTClassifier, \
+    XORModel, XORDataModule


 @pytest.fixture
@@ -319,3 +321,30 @@ def on_train_start(self, trainer, pl_module):
     trainer = get_trainer(
         tmpdir, plugins=[plugin], callbacks=[UnusedParameterCallback()])
     trainer.fit(model)
+
+
+def test_metrics(tmpdir, ray_start_2_cpus):
+    """Tests if metrics are returned correctly"""
+    model = XORModel()
+    plugin = RayPlugin(num_workers=2, find_unused_parameters=False)
+    trainer = get_trainer(
+        tmpdir,
+        plugins=[plugin],
+        max_epochs=1,
+        num_sanity_val_steps=0,
+        reload_dataloaders_every_n_epochs=1)
+    dataset = XORDataModule()
+    trainer.fit(model, dataset)
+    callback_metrics = trainer.callback_metrics
+    logged_metrics = trainer.logged_metrics
+    assert callback_metrics["avg_val_loss"] == logged_metrics["avg_val_loss"]
+    assert logged_metrics["val_foo"] == torch.tensor(1.234)
+    assert callback_metrics["val_foo"] == torch.tensor(1.234)
+    # forked name is used for on_step logged metrics
+    forked_name_loss = "val_loss" + "_step"
+    forked_name_bar = "val_bar" + "_step"
+    assert forked_name_loss in logged_metrics.keys()
+    assert logged_metrics[forked_name_bar] == torch.tensor(5.678)
+    # callback_metrics doesn't record on_step metrics
+    assert forked_name_loss not in callback_metrics.keys()
+    assert forked_name_bar not in callback_metrics.keys()
```
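The forked-name assertions rely on Lightning's logging behavior: `self.log(name, ..., on_step=True)` in a validation step records the per-step value under `<name>_step` in `logged_metrics`, while `callback_metrics` only receives epoch-level values, so the `_step` keys must be absent there.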

ray_lightning/tests/utils.py

Lines changed: 62 additions & 0 deletions
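`utils.py` gains the minimal XOR model and data module the test drives: the model logs a per-step constant (`val_bar`) and an epoch-level constant (`val_foo`) purely so the test has known values to assert on.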
```diff
@@ -148,6 +148,68 @@ def validation_epoch_end(self, outputs):
         self.log("ptl/val_accuracy", avg_acc)


+class XORModel(LightningModule):
+    def __init__(self, input_dim=2, output_dim=1):
+        super(XORModel, self).__init__()
+        self.save_hyperparameters()
+        self.lin1 = torch.nn.Linear(input_dim, 8)
+        self.lin2 = torch.nn.Linear(8, output_dim)
+
+    def forward(self, features):
+        x = features.float()
+        x = self.lin1(x)
+        x = torch.tanh(x)
+        x = self.lin2(x)
+        x = torch.sigmoid(x)
+        return x
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch["x"], batch["y"].unsqueeze(1)
+        y_hat = self(x)
+        loss = F.binary_cross_entropy(y_hat, y.float())
+        return loss
+
+    def validation_step(self, batch, batch_nb):
+        x, y = batch["x"], batch["y"].unsqueeze(1)
+        y_hat = self(x)
+        loss = F.binary_cross_entropy(y_hat, y.float())
+        self.log("val_loss", loss, on_step=True)
+        # Log a constant for test purpose
+        self.log("val_bar", torch.tensor(5.678), on_step=True)
+        return loss
+
+    def validation_epoch_end(self, outputs):
+        avg_loss = torch.stack(outputs).mean()
+        self.log("avg_val_loss", avg_loss)
+        # Log a constant for test purpose
+        self.log("val_foo", torch.tensor(1.234))
+
+
+class XORDataModule(LightningDataModule):
+    def train_dataloader(self):
+        input_train = [{
+            "x": torch.tensor([[0.0, 0.0]]),
+            "y": torch.tensor([0])
+        }, {
+            "x": torch.tensor([[1.0, 1.0]]),
+            "y": torch.tensor([0])
+        }]
+        return iter(input_train)
+
+    def val_dataloader(self):
+        input_val = [{
+            "x": torch.tensor([[0.0, 1.0]]),
+            "y": torch.tensor([1])
+        }, {
+            "x": torch.tensor([[1.0, 0.0]]),
+            "y": torch.tensor([1])
+        }]
+        return iter(input_val)
+
+
 def get_trainer(dir,
                 plugins: List[PLUGIN_INPUT],
                 max_epochs: int = 1,
```
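Note that both dataloader hooks return plain iterators rather than `DataLoader` instances; an iterator is exhausted after one pass, which is presumably why `test_metrics` sets `reload_dataloaders_every_n_epochs=1` so that fresh iterators are requested each epoch.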

requirements-test.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,3 +10,4 @@ ray[tune]
 torch==1.8.1
 torchmetrics
 torchvision
+protobuf<=3.20.1
```
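The pin is likely a test-environment fix: protobuf releases after 3.20.1 (the 4.x line) changed the generated-code API and break older packages that import it, so the test requirements cap the version.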
