Commit c77b057

Torch-TRT save through torch.export.save

1 parent 4130fb2 commit c77b057

14 files changed: +105 -33 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ limitations under the License.
 - new: GPU and Host memory usage logging
 - change: Install the TensorRT package for architectures other than x86_64
 - change: Disable conversion fallback for TensorRT paths and expose control option in custom config
+- change: Use torch.export.save for Torch-TRT model serialization
 - fix: Correctness command relative tolerance formula
 - fix: Memory management during export and conversion process for Torch

model_navigator/commands/convert/converters/ep2torchtrt.py

Lines changed: 61 additions & 16 deletions

@@ -14,19 +14,21 @@
 """Convert ExportedProgram model to Torch-TensorRT model."""

 import pathlib
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import fire
 import numpy as np
 import torch  # pytype: disable=import-error
 from loguru import logger
+from packaging.version import Version

 from model_navigator.configuration import TensorRTPrecision, TensorRTPrecisionMode
 from model_navigator.configuration.device import map_device_string
-from model_navigator.core.dataloader import load_samples
+from model_navigator.core.dataloader import expand_sample, load_samples
+from model_navigator.core.logger import LOGGER
 from model_navigator.core.tensor import TensorMetadata
 from model_navigator.frameworks.tensorrt import utils as tensorrt_utils
-from model_navigator.frameworks.tensorrt.timing_tactics import TimingCacheManager, trt_cache_inplace_cache_dir
+from model_navigator.frameworks.tensorrt.timing_tactics import TimingCacheManager
 from model_navigator.utils.common import numpy_to_torch_dtype


@@ -59,9 +61,10 @@ def convert(
     exported_model_path: str,
     converted_model_path: str,
     input_metadata: Dict[str, Any],
-    shapes: Dict[str, Dict[str, int]],
+    shapes: Dict[str, Dict[str, List[int]]],
     batch_dim: Optional[int],
     max_workspace_size: int,
+    pickle_protocol: int,
     precision: str,
     precision_mode: str,
     target_device: str,
@@ -82,12 +85,13 @@ def convert(
             and respective values.
         batch_dim: Batch dimension.
         max_workspace_size: Maximum workspace size in bytes.
+        pickle_protocol: Pickle protocol used during model serialization
         precision: TensorRT precision. Could be "fp16" or "fp32".
         precision_mode: TensorRT precision mode.
         target_device: Device on which perform the conversion
         debug: If True print debug logs.
         custom_args: Dictionary with passthrough parameters. For available arguments check PyTorch
-            documentation: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html
+            documentation: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html
         timing_cache_dir: Directory to save timing cache. Defaults to None which means it will be saved in workspace root.
         model_name: Model name for the timing cache. Defaults to None which means it will be named after the model file.
         navigator_workspace: Model Navigator workspace path. When None use current workdir. Defaults to None.
@@ -106,9 +110,34 @@

     conversion_sample = load_samples("conversion_samples", navigator_workspace, batch_dim)[0]

+    if batch_dim is None:
+        max_batch_size = None
+        expanded_sample = expand_sample(conversion_sample, input_metadata, batch_dim=batch_dim, batch_size=None)
+    else:
+        # WAR to make data dynamic
+        max_batch_size = list(shapes.values())[0]["max"][0]
+        batch_size = 2 if max_batch_size > 1 else 1  # select the minimum value to expand samples
+        expanded_sample = expand_sample(conversion_sample, input_metadata, batch_dim=batch_dim, batch_size=batch_size)
+
+    dummy_input = {n: torch.from_numpy(val).to(target_device) for n, val in expanded_sample.items()}
+    dummy_input = input_metadata.unflatten_sample(dummy_input, wrap_input=False)
+
+    if not isinstance(dummy_input, tuple):
+        dummy_input = (dummy_input,)
+    if not isinstance(dummy_input[-1], dict):
+        dummy_input = (*dummy_input, {})
+    *args, kwargs = dummy_input
+
     input_dtypes = [numpy_to_torch_dtype(np.dtype(input_dtype)) for input_dtype in input_dtypes]
     model_input_shapes = []
-    for input_shapes, input_dtype in zip(shapes.values(), input_dtypes):
+    dynamic_shapes = []
+    for input_name, input_dtype in zip(shapes.keys(), input_dtypes):
+        input_shapes = shapes.get(input_name)
+        tensor_metadata = input_metadata.get(input_name)
+        if not tensor_metadata or not input_shapes:
+            LOGGER.warning(f"Input metadata or input shapes for input {input_name} is not found")
+            continue
+
         model_input_shapes.append(
             torch_tensorrt.Input(
                 min_shape=input_shapes["min"],
@@ -118,6 +147,18 @@
             )
         )

+        dynamic_shape_map = {}
+        if max_batch_size is not None and max_batch_size > 1 and len(tensor_metadata.shape) > 0:
+            dynamic_shape_map[0] = torch.export.Dim(f"{input_name}_batch", min=1, max=max_batch_size)
+
+        for idx in range(1, len(input_shapes["min"])):
+            min_value = input_shapes["min"][idx]
+            max_value = input_shapes["max"][idx]
+            if min_value != max_value:
+                dynamic_shape_map[idx] = torch.export.Dim(f"{input_name}__{idx}", min=min_value, max=max_value)
+
+        dynamic_shapes.append(dynamic_shape_map)
+
     exported_model_path = pathlib.Path(exported_model_path)
     if not exported_model_path.is_absolute():
         exported_model_path = navigator_workspace / exported_model_path
@@ -134,19 +175,14 @@

     target_device = map_device_string(target_device)

-    # saving timing cache in model_navigator workspace or ...
-    timing_cache = trt_cache_inplace_cache_dir()
-    if timing_cache_dir is not None:
-        timing_cache = pathlib.Path(timing_cache_dir)
-
     with TimingCacheManager(model_name=model_name, cache_path=timing_cache_dir) as timing_cache:
         timing_cache_path = timing_cache.as_posix() if timing_cache else None

         # reusing custom_args as dynamo.compile has a default cache path argument
         if timing_cache_path is not None:
             custom_args["timing_cache_path"] = timing_cache_path

-        tr_model_compiled = torch_tensorrt.dynamo.compile(
+        trt_model_compiled = torch_tensorrt.dynamo.compile(
             exported_program=model,
             inputs=model_input_shapes,
             workspace_size=max_workspace_size,
@@ -155,15 +191,24 @@
             **custom_args,
         )

+        exported_model = torch.export.export(
+            trt_model_compiled,
+            args=tuple(args),
+            kwargs=kwargs,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
         converted_model_path = pathlib.Path(converted_model_path)
         if not converted_model_path.is_absolute():
             converted_model_path = navigator_workspace / converted_model_path

-        inputs = []
-        for _, val in conversion_sample.items():
-            inputs.append(torch.from_numpy(val).to(target_device))
+        save_kwargs = {}
+        if Version(torch.__version__) > Version("2.6"):
+            LOGGER.info("Using pickle protocol {}.", pickle_protocol)
+            save_kwargs["pickle_protocol"] = pickle_protocol

-        torch_tensorrt.save(tr_model_compiled, converted_model_path.as_posix(), inputs=inputs)
+        torch.export.save(exported_model, converted_model_path.as_posix(), **save_kwargs)


 if __name__ == "__main__":
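
With this change the converted artifact is a standard torch.export archive rather than a torch_tensorrt.save output, so it can be read back with torch.export.load. A minimal load-side sketch; the path and input shape are illustrative, and torch_tensorrt must be importable so the embedded TensorRT engine ops can execute:

    import torch
    import torch_tensorrt  # noqa: F401  # registers the TensorRT runtime ops used by the exported graph

    # "model.ep" stands in for the converted_model_path written by convert().
    exported_program = torch.export.load("model.ep")
    module = exported_program.module()  # callable wrapper around the exported graph

    sample = torch.randn(2, 3, 224, 224, device="cuda")  # a batch of 2 stays inside the exported Dim range
    output = module(sample)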

model_navigator/commands/convert/torch.py

Lines changed: 3 additions & 0 deletions

@@ -114,6 +114,7 @@ def _run(
         precision: TensorRTPrecision,
         precision_mode: TensorRTPrecisionMode,
         max_workspace_size: int,
+        pickle_protocol: int,
         verbose: bool,
         debug: bool,
         dataloader_trt_profile: TensorRTProfile,
@@ -138,6 +139,7 @@ def _run(
             precision: TensorRTPrecision.
             precision_mode: TensorRT precision mode.
             max_workspace_size: TensorRT maximum workspace size.
+            pickle_protocol: Pickle protocol for model serialization.
             verbose: If True verbose logging.
             debug: If True print debug logs.
             dataloader_trt_profile: Dataloader TensorRT profile.
@@ -183,6 +185,7 @@ def get_args(max_batch_size=None):
             "max_workspace_size": max_workspace_size,
             "precision": precision.value,
             "precision_mode": precision_mode.value,
+            "pickle_protocol": pickle_protocol,
             "navigator_workspace": workspace.path.as_posix(),
             "target_device": target_device.value,
             "custom_args": custom_args,

model_navigator/commands/export/exporters/torch2dynamo_onnx.py

Lines changed: 4 additions & 4 deletions

@@ -124,15 +124,15 @@ def expand_batch_dim(tensor, batch_dim, max_batch_size):
         if not tensor_metadata:
             continue

-        dynamic_shapes_ = {}
+        dynamic_shape_map = {}
         if max_batch_size is not None and max_batch_size > 1 and len(tensor_metadata.shape) > 0:
-            dynamic_shapes_[0] = torch.export.Dim("batch", min=1, max=max_batch_size)
+            dynamic_shape_map[0] = torch.export.Dim("batch", min=1, max=max_batch_size)

         for idx in range(1, len(spec_.min)):
             if spec_.min[idx] != spec_.max[idx]:
-                dynamic_shapes_[idx] = torch.export.Dim(f"{name}__{idx}", min=spec_.min[idx], max=spec_.max[idx])
+                dynamic_shape_map[idx] = torch.export.Dim(f"{name}__{idx}", min=spec_.min[idx], max=spec_.max[idx])

-        dynamic_shapes.append(dynamic_shapes_)
+        dynamic_shapes.append(dynamic_shape_map)

     try:
         exported_model = torch.onnx.export(
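
Beyond the rename (dynamic_shapes_ to dynamic_shape_map), this hunk shows the per-input dict format the dynamo ONNX path consumes. A hedged sketch of the same pattern in isolation, assuming a PyTorch version (2.5+) where torch.onnx.export accepts dynamo=True together with dynamic_shapes; the model, shapes, and file name are illustrative:

    import torch


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2.0


    # One dict per positional input, mapping dim index -> Dim, like dynamic_shape_map above.
    batch = torch.export.Dim("batch", min=1, max=8)
    onnx_program = torch.onnx.export(
        TinyModel(),
        (torch.randn(2, 4),),
        dynamo=True,
        dynamic_shapes=[{0: batch}],
    )
    onnx_program.save("tiny_model.onnx")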

model_navigator/commands/export/exporters/torch2exportedprogram.py

Lines changed: 5 additions & 7 deletions

@@ -77,10 +77,8 @@ def export(
         max_batch_size = None
         expanded_sample = expand_sample(conversion_sample, input_metadata, batch_dim=batch_dim, batch_size=None)
     else:
-        # WAR for to big batch size value
-        max_batch_size = max_batch_size if max_batch_size < 2048 else max_batch_size - 1
         # WAR to make data dynamic
-        batch_size = min(2, max_batch_size)  # select the minimum value to expand samples
+        batch_size = 2 if max_batch_size > 1 else 1  # select the minimum value to expand samples
         expanded_sample = expand_sample(conversion_sample, input_metadata, batch_dim=batch_dim, batch_size=batch_size)

     dummy_input = {n: torch.from_numpy(val).to(target_device) for n, val in expanded_sample.items()}
@@ -103,15 +101,15 @@ def export(
         if not tensor_metadata:
             continue

-        dynamic_shapes_ = {}
+        dynamic_shape_map = {}
         if max_batch_size is not None and max_batch_size > 1 and len(tensor_metadata.shape) > 0:
-            dynamic_shapes_[0] = torch.export.Dim(f"{name}_batch", min=1, max=max_batch_size)
+            dynamic_shape_map[0] = torch.export.Dim(f"{name}_batch", min=1, max=max_batch_size)

         for idx in range(1, len(spec_.min)):
             if spec_.min[idx] != spec_.max[idx]:
-                dynamic_shapes_[idx] = torch.export.Dim(f"{name}__{idx}", min=spec_.min[idx], max=spec_.max[idx])
+                dynamic_shape_map[idx] = torch.export.Dim(f"{name}__{idx}", min=spec_.min[idx], max=spec_.max[idx])

-        dynamic_shapes.append(dynamic_shapes_)
+        dynamic_shapes.append(dynamic_shape_map)

     try:
         exported_model = torch.export.export(
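
The `batch_size = 2 if max_batch_size > 1 else 1` workaround replaces the earlier `min(2, max_batch_size)` logic but serves the same purpose: torch.export specializes dimensions it observes as 0 or 1 into constants, so tracing with a batch of 2 keeps the batch dimension symbolic and lets the torch.export.Dim range apply. A minimal sketch of the pattern with an illustrative model and range:

    import torch


    class Projection(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)


    # Trace with batch size 2, not 1: a size-1 dim would be specialized to a constant.
    sample = torch.randn(2, 16)
    dynamic_shapes = [{0: torch.export.Dim("x_batch", min=1, max=8)}]

    exported = torch.export.export(Projection(), (sample,), dynamic_shapes=dynamic_shapes, strict=False)
    module = exported.module()  # accepts any batch size in [1, 8]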

model_navigator/commands/export/exporters/torch2torchscript.py

Lines changed: 3 additions & 0 deletions

@@ -24,6 +24,7 @@
 from model_navigator.core.dataloader import load_samples
 from model_navigator.core.tensor import TensorMetadata
 from model_navigator.exceptions import ModelNavigatorUserInputError
+from model_navigator.frameworks.torch.utils import offload_torch_model_to_cpu
 from model_navigator.utils.common import numpy_to_torch_dtype


@@ -122,6 +123,8 @@ def export(

     torch.jit.save(script_module, exported_model_path.as_posix())

+    offload_torch_model_to_cpu(script_module)
+

 if __name__ == "__main__":
     fire.Fire(export)
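
The new offload_torch_model_to_cpu call releases GPU memory held by the scripted module once it has been written to disk. The helper's body is not part of this diff; a plausible sketch, under the assumption that it only needs to move parameters off-device and flush the CUDA caching allocator:

    import torch


    def offload_torch_model_to_cpu(model: torch.nn.Module) -> None:
        # Hypothetical reimplementation; the real helper lives in
        # model_navigator.frameworks.torch.utils and may differ.
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the driver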

model_navigator/configuration/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,7 @@
     DEFAULT_MIN_SEGMENT_SIZE,
     DEFAULT_MIN_TRIALS,
     DEFAULT_ONNX_OPSET,
+    DEFAULT_PICKLE_PROTOCOL_TORCHTRT,
     DEFAULT_STABILITY_PERCENTAGE,
     DEFAULT_STABILIZATION_WINDOWS,
     DEFAULT_THROUGHPUT_BACKOFF_LIMIT,
@@ -844,6 +845,7 @@ class TorchTensorRTConfig(CustomConfigForTensorRT):
     """Torch custom config used for TensorRT TorchScript conversion."""

     max_workspace_size: Optional[int] = DEFAULT_MAX_WORKSPACE_SIZE_TORCHTRT
+    pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL_TORCHTRT

     @property
     def format(self) -> Format:
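
With the new field, the serialization protocol becomes user-configurable. A sketch of how it could be set from the public API, assuming the usual nav.torch.optimize entry point and the top-level TorchTensorRTConfig re-export; the model, dataloader, and protocol value are illustrative:

    import torch
    import model_navigator as nav

    model = torch.nn.Linear(16, 8).eval()
    dataloader = [torch.randn(2, 16) for _ in range(8)]

    package = nav.torch.optimize(
        model=model,
        dataloader=dataloader,
        # Override the new default (5), e.g. for tools that cannot read protocol-5 pickles.
        custom_configs=[nav.TorchTensorRTConfig(pickle_protocol=4)],
    )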

model_navigator/configuration/constants.py

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@
 DEFAULT_MAX_WORKSPACE_SIZE_TORCHTRT = (
     0  # Default to use full device memory https://pytorch.org/TensorRT/py_api/dynamo.html
 )
+DEFAULT_PICKLE_PROTOCOL_TORCHTRT = 5
 DEFAULT_MIN_SEGMENT_SIZE = 3
 DEFAULT_TENSORRT_MAX_DIMENSION_SIZE = 2**31 - 1
 OPT_MAX_SHAPE_RATIO = 4 / 5
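
Protocol 5 (PEP 574) is a reasonable default for large artifacts: it adds out-of-band buffer support, which avoids copying big binary payloads (such as serialized engine blobs) into the pickle stream, and it is available on every Python version current PyTorch supports. A small standalone illustration of the out-of-band mechanism, with a stand-in payload:

    import pickle

    engine_blob = bytearray(64 * 1024 * 1024)  # stand-in for a large serialized engine

    buffers = []
    data = pickle.dumps(pickle.PickleBuffer(engine_blob), protocol=5, buffer_callback=buffers.append)
    restored = pickle.loads(data, buffers=buffers)  # rebuilt without copying the payload into `data`

Note that, per the version gate in ep2torchtrt.py above, the value is only forwarded to torch.export.save on torch versions newer than 2.6.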

model_navigator/configuration/model/model_config.py

Lines changed: 4 additions & 0 deletions

@@ -609,6 +609,7 @@ def __init__(
         precision: TensorRTPrecision,
         precision_mode: TensorRTPrecisionMode,
         max_workspace_size: int,
+        pickle_protocol: int,
         trt_profiles: Optional[List[TensorRTProfile]] = None,
         parent: Optional[ModelConfig] = None,
         custom_args: Optional[Dict[str, Any]] = None,
@@ -622,6 +623,7 @@ def __init__(
             precision: TensorRT model precision
             precision_mode: Mode how the precision flags are combined
             max_workspace_size: The maximum GPU memory the model can use temporarily during execution
+            pickle_protocol: Pickle protocol used during model serialization
             trt_profiles: TensorRT profiles
             custom_args: Custom arguments passed to Torch TensorRT conversion
             device: runtime device e.g. "cuda:0"
@@ -631,6 +633,7 @@ def __init__(
         self.precision = precision
         self.precision_mode = precision_mode
         self.max_workspace_size = max_workspace_size
+        self.pickle_protocol = pickle_protocol
        self.trt_profiles = trt_profiles
        self.custom_args = custom_args
        self.runner_config = DeviceRunnerConfig(device=device)
@@ -648,6 +651,7 @@ def _from_dict(cls, data_dict: Dict):
             precision=cls._parse_string(TensorRTPrecision, data_dict.get("precision")),
             precision_mode=cls._parse_string(TensorRTPrecisionMode, data_dict.get("precision_mode")),
             max_workspace_size=cls._parse_string(int, data_dict.get("max_workspace_size")),
+            pickle_protocol=cls._parse_string(int, data_dict.get("pickle_protocol")),
             trt_profiles=trt_profiles,
             device=data_dict.get("device"),
             conversion_fallback=data_dict.get("conversion_fallback", False),

model_navigator/configuration/model/model_config_builder.py

Lines changed: 1 addition & 0 deletions

@@ -308,6 +308,7 @@ def get_torch_trt_config(
         precision=precision,
         precision_mode=torch_trt_config.precision_mode,
         max_workspace_size=torch_trt_config.max_workspace_size,
+        pickle_protocol=torch_trt_config.pickle_protocol,
         trt_profiles=torch_trt_config.trt_profiles,
         custom_args=torch_trt_config.custom_args,
         device=torch_trt_config.device,
