
Commit 2e414f7

committed
Added a recipe for showcasing torch.export flow for 4 models
1 parent 1263f06 commit 2e414f7

2 files changed: +326 -0 lines changed

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -157,6 +157,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/torch_export_aoti_python.html
    :tags: Basics

+.. customcarditem::
+   :header: Demonstration of torch.export flow, common challenges and the solutions to address them
+   :card_description: Learn how to export models for popular use cases
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_export_challenges_solutions.html
+   :tags: Basics
+
 .. Interpretability

 .. customcarditem::
recipes_source/torch_export_challenges_solutions.rst

Lines changed: 319 additions & 0 deletions
@@ -0,0 +1,319 @@
Demonstration of torch.export flow, common challenges and the solutions to address them
========================================================================================
**Authors:** `Ankith Gunapal`, `Jordi Ramon`, `Marcos Carranza`

In a previous `tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`__, we learned how to use `torch.export <https://pytorch.org/docs/stable/export.html>`__.
This tutorial builds on the previous one: it walks through exporting popular models with code and addresses common challenges you might face with `torch.export`.

You will learn how to export models for these use cases:

* Video Classification (MViT)
* Pose Estimation (YOLO11 Pose)
* Image Captioning (BLIP)
* Promptable Image Segmentation (SAM2)

Each of the four models was chosen to demonstrate unique features of `torch.export`, along with some practical considerations
and issues faced in the implementation.

Prerequisites
-------------

* PyTorch 2.4 or later
* Basic understanding of ``torch.export`` and PyTorch eager inference
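
As a quick refresher on the second prerequisite, ``torch.export.export`` takes an ``nn.Module`` plus example inputs and returns an ``ExportedProgram`` whose graph module can be run via ``.module()``. A minimal sketch (the toy module below is ours, not part of this recipe):

.. code:: python

    import torch

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    # Export traces the module with the example input into a single graph.
    ep = torch.export.export(ToyModel(), (torch.randn(2, 3),))
    print(ep)  # Inspect the captured graph and input/output signature.
    out = ep.module()(torch.randn(2, 3))  # Run the exported program.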


Key requirement for `torch.export`: No graph break
--------------------------------------------------

`torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__ speeds up PyTorch code by JIT-compiling it into optimized kernels. It optimizes the given model
using TorchDynamo and creates an optimized graph, which is then lowered to the hardware using the backend specified in the API.
When TorchDynamo encounters unsupported Python features, it breaks the computation graph, lets the default Python interpreter
handle the unsupported code, and then resumes capturing the graph. This break in the computation graph is called a `graph break <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torchdynamo-and-fx-graphs>`__.

One of the key differences between `torch.export` and `torch.compile` is that `torch.export` does not support graph breaks,
which means the entire model, or the part of the model you are exporting, needs to be a single graph. This is because handling graph breaks
involves interpreting the unsupported operation with default Python evaluation, which is incompatible with what `torch.export` is
designed for.

You can identify graph breaks in your program by running it with the following command:

.. code:: console

    TORCH_LOGS="graph_breaks" python <file_name>.py

You will need to modify your program to get rid of graph breaks. Once resolved, you are ready to export the model.
PyTorch runs `nightly benchmarks <https://hud.pytorch.org/benchmark/compilers>`__ for `torch.compile` on popular HuggingFace and TIMM models.
Most of these models have no graph breaks.
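
For illustration, data-dependent Python control flow is a common source of graph breaks. The toy function below is our own sketch, not from this recipe; running it with ``TORCH_LOGS="graph_breaks"`` should report the break (exact behavior can vary across PyTorch versions):

.. code:: python

    import torch

    def f(x):
        # Branching on a tensor value requires a Python bool at trace time,
        # so TorchDynamo breaks the graph here and falls back to Python.
        if x.sum() > 0:
            return x + 1
        return x - 1

    compiled_f = torch.compile(f)
    compiled_f(torch.randn(4))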

The models in this recipe have no graph breaks, yet they still fail with `torch.export`, each for a different reason.

Video Classification
--------------------

MViT is a class of models based on `MultiScale Vision Transformers <https://arxiv.org/abs/2104.11227>`__. This model has been trained for video classification on the `Kinetics-400 Dataset <https://arxiv.org/abs/1705.06950>`__.
With a relevant dataset, this model can be used for action recognition in the context of gaming.

The code below exports MViT by tracing it with `batch_size=2` and then checks whether the ExportedProgram can run with `batch_size=4`.

.. code:: python

    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb

    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)
    # Permute to get [batch_size, channels, num_frames, height, width].
    input_frames = torch.permute(input_frames, (0, 4, 1, 2, 3))

    # Export the model.
    exported_program = torch.export.export(
        model,
        (input_frames,),
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = torch.permute(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()

Error: Static batch size
~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: console

    raise RuntimeError(
    RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

By default, the export flow traces the program assuming that all input shapes are static, so if you run the program with
input shapes different from the ones you used while tracing, you will run into an error.

Solution
~~~~~~~~

To address the error, we mark the first dimension of the input (`batch_size`) as dynamic and specify the expected range of `batch_size`.
In the corrected example shown below, we specify that the expected `batch_size` can range from 1 to 16.
One detail to notice is that `min=2` is not a bug; the reason is explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__. A detailed description of dynamic shapes
for `torch.export` can be found in the export tutorial. The code shown below demonstrates how to export MViT with a dynamic batch size.

.. code:: python

    import torch
    from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
    import traceback as tb


    model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

    # Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(2, 16, 224, 224, 3)

    # Permute to get [batch_size, channels, num_frames, height, width].
    input_frames = torch.permute(input_frames, (0, 4, 1, 2, 3))

    # Export the model with the batch dimension marked as dynamic.
    batch_dim = torch.export.Dim("batch", min=2, max=16)
    exported_program = torch.export.export(
        model,
        (input_frames,),
        # Specify the first dimension of the input x as dynamic.
        dynamic_shapes={"x": {0: batch_dim}},
    )

    # Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
    input_frames = torch.randn(4, 16, 224, 224, 3)
    input_frames = torch.permute(input_frames, (0, 4, 1, 2, 3))
    try:
        exported_program.module()(input_frames)
    except Exception:
        tb.print_exc()
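
With the dynamic batch dimension in place, the `batch_size=4` call above succeeds. A quick way to sanity-check the declared range (our own addition, not part of the original recipe):

.. code:: python

    # Any batch size within the declared [2, 16] range should now run...
    exported_program.module()(torch.permute(torch.randn(8, 16, 224, 224, 3), (0, 4, 1, 2, 3)))
    # ...while a batch size outside the range (for example, 32) fails the
    # range guard recorded in the ExportedProgram.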


Pose Estimation
---------------

Pose estimation is a popular computer vision task that identifies the location of human joints in a 2D image.
`Ultralytics <https://docs.ultralytics.com/tasks/pose/>`__ has published a pose estimation model based on `YOLO11 <https://docs.ultralytics.com/models/yolo11/>`__, trained on the `COCO Dataset <https://cocodataset.org/#keypoints-2017>`__. This model can be used
to analyze human pose and determine action or intent. The code below tries to export the YOLO11 pose model with `batch_size=1`.

.. code:: python

    from ultralytics import YOLO
    import torch
    from torch.export import export

    pose_model = YOLO("yolo11n-pose.pt")  # Load model
    pose_model.model.eval()

    inputs = torch.rand((1, 3, 640, 640))
    exported_program: torch.export.ExportedProgram = export(pose_model.model, args=(inputs,))

Error: strict tracing with TorchDynamo
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: console

    torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {zuf0} not in returned outputs FakeTensor(..., size=(6400, 1)) ((1, 1), 0).

By default, `torch.export` traces your code using `TorchDynamo <https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html>`__, a bytecode analysis engine that symbolically analyzes your code and builds a graph.
This analysis provides a stronger safety guarantee, but not all Python code is supported. When we export the `yolo11n-pose` model using the
default strict mode, it errors out.

Solution
~~~~~~~~

To address the above error, `torch.export` supports non-strict mode, where the program is traced using the Python interpreter. This works similarly to
PyTorch eager execution; the only difference is that all ``Tensor`` objects are replaced by ``ProxyTensors``, which record all their operations into
a graph. By passing `strict=False`, we are able to export the program.

.. code:: python

    from ultralytics import YOLO
    import torch
    from torch.export import export

    pose_model = YOLO("yolo11n-pose.pt")  # Load model
    pose_model.model.eval()

    inputs = torch.rand((1, 3, 640, 640))
    exported_program: torch.export.ExportedProgram = export(pose_model.model, args=(inputs,), strict=False)
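
As a quick check (our addition, not part of the original recipe), the exported program can be run on the same example input to confirm it executes end to end:

.. code:: python

    # Run the exported graph module with the traced input shape.
    outputs = exported_program.module()(inputs)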


Image Captioning
----------------

Image captioning is the task of describing the contents of an image in words. In the context of gaming, image captioning can be used to enhance the
gameplay experience by dynamically generating text descriptions of the various game objects in the scene, thereby providing the gamer with additional
details. `BLIP <https://arxiv.org/pdf/2201.12086>`__ is a popular model for image captioning `released by Salesforce Research <https://github.com/salesforce/BLIP>`__. The code below tries to export BLIP with `batch_size=1`.

.. code:: python

    import torch
    from models.blip import blip_decoder

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_size = 384
    image = torch.randn(1, 3, 384, 384).to(device)
    caption_input = ""

    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)

    exported_program: torch.export.ExportedProgram = torch.export.export(model, args=(image, caption_input), strict=False)


Error: Unsupported Python operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

While exporting a model, it might fail because the model implementation contains certain Python operations that are not yet supported by `torch.export`.
Some of these failures have a workaround. BLIP is an example where the original model errors out and a small change in the code resolves the issue.
`torch.export` lists common cases of supported and unsupported operations in `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ and shows how you can modify your code to make it export-compatible.

.. code:: console

    File "/BLIP/models/blip.py", line 112, in forward
        text.input_ids[:,0] = self.tokenizer.bos_token_id
    File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
        outs_unwrapped = func._op_dk(
    RuntimeError: cannot mutate tensors with frozen storage


Solution
~~~~~~~~

Clone the `tensor <https://github.com/salesforce/BLIP/blob/main/models/blip.py#L112>`__ where export fails, so the in-place update happens on a fresh, writable copy.

.. code:: python

    text.input_ids = text.input_ids.clone()  # Clone the tensor before mutating it
    text.input_ids[:,0] = self.tokenizer.bos_token_id


Promptable Image Segmentation
-----------------------------

Image segmentation is a computer vision technique that divides a digital image into distinct groups of pixels, or segments, based on their characteristics.
The Segment Anything Model (`SAM <https://ai.meta.com/blog/segment-anything-foundation-model-image-segmentation/>`__) introduced promptable image segmentation, which predicts object masks given prompts that indicate the desired object. `SAM 2 <https://ai.meta.com/sam2/>`__ is
the first unified model for segmenting objects across images and videos. The `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ class provides an easy interface for prompting
the model. It accepts both point and box prompts, as well as masks from a previous iteration of prediction. Since SAM2 provides strong
zero-shot performance for object tracking, it can be used for tracking game objects in a scene.
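
For context, a typical prompting flow with `SAM2ImagePredictor` looks roughly like the sketch below; the config name, checkpoint path, image, and prompt values are illustrative placeholders, not from this recipe:

.. code:: python

    import numpy as np
    import torch
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    # Build the model and wrap it in the predictor (paths are placeholders).
    sam2_model = build_sam2("sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt")
    predictor = SAM2ImagePredictor(sam2_model)

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # Stand-in for a real HWC image.
    with torch.inference_mode():
        predictor.set_image(image)
        # A single foreground point prompt at pixel (320, 240).
        masks, scores, logits = predictor.predict(
            point_coords=np.array([[320, 240]]),
            point_labels=np.array([1]),
        )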

The code below tries to export `SAM2ImagePredictor` with `batch_size=1`. The tensor operations in the `predict` method of `SAM2ImagePredictor <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L20>`__ happen in the `_predict <https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py#L291>`__ method, so we try to export that.

.. code:: python

    ep = torch.export.export(
        self._predict,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )

Error: Model is not of type `torch.nn.Module`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

`torch.export` expects the module to be of type `torch.nn.Module`. However, what we are trying to export here is a bound method, hence it errors.

.. code:: console

    Traceback (most recent call last):
      File "/sam2/image_predict.py", line 20, in <module>
        masks, scores, _ = predictor.predict(
      File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
        ep = torch.export.export(
      File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
        raise ValueError(
    ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

Solution
~~~~~~~~

We write a helper class that inherits from `torch.nn.Module` and calls the `_predict` method in its `forward` method. The complete code can be found `here <https://github.com/anijain2305/sam2/blob/ued/sam2/sam2_image_predictor.py#L293-L311>`__.

.. code:: python

    class ExportHelper(torch.nn.Module):
        def __init__(self):
            super().__init__()

        # `_` receives the ExportHelper instance; `self` is the enclosing
        # SAM2ImagePredictor, captured from the surrounding method's scope.
        def forward(_, *args, **kwargs):
            return self._predict(*args, **kwargs)

    model_to_export = ExportHelper()
    ep = torch.export.export(
        model_to_export,
        args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
        kwargs={"return_logits": return_logits},
        strict=False,
    )
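
After export, the resulting program is invoked through its graph module; a sketch with the same arguments as `_predict`:

.. code:: python

    # Run the exported `_predict` graph with the original prompt tensors.
    outputs = ep.module()(
        unnorm_coords, labels, unnorm_box, mask_input, multimask_output,
        return_logits=return_logits,
    )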

Conclusion
----------

In this tutorial, we have learned how to use `torch.export` to export models for popular use cases by addressing challenges through correct configuration and simple code modifications.
