Skip to content

Commit a6aaf58

Browse files
Move the ONNX export step after the find-max-batch-size step
1 parent e1aada0 commit a6aaf58

File tree

6 files changed

+26
-28
lines changed

6 files changed

+26
-28
lines changed

model_navigator/pipelines/builders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from model_navigator.pipelines.pipeline import Pipeline # noqa: F401
2727

2828
if is_torch_available():
29-
from .torch import torch_conversion_builder, torch_export_builder # noqa: F401
29+
from .torch import torch_conversion_builder, torch_export_builder, torch_export_onnx_builder # noqa: F401
3030
from .torch_tensorrt import torch_tensorrt_conversion_builder # noqa: F401
3131

3232
if is_tf_available():

model_navigator/pipelines/builders/find_device_max_batch_size.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,24 +95,6 @@ def _find_max_batch_size_config_for_torch(config: CommonConfig, models_config: D
9595
)
9696
configurations.append(mbs_config)
9797

98-
for model_cfg in models_config.get(Format.ONNX, []):
99-
runner_cls = {
100-
DeviceKind.CUDA: OnnxrtCUDARunner,
101-
DeviceKind.CPU: OnnxrtCPURunner,
102-
}[config.target_device]
103-
104-
if model_cfg.format != runner_cls.format():
105-
raise ModelNavigatorRuntimeError(
106-
f"Model config format `{model_cfg.format}` does not match `{runner_cls.format()}`."
107-
)
108-
mbs_config = FindMaxBatchSizeConfig(
109-
format=Format.ONNX,
110-
model_path=model_cfg.path,
111-
runner_cls=runner_cls,
112-
reproduction_scripts_dir=pathlib.Path(model_cfg.key),
113-
)
114-
configurations.append(mbs_config)
115-
11698
return configurations
11799

118100

model_navigator/pipelines/builders/torch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,20 @@ def torch_export_builder(config: CommonConfig, models_config: Dict[Format, List[
5151
for model_cfg in models_config.get(Format.TORCHSCRIPT, []):
5252
execution_units.append(ExecutionUnit(command=ExportTorch2TorchScript, model_config=model_cfg))
5353

54+
return Pipeline(name=PIPELINE_TORCH_EXPORT, execution_units=execution_units)
55+
56+
57+
def torch_export_onnx_builder(config: CommonConfig, models_config: Dict[Format, List[ModelConfig]]) -> Pipeline:
58+
"""Prepare export steps for pipeline.
59+
60+
Args:
61+
config: A configuration for pipelines
62+
models_config: List of model configs per format
63+
64+
Returns:
65+
Pipeline with steps for export
66+
"""
67+
execution_units: List[ExecutionUnit] = []
5468
for model_cfg in models_config.get(Format.ONNX, []):
5569
if model_cfg.parent_path in (None, Format.TORCH):
5670
# If model_path provided in onnx config, copy this onnx instead of exporting.

model_navigator/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
tensorrt_conversion_builder,
4343
torch_conversion_builder,
4444
torch_export_builder,
45+
torch_export_onnx_builder,
4546
torch_tensorrt_conversion_builder,
4647
verify_builder,
4748
)
@@ -144,6 +145,7 @@ def optimize(
144145
preprocessing_builder,
145146
torch_export_builder,
146147
find_device_max_batch_size_builder,
148+
torch_export_onnx_builder,
147149
torch_exportedprogram_builder,
148150
torch_conversion_builder,
149151
torch_tensorrt_conversion_builder,

tests/unit/base/test_pipelines_builders_find_device_max_batch_size.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,14 @@ def test_find_device_max_batch_size_builder_return_execution_unit_when_torch_fra
8080
assert len(pipeline.execution_units) == 1
8181

8282
execution_unit = pipeline.execution_units[0]
83-
assert len(execution_unit.kwargs["configurations"]) == 3
83+
assert len(execution_unit.kwargs["configurations"]) == 2
8484

8585
configuration = execution_unit.kwargs["configurations"][0]
8686
assert configuration.runner_cls == TorchCUDARunner
8787

8888
configuration = execution_unit.kwargs["configurations"][1]
8989
assert configuration.runner_cls == TorchCompileCUDARunner
9090

91-
configuration = execution_unit.kwargs["configurations"][2]
92-
assert configuration.runner_cls == OnnxrtCUDARunner
93-
9491

9592
def test_find_device_max_batch_size_builder_return_execution_unit_when_tensorflow_framework_is_used():
9693
config = CommonConfig(

tests/unit/torch/test_pipelines_builders_graph_surgeon.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
from model_navigator.configuration.common_config import CommonConfig
1717
from model_navigator.configuration.model.model_config import ONNXModelConfig, TorchScriptModelConfig
1818
from model_navigator.frameworks import Framework
19-
from model_navigator.pipelines.builders.torch import torch_conversion_builder, torch_export_builder
19+
from model_navigator.pipelines.builders.torch import (
20+
torch_conversion_builder,
21+
torch_export_onnx_builder,
22+
)
2023

2124

22-
def test_torch_export_builder_return_graph_surgeon_optimization_when_enabled():
25+
def test_torch_export_onnx_builder_return_graph_surgeon_optimization_when_enabled():
2326
config = CommonConfig(
2427
framework=Framework.TORCH,
2528
dataloader=[{"input_name": [idx]} for idx in range(10)],
@@ -34,12 +37,12 @@ def test_torch_export_builder_return_graph_surgeon_optimization_when_enabled():
3437
models_config = {
3538
Format.ONNX: [ONNXModelConfig(opset=DEFAULT_ONNX_OPSET, graph_surgeon_optimization=True)],
3639
}
37-
pipeline = torch_export_builder(config=config, models_config=models_config)
40+
pipeline = torch_export_onnx_builder(config=config, models_config=models_config)
3841
assert len(pipeline.execution_units) == 2
3942
assert pipeline.execution_units[-1].command == GraphSurgeonOptimize
4043

4144

42-
def test_torch_export_builder_does_not_return_graph_surgeon_optimization_when_disabled():
45+
def test_torch_export_onnx_builder_does_not_return_graph_surgeon_optimization_when_disabled():
4346
config = CommonConfig(
4447
framework=Framework.TORCH,
4548
dataloader=[{"input_name": [idx]} for idx in range(10)],
@@ -54,7 +57,7 @@ def test_torch_export_builder_does_not_return_graph_surgeon_optimization_when_di
5457
models_config = {
5558
Format.ONNX: [ONNXModelConfig(opset=DEFAULT_ONNX_OPSET, graph_surgeon_optimization=False)],
5659
}
57-
pipeline = torch_export_builder(config=config, models_config=models_config)
60+
pipeline = torch_export_onnx_builder(config=config, models_config=models_config)
5861
assert len(pipeline.execution_units) == 1
5962

6063

0 commit comments

Comments (0)