-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Open
Labels
Description
🐛 Describe the bug
When resizing a tensor wrapped in `tv_tensors.Mask`, the interpolation mode `NEAREST_EXACT` is silently replaced by `NEAREST`.
import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as TF
from torchvision import tv_tensors
def compare(mask, size):
    """Compare NEAREST and NEAREST_EXACT resizing of ``mask`` to ``size``.

    ``mask`` is resized twice with ``TF.resize``: once with
    ``InterpolationMode.NEAREST`` and once with
    ``InterpolationMode.NEAREST_EXACT``. As a reference, it is also resized
    with ``F.interpolate`` using ``mode='nearest'`` and
    ``mode='nearest-exact'``. Four booleans are printed, one per line:

    1. TF NEAREST result equals TF NEAREST_EXACT result
    2. F.interpolate 'nearest' result equals F.interpolate 'nearest-exact' result
    3. TF NEAREST result equals F.interpolate 'nearest' result
    4. TF NEAREST_EXACT result equals F.interpolate 'nearest-exact' result

    Args:
        mask: Tensor to resize. The reference path adds and then removes a
            leading batch dimension, so this presumably expects a (C, H, W)
            layout; the repro passes shape (1, H, W).
        size: Target (height, width) passed to both resize paths.

    Returns:
        The four booleans, in the order they are printed. This lets callers
        assert on the outcome instead of reading stdout.
    """
    mask_nearest = TF.resize(mask, size=size, interpolation=T.InterpolationMode.NEAREST)
    mask_nearest_exact = TF.resize(mask, size=size, interpolation=T.InterpolationMode.NEAREST_EXACT)
    # F.interpolate needs an explicit batch dimension; add it for the call and
    # drop it afterwards so the shapes match the TF.resize outputs.
    mask_nearest_interpolate = F.interpolate(mask.unsqueeze(0), size=size, mode='nearest').squeeze(0)
    mask_nearest_exact_interpolate = F.interpolate(mask.unsqueeze(0), size=size, mode='nearest-exact').squeeze(0)
    results = (
        torch.equal(mask_nearest, mask_nearest_exact),
        torch.equal(mask_nearest_interpolate, mask_nearest_exact_interpolate),
        torch.equal(mask_nearest, mask_nearest_interpolate),
        torch.equal(mask_nearest_exact, mask_nearest_exact_interpolate),
    )
    for result in results:
        print(result)
    return results
# Fix the RNG seed so the random mask, and therefore the printed output,
# is reproducible between runs.
torch.manual_seed(0)

height = 64
width = 64
new_height = 16
new_width = 16

# Plain tensor: in the reported output, NEAREST and NEAREST_EXACT give
# different results, and each matches its F.interpolate counterpart.
mask = torch.randint(low=0, high=2, size=(1, height, width)).float()
compare(mask, (new_height, new_width))

# Blank line separating the two runs' output.
print()

# Same data wrapped as tv_tensors.Mask. In the reported output, TF.resize
# with NEAREST_EXACT now equals the NEAREST result and no longer matches
# F.interpolate(mode='nearest-exact'). This is the bug being reported.
mask = tv_tensors.Mask(mask)
compare(mask, (new_height, new_width))
False
False
True
True

True
False
True
False
Versions
torch = "2.8.0"
torchvision = "0.23.0"