@@ -4,9 +4,9 @@ The previous TorchScript-based ONNX exporter would execute the model once to tra
44memory on your GPU if the model's memory requirements exceeded the available GPU memory. This issue has been addressed with the new
55TorchDynamo-based ONNX exporter.
66
7- The TorchDynamo-based ONNX exporter leverages ` FakeTensorMode < https://pytorch.org/docs/stable/ torch.compiler_fake_tensor.html >`_ to
8- avoid performing actual tensor computations during the export process. This approach results in significantly lower memory usage
9- compared to the TorchScript-based ONNX exporter.
7+ The TorchDynamo-based ONNX exporter utilizes torch.export.export() function to leverage
8+ ` FakeTensorMode < https://pytorch.org/docs/stable/torch.compiler_fake_tensor.html >`_ to avoid performing actual tensor computations
9+ during the export process. This approach results in significantly lower memory usage compared to the TorchScript-based ONNX exporter.
1010
1111Below is an example demonstrating the memory usage difference between TorchScript-based and TorchDynamo-based ONNX exporters.
1212In this example, we use the HighResNet model from MONAI. Before proceeding, please install it from PyPI:
@@ -29,7 +29,6 @@ The code below could be run to generate a snapshot file which records the state
2929
3030 import torch
3131
32- from torch.onnx.utils import export
3332 from monai.networks.nets import (
3433 HighResNet,
3534 )
@@ -44,17 +43,19 @@ The code below could be run to generate a snapshot file which records the state
4443 data = torch.randn(30 , 1 , 48 , 48 , 48 , dtype = torch.float32).to(" cuda" )
4544
4645 with torch.no_grad():
47- export(
46+ onnx_program = torch.onnx. export(
4847 model,
4948 data,
5049 " torchscript_exporter_highresnet.onnx" ,
50+ dynamo = False ,
5151 )
5252
53- snapshot_name = f " torchscript_exporter_example.pickle "
53+ snapshot_name = " torchscript_exporter_example.pickle"
5454 print (f " generate { snapshot_name} " )
5555
5656 torch.cuda.memory._dump_snapshot(snapshot_name)
57- print (f " Export is done. " )
57+ print (" Export is done." )
58+
5859
5960 Open `pytorch.org/memory_viz <https://pytorch.org/memory_viz >`_ and drag/drop the generated pickled snapshot file into the visualizer.
6061The memory usage is described as below:
0 commit comments