
combined losses for segmentation task break distributed training #982

@matteot11

Description


Describe the issue
I installed the most recent terratorch version from the main branch (commit 7cbf469f1c179860b81117512f873c6ff2b1ff5a) to use the latest features; specifically, I was interested in combining multiple losses for semantic segmentation, such as "ce" + "dice". I installed it with:
pip install git+https://github.com/IBM/terratorch.git

However, with this version installed, trainer.fit() hangs forever after the first train+val epoch. This only happens when training is distributed over more than one GPU and a CombinedLoss is used; it works with a single loss such as "dice".
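For reference, this is how the loss is passed to the task (the dict form maps each loss name to its weight, which is how I understand the new combined-loss syntax; the full script is below):

# single loss: works
task = SemanticSegmentationTask(..., loss="dice")
# combined losses: trainer.fit() hangs after the first train+val epoch
task = SemanticSegmentationTask(..., loss={"ce": 1.0, "dice": 5.0})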

To Reproduce
Below is a simple semantic segmentation script, trained on random data, that reproduces the issue:

import lightning as L
from terratorch.tasks import SemanticSegmentationTask
from torch.utils.data import Dataset, DataLoader
import torch

LOSS = "dice"  # works
# LOSS = {"ce": 1.0, "dice": 5.0}  # does not work


class DummyDataModule(L.LightningDataModule):
    def setup(self, stage) -> None:
        self.train_dataset = RandomDataset()
        self.val_dataset = RandomDataset()

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=False,
            drop_last=False,
        )


class RandomDataset(Dataset):
    def __init__(self):
        super().__init__()
        # 100 random image tensors of shape (4, 4, 64, 64) with 64x64 all-ones masks
        self.imgs = [torch.rand((4, 4, 64, 64)) for _ in range(100)]
        self.masks = [torch.ones((64, 64)).long() for _ in range(100)]

    def __getitem__(self, idx):
        return {
            "image": self.imgs[idx],
            "mask": self.masks[idx],
        }

    def __len__(self):
        return len(self.imgs)


module = SemanticSegmentationTask(
    model_factory="EncoderDecoderFactory",
    model_args={
        "backbone": "prithvi_eo_v2_300",
        "backbone_pretrained": True,
        "backbone_img_size": 64,
        "backbone_num_frames": 4,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW"],
        "backbone_coords_encoding": [],
        "necks": [
            {"name": "SelectIndices", "indices": [5, 11, 17, 23]},
            {"name": "ReshapeTokensToImage", "effective_time_dim": 4},
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        "decoder": "UNetDecoder",
        "decoder_channels": [512, 256, 128, 64],
        "head_dropout": 0.1,
        "num_classes": 2,
    },
    loss=LOSS,
    lr=1e-4,
    optimizer="AdamW",
    ignore_index=-1,
    freeze_backbone=False,
    freeze_decoder=False,
    plot_on_val=False,
)
datamodule = DummyDataModule()
trainer = L.Trainer(
    max_epochs=5,
    default_root_dir="./logs",
    num_sanity_val_steps=2,
    precision="bf16-mixed",
    log_every_n_steps=10,
    limit_val_batches=5,
    limit_train_batches=5,
)
trainer.fit(module, datamodule=datamodule)
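For completeness: I do not pass devices or strategy to the Trainer, so Lightning picks up both GPUs on the machine automatically. If it helps with reproducing, making the distributed setup explicit should (as far as I understand Lightning's defaults on a 2-GPU node) be equivalent to:

trainer = L.Trainer(
    accelerator="gpu",
    devices=2,       # both TITAN RTX GPUs; with devices=1 the hang does not occur
    strategy="ddp",  # assumption: the strategy Lightning auto-selects for 2 GPUs here
    max_epochs=5,
    default_root_dir="./logs",
    num_sanity_val_steps=2,
    precision="bf16-mixed",
    log_every_n_steps=10,
    limit_val_batches=5,
    limit_train_batches=5,
)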

Screenshots or log output
The Lightning fit progress bar simply gets stuck at the end of the first epoch (see the attached screenshot).

Expected behavior
Training should also complete when a CombinedLoss is used, just as it does with a single loss.

Deployment information
Describe what you've deployed and how:

  • TerraTorch version: latest version from main branch (commit 7cbf469)
  • Installation source: pip
  • torch version: 2.9.0 (cu12)
  • devices: 2x NVIDIA TITAN RTX GPUs

I'm tagging @blumenstiel, since I noticed that the CombinedLoss feature was recently added in one of your PRs. Sorry if you are not the right person to ask.
