@@ -28,7 +28,7 @@
 from model_navigator.pipelines.pipeline import Pipeline
 from model_navigator.runners.onnx import OnnxrtCPURunner, OnnxrtCUDARunner
 from model_navigator.runners.tensorflow import TensorFlowSavedModelCPURunner, TensorFlowSavedModelCUDARunner
-from model_navigator.runners.torch import TorchCPURunner, TorchCUDARunner
+from model_navigator.runners.torch import TorchCompileCPURunner, TorchCompileCUDARunner, TorchCPURunner, TorchCUDARunner
 from model_navigator.utils.config_helpers import do_find_device_max_batch_size


@@ -77,22 +77,24 @@ def find_device_max_batch_size_builder(
 def _find_max_batch_size_config_for_torch(config: CommonConfig, models_config: Dict[Format, List[ModelConfig]]):
     configurations = []
     for model_cfg in models_config.get(Format.TORCH, []):
-        runner_cls = {
-            DeviceKind.CUDA: TorchCUDARunner,
-            DeviceKind.CPU: TorchCPURunner,
+        runners_cls = {
+            DeviceKind.CUDA: [TorchCUDARunner, TorchCompileCUDARunner],
+            DeviceKind.CPU: [TorchCPURunner, TorchCompileCPURunner],
         }[config.target_device]

-        if model_cfg.format != runner_cls.format():
-            raise ModelNavigatorRuntimeError(
-                f"Model config format `{model_cfg.format}` does not match `{runner_cls.format()}`."
+        for runner_cls in runners_cls:
+            if model_cfg.format != runner_cls.format():
+                raise ModelNavigatorRuntimeError(
+                    f"Model config format `{model_cfg.format}` does not match `{runner_cls.format()}`."
+                )
+            mbs_config = FindMaxBatchSizeConfig(
+                format=Format.TORCH,
+                model=config.model,
+                runner_cls=runner_cls,
+                reproduction_scripts_dir=pathlib.Path(model_cfg.key),
             )
-        mbs_config = FindMaxBatchSizeConfig(
-            format=Format.TORCH,
-            model=config.model,
-            runner_cls=runner_cls,
-            reproduction_scripts_dir=pathlib.Path(model_cfg.key),
-        )
-        configurations.append(mbs_config)
+            configurations.append(mbs_config)
+
     for model_cfg in models_config.get(Format.ONNX, []):
         runner_cls = {
             DeviceKind.CUDA: OnnxrtCUDARunner,
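The Torch branch of `_find_max_batch_size_config_for_torch` now builds one `FindMaxBatchSizeConfig` per runner class registered for the target device, so the max-batch-size search probes both the eager and the torch.compile runners instead of only the eager one. Below is a minimal, self-contained Python sketch of that dispatch pattern; `TorchEagerRunner`, `TorchCompileRunner`, and `build_probe_configs` are hypothetical stand-ins for illustration, not the model_navigator API.

from enum import Enum
from typing import Dict, List, Type


class DeviceKind(Enum):
    CPU = "cpu"
    CUDA = "cuda"


# Hypothetical stand-ins for the runner classes; each reports the model
# format it can execute, mirroring runner_cls.format() in the diff above.
class TorchEagerRunner:
    @classmethod
    def format(cls) -> str:
        return "torch"


class TorchCompileRunner:
    @classmethod
    def format(cls) -> str:
        return "torch"


def build_probe_configs(target_device: DeviceKind, model_format: str) -> List[dict]:
    # One list of runner classes per device, as in the updated runners_cls mapping.
    runners_cls: Dict[DeviceKind, List[Type]] = {
        DeviceKind.CUDA: [TorchEagerRunner, TorchCompileRunner],
        DeviceKind.CPU: [TorchEagerRunner, TorchCompileRunner],
    }[target_device]

    configurations = []
    for runner_cls in runners_cls:
        # Fail fast if the model format does not match what the runner expects.
        if model_format != runner_cls.format():
            raise RuntimeError(
                f"Model config format `{model_format}` does not match `{runner_cls.format()}`."
            )
        # In model_navigator this would be a FindMaxBatchSizeConfig; a plain dict here.
        configurations.append({"runner_cls": runner_cls, "format": model_format})
    return configurations


configs = build_probe_configs(DeviceKind.CUDA, "torch")
assert [c["runner_cls"] for c in configs] == [TorchEagerRunner, TorchCompileRunner]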