2121import onnx_graphsurgeon as gs
2222import torch # pytype: disable=import-error
2323
24+ from model_navigator .configuration import TensorRTProfile
2425from model_navigator .core .dataloader import load_samples
2526from model_navigator .core .tensor import TensorMetadata
2627
@@ -37,6 +38,7 @@ def get_model() -> torch.nn.Module:
3738def export (
3839 exported_model_path : str ,
3940 input_metadata : Dict [str , Any ],
41+ dataloader_trt_profile : Dict [str , Any ],
4042 input_names : List [str ],
4143 output_names : List [str ],
4244 batch_dim : Optional [int ],
@@ -45,12 +47,15 @@ def export(
4547 verbose : bool ,
4648 custom_args : Dict [str , Any ],
4749 navigator_workspace : Optional [str ] = None ,
50+ dataloader_max_batch_size : Optional [int ] = None ,
51+ device_max_batch_size : Optional [int ] = None ,
4852):
4953 """Export Torch model using dynamo.
5054
5155 Args:
5256 exported_model_path (str): Output ONNX model path.
5357 input_metadata (Dict [str , Any ]): Model input metadata mapping .
58+ dataloader_trt_profile: Profiles generated based on shapes.
5459 input_names (List[str]): List of model input names.
5560 output_names (List[str]): List of model output names.
5661 batch_dim (Optional[int]): Batch dimension.
@@ -61,6 +66,8 @@ def export(
6166 When None use current workdir. Defaults to None.
6267 custom_args (Optional [Dict [str , Any ]], optional ): Passthrough parameters for torch.onnx.export
6368 For available arguments check PyTorch documentation: https://pytorch.org/docs/stable/onnx.html
69+ dataloader_max_batch_size: Maximum batch size in the dataloader. Defaults to None.
70+ device_max_batch_size: Maximum batch size that fits on the device. Defaults to None.
6471 """
6572 model = get_model ()
6673
@@ -71,6 +78,14 @@ def export(
7178 profiling_sample = load_samples ("profiling_sample" , navigator_workspace , batch_dim )[0 ]
7279 input_metadata = TensorMetadata .from_json (input_metadata )
7380
def expand_batch_dim(tensor, batch_dim, max_batch_size):
    """Grow ``tensor`` along ``batch_dim`` so its size reaches ``max_batch_size``.

    Args:
        tensor: Input tensor to grow.
        batch_dim: Index of the batch dimension, or None when the model has no batch dimension.
        max_batch_size: Target size of the batch dimension.

    Returns:
        The original tensor when no growth is needed (``batch_dim`` is None or the
        dimension is already >= ``max_batch_size``); otherwise a tensor whose
        ``batch_dim`` has size ``max_batch_size``.
    """
    if batch_dim is None or tensor.shape[batch_dim] >= max_batch_size:
        return tensor

    if tensor.shape[batch_dim] == 1:
        # Singleton batch dimension: zero-copy broadcast via expand.
        expand_shape = list(tensor.shape)
        expand_shape[batch_dim] = max_batch_size
        return tensor.expand(*expand_shape)

    # Tensor.expand only supports singleton dims; for batch sizes > 1 tile the
    # sample along the batch dimension and trim to the exact target size.
    repeats = -(-max_batch_size // tensor.shape[batch_dim])  # ceiling division
    tiled = torch.cat([tensor] * repeats, dim=batch_dim)
    return tiled.narrow(batch_dim, 0, max_batch_size)
88+
7489 dummy_input = {n : torch .from_numpy (val ).to (target_device ) for n , val in profiling_sample .items ()}
7590 dummy_input = input_metadata .unflatten_sample (dummy_input , wrap_input = False )
7691
@@ -80,23 +95,53 @@ def export(
8095 dummy_input = (* dummy_input , {})
8196 * args , kwargs = dummy_input
8297
98+ # Expand batch_dim of tensors to max_batch_size
99+ max_batch_size = device_max_batch_size or dataloader_max_batch_size
100+ if max_batch_size is not None :
101+ args = tuple (
102+ expand_batch_dim (arg , batch_dim , max_batch_size ) if isinstance (arg , torch .Tensor ) else arg for arg in args
103+ )
104+ kwargs = {
105+ k : expand_batch_dim (v , batch_dim , max_batch_size ) if isinstance (v , torch .Tensor ) else v
106+ for k , v in kwargs .items ()
107+ }
108+
83109 loglevel = logging .WARNING if verbose else logging .ERROR
84- export_options_kwargs = {}
85- export_options_kwargs ["diagnostic_options" ] = torch .onnx .DiagnosticOptions (verbosity_level = loglevel )
86- if dynamic_shapes :
87- export_options_kwargs ["dynamic_shapes" ] = True
88- export_options = torch .onnx .ExportOptions (** export_options_kwargs )
89110
90111 root_logger = logging .getLogger ()
91112 original_loglevel = root_logger .getEffectiveLevel ()
92113 root_logger .setLevel (loglevel )
114+
115+ # Dynamic shapes support
116+
117+ # Collect trt profile for min and max shape data
118+ # FIXME: Use a common structure for the min/max shapes
119+ dataloader_trt_profile = TensorRTProfile .from_dict (dataloader_trt_profile )
120+ dynamic_shapes = []
121+ for name , spec_ in dataloader_trt_profile .items ():
122+ tensor_metadata = input_metadata .get (name )
123+ if not tensor_metadata :
124+ continue
125+
126+ dynamic_shapes_ = {}
127+ if max_batch_size is not None and max_batch_size > 1 and len (tensor_metadata .shape ) > 0 :
128+ dynamic_shapes_ [0 ] = torch .export .Dim ("batch" , min = 1 , max = max_batch_size )
129+
130+ for idx in range (1 , len (spec_ .min )):
131+ if spec_ .min [idx ] != spec_ .max [idx ]:
132+ dynamic_shapes_ [idx ] = torch .export .Dim (f"{ name } __{ idx } " , min = spec_ .min [idx ], max = spec_ .max [idx ])
133+
134+ dynamic_shapes .append (dynamic_shapes_ )
135+
93136 try :
94- exported_model = torch .onnx .dynamo_export (
137+ exported_model = torch .onnx .export (
95138 model ,
96- * args ,
139+ args = tuple (args ),
140+ kwargs = kwargs ,
97141 ** custom_args ,
98- ** kwargs ,
99- export_options = export_options ,
142+ dynamo = True ,
143+ dynamic_shapes = dynamic_shapes ,
144+ fallback = False ,
100145 )
101146
102147 exported_model_path = pathlib .Path (exported_model_path )
0 commit comments