1212"""
1313
1414# %%
15- from typing import Any , Dict
15+ from typing import Any , Dict , List
1616
1717import torch
1818from torchvision import tv_tensors
@@ -109,9 +109,12 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
109109print (f"The transformed bboxes are:\n { structured_output ['annotations' ][0 ]} " )
110110
111111# %%
112+ # Basics: override the `transform()` method
113+ # -----------------------------------------
114+ #
112115# In order to support arbitrary inputs in your custom transform, you will need
113116# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
114- # `.transform()` method (not the `forward()` method!).
117+ # `.transform()` method (not the `forward()` method!). Below is a basic example:
115118
116119
117120class MyCustomTransform (v2 .Transform ):
@@ -134,12 +137,63 @@ def transform(self, inpt: Any, params: Dict[str, Any]):
134137print (f"The transformed bboxes are:\n { structured_output ['annotations' ][0 ]} " )
135138
136139# %%
137- # An important thing to note is that when we call `my_custom_transform` on
138- # `structured_input`, the input is flattened and then each individual part is
139- # passed to `transform()`. That is, `transform()` received the input image, then
140- # the bounding boxes, etc. It is then within `transform()` that you can decide
141- # how to transform each input, based on their type.
140+ # An important thing to note is that when we call ``my_custom_transform`` on
141+ # ``structured_input``, the input is flattened and then each individual part is
142+ # passed to ``transform()``. That is, ``transform()`` receives the input image,
143+ # then the bounding boxes, etc. Within ``transform()``, you can decide how to
144+ # transform each input, based on their type.
145+ #
146+ # If you're curious why the other tensor (``torch.arange()``) didn't get passed
147+ # to ``transform()``, see :ref:`passthrough_heuristic`.
148+ #
149+ # Advanced: The ``make_params()`` method
150+ # --------------------------------------
151+ #
152+ # The ``make_params()`` method is called internally before calling
153+ # ``transform()`` on each input. This is typically useful to generate random
154+ # parameter values. In the example below, we use it to randomly apply the
155+ # transformation with a probability of 0.5.
156+
157+
class MyRandomTransform(MyCustomTransform):
    """Apply ``MyCustomTransform`` to a sample with probability ``p``.

    The apply/skip decision is drawn once per call in ``make_params()`` and
    handed to ``transform()`` through ``params``, so ``transform()`` itself
    never draws random values.

    Args:
        p: Probability of applying the transformation. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        # Run the parent initializer before assigning instance attributes:
        # this is the conventional order, and the one the ``torch.nn.Module``
        # docs ask for.
        super().__init__()
        self.p = p

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Draw the apply/skip decision for the current call.

        Args:
            flat_inputs: The flattened inputs of the call (unused here).

        Returns:
            A dict with a single boolean entry, ``"apply_transform"``.
        """
        apply_transform = (torch.rand(size=(1,)) < self.p).item()
        return dict(apply_transform=apply_transform)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        """Transform ``inpt``, or return it unchanged when the call is skipped.

        Args:
            inpt: One input to (maybe) transform.
            params: The dict produced by ``make_params()``.

        Returns:
            ``inpt`` itself when ``params["apply_transform"]`` is falsy,
            otherwise the result of the parent class' ``transform()``.
        """
        if not params["apply_transform"]:
            print("Not transforming anything!")
            return inpt
        return super().transform(inpt, params)
175+
176+ my_random_transform = MyRandomTransform ()
177+
178+ torch .manual_seed (0 ) # seed the RNG so the two calls below are reproducible
179+ _ = my_random_transform (structured_input ) # transforms (the draw in make_params() is below p)
180+ _ = my_random_transform (structured_input ) # doesn't transform (the draw is >= p); inputs are returned as-is
181+
182+ # %%
183+ #
184+ # .. note::
185+ #
186+ # It's important for such random parameter generation to happen within
187+ # ``make_params()`` and not within ``transform()``, so that for a given
188+ # transform call, the same random decision applies to all the inputs. If we
189+ # were to draw random values within ``transform()``, we would risk e.g.
190+ # transforming the image while *not* transforming the bounding boxes.
142191#
143- # If you're curious why the other tensor (`torch.arange()`) didn't get passed to `transform()`, see :ref:`_passthrough_heuristic`.
192+ # The ``make_params()`` method takes the list of all the inputs as a parameter
193+ # (each of the elements in this list will later be passed to ``transform()``).
194+ # You can use ``flat_inputs`` to e.g. figure out the dimensions of the input,
195+ # using :func:`~torchvision.transforms.v2.query_chw` or
196+ # :func:`~torchvision.transforms.v2.query_size`.
144197#
145- # TODO explain make_params()
198+ # ``make_params()`` should return a dict (or actually, anything you want) that
199+ # will then be passed to ``transform()``.
0 commit comments