Description
I believe that removing the None check on the ignore_index parameter:
terratorch/terratorch/tasks/segmentation_tasks.py
Lines 222 to 224 in a65d86a
    if loss == "ce":
        ignore_value = -100 if ignore_index is None else ignore_index
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
in commit 5afdc22 (cc @blumenstiel) left the ignore_index parameter broken when it is None (I mention this in my comment on that commit, quoted below).
I believe the removal of the None check on ignore_index (which replaced None with -100) can cause issues for CrossEntropyLoss, which does not seem to handle an ignore_index of None and instead requires an integer if one is provided.
I'm encountering this error:
[...]
  File "VENV/lib/python3.12/site-packages/terratorch/tasks/segmentation_tasks.py", line 389, in validation_step
    loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
  File "VENV/lib/python3.12/site-packages/terratorch/tasks/loss_handler.py", line 45, in compute_loss
    loss = self._compute_loss(model_output.output, ground_truth, criterion)
  File "VENV/lib/python3.12/site-packages/terratorch/tasks/loss_handler.py", line 78, in _compute_loss
    loss: Tensor = criterion(y_hat, ground_truth)
  File "VENV/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "VENV/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "VENV/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "VENV/lib/python3.12/site-packages/torch/nn/functional.py", line 3059, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
TypeError: cross_entropy_loss(): argument 'ignore_index' (position 5) must be int, not NoneType
Originally posted by @Foxigod in 5afdc22
Was the intention to force the user to provide an integer ignore_index? If so, it would probably be best to report a None value to the user as a missing parameter. Or was the removal accidental? In that case we should simply re-implement the None check for ignore_index and either replace None with -100, as was done previously, or instantiate the loss without passing ignore_index at all, which would then fall back to the default value of -100 in nn.CrossEntropyLoss: https://github.com/pytorch/pytorch/blob/1aaedbcfdd5c0615a882eebba3f51b2409162142/torch/nn/modules/loss.py#L1352
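To make the first two options concrete, here is a minimal sketch in plain Python. The helper names `require_ignore_index` and `resolve_ignore_index` and the constant name are hypothetical (they do not exist in terratorch); only the -100 fallback comes from the earlier code quoted above.

```python
from typing import Optional

# -100 is the value the earlier terratorch code substituted for None,
# and it is also the default ignore_index of nn.CrossEntropyLoss.
DEFAULT_IGNORE_INDEX = -100


def require_ignore_index(ignore_index: Optional[int]) -> int:
    """Option A (hypothetical helper): treat ignore_index as mandatory.

    Fails early with a readable message instead of the TypeError
    raised deep inside torch during the first loss computation.
    """
    if ignore_index is None:
        raise ValueError("ignore_index must be an integer for the 'ce' loss, got None")
    return ignore_index


def resolve_ignore_index(ignore_index: Optional[int]) -> int:
    """Option B (hypothetical helper): restore the earlier None -> -100 fallback."""
    return DEFAULT_IGNORE_INDEX if ignore_index is None else ignore_index
```

With option B, the value passed to the loss is always an integer, so `nn.CrossEntropyLoss(ignore_index=resolve_ignore_index(ignore_index), ...)` would behave as it did before the commit. Option A keeps the stricter behaviour but surfaces the problem at construction time.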
I suppose these changes could be made in the init_loss() function:
terratorch/terratorch/tasks/segmentation_tasks.py
Lines 30 to 32 in 85467a9
    def init_loss(loss: str, ignore_index: int = None, class_weights: list = None) -> nn.Module:
        if loss == "ce":
            return nn.CrossEntropyLoss(ignore_index=ignore_index, weight=class_weights)
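As a rough illustration of the third option (not passing ignore_index when it is None), init_loss() could build its keyword arguments conditionally. This is only a sketch under assumptions: the helper `ce_loss_kwargs` is hypothetical, class_weights is assumed to already be a tensor or None, and the other loss types handled by the real function are elided in the snippet above and therefore not reproduced.

```python
from typing import Any, Dict, Optional


def ce_loss_kwargs(ignore_index: Optional[int] = None, class_weights: Any = None) -> Dict[str, Any]:
    """Hypothetical helper: kwargs for nn.CrossEntropyLoss.

    ignore_index is only included when it is an actual integer, so a None
    value lets PyTorch fall back to its own default (-100).
    """
    kwargs: Dict[str, Any] = {"weight": class_weights}
    if ignore_index is not None:
        kwargs["ignore_index"] = ignore_index
    return kwargs


def init_loss(loss: str, ignore_index: Optional[int] = None, class_weights: Any = None):
    # torch is imported inside the function so that ce_loss_kwargs
    # can be used and tested without torch installed.
    from torch import nn

    if loss == "ce":
        return nn.CrossEntropyLoss(**ce_loss_kwargs(ignore_index, class_weights))
    # The remaining loss types are not shown in the quoted snippet.
    raise ValueError(f"Loss type {loss!r} is not covered by this sketch")
```

Calling `init_loss("ce")` with no ignore_index would then produce a criterion that uses PyTorch's default, while `init_loss("ce", ignore_index=255)` passes the value through unchanged.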