1414"""Convert ExportedProgram model to Torch-TensorRT model."""
1515
1616import pathlib
17- from typing import Any , Dict , Optional
17+ from typing import Any , Dict , List , Optional
1818
1919import fire
2020import numpy as np
2121import torch # pytype: disable=import-error
2222from loguru import logger
23+ from packaging .version import Version
2324
2425from model_navigator .configuration import TensorRTPrecision , TensorRTPrecisionMode
2526from model_navigator .configuration .device import map_device_string
26- from model_navigator .core .dataloader import load_samples
27+ from model_navigator .core .dataloader import expand_sample , load_samples
28+ from model_navigator .core .logger import LOGGER
2729from model_navigator .core .tensor import TensorMetadata
2830from model_navigator .frameworks .tensorrt import utils as tensorrt_utils
29- from model_navigator .frameworks .tensorrt .timing_tactics import TimingCacheManager , trt_cache_inplace_cache_dir
31+ from model_navigator .frameworks .tensorrt .timing_tactics import TimingCacheManager
3032from model_navigator .utils .common import numpy_to_torch_dtype
3133
3234
@@ -59,9 +61,10 @@ def convert(
     exported_model_path: str,
     converted_model_path: str,
     input_metadata: Dict[str, Any],
-    shapes: Dict[str, Dict[str, int]],
+    shapes: Dict[str, Dict[str, List[int]]],
     batch_dim: Optional[int],
     max_workspace_size: int,
+    pickle_protocol: int,
     precision: str,
     precision_mode: str,
     target_device: str,
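
Note on the signature change above: the shapes argument now carries full per-dimension profiles rather than single ints. A minimal sketch of the expected structure, with a hypothetical input name and sizes (not taken from this commit):

shapes = {
    "input__0": {
        "min": [1, 3, 224, 224],
        "opt": [8, 3, 224, 224],
        "max": [16, 3, 224, 224],
    }
}
# The conversion code below reads the maximum batch size from the first profile:
max_batch_size = list(shapes.values())[0]["max"][0]  # -> 16
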
@@ -82,12 +85,13 @@ def convert(
             and respective values.
         batch_dim: Batch dimension.
         max_workspace_size: Maximum workspace size in bytes.
+        pickle_protocol: Pickle protocol used during model serialization
         precision: TensorRT precision. Could be "fp16" or "fp32".
         precision_mode: TensorRT precision mode.
         target_device: Device on which perform the conversion
         debug: If True print debug logs.
         custom_args: Dictionary with passthrough parameters. For available arguments check PyTorch
-            documentation: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html
+            documentation: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html
         timing_cache_dir: Directory to save timing cache. Defaults to None which means it will be saved in workspace root.
         model_name: Model name for the timing cache. Defaults to None which means it will be named after the model file.
         navigator_workspace: Model Navigator workspace path. When None use current workdir. Defaults to None.
@@ -106,9 +110,34 @@ def convert(
 
     conversion_sample = load_samples("conversion_samples", navigator_workspace, batch_dim)[0]
 
+    if batch_dim is None:
+        max_batch_size = None
+        expanded_sample = expand_sample(conversion_sample, input_metadata, batch_dim=batch_dim, batch_size=None)
+    else:
+        # WAR to make data dynamic
+        max_batch_size = list(shapes.values())[0]["max"][0]
+        batch_size = 2 if max_batch_size > 1 else 1  # select the minimum value to expand samples
+        expanded_sample = expand_sample(conversion_sample, input_metadata, batch_dim=batch_dim, batch_size=batch_size)
+
+    dummy_input = {n: torch.from_numpy(val).to(target_device) for n, val in expanded_sample.items()}
+    dummy_input = input_metadata.unflatten_sample(dummy_input, wrap_input=False)
+
+    if not isinstance(dummy_input, tuple):
+        dummy_input = (dummy_input,)
+    if not isinstance(dummy_input[-1], dict):
+        dummy_input = (*dummy_input, {})
+    *args, kwargs = dummy_input
+
     input_dtypes = [numpy_to_torch_dtype(np.dtype(input_dtype)) for input_dtype in input_dtypes]
     model_input_shapes = []
-    for input_shapes, input_dtype in zip(shapes.values(), input_dtypes):
+    dynamic_shapes = []
+    for input_name, input_dtype in zip(shapes.keys(), input_dtypes):
+        input_shapes = shapes.get(input_name)
+        tensor_metadata = input_metadata.get(input_name)
+        if not tensor_metadata or not input_shapes:
+            LOGGER.warning(f"Input metadata or input shapes for input {input_name} is not found")
+            continue
+
         model_input_shapes.append(
             torch_tensorrt.Input(
                 min_shape=input_shapes["min"],
@@ -118,6 +147,18 @@ def convert(
             )
         )
 
+        dynamic_shape_map = {}
+        if max_batch_size is not None and max_batch_size > 1 and len(tensor_metadata.shape) > 0:
+            dynamic_shape_map[0] = torch.export.Dim(f"{input_name}_batch", min=1, max=max_batch_size)
+
+        for idx in range(1, len(input_shapes["min"])):
+            min_value = input_shapes["min"][idx]
+            max_value = input_shapes["max"][idx]
+            if min_value != max_value:
+                dynamic_shape_map[idx] = torch.export.Dim(f"{input_name}__{idx}", min=min_value, max=max_value)
+
+        dynamic_shapes.append(dynamic_shape_map)
+
     exported_model_path = pathlib.Path(exported_model_path)
     if not exported_model_path.is_absolute():
         exported_model_path = navigator_workspace / exported_model_path
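
The dynamic_shapes list assembled in the hunk above follows the torch.export convention: one mapping per positional input, from dimension index to a torch.export.Dim with its allowed range. A standalone sketch of that convention, with an illustrative module and sizes (not part of this commit):

import torch


class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2


batch = torch.export.Dim("x_batch", min=1, max=16)  # dimension 0 may vary between 1 and 16
example_args = (torch.randn(2, 8),)                 # dimension 1 stays static
exported = torch.export.export(Doubler(), example_args, dynamic_shapes=[{0: batch}])
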
@@ -134,19 +175,14 @@ def convert(
 
     target_device = map_device_string(target_device)
 
-    # saving timing cache in model_navigator workspace or ...
-    timing_cache = trt_cache_inplace_cache_dir()
-    if timing_cache_dir is not None:
-        timing_cache = pathlib.Path(timing_cache_dir)
-
     with TimingCacheManager(model_name=model_name, cache_path=timing_cache_dir) as timing_cache:
         timing_cache_path = timing_cache.as_posix() if timing_cache else None
 
         # reusing custom_args as dynamo.compile has a default cache path argument
         if timing_cache_path is not None:
             custom_args["timing_cache_path"] = timing_cache_path
 
-        tr_model_compiled = torch_tensorrt.dynamo.compile(
+        trt_model_compiled = torch_tensorrt.dynamo.compile(
             exported_program=model,
             inputs=model_input_shapes,
             workspace_size=max_workspace_size,
@@ -155,15 +191,24 @@ def convert(
             **custom_args,
         )
 
+        exported_model = torch.export.export(
+            trt_model_compiled,
+            args=tuple(args),
+            kwargs=kwargs,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
         converted_model_path = pathlib.Path(converted_model_path)
         if not converted_model_path.is_absolute():
             converted_model_path = navigator_workspace / converted_model_path
 
-        inputs = []
-        for _, val in conversion_sample.items():
-            inputs.append(torch.from_numpy(val).to(target_device))
+        save_kwargs = {}
+        if Version(torch.__version__) > Version("2.6"):
+            LOGGER.info("Using pickle protocol {}.", pickle_protocol)
+            save_kwargs["pickle_protocol"] = pickle_protocol
 
-        torch_tensorrt.save(tr_model_compiled, converted_model_path.as_posix(), inputs=inputs)
+        torch.export.save(exported_model, converted_model_path.as_posix(), **save_kwargs)
 
 
 if __name__ == "__main__":
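
Since the artifact is now written with torch.export.save, it is a serialized ExportedProgram and can be read back with torch.export.load. A hedged usage sketch; the path and input shape are placeholders, and downstream Model Navigator code may load the model differently:

import torch

exported_program = torch.export.load("converted_model.ep")  # placeholder path
module = exported_program.module()  # callable module wrapping the compiled TensorRT graph
output = module(torch.randn(2, 3, 224, 224).to("cuda"))     # illustrative input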