
Commit 9013d66

Use torch.compile in heuristic search for max batch size
1 parent dc2632e commit 9013d66

File tree

3 files changed (+22, -16 lines changed)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ limitations under the License.
   enables dynamic shapes user setup for TorchTensorRT compilation
 - new: autocast_dtype added Torch runner configuration to set the dtype for autocast
 - new: New version of Onnx Runtime 1.20 for python version >= 3.10
+- new: Use `torch.compile` path in heuristic search for max batch size
 - change: Removed TensorFlow dependencies for `nav.jax.optimize`
 - change: Removed PyTorch dependencies from `nav.profile`
 - change: Collect all Python packages in status instead of filtered list
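
The changelog entry above refers to the heuristic that probes a device for the largest batch size a runner can handle. As an illustrative sketch only (this is not Model Navigator's internal implementation, and `probe_max_batch_size` is a made-up helper), routing that search through the `torch.compile` path means the probe runs a compiled module and grows the batch until it no longer fits:

import torch


def probe_max_batch_size(model: torch.nn.Module, sample: torch.Tensor) -> int:
    # `sample` is assumed to be a single example with a leading batch dimension of 1.
    model = model.cuda().eval()
    compiled = torch.compile(model)  # exercise the torch.compile execution path
    batch_size, last_ok = 1, 0
    while True:
        try:
            # Tile the sample along the batch dimension and run a forward pass.
            batch = sample.cuda().repeat(batch_size, *([1] * (sample.dim() - 1)))
            with torch.inference_mode():
                compiled(batch)
            last_ok, batch_size = batch_size, batch_size * 2
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            return last_ok

In Model Navigator itself the probing is handled by the runner classes wired up in the builder below; the sketch only illustrates why a compiled runner can report a different maximum batch size than the eager one, since compilation can change memory use and kernel selection.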

model_navigator/pipelines/builders/find_device_max_batch_size.py

Lines changed: 16 additions & 14 deletions
@@ -28,7 +28,7 @@
 from model_navigator.pipelines.pipeline import Pipeline
 from model_navigator.runners.onnx import OnnxrtCPURunner, OnnxrtCUDARunner
 from model_navigator.runners.tensorflow import TensorFlowSavedModelCPURunner, TensorFlowSavedModelCUDARunner
-from model_navigator.runners.torch import TorchCPURunner, TorchCUDARunner
+from model_navigator.runners.torch import TorchCompileCPURunner, TorchCompileCUDARunner, TorchCPURunner, TorchCUDARunner
 from model_navigator.utils.config_helpers import do_find_device_max_batch_size


@@ -77,22 +77,24 @@ def find_device_max_batch_size_builder(
 def _find_max_batch_size_config_for_torch(config: CommonConfig, models_config: Dict[Format, List[ModelConfig]]):
     configurations = []
     for model_cfg in models_config.get(Format.TORCH, []):
-        runner_cls = {
-            DeviceKind.CUDA: TorchCUDARunner,
-            DeviceKind.CPU: TorchCPURunner,
+        runners_cls = {
+            DeviceKind.CUDA: [TorchCUDARunner, TorchCompileCUDARunner],
+            DeviceKind.CPU: [TorchCPURunner, TorchCompileCPURunner],
         }[config.target_device]

-        if model_cfg.format != runner_cls.format():
-            raise ModelNavigatorRuntimeError(
-                f"Model config format `{model_cfg.format}` does not match `{runner_cls.format()}`."
+        for runner_cls in runners_cls:
+            if model_cfg.format != runner_cls.format():
+                raise ModelNavigatorRuntimeError(
+                    f"Model config format `{model_cfg.format}` does not match `{runner_cls.format()}`."
+                )
+            mbs_config = FindMaxBatchSizeConfig(
+                format=Format.TORCH,
+                model=config.model,
+                runner_cls=runner_cls,
+                reproduction_scripts_dir=pathlib.Path(model_cfg.key),
             )
-        mbs_config = FindMaxBatchSizeConfig(
-            format=Format.TORCH,
-            model=config.model,
-            runner_cls=runner_cls,
-            reproduction_scripts_dir=pathlib.Path(model_cfg.key),
-        )
-        configurations.append(mbs_config)
+            configurations.append(mbs_config)
+
     for model_cfg in models_config.get(Format.ONNX, []):
         runner_cls = {
             DeviceKind.CUDA: OnnxrtCUDARunner,
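
Read linearly, the Torch branch of `_find_max_batch_size_config_for_torch` looks like this after the change (reconstructed from the diff above; imports and the ONNX/TensorFlow branches are elided, so this excerpt is not runnable on its own):

def _find_max_batch_size_config_for_torch(config: CommonConfig, models_config: Dict[Format, List[ModelConfig]]):
    configurations = []
    for model_cfg in models_config.get(Format.TORCH, []):
        # Each target device now maps to a list of runners: the eager Torch
        # runner plus the corresponding torch.compile runner.
        runners_cls = {
            DeviceKind.CUDA: [TorchCUDARunner, TorchCompileCUDARunner],
            DeviceKind.CPU: [TorchCPURunner, TorchCompileCPURunner],
        }[config.target_device]

        for runner_cls in runners_cls:
            # The format check and the search configuration are now created
            # once per runner instead of once per model config.
            if model_cfg.format != runner_cls.format():
                raise ModelNavigatorRuntimeError(
                    f"Model config format `{model_cfg.format}` does not match `{runner_cls.format()}`."
                )
            mbs_config = FindMaxBatchSizeConfig(
                format=Format.TORCH,
                model=config.model,
                runner_cls=runner_cls,
                reproduction_scripts_dir=pathlib.Path(model_cfg.key),
            )
            configurations.append(mbs_config)

The net effect is that a CUDA-targeted Torch model now contributes two max-batch-size search configurations (eager and compiled) where it previously contributed one.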

tests/unit/base/test_pipelines_builders_find_device_max_batch_size.py

Lines changed: 5 additions & 2 deletions
@@ -31,7 +31,7 @@
 from model_navigator.pipelines.builders.find_device_max_batch_size import find_device_max_batch_size_builder
 from model_navigator.runners.onnx import OnnxrtCUDARunner
 from model_navigator.runners.tensorflow import TensorFlowSavedModelCUDARunner
-from model_navigator.runners.torch import TorchCUDARunner
+from model_navigator.runners.torch import TorchCompileCUDARunner, TorchCUDARunner


 def test_find_device_max_batch_size_builder_return_execution_unit_when_torch_framework_is_used():
@@ -84,12 +84,15 @@ def test_find_device_max_batch_size_builder_return_execution_unit_when_torch_fra
     assert len(pipeline.execution_units) == 1

     execution_unit = pipeline.execution_units[0]
-    assert len(execution_unit.kwargs["configurations"]) == 2
+    assert len(execution_unit.kwargs["configurations"]) == 3

     configuration = execution_unit.kwargs["configurations"][0]
     assert configuration.runner_cls == TorchCUDARunner

     configuration = execution_unit.kwargs["configurations"][1]
+    assert configuration.runner_cls == TorchCompileCUDARunner
+
+    configuration = execution_unit.kwargs["configurations"][2]
     assert configuration.runner_cls == OnnxrtCUDARunner
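
A compact equivalent of the ordering assertions above (a sketch only: `assert_expected_runner_order` is a made-up helper, and `pipeline` stands in for the object the real test builds from its config fixtures):

from model_navigator.runners.onnx import OnnxrtCUDARunner
from model_navigator.runners.torch import TorchCompileCUDARunner, TorchCUDARunner


def assert_expected_runner_order(pipeline) -> None:
    configurations = pipeline.execution_units[0].kwargs["configurations"]
    # Eager Torch runner first, the new torch.compile runner second, ONNX Runtime last.
    assert [c.runner_cls for c in configurations] == [
        TorchCUDARunner,
        TorchCompileCUDARunner,
        OnnxrtCUDARunner,
    ]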
