diff --git a/segmentation_models_pytorch/__version__.py b/segmentation_models_pytorch/__version__.py index 3d187266..07597656 100644 --- a/segmentation_models_pytorch/__version__.py +++ b/segmentation_models_pytorch/__version__.py @@ -1 +1 @@ -__version__ = "0.5.0" +__version__ = "0.5.1.dev0" diff --git a/segmentation_models_pytorch/losses/_functional.py b/segmentation_models_pytorch/losses/_functional.py index 791901f0..07efd7f4 100644 --- a/segmentation_models_pytorch/losses/_functional.py +++ b/segmentation_models_pytorch/losses/_functional.py @@ -66,7 +66,7 @@ def focal_loss_with_logits( References: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py """ - target = target.type(output.type()) + target = target.to(dtype=output.dtype, device=output.device) logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none") pt = torch.exp(-logpt)