"""

# %%
from typing import Any, Dict

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# %%
# (Earlier parts of the tutorial, omitted here, define ``img``, ``bboxes`` and
# ``label``, and build a first custom transform whose
# ``forward(self, img, bboxes, label)`` method assumes the inputs always come in
# that fixed structure.)
#
# A key feature of the built-in Torchvision V2 transforms is that they can accept
# arbitrary input structures and return the same structure as output (with the
# transformed entries). For example, transforms can accept a single image, a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
# an example with the built-in :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something that will be ignored": (1, "hello"),
    "another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

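# %%
# The same holds when built-in transforms are chained: a
# :class:`~torchvision.transforms.v2.Compose` pipeline also accepts and preserves
# the structure. The extra :class:`~torchvision.transforms.v2.RandomVerticalFlip`
# step below is just an arbitrary choice for illustration:

pipeline = v2.Compose([
    v2.RandomHorizontalFlip(p=1),
    v2.RandomVerticalFlip(p=1),
])
structured_output = pipeline(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
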
# %%
# In order to support arbitrary inputs in your custom transform, you will need
# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
# ``transform()`` method (not the ``forward()`` method!).


class MyCustomTransform(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) == torch.Tensor:
            # Exact type check on purpose: TVTensors such as BoundingBoxes are
            # tensor subclasses, so isinstance() would match them too.
            print(f"I'm transforming an image of shape {inpt.shape}")
            return inpt + 1  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation


my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# An important thing to note is that when we call ``my_custom_transform`` on
# ``structured_input``, the input is flattened and then each individual part is
# passed to ``transform()``. That is, ``transform()`` receives the input image,
# then the bounding boxes, etc. It is then within ``transform()`` that you can
# decide how to transform each input, based on its type.
#
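# Purely to illustrate what "flattened" means here, the sketch below peeks at
# the flat list of leaves with ``torch.utils._pytree`` (a private PyTorch
# utility that the V2 transforms rely on internally); this is for illustration
# only and is nothing your own transform needs to call:

from torch.utils import _pytree as pytree

flat_inputs, spec = pytree.tree_flatten(structured_input)
print(f"The flattened input has {len(flat_inputs)} leaves:")
for leaf in flat_inputs:
    print(f"  - {type(leaf).__name__}")
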
# %%
# If you're curious why the other tensor (``torch.arange()``) didn't get passed
# to ``transform()``, see :ref:`passthrough_heuristic`.
#
# Finally, if your transform is random and needs parameters (e.g. whether to
# flip at all, or a random crop size) that must be identical across all the
# inputs of a single call, you can override ``make_params()``: it is called once
# per call, receives the flat list of inputs, and the dict it returns is passed
# to each ``transform()`` invocation as the ``params`` argument.
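#
# As a minimal sketch (the exact ``make_params()`` signature and the ``p`` /
# ``apply`` names below are assumptions for illustration, not part of the
# torchvision API), a random variant of the transform above could look like
# this:


class MyRandomTransform(MyCustomTransform):
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def make_params(self, flat_inputs) -> Dict[str, Any]:
        # Sampled once per call, so the same decision is shared by the image,
        # the bounding boxes, and every other transformable input.
        return dict(apply=torch.rand(1).item() < self.p)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if not params["apply"]:
            return inpt  # leave this input untouched
        return super().transform(inpt, params)


structured_output = MyRandomTransform(p=1.0)(structured_input)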