Skip to content

Commit dc2632e

Browse files
Update onnx dynamo export method
1 parent ea91d48 commit dc2632e

File tree

5 files changed

+86
-15
lines changed

5 files changed

+86
-15
lines changed

model_navigator/commands/export/exporters/torch2dynamo_onnx.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import onnx_graphsurgeon as gs
2222
import torch # pytype: disable=import-error
2323

24+
from model_navigator.configuration import TensorRTProfile
2425
from model_navigator.core.dataloader import load_samples
2526
from model_navigator.core.tensor import TensorMetadata
2627

@@ -37,6 +38,7 @@ def get_model() -> torch.nn.Module:
3738
def export(
3839
exported_model_path: str,
3940
input_metadata: Dict[str, Any],
41+
dataloader_trt_profile: Dict[str, Any],
4042
input_names: List[str],
4143
output_names: List[str],
4244
batch_dim: Optional[int],
@@ -45,12 +47,15 @@ def export(
4547
verbose: bool,
4648
custom_args: Dict[str, Any],
4749
navigator_workspace: Optional[str] = None,
50+
dataloader_max_batch_size: Optional[int] = None,
51+
device_max_batch_size: Optional[int] = None,
4852
):
4953
"""Export Torch model using dynamo.
5054
5155
Args:
5256
exported_model_path (str): Output ONNX model path.
5357
input_metadata (Dict[str, Any]): List of input metadata.
58+
dataloader_trt_profile: Profiles generated based on shapes.
5459
input_names (List[str]): List of model input names.
5560
output_names (List[str]): List of model output names.
5661
batch_dim (Optional[int]): Batch dimension.
@@ -61,6 +66,8 @@ def export(
6166
When None use current workdir. Defaults to None.
6267
custom_args (Optional[Dict[str, Any]], optional): Passthrough parameters for torch.onnx.export
6368
For available arguments check PyTorch documentation: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export
69+
dataloader_max_batch_size: Maximum batch size in the dataloader. Defaults to None.
70+
device_max_batch_size: Maximum batch size that fits on the device. Defaults to None.
6471
"""
6572
model = get_model()
6673

@@ -71,6 +78,14 @@ def export(
7178
profiling_sample = load_samples("profiling_sample", navigator_workspace, batch_dim)[0]
7279
input_metadata = TensorMetadata.from_json(input_metadata)
7380

81+
def expand_batch_dim(tensor, batch_dim, max_batch_size):
    """Grow the batch dimension of a tensor up to ``max_batch_size``.

    Args:
        tensor: Tensor whose batch dimension should be enlarged.
        batch_dim: Index of the batch dimension, or None when the model has no batch dimension.
        max_batch_size: Target size of the batch dimension.

    Returns:
        The input tensor unchanged when ``batch_dim`` is None or the batch dimension is already
        at least ``max_batch_size``; otherwise a tensor whose batch dimension equals
        ``max_batch_size``.
    """
    if batch_dim is None:
        return tensor

    current_batch_size = tensor.shape[batch_dim]
    if current_batch_size >= max_batch_size:
        return tensor

    if current_batch_size > 1:
        # Tensor.expand can only broadcast singleton dimensions, so a batch of 2..max-1
        # samples has to be tiled and then trimmed to the exact target size.
        repeats = [1] * tensor.dim()
        repeats[batch_dim] = -(-max_batch_size // current_batch_size)  # ceil division
        return tensor.repeat(*repeats).narrow(batch_dim, 0, max_batch_size)

    # Singleton batch: expand returns a broadcast view without copying data.
    expand_shape = list(tensor.shape)
    expand_shape[batch_dim] = max_batch_size
    return tensor.expand(*expand_shape)
88+
7489
dummy_input = {n: torch.from_numpy(val).to(target_device) for n, val in profiling_sample.items()}
7590
dummy_input = input_metadata.unflatten_sample(dummy_input, wrap_input=False)
7691

@@ -80,23 +95,53 @@ def export(
8095
dummy_input = (*dummy_input, {})
8196
*args, kwargs = dummy_input
8297

98+
# Expand batch_dim of tensors to max_batch_size
99+
max_batch_size = device_max_batch_size or dataloader_max_batch_size
100+
if max_batch_size is not None:
101+
args = tuple(
102+
expand_batch_dim(arg, batch_dim, max_batch_size) if isinstance(arg, torch.Tensor) else arg for arg in args
103+
)
104+
kwargs = {
105+
k: expand_batch_dim(v, batch_dim, max_batch_size) if isinstance(v, torch.Tensor) else v
106+
for k, v in kwargs.items()
107+
}
108+
83109
loglevel = logging.WARNING if verbose else logging.ERROR
84-
export_options_kwargs = {}
85-
export_options_kwargs["diagnostic_options"] = torch.onnx.DiagnosticOptions(verbosity_level=loglevel)
86-
if dynamic_shapes:
87-
export_options_kwargs["dynamic_shapes"] = True
88-
export_options = torch.onnx.ExportOptions(**export_options_kwargs)
89110

90111
root_logger = logging.getLogger()
91112
original_loglevel = root_logger.getEffectiveLevel()
92113
root_logger.setLevel(loglevel)
114+
115+
# Dynamic shapes support
116+
117+
# Collect trt profile for min and max shape data
118+
# FIXME: Use a common structure for the min/max shapes
119+
dataloader_trt_profile = TensorRTProfile.from_dict(dataloader_trt_profile)
120+
dynamic_shapes = []
121+
for name, spec_ in dataloader_trt_profile.items():
122+
tensor_metadata = input_metadata.get(name)
123+
if not tensor_metadata:
124+
continue
125+
126+
dynamic_shapes_ = {}
127+
if max_batch_size is not None and max_batch_size > 1 and len(tensor_metadata.shape) > 0:
128+
dynamic_shapes_[0] = torch.export.Dim("batch", min=1, max=max_batch_size)
129+
130+
for idx in range(1, len(spec_.min)):
131+
if spec_.min[idx] != spec_.max[idx]:
132+
dynamic_shapes_[idx] = torch.export.Dim(f"{name}__{idx}", min=spec_.min[idx], max=spec_.max[idx])
133+
134+
dynamic_shapes.append(dynamic_shapes_)
135+
93136
try:
94-
exported_model = torch.onnx.dynamo_export(
137+
exported_model = torch.onnx.export(
95138
model,
96-
*args,
139+
args=tuple(args),
140+
kwargs=kwargs,
97141
**custom_args,
98-
**kwargs,
99-
export_options=export_options,
142+
dynamo=True,
143+
dynamic_shapes=dynamic_shapes,
144+
fallback=False,
100145
)
101146

102147
exported_model_path = pathlib.Path(exported_model_path)

model_navigator/commands/export/torch.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,13 +322,16 @@ def _run(
322322
path: pathlib.Path,
323323
input_metadata: TensorMetadata,
324324
output_metadata: TensorMetadata,
325+
dataloader_trt_profile: TensorRTProfile,
325326
target_device: DeviceKind,
326327
dynamic_axes: Dict[str, Union[Dict[int, str], List[int]]],
327328
dynamo_dynamic_shapes: Optional[bool],
328329
verbose: bool,
329330
custom_args: Dict[str, Any],
330331
model: Any = None,
331332
batch_dim: Optional[int] = None,
333+
dataloader_max_batch_size: Optional[int] = None,
334+
device_max_batch_size: Optional[int] = None,
332335
) -> CommandOutput:
333336
"""Execute command.
334337
@@ -338,6 +341,7 @@ def _run(
338341
opset: ONNX opset
339342
input_metadata: Model inputs metadata
340343
output_metadata: Model outputs metadata
344+
dataloader_trt_profile: Profile from dataloader
341345
target_device: Target device for export - determine the exported model
342346
dynamic_axes: Definition of model inputs dynamic axes
343347
dynamo_dynamic_shapes: Enable dynamo dynamic shapes
@@ -347,6 +351,8 @@ def _run(
347351
custom_args (Optional[Dict[str, Any]], optional): Passthrough parameters for torch.onnx.dynamo_export
348352
Can be used to pass ExportOptions object.
349353
For available arguments check PyTorch documentation: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export
354+
dataloader_max_batch_size: The maximal batch size obtained from dataloader
355+
device_max_batch_size: The maximal batch size obtained for device
350356
351357
Returns:
352358
CommandOutput object with status
@@ -394,8 +400,10 @@ def on_exit():
394400
"navigator_workspace": workspace.path.as_posix(),
395401
"custom_args": custom_args,
396402
"verbose": verbose,
403+
"dataloader_max_batch_size": dataloader_max_batch_size,
404+
"device_max_batch_size": device_max_batch_size,
405+
"dataloader_trt_profile": dataloader_trt_profile.to_dict(),
397406
}
398-
399407
args = parse_kwargs_to_cmd(kwargs)
400408

401409
context.execute_python_script(

model_navigator/pipelines/builders/torch.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from model_navigator.configuration.model.model_config import ModelConfig, ONNXModelConfig
3131
from model_navigator.pipelines.constants import (
3232
PIPELINE_TORCH_CONVERSION,
33+
PIPELINE_TORCH_DYNAMO_ONNX,
3334
PIPELINE_TORCH_EXPORT,
3435
PIPELINE_TORCH_EXPORTEDPROGRAM,
3536
)
@@ -58,10 +59,7 @@ def torch_export_builder(config: CommonConfig, models_config: Dict[Format, List[
5859
ExecutionUnit(command=CopyModelFromPath, model_config=models_config[Format.ONNX][0])
5960
)
6061
else:
61-
if model_cfg.dynamo_export: # pytype: disable=attribute-error
62-
execution_units.append(ExecutionUnit(command=ExportTorch2DynamoONNX, model_config=model_cfg))
63-
else:
64-
execution_units.append(ExecutionUnit(command=ExportTorch2ONNX, model_config=model_cfg))
62+
execution_units.append(ExecutionUnit(command=ExportTorch2ONNX, model_config=model_cfg))
6563

6664
assert isinstance(model_cfg, ONNXModelConfig)
6765
if model_cfg.graph_surgeon_optimization:
@@ -70,6 +68,24 @@ def torch_export_builder(config: CommonConfig, models_config: Dict[Format, List[
7068
return Pipeline(name=PIPELINE_TORCH_EXPORT, execution_units=execution_units)
7169

7270

71+
def torch_dynamo_onnx_builder(config: CommonConfig, models_config: Dict[Format, List[ModelConfig]]) -> Pipeline:
    """Build the pipeline that exports ONNX models through the Torch dynamo exporter.

    Args:
        config: A configuration for pipelines (not read by this builder)
        models_config: List of model configs per format

    Returns:
        Pipeline with one ExportTorch2DynamoONNX unit for every ONNX model config
        that has ``dynamo_export`` enabled
    """
    onnx_model_configs = models_config.get(Format.ONNX, [])
    # Only configs that explicitly opt in to dynamo export get an execution unit.
    execution_units: List[ExecutionUnit] = [
        ExecutionUnit(command=ExportTorch2DynamoONNX, model_config=onnx_cfg)
        for onnx_cfg in onnx_model_configs
        if onnx_cfg.dynamo_export  # pytype: disable=attribute-error
    ]
    return Pipeline(name=PIPELINE_TORCH_DYNAMO_ONNX, execution_units=execution_units)
87+
88+
7389
def torch_exportedprogram_builder(config: CommonConfig, models_config: Dict[Format, List[ModelConfig]]) -> Pipeline:
7490
"""Prepare export steps for pipeline.
7591

model_navigator/pipelines/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@
2727
PIPELINE_TORCH_CONVERSION = "PyTorch Conversion"
2828
PIPELINE_TORCH_EXPORT = "PyTorch Export"
2929
PIPELINE_TORCH_EXPORTEDPROGRAM = "PyTorch ExportedProgram Export"
30+
PIPELINE_TORCH_DYNAMO_ONNX = "PyTorch Dynamo ONNX Export"
3031
PIPELINE_VERIFY_MODELS = "Verify Models"

model_navigator/torch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
verify_builder,
4646
)
4747
from model_navigator.pipelines.builders.find_device_max_batch_size import find_device_max_batch_size_builder
48-
from model_navigator.pipelines.builders.torch import torch_exportedprogram_builder
48+
from model_navigator.pipelines.builders.torch import torch_dynamo_onnx_builder, torch_exportedprogram_builder
4949
from model_navigator.pipelines.wrappers.optimize import optimize_pipeline
5050
from model_navigator.runners.base import NavigatorRunner
5151
from model_navigator.runners.utils import default_runners, filter_runners
@@ -140,6 +140,7 @@ def optimize(
140140
torch_export_builder,
141141
find_device_max_batch_size_builder,
142142
torch_exportedprogram_builder,
143+
torch_dynamo_onnx_builder,
143144
torch_conversion_builder,
144145
torch_tensorrt_conversion_builder,
145146
tensorrt_conversion_builder,

0 commit comments

Comments
 (0)