Skip to content

Commit 78bf24d

Browse files
Support for passing runner configuration to nav.profile
1 parent 9013d66 commit 78bf24d

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
- fix: Removed option from ExportOption removed from Torch 2.5
3232
- fix: Improved preprocessing stage in Torch based runners
3333
- fix: Warn when using autocast with bfloat16 in Torch
34+
- fix: Pass runner configuration to runners in nav.profile
3435

3536
## 0.12.0
3637

model_navigator/configuration/model/model_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _from_dict(cls, data_dict: Dict):
195195
@staticmethod
196196
def _parse_string(parse_func: Callable, val: Optional[str] = None):
197197
"""Parses string with parse_func or returns None if val not provided."""
198-
if val:
198+
if val is not None:
199199
return parse_func(val)
200200
else:
201201
return None

model_navigator/package/package.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636
from model_navigator.configuration.common_config import CommonConfig
3737
from model_navigator.configuration.device import get_device_kind_from_device_string
38+
from model_navigator.configuration.runner.runner_config import RunnerConfig
3839
from model_navigator.core.logger import LOGGER
3940
from model_navigator.core.workspace import Workspace
4041
from model_navigator.exceptions import (
@@ -180,8 +181,12 @@ def get_runner(
180181
The optimal runner for the optimized model.
181182
"""
182183
runtime_result = self.get_best_runtime(strategies=strategies, include_source=include_source, inplace=inplace)
183-
184184
model_config = runtime_result.model_status.model_config
185+
186+
runner_config = None
187+
if hasattr(runtime_result.model_status.model_config, "runner_config"):
188+
runner_config = runtime_result.model_status.model_config.runner_config # pytype: disable=attribute-error
189+
185190
runner_status = runtime_result.runner_status
186191

187192
if not is_source_format(model_config.format) and not (self.workspace.path / model_config.path).exists():
@@ -199,7 +204,12 @@ def get_runner(
199204
)
200205

201206
return self._get_runner(
202-
model_config.key, runner_status.runner_name, return_type=return_type, device=device, inplace=inplace
207+
model_config.key,
208+
runner_status.runner_name,
209+
return_type=return_type,
210+
device=device,
211+
inplace=inplace,
212+
runner_config=runner_config,
203213
)
204214

205215
def get_best_model_status(
@@ -239,7 +249,13 @@ def is_empty(self) -> bool:
239249
return True
240250

241251
def _get_runner(
242-
self, model_key: str, runner_name: str, device: str, return_type: TensorType, inplace: bool = False
252+
self,
253+
model_key: str,
254+
runner_name: str,
255+
device: str,
256+
return_type: TensorType,
257+
inplace: bool = False,
258+
runner_config: Optional[RunnerConfig] = None,
243259
) -> NavigatorRunner:
244260
"""Load runner.
245261
@@ -249,6 +265,7 @@ def _get_runner(
249265
return_type: Type of the runner output.
250266
device: Device on which the model has been executed
251267
inplace: Indicate if runner is in inplace mode.
268+
runner_config: Runner configuration.
252269
253270
Raises:
254271
ModelNavigatorNotFoundError when no runner found for provided constraints.
@@ -266,15 +283,23 @@ def _get_runner(
266283
else:
267284
model = self.workspace.path / model_config.path
268285

286+
if runner_config is None:
287+
runner_config = {}
288+
269289
device_kind = get_device_kind_from_device_string(device)
270290
LOGGER.info(f"Creating model `{model_key}` on runner `{runner_name}` and device `{device}`")
291+
# TODO: implement better handling for redundant device argument in _get_runner and runner_config
292+
runner_config_dict = runner_config.to_dict(parse=True) if runner_config else {}
293+
runner_config_dict["device"] = device
294+
271295
return get_runner(runner_name, device_kind)(
272296
model=model,
273297
input_metadata=self.status.input_metadata,
274298
output_metadata=self.status.output_metadata,
275299
return_type=return_type,
276-
device=device,
300+
# device=device, # TODO: remove redundant device argument and use runner_config
277301
inplace=inplace,
302+
**runner_config_dict,
278303
) # pytype: disable=not-instantiable
279304

280305
def get_best_runtime(

0 commit comments

Comments
 (0)