"""

# %%
from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
# %%
# A key feature of the built-in Torchvision V2 transforms is that they can accept
# arbitrary input structures and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
# an example using the built-in transform
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something that will be ignored": (1, "hello"),
    "another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# Basics: override the ``transform()`` method
# -------------------------------------------
#
# In order to support arbitrary inputs in your custom transform, you will need
# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
# ``.transform()`` method (not the ``forward()`` method!). Below is a basic example:


class MyCustomTransform(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) is torch.Tensor:  # exact type check: TVTensors are also Tensor subclasses
            print(f"I'm transforming an image of shape {inpt.shape}")
            return inpt + 1  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation

my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

 | 
| 105 | 139 | # %% | 
# An important thing to note is that when we call ``my_custom_transform`` on
# ``structured_input``, the input is flattened and then each individual part is
# passed to ``transform()``. That is, ``transform()`` receives the input image,
# then the bounding boxes, etc. Within ``transform()``, you can decide how to
# transform each input, based on its type.
#
# If you're curious why the other tensor (``torch.arange()``) didn't get passed
# to ``transform()``, see :ref:`this note <passthrough_heuristic>` for more
# details.
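#
# As a rough sketch of that flatten / transform / unflatten round trip, here's
# the general idea using the private ``torch.utils._pytree`` helpers. This is
# only an illustration of the mechanism, not torchvision's exact internal logic:

from torch.utils import _pytree as pytree

flat_leaves, spec = pytree.tree_flatten(structured_input)
print(f"Flattened into {len(flat_leaves)} leaves")
# ... this is roughly where each supported leaf would get transformed ...
rebuilt = pytree.tree_unflatten(flat_leaves, spec)
assert rebuilt.keys() == structured_input.keys()

# %%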
#
# Advanced: The ``make_params()`` method
# --------------------------------------
#
# The ``make_params()`` method is called internally before calling
# ``transform()`` on each input. This is typically useful to generate random
# parameter values. In the example below, we use it to randomly apply the
# transformation with a probability of 0.5.


class MyRandomTransform(MyCustomTransform):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        apply_transform = (torch.rand(size=(1,)) < self.p).item()
        params = dict(apply_transform=apply_transform)
        return params

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if not params["apply_transform"]:
            print("Not transforming anything!")
            return inpt
        else:
            return super().transform(inpt, params)


my_random_transform = MyRandomTransform()

torch.manual_seed(0)
_ = my_random_transform(structured_input)  # transforms
_ = my_random_transform(structured_input)  # doesn't transform

# %%
#
# .. note::
#
#     It's important for such random parameter generation to happen within
#     ``make_params()`` and not within ``transform()``, so that for a given
#     transform call, the same RNG applies to all the inputs in the same way. If
#     we were to draw the random numbers within ``transform()``, we would risk e.g.
#     transforming the image while *not* transforming the bounding boxes.
#
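# To make that failure mode concrete, here's a deliberately broken sketch (a
# hypothetical class, not part of the API) that flips the coin inside
# ``transform()``: the image and its bounding boxes can then disagree on
# whether they were flipped.


class BadRandomFlip(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        # BUG: a fresh coin flip per input, so image and boxes can desynchronize.
        if torch.rand(size=(1,)).item() < 0.5:
            return v2.functional.horizontal_flip(inpt)
        return inpt


# %%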
# The ``make_params()`` method takes the list of all the inputs as a parameter
# (each element in this list will later be passed to ``transform()``). You can
# use ``flat_inputs`` to e.g. figure out the dimensions of the input, using
# :func:`~torchvision.transforms.v2.query_chw` or
# :func:`~torchvision.transforms.v2.query_size`, as sketched below.
#
# ``make_params()`` should return a dict (or actually, anything you want) that
# will then be passed to ``transform()``.
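
# %%
# For instance, here's a sketch of a transform whose ``make_params()`` uses
# :func:`~torchvision.transforms.v2.query_size` on ``flat_inputs`` to draw a
# size-dependent parameter. ``MyRandomShift`` is a hypothetical example, and
# the ``+ shift`` update is a dummy transformation in the spirit of the
# examples above, not a real geometric transformation:


class MyRandomShift(v2.Transform):
    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = v2.query_size(flat_inputs)  # (H, W) shared by all inputs
        shift = int(torch.randint(0, height // 4 + 1, size=(1,)).item())
        return dict(shift=shift)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) is torch.Tensor:
            return inpt + params["shift"]  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            return tv_tensors.wrap(inpt + params["shift"], like=inpt)  # dummy transformation
        return inpt


torch.manual_seed(0)
# We pass only img and bboxes here: ``query_size()`` needs inputs that carry a
# spatial size, so the 1-D ``torch.arange(10)`` tensor above would not qualify.
_ = MyRandomShift()((img, bboxes))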