
Commit 956f319: "Add some doc"
Parent: 0fdd655

File tree: 4 files changed, +55 -17 lines

docs/source/transforms.rst

Lines changed: 6 additions & 0 deletions
@@ -508,6 +508,12 @@ are combining pairs of images together. These can be used after the dataloader
 Developer tools
 ^^^^^^^^^^^^^^^
 
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    v2.Transform
+
 .. autosummary::
     :toctree: generated/
     :template: function.rst

gallery/transforms/plot_custom_transforms.py

Lines changed: 41 additions & 17 deletions
@@ -12,6 +12,8 @@
 """
 
 # %%
+from typing import Any, Dict
+
 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import v2
@@ -89,33 +91,55 @@ def forward(self, img, bboxes, label):  # we assume inputs are always structured
 # A key feature of the builtin Torchvision V2 transforms is that they can accept
 # arbitrary input structure and return the same structure as output (with
 # transformed entries). For example, transforms can accept a single image, or a
-# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
+# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
+# an example on the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:
 
 structured_input = {
     "img": img,
     "annotations": (bboxes, label),
-    "something_that_will_be_ignored": (1, "hello")
+    "something that will be ignored": (1, "hello"),
+    "another tensor that is ignored": torch.arange(10),
 }
 structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
 
 assert isinstance(structured_output, dict)
-assert structured_output["something_that_will_be_ignored"] == (1, "hello")
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
+print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
+
+# %%
+# In order to support arbitrary inputs in your custom transform, you will need
+# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
+# `.transform()` method (not the `forward()` method!).
+
+
+class MyCustomTransform(v2.Transform):
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        if type(inpt) == torch.Tensor:
+            print(f"I'm transforming an image of shape {inpt.shape}")
+            return inpt + 1  # dummy transformation
+        elif isinstance(inpt, tv_tensors.BoundingBoxes):
+            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
+            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation
+
+
+my_custom_transform = MyCustomTransform()
+structured_output = my_custom_transform(structured_input)
+
+assert isinstance(structured_output, dict)
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
 print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
 
 # %%
-# If you want to reproduce this behavior in your own transform, we invite you to
-# look at our `code
-# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
-# and adapt it to your needs.
+# An important thing to note is that when we call `my_custom_transform` on
+# `structured_input`, the input is flattened and then each individual part is
+# passed to `transform()`. That is, `transform()` receives the input image, then
+# the bounding boxes, etc. It is then within `transform()` that you can decide
+# how to transform each input, based on its type.
 #
-# In brief, the core logic is to unpack the input into a flat list using `pytree
-# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
-# then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all TVTensors are
-# tensor-subclasses) plus some custom logic that is out of score here - check the
-# code for details. The (potentially transformed) entries are then repacked and
-# returned, in the same structure as the input.
+# If you're curious why the other tensor (`torch.arange()`) didn't get passed to `transform()`, see :ref:`passthrough_heuristic`.
 #
-# We do not provide public dev-facing tools to achieve that at this time, but if
-# this is something that would be valuable to you, please let us know by opening
-# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
+# TODO explain make_params()
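
The new gallery text says the input is flattened, each leaf is handed to `transform()`, and the result is repacked; the removed paragraph attributed this to `pytree`. A minimal sketch of that flow, assuming only the private `torch.utils._pytree` helpers; this illustrates the mechanism, not the actual torchvision implementation:

# Sketch of the flatten / transform / repack flow described above,
# using the private pytree helpers the removed text linked to.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

structured_input = {
    "img": torch.rand(3, 32, 32),
    "annotations": (torch.tensor([[0, 0, 10, 10]]), 3),
    "something that will be ignored": (1, "hello"),
}

# 1. Unpack the arbitrary structure into a flat list of leaves, plus a
#    spec that remembers how to rebuild the structure.
flat_inputs, spec = tree_flatten(structured_input)

# 2. Transform only the leaves whose class marks them as transformable;
#    everything else passes through untouched.
flat_outputs = [
    inpt + 1 if isinstance(inpt, torch.Tensor) else inpt  # dummy transformation
    for inpt in flat_inputs
]

# 3. Repack the (potentially transformed) leaves into the input's structure.
structured_output = tree_unflatten(flat_outputs, spec)
assert structured_output["something that will be ignored"] == (1, "hello")

The real base class decides which leaves to transform from their class (all TVTensors are tensor subclasses); see the `_transformed_types` comment in the `_transform.py` hunk below.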

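The gallery diff ends with a TODO for `make_params()`. For orientation only, a hedged sketch of the pattern that method enables: sample randomness once per call so every leaf of the same sample sees the same decision. The name and signature shown match the current torchvision v2 API and are an assumption about what the TODO will eventually document:

# Assumed API: make_params() runs once per call, before transform() is
# invoked on each leaf, so all leaves share a single random draw.
from typing import Any, Dict, List

import torch
from torchvision.transforms import v2


class MyRandomTransform(v2.Transform):
    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # One coin flip per call, shared by every leaf of the sample.
        return dict(apply=(torch.rand(()) < 0.5).item())

    def transform(self, inpt: Any, params: Dict[str, Any]):
        # The same params dict is passed for every leaf, keeping e.g. an
        # image and its bounding boxes in sync.
        if not params["apply"]:
            return inpt
        return inpt + 1  # dummy transformation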
torchvision/transforms/v2/_transform.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,11 @@
 
 
 class Transform(nn.Module):
+    """Base class to implement your own v2 transforms.
+
+    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py` for
+    more details.
+    """
 
     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
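
The trailing context above names the class attribute that drives the pass-through behaviour. A hedged sketch of a subclass narrowing it; `_transformed_types` is a private attribute of `_transform.py`, so treat its name and semantics as an implementation detail that may change:

# Sketch of the pass-through mechanism the context lines describe.
from typing import Any, Dict

from torchvision import tv_tensors
from torchvision.transforms import v2


class BoxesOnlyTransform(v2.Transform):
    # Only bounding boxes are handed to transform(); every other leaf
    # (images, labels, plain tensors, ...) is passed through untouched.
    _transformed_types = (tv_tensors.BoundingBoxes,)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation

With this, `BoxesOnlyTransform()(structured_input)` would shift only the bounding boxes and return images, labels, and strings unchanged.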

torchvision/transforms/v2/_utils.py

Lines changed: 3 additions & 0 deletions
@@ -159,6 +159,9 @@ def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
 
 
 def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
+    print("AEFAEFAE")
+    print(len(flat_inputs))
+    print([type(inpt) for inpt in flat_inputs])
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
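
For context on the instrumented function: `query_chw` reduces the flat input list to a single `(channels, height, width)` triple. A rough sketch of that core logic, inferred from the context lines; the real helper filters leaves with torchvision's own type checks and error handling:

# Simplified sketch of query_chw's core logic, based on the hunk above.
from typing import Any, List, Tuple

import torch
from torchvision.transforms.v2.functional import get_dimensions


def sketch_query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    # Collect each image-like leaf's (C, H, W)...
    chws = {
        tuple(get_dimensions(inpt))
        for inpt in flat_inputs
        if isinstance(inpt, torch.Tensor)  # simplified type filter
    }
    # ...and require that all inputs agree on a single size.
    if len(chws) != 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the inputs: {chws}")
    c, h, w = chws.pop()
    return c, h, w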
