Demonstration of torch.export flow, common challenges and the solutions to address them
==========================================================================================
**Authors:** `Ankith Gunapal`, `Jordi Ramon`, `Marcos Carranza`

In a previous `tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__, we learned how to use `torch.export <https://pytorch.org/docs/stable/export.html>`__.
This tutorial builds on the previous one: it explores the process of exporting popular models with code and addresses common challenges you might face with `torch.export`.

You will learn how to export models for the following use cases:

* Video Classification (MViT)
* Pose Estimation (YOLO11 Pose)
* Image Captioning (BLIP)
* Promptable Image Segmentation (SAM2)

Each of the four models was chosen to demonstrate unique features of `torch.export`, as well as some practical considerations
and issues faced in the implementation.

Prerequisites
-------------

* PyTorch 2.4 or later
* Basic understanding of ``torch.export`` and PyTorch eager inference

Key requirement for `torch.export`: No graph break
-------------------------------------------------------

`torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__ speeds up PyTorch code by JIT-compiling it into optimized kernels. It optimizes the given model
using TorchDynamo and creates an optimized graph, which is then lowered to the hardware using the backend specified in the API.
When TorchDynamo encounters unsupported Python features, it breaks the computation graph, lets the default Python interpreter
handle the unsupported code, and then resumes capturing the graph. This break in the computation graph is called a `graph break <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torchdynamo-and-fx-graphs>`__.

One of the key differences between `torch.export` and `torch.compile` is that `torch.export` does not support graph breaks,
which means that the entire model, or the part of the model you are exporting, needs to be a single graph. This is because handling graph breaks
involves interpreting the unsupported operation with the default Python evaluation, which is incompatible with what `torch.export` is
designed for.

You can identify graph breaks in your program by using the following command:

.. code:: console

    TORCH_LOGS="graph_breaks" python <file_name>.py

You will need to modify your program to get rid of graph breaks. Once resolved, you are ready to export the model.
PyTorch runs `nightly benchmarks <https://hud.pytorch.org/benchmark/compilers>`__ for `torch.compile` on popular HuggingFace and TIMM models.
Most of these models have no graph breaks.

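To make this concrete, below is a minimal, self-contained sketch (not taken from any of the models in this tutorial) of a toy module with data-dependent control flow. `torch.compile` handles the branch by inserting a graph break and falling back to Python, while `torch.export` refuses to capture it; rewriting the branch with `torch.cond` is one way to make such code export-compatible.

.. code:: python

    import torch

    class ToyBranchModel(torch.nn.Module):
        def forward(self, x):
            # Data-dependent control flow: torch.compile graph-breaks here,
            # but torch.export cannot, so exporting this module fails.
            if x.sum() > 0:
                return x + 1
            return x - 1

    class ToyCondModel(torch.nn.Module):
        def forward(self, x):
            # torch.cond captures both branches in a single graph,
            # which makes the module exportable.
            return torch.cond(x.sum() > 0, lambda x: x + 1, lambda x: x - 1, (x,))

    x = torch.randn(4)

    torch.compile(ToyBranchModel())(x)  # runs, with a graph break at the branch

    try:
        torch.export.export(ToyBranchModel(), (x,))  # fails: no graph breaks allowed
    except Exception as e:
        print(f"{type(e).__name__}: {e}")

    exported = torch.export.export(ToyCondModel(), (x,))  # succeeds
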
The models in this recipe have no graph breaks but still fail with `torch.export`.

Video Classification
--------------------

MViT is a class of models based on `MultiScale Vision Transformers <https://arxiv.org/abs/2104.11227>`__. This model has been trained for video classification using the `Kinetics-400 Dataset <https://arxiv.org/abs/1705.06950>`__.
With a relevant dataset, this model can be used for action recognition in the context of gaming.

The code below exports MViT by tracing with `batch_size=2` and then checks whether the ExportedProgram can run with `batch_size=4`.

.. code:: python

    import numpy as np
    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb

    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)
    # Transpose to get [batch_size, channels, num_frames, height, width].
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

    # Export the model.
    exported_program = torch.export.export(
        model,
        (input_frames,),
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()

Error: Static batch size
~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: console

    raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

By default, the exporting flow traces the program assuming that all input shapes are static, so if you run the program with
input shapes that are different from the ones you used while tracing, you will run into an error.

Solution
~~~~~~~~

To address the error, we specify the first dimension of the input (`batch_size`) to be dynamic, along with the expected range of `batch_size`.
In the corrected example shown below, we specify that the expected `batch_size` can range from 1 to 16.
One detail to notice is that `min=2` is not a bug; the reason is explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__. A detailed description of dynamic shapes
for `torch.export` can be found in the export tutorial. The code shown below demonstrates how to export MViT with dynamic batch sizes.

.. code:: python

    import numpy as np
    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb


    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)

    # Transpose to get [batch_size, channels, num_frames, height, width].
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

    # Export the model.
    batch_dim = torch.export.Dim("batch", min=2, max=16)
    exported_program = torch.export.export(
        model,
        (input_frames,),
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()

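Once exported with the dynamic batch dimension, the ExportedProgram accepts any batch size within the declared range, while out-of-range batch sizes should be rejected by the input checks. The short sketch below is an illustrative addition (not part of the original example) that reuses `exported_program` and the imports from the code above.

.. code:: python

    # Batch sizes inside the declared [2, 16] range are expected to run.
    for batch_size in (2, 16):
        frames = np.transpose(torch.randn(batch_size, 16, 224, 224, 3), (0, 4, 1, 2, 3))
        exported_program.module()(frames)

    # A batch size outside the range (here 32 > max=16) should fail the input check.
    try:
        frames = np.transpose(torch.randn(32, 16, 224, 224, 3), (0, 4, 1, 2, 3))
        exported_program.module()(frames)
    except Exception:
        tb.print_exc()
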
Pose Estimation
---------------

Pose estimation is a popular computer vision technique used to identify the locations of a person's joints in a 2D image.
`Ultralytics <https://docs.ultralytics.com/tasks/pose/>`__ has published a pose estimation model based on `YOLO11 <https://docs.ultralytics.com/models/yolo11/>`__, trained on the `COCO Dataset <https://cocodataset.org/#keypoints-2017>`__. This model can be used
for analyzing human pose to determine action or intent. The code below tries to export the YOLO11 Pose model with `batch_size=1`.

.. code:: python

    from ultralytics import YOLO
    import torch
    from torch.export import export

    pose_model = YOLO("yolo11n-pose.pt")  # Load model
    pose_model.model.eval()

    inputs = torch.rand((1, 3, 640, 640))
    exported_program: torch.export.ExportedProgram = export(pose_model.model, args=(inputs,))

Error: strict tracing with TorchDynamo
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: console

    torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {zuf0} not in returned outputs FakeTensor(..., size=(6400, 1)) ((1, 1), 0).

By default, `torch.export` traces your code using `TorchDynamo <https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html>`__, a bytecode analysis engine that symbolically analyzes your code and builds a graph.
This analysis provides a stronger safety guarantee, but not all Python code is supported. When we export the `yolo11n-pose` model using the
default strict mode, it fails with the error above.

Solution
~~~~~~~~

To address the above error, `torch.export` supports a non-strict mode, in which the program is traced using the Python interpreter. This works similarly to
PyTorch eager execution; the only difference is that all `Tensor` objects are replaced by `ProxyTensors`, which record all of their operations into
a graph. By passing `strict=False`, we are able to export the program.

.. code:: python

    from ultralytics import YOLO
    import torch
    from torch.export import export

    pose_model = YOLO("yolo11n-pose.pt")  # Load model
    pose_model.model.eval()

    inputs = torch.rand((1, 3, 640, 640))
    exported_program: torch.export.ExportedProgram = export(pose_model.model, args=(inputs,), strict=False)

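The resulting ExportedProgram can then be called in place of the eager model. The one-liner below is an illustrative sketch (not part of the original example), reusing `exported_program` and `inputs` from the code above.

.. code:: python

    # Run the exported program on the same input that was used for tracing.
    outputs = exported_program.module()(inputs)
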
Image Captioning
----------------

Image captioning is the task of describing the contents of an image in words. In the context of gaming, image captioning can be used to enhance the
gameplay experience by dynamically generating text descriptions of the various game objects in the scene, thereby providing the gamer with additional
details. `BLIP <https://arxiv.org/pdf/2201.12086>`__ is a popular model for image captioning `released by Salesforce Research <https://github.com/salesforce/BLIP>`__. The code below tries to export BLIP with `batch_size=1`.

.. code:: python

    import torch
    from models.blip import blip_decoder

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_size = 384
    image = torch.randn(1, 3, 384, 384).to(device)
    caption_input = ""

    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(image, caption_input), strict=False)

Error: Unsupported Python operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Exporting a model might fail because the model implementation contains certain Python operations that are not yet supported by `torch.export`.
Some of these failures have a workaround. BLIP is an example where the original model errors out and a small change in the code resolves the issue.
`torch.export` lists the common cases of supported and unsupported operations in `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ and shows how you can modify your code to make it export-compatible.

.. code:: console

    File "/BLIP/models/blip.py", line 112, in forward
        text.input_ids[:,0] = self.tokenizer.bos_token_id
    File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
        outs_unwrapped = func._op_dk(
    RuntimeError: cannot mutate tensors with frozen storage

Solution
~~~~~~~~

Clone the `tensor <https://github.com/salesforce/BLIP/blob/main/models/blip.py#L112>`__ where export fails.

.. code:: python

    text.input_ids = text.input_ids.clone()  # clone the tensor
    text.input_ids[:,0] = self.tokenizer.bos_token_id

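With this one-line change in the BLIP source, the export call above succeeds. As a quick sanity check, you can run the exported program on the traced inputs and compare it against eager execution. The snippet below is an illustrative addition (not part of the original BLIP example), reusing `model`, `image`, `caption_input`, and `exported_program` from above.

.. code:: python

    # Run both the eager model and the exported program on the same inputs.
    with torch.no_grad():
        eager_output = model(image, caption_input)
        exported_output = exported_program.module()(image, caption_input)

    # The outputs are expected to match up to numerical tolerance.
    torch.testing.assert_close(eager_output, exported_output)
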
Promptable Image Segmentation
-----------------------------

Image segmentation is a computer vision technique that divides a digital image into distinct groups of pixels, or segments, based on their characteristics.
The Segment Anything Model (`SAM <https://ai.meta.com/blog/segment-anything-foundation-model-image-segmentation/>`__) introduced promptable image segmentation, which predicts object masks given prompts that indicate the desired object. `SAM 2 <https://ai.meta.com/sam2/>`__ is
the first unified model for segmenting objects across images and videos. The `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ class provides an easy interface for prompting
the model. The model can take as input both point and box prompts, as well as masks from a previous iteration of prediction. Since SAM2 provides strong
zero-shot performance for object tracking, it can be used for tracking game objects in a scene. The code below tries to export `SAM2ImagePredictor` with `batch_size=1`.

The tensor operations in the `predict` method of `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ happen in the `_predict <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L291>`__ method, so we try to export that method.

.. code:: python

    ep = torch.export.export(
        self._predict,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )

Error: Model is not of type `torch.nn.Module`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

`torch.export` expects the module to be of type `torch.nn.Module`. However, what we are trying to export is a class method, hence the error.

.. code:: console

    Traceback (most recent call last):
      File "/sam2/image_predict.py", line 20, in <module>
        masks, scores, _ = predictor.predict(
      File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
        ep = torch.export.export(
      File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
        raise ValueError(
    ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

Solution
~~~~~~~~

We write a helper class that inherits from `torch.nn.Module` and calls `_predict` in its `forward` method. The complete code can be found `here <https://github.com/anijain2305/sam2/blob/ued/sam2/sam2_image_predictor.py#L293-L311>`__.

.. code:: python

    class ExportHelper(torch.nn.Module):
        def __init__(self):
            super().__init__()

        # The first parameter is deliberately named `_`: it receives the ExportHelper
        # instance, while `self` below refers to the enclosing SAM2ImagePredictor
        # instance captured from the surrounding `predict` method.
        def forward(_, *args, **kwargs):
            return self._predict(*args, **kwargs)

    model_to_export = ExportHelper()
    ep = torch.export.export(
        model_to_export,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )

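The returned ExportedProgram can then be invoked in place of the original `_predict` call. The sketch below is illustrative (not from the SAM2 repository) and reuses the variable names from the export call above.

.. code:: python

    # Call the exported program with the same arguments that were used for tracing.
    outputs = ep.module()(
        unnorm_coords, labels, unnorm_box, mask_input, multimask_output,
        return_logits=return_logits,
    )
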
Conclusion
----------

In this tutorial, we have learned how to use `torch.export` to export models for popular use cases by addressing challenges through the correct configuration and simple code modifications.