Skip to content

Commit e2519c8

Browse files
committed
More docs
1 parent 956f319 commit e2519c8

File tree

3 files changed

+72
-12
lines changed

3 files changed

+72
-12
lines changed

docs/source/transforms.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,9 @@ Developer tools
519519
:template: function.rst
520520

521521
v2.functional.register_kernel
522+
v2.query_size
523+
v2.query_chw
524+
v2.get_bounding_boxes
522525

523526

524527
V1 API Reference

gallery/transforms/plot_custom_transforms.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
# %%
15-
from typing import Any, Dict
15+
from typing import Any, Dict, List
1616

1717
import torch
1818
from torchvision import tv_tensors
@@ -109,9 +109,12 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
109109
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
110110

111111
# %%
112+
# Basics: override the `transform()` method
113+
# -----------------------------------------
114+
#
112115
# In order to support arbitrary inputs in your custom transform, you will need
113116
# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
114-
# `.transform()` method (not the `forward()` method!).
117+
# `.transform()` method (not the `forward()` method!). Below is a basic example:
115118

116119

117120
class MyCustomTransform(v2.Transform):
@@ -134,12 +137,63 @@ def transform(self, inpt: Any, params: Dict[str, Any]):
134137
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
135138

136139
# %%
137-
# An important thing to note is that when we call `my_custom_transform` on
138-
# `structured_input`, the input is flattened and then each individual part is
139-
# passed to `transform()`. That is, `transform()` received the input image, then
140-
# the bounding boxes, etc. It is then within `transform()` that you can decide
141-
# how to transform each input, based on their type.
140+
# An important thing to note is that when we call ``my_custom_transform`` on
141+
# ``structured_input``, the input is flattened and then each individual part is
142+
# passed to ``transform()``. That is, ``transform()`` receives the input image,
143+
# then the bounding boxes, etc. Within ``transform()``, you can decide how to
144+
# transform each input, based on their type.
145+
#
146+
# If you're curious why the other tensor (``torch.arange()``) didn't get passed
147+
# to ``transform()``, see :ref:`passthrough_heuristic`.
148+
#
149+
# Advanced: The ``make_params()`` method
150+
# --------------------------------------
151+
#
152+
# The ``make_params()`` method is called internally before calling
153+
# ``transform()`` on each input. This is typically useful to generate random
154+
# parameter values. In the example below, we use it to randomly apply the
155+
# transformation with a probability of 0.5.
156+
157+
158+
class MyRandomTransform(MyCustomTransform):
159+
def __init__(self, p=0.5):
160+
self.p = p
161+
super().__init__()
162+
163+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
164+
apply_transform = (torch.rand(size=(1,)) < self.p).item()
165+
params = dict(apply_transform=apply_transform)
166+
return params
167+
168+
def transform(self, inpt: Any, params: Dict[str, Any]):
169+
if not params["apply_transform"]:
170+
print("Not transforming anything!")
171+
return inpt
172+
else:
173+
return super().transform(inpt, params)
174+
175+
176+
my_random_transform = MyRandomTransform()
177+
178+
torch.manual_seed(0)
179+
_ = my_random_transform(structured_input) # transforms
180+
_ = my_random_transform(structured_input) # doesn't transform
181+
182+
# %%
183+
#
184+
# .. note::
185+
#
186+
# It's important for such random parameter generation to happen within
187+
# ``make_params()`` and not within ``transform()``, so that for a given
188+
# transform call, the same random parameters apply to all the inputs. If
189+
# we were to draw random values within ``transform()``, we would risk e.g.
190+
# transforming the image while *not* transforming the bounding boxes.
142191
#
143-
# If you're curious why the other tensor (`torch.arange()`) didn't get passed to `transform()`, see :ref:`_passthrough_heuristic`.
192+
# The ``make_params()`` method takes the list of all the inputs as a parameter
193+
# (each of the elements in this list will later be passed to ``transform()``).
194+
# You can use ``flat_inputs`` to e.g. figure out the dimensions of the input,
195+
# using :func:`~torchvision.transforms.v2.query_chw` or
196+
# :func:`~torchvision.transforms.v2.query_size`.
144197
#
145-
# TODO explain make_params()
198+
# ``make_params()`` should return a dict (or actually, anything you want) that
199+
# will then be passed to ``transform()``.

torchvision/transforms/v2/_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None])
151151

152152

153153
def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
154+
"""Return the Bounding Boxes in the input.
155+
156+
Assumes only one ``BoundingBoxes`` object is present.
157+
"""
154158
# This assumes there is only one bbox per sample as per the general convention
155159
try:
156160
return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
@@ -159,9 +163,7 @@ def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
159163

160164

161165
def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
162-
print("AEFAEFAE")
163-
print(len(flat_inputs))
164-
print([type(inpt) for inpt in flat_inputs])
166+
"""Return Channel, Height, and Width."""
165167
chws = {
166168
tuple(get_dimensions(inpt))
167169
for inpt in flat_inputs
@@ -176,6 +178,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
176178

177179

178180
def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
181+
"""Return Height and Width."""
179182
sizes = {
180183
tuple(get_size(inpt))
181184
for inpt in flat_inputs

0 commit comments

Comments
 (0)