Describe the issue
I installed the most recent terratorch version from the main branch (commit 7cbf469f1c179860b81117512f873c6ff2b1ff5a) to use the latest features; specifically, I was interested in combining multiple losses for semantic segmentation, such as "ce" + "dice". I installed it with:
pip install git+https://github.com/IBM/terratorch.git
However, with this version installed, trainer.fit() hangs indefinitely after the first train+val epoch. This only happens when training is distributed over more than one GPU and a CombinedLoss is used; with a single loss, such as "dice", it works fine.
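For context, what I expect the dict form of loss to do is compute a weighted sum of the individual terms. The sketch below only illustrates that assumption; the function name and the minimal soft-dice stand-in are mine, not terratorch's actual CombinedLoss implementation:

import torch
import torch.nn.functional as F

# Purely illustrative: my assumption of what loss={"ce": 1.0, "dice": 5.0}
# should amount to, i.e. a weighted sum of the two terms. This is NOT the
# terratorch CombinedLoss code, just the behaviour I expect from it.
def expected_combined_loss(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    ce = F.cross_entropy(logits, mask, ignore_index=-1)

    # minimal soft-dice stand-in (ignore_index handling omitted for brevity)
    probs = logits.softmax(dim=1)
    target = F.one_hot(mask.clamp(min=0), num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
    intersection = (probs * target).sum(dim=(0, 2, 3))
    cardinality = (probs + target).sum(dim=(0, 2, 3))
    dice = 1.0 - (2.0 * intersection / cardinality.clamp(min=1e-6)).mean()

    return 1.0 * ce + 5.0 * dice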
To Reproduce
Below is a simple semantic segmentation script, trained on random data, that reproduces the issue:
import lightning as L
from terratorch.tasks import SemanticSegmentationTask
from torch.utils.data import Dataset, DataLoader
import torch

LOSS = "dice"  # works
# LOSS = {"ce": 1.0, "dice": 5.0}  # does not work


class DummyDataModule(L.LightningDataModule):
    def setup(self, stage) -> None:
        self.train_dataset = RandomDataset()
        self.val_dataset = RandomDataset()

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=False,
            drop_last=False,
        )


class RandomDataset(Dataset):
    def __init__(self):
        self.imgs = [torch.rand((4, 4, 64, 64)) for _ in range(100)]
        self.masks = [torch.ones((64, 64)).long() for _ in range(100)]
        super().__init__()

    def __getitem__(self, idx):
        return {
            "image": self.imgs[idx],
            "mask": self.masks[idx],
        }

    def __len__(self):
        return len(self.imgs)


module = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": "prithvi_eo_v2_300",
        "backbone_pretrained": True,
        "backbone_img_size": 64,
        "backbone_num_frames": 4,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW"],
        "backbone_coords_encoding": [],
        "necks": [
            {"name": "SelectIndices", "indices": [5, 11, 17, 23]},
            {"name": "ReshapeTokensToImage", "effective_time_dim": 4},
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        "head_dropout": 0.1,
        "num_classes": 2,
    },
    loss=LOSS,
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,
    freeze_backbone=False,
    freeze_decoder=False,
    plot_on_val=False,
)

datamodule = DummyDataModule()
trainer = L.Trainer(
    max_epochs=5,
    default_root_dir="./logs",
    num_sanity_val_steps=2,
    precision="bf16-mixed",
    log_every_n_steps=10,
    limit_val_batches=5,
    limit_train_batches=5,
)
trainer.fit(module, datamodule=datamodule)
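Note that the script does not set devices or strategy explicitly; on my machine Lightning's defaults pick up both GPUs. To pin the multi-GPU case explicitly, the Trainer could be configured as below (these arguments were not part of my original run; I assume a plain DDP setup):

trainer = L.Trainer(
    max_epochs=5,
    default_root_dir="./logs",
    num_sanity_val_steps=2,
    precision="bf16-mixed",
    log_every_n_steps=10,
    limit_val_batches=5,
    limit_train_batches=5,
    accelerator="gpu",  # explicit multi-GPU settings (assumption: standard DDP)
    devices=2,
    strategy="ddp",
)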
Screenshots or log output
The Lightning fit progress bar simply gets stuck at the end of the first epoch.
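If it helps with triage, this is the kind of generic DDP-hang instrumentation I would add before trainer.fit() to see where each rank is blocked (standard torch/NCCL debug switches plus faulthandler; nothing terratorch-specific, and untested here):

import faulthandler
import os

# Generic DDP debugging aids; set before the process group is created.
os.environ.setdefault("NCCL_DEBUG", "INFO")                 # verbose NCCL logging
os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")  # flag mismatched collectives

# Dump the Python stack of every thread if the process is still alive
# after 10 minutes, which should show where each rank is stuck.
faulthandler.dump_traceback_later(timeout=600, exit=False)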
Expected behavior
Training should work with CombinedLoss, too.
Deployment information
Describe what you've deployed and how:
- TerraTorch version: latest version from main branch (commit 7cbf469)
- Installation source: pip
- torch version: 2.9.0 (cu12)
- devices: 2x NVIDIA TITAN RTX GPUs
I'm tagging @blumenstiel, since I noticed that the CombinedLoss feature was recently added in a PR of yours. Sorry if you are not the right person to ask.