Description
🚀 The feature
Supporting arbitrary input structures in custom transforms is especially important when transforms are composed:
```python
from torchvision.transforms.v2 import Compose, RandomCrop

tr = Compose([RandomCrop((128, 128)), CustomTransform()])
```

This can be done by inheriting from `torchvision.transforms.v2.Transform` and implementing the private `._transform` method. That way you don't have to unravel the data structure yourself, because the base class's `.forward` method already does this.
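```python
from typing import Any, Dict

from torchvision import tv_tensors
from torchvision.transforms.v2 import Transform


class CustomTransform(Transform):
    def __init__(self, **kwargs):
        super().__init__()

    # .forward() flattens the (possibly nested) input and calls this once per element it transforms.
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if isinstance(inpt, tv_tensors.Image):
            transformed_inpt = inpt  # image-specific logic goes here
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            transformed_inpt = inpt  # box-specific logic goes here
        else:
            transformed_inpt = inpt  # other inputs pass through unchanged
        return transformed_inpt
```

This approach is also described in the blog post "How to Create Custom Torchvision V2 Transforms". However, the official torchvision docs do not describe it yet and instead suggest hard-coding the input structure.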
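The following toy, pure-Python model shows the division of labour described above. The base class walks an arbitrarily nested input once, and a subclass only specifies what to do with a single element. All names here are hypothetical, including `ToyTransform`, `map_leaves`, `HorizontalFlip`, and the dataclass stand-ins for `Image`/`BoundingBoxes`. This is not torchvision's implementation: the real `.forward` is more involved, for example in deciding which inputs get transformed at all.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Image:  # toy stand-in for tv_tensors.Image (hypothetical)
    pixels: List[List[int]]


@dataclass
class BoundingBoxes:  # toy stand-in for tv_tensors.BoundingBoxes, XYXY format (hypothetical)
    boxes: List[List[int]]


def map_leaves(fn, obj):
    """Rebuild nested dicts/lists/tuples, applying fn to every non-container leaf."""
    if isinstance(obj, dict):
        return {k: map_leaves(fn, v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [map_leaves(fn, v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(map_leaves(fn, v) for v in obj)
    return fn(obj)


class ToyTransform:
    """Toy base class: forward() walks the structure, subclasses handle one leaf."""

    def __call__(self, *inputs: Any) -> Any:
        return self.forward(*inputs)

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        params: Dict[str, Any] = {}  # a random transform would sample these once per call
        return map_leaves(lambda leaf: self._transform(leaf, params), sample)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt  # default: pass through unchanged


class HorizontalFlip(ToyTransform):
    def __init__(self, width: int) -> None:
        self.width = width  # image width, needed to mirror box coordinates

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if isinstance(inpt, Image):
            return Image([row[::-1] for row in inpt.pixels])
        if isinstance(inpt, BoundingBoxes):
            w = self.width
            return BoundingBoxes([[w - x2, y1, w - x1, y2] for x1, y1, x2, y2 in inpt.boxes])
        return inpt  # labels and other leaves pass through untouched
```

Here is an example. `HorizontalFlip(width=3)({"image": Image([[1, 2, 3]]), "boxes": BoundingBoxes([[0, 0, 1, 1]]), "label": 7})` returns a dict with three changes relative to the input: the image row is reversed to `[3, 2, 1]`, the box is mirrored to `[2, 0, 3, 1]`, and the label is left unchanged. The subclass never looks at the dict itself.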
Having to implement a private method for this (even though the `Transform` class itself is public) feels wrong: it means our code could break at any time. I would appreciate it if `._transform` were made public (renamed to `.transform`) and if the `Transform` class received proper documentation on how to implement this method for custom transforms.
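To make the request concrete, here is a toy, pure-Python sketch of one backward-compatible way such a rename could work. Everything in it is hypothetical: the class is not torchvision's, and the structure flattening done by the real `.forward` is omitted. In the sketch, subclasses override a public `transform`, and the private `_transform` stays as a delegating alias so existing subclasses keep working. Subclasses that still override `_transform` receive a `DeprecationWarning`.

```python
import warnings
from typing import Any, Dict


class Transform:
    """Toy base class (hypothetical, not torchvision's) exposing a public hook."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Subclasses that still override the private name get a deprecation nudge.
        if "_transform" in cls.__dict__:
            warnings.warn(
                f"{cls.__name__} overrides _transform(); override transform() instead",
                DeprecationWarning,
                stacklevel=2,
            )

    def __call__(self, inpt: Any) -> Any:
        # Structure flattening is omitted; a single input is dispatched directly.
        params: Dict[str, Any] = {}
        return self._transform(inpt, params)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Public extension point: override this in custom transforms."""
        return inpt

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Kept as a delegating alias so old-style subclasses keep working.
        return self.transform(inpt, params)


class Double(Transform):
    # New style: only the public hook is overridden.
    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt * 2
```

The sketch dispatches on `_transform` rather than `transform`. This lets old-style and new-style subclasses coexist during a deprecation period. For example, `Double()(3)` returns `6`.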
Motivation, pitch
The torchvision.transforms.v2 API has been around for quite some time now. It would be nice to let developers write transforms with the same quality and flexibility as the built-in ones!
Alternatives
No response
Additional context
No response