
CrossEntropyLoss incompatible with default ignore_index parameter. #1019

@Foxigod

I believe that removing the None check on the ignore_index parameter:

    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)

in commit 5afdc22 (cc @blumenstiel) broke the ignore_index parameter whenever it is left unset (I mention this in a comment on that commit).

The removed check replaced a None ignore_index with -100. Without it, None is passed straight through to CrossEntropyLoss, which does not accept None: if ignore_index is provided, it must be an integer.
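The failure can be reproduced with plain PyTorch, independently of terratorch. In this minimal sketch (tensor shapes are arbitrary), nn.CrossEntropyLoss accepts ignore_index=None at construction time, since the constructor stores the value without validating it; the TypeError is raised only when the loss is first called, which is why it shows up inside validation_step rather than when the task is set up.

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)            # 4 samples, 3 classes
targets = torch.tensor([0, 1, 2, 1])  # integer class indices

# Default construction: ignore_index falls back to PyTorch's default of -100.
default_loss = nn.CrossEntropyLoss()
print(default_loss.ignore_index)  # -100
print(default_loss(logits, targets))

# Passing None explicitly is accepted when the module is built...
broken_loss = nn.CrossEntropyLoss(ignore_index=None)

# ...but the forward call rejects it, matching the traceback.
caught = None
try:
    broken_loss(logits, targets)
except TypeError as err:
    caught = err
    print(f"TypeError: {err}")
```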

I'm encountering the following error:

  [...]
  File "VENV/lib/python3.12/site-packages/terratorch/tasks/segmentation_tasks.py", line 389, in validation_step
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "VENV/lib/python3.12/site-packages/terratorch/tasks/loss_handler.py", line 45, in compute_loss
    loss = self._compute_loss(model_output.output, ground_truth, criterion)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "VENV/lib/python3.12/site-packages/terratorch/tasks/loss_handler.py", line 78, in _compute_loss
    loss: Tensor = criterion(y_hat, ground_truth)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "VENV/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "VENV/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "VENV/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "VENV/lib/python3.12/site-packages/torch/nn/functional.py", line 3059, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cross_entropy_loss(): argument 'ignore_index' (position 5) must be int, not NoneType

Originally posted by @Foxigod in 5afdc22

Was the intention to force the user to provide an integer ignore_index? If so, it would probably be best to report it to the user explicitly as a missing parameter. Or was the removal accidental? In that case we should simply reinstate the None check on ignore_index and either replace None with -100 as before, or instantiate the loss without passing ignore_index at all, so that nn.CrossEntropyLoss uses its own default of -100 (https://github.com/pytorch/pytorch/blob/1aaedbcfdd5c0615a882eebba3f51b2409162142/torch/nn/modules/loss.py#L1352).
I suppose these changes could be made in the init_loss() function:

    def init_loss(loss: str, ignore_index: int = None, class_weights: list = None) -> nn.Module:
        if loss == "ce":
            return nn.CrossEntropyLoss(ignore_index=ignore_index, weight=class_weights)
