🚀 The feature
Make transforms.v2 ignore raw tensors / let them pass through instead of assuming that they are images.
- This could be activated manually by a flag, in order not to disrupt the existing API.
- Or, if one element of the sample is already an `Image` (detected upon calling `tree_flatten`), don't assume that the pure tensors are images too.
- Alternatively, implement a new `TVTensor` type which is registered for all transformation kernels with an identity / passthrough op (see the sketch after this list). It would need to be implemented thoroughly so that the type does not get unwrapped by some transformations.
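A minimal sketch of the third option, assuming a recent torchvision where `tv_tensors.TVTensor` and `transforms.v2.functional.register_kernel` are public; the `PassthroughTensor` name and the list of functionals covered below are purely illustrative:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

class PassthroughTensor(tv_tensors.TVTensor):
    # Hypothetical TVTensor subclass whose only purpose is to be recognised
    # by the v2 dispatch machinery and left untouched by every kernel.
    pass

def _identity(inpt, *args, **kwargs):
    # No-op kernel: ignore all transform parameters and return the input.
    return inpt

# Register the identity kernel for each functional. In practice this would
# have to cover every v2 functional, which is the tedious part noted above.
for functional in (F.horizontal_flip, F.resize, F.adjust_brightness):
    F.register_kernel(functional, PassthroughTensor)(_identity)

# Wrap the extra attributes so the transforms pass them through unchanged.
attributes = torch.rand(5, 7).as_subclass(PassthroughTensor)
```

The caveat from the proposal still applies: any code path that touches the tensor outside a registered kernel may silently unwrap it back to a plain tensor.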
I understand that the current behaviour reduces the migration work for v1 code, but it also makes fresh code a bit odd to work with, IMHO.
Motivation, pitch
I'm trying to use transforms.v2 in combination with tv_tensors. My samples contain a combination of an image, boxes, keypoints, labels, and some extra attributes stored as tensors (they are multi-dimensional, not scalar). Unfortunately, transforms.v2 modules get confused by the tensors.
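For concreteness, here is a hypothetical sample of that shape (field names and sizes are made up); whether the extra tensor is transformed or passed through depends on the version's heuristics, which is exactly the ambiguity this issue is about:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

sample = {
    "image": tv_tensors.Image(torch.rand(3, 224, 224)),
    "boxes": tv_tensors.BoundingBoxes(
        torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
        format="XYXY",
        canvas_size=(224, 224),
    ),
    "label": torch.tensor(3),
    # Extra multi-dimensional attributes that are *not* an image.
    "attributes": torch.rand(5, 7),
}

out = v2.RandomHorizontalFlip(p=1.0)(sample)
# Depending on the heuristics in use, "attributes" may be treated as an
# image and flipped along its last dimension instead of being left alone.
```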
Alternatives
- Implement a `PassthroughTVTensor` myself. This is verbose and tedious since I don't know the complete list of transformation kernels.
- Don't use tv_tensors and instead implement all transformations as `nn.Module` wrappers. I pass only a tuple of the sample fields that need transformations and recompose the sample back together afterwards:
import torch.nn as nn
from torchvision.transforms import v2 as T

class RandomPhotometricDistort(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.t = T.RandomPhotometricDistort(*args, **kwargs)

    def forward(self, sample):
        # apply the transform to the image field only and merge it back into the sample
        return sample | {"image": self.t(sample["image"])}
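For geometry-affecting transforms the same pattern works by handing the synchronised fields to the v2 transform as a tuple. A sketch under the assumption that the image and boxes are stored as tv_tensors (field names are made up):

```python
import torch.nn as nn
from torchvision.transforms import v2 as T

class RandomHorizontalFlip(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.t = T.RandomHorizontalFlip(*args, **kwargs)

    def forward(self, sample):
        # Only the fields that must stay in sync are handed to the v2
        # transform; labels and extra attribute tensors are copied untouched.
        image, boxes = self.t((sample["image"], sample["boxes"]))
        return sample | {"image": image, "boxes": boxes}
```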
Additional context

_transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray)
_transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray)