3535)
3636from model_navigator .configuration .common_config import CommonConfig
3737from model_navigator .configuration .device import get_device_kind_from_device_string
38+ from model_navigator .configuration .runner .runner_config import RunnerConfig
3839from model_navigator .core .logger import LOGGER
3940from model_navigator .core .workspace import Workspace
4041from model_navigator .exceptions import (
@@ -180,8 +181,12 @@ def get_runner(
180181 The optimal runner for the optimized model.
181182 """
182183 runtime_result = self .get_best_runtime (strategies = strategies , include_source = include_source , inplace = inplace )
183-
184184 model_config = runtime_result .model_status .model_config
185+
186+ runner_config = None
187+ if hasattr (runtime_result .model_status .model_config , "runner_config" ):
188+ runner_config = runtime_result .model_status .model_config .runner_config # pytype: disable=attribute-error
189+
185190 runner_status = runtime_result .runner_status
186191
187192 if not is_source_format (model_config .format ) and not (self .workspace .path / model_config .path ).exists ():
@@ -199,7 +204,12 @@ def get_runner(
199204 )
200205
201206 return self ._get_runner (
202- model_config .key , runner_status .runner_name , return_type = return_type , device = device , inplace = inplace
207+ model_config .key ,
208+ runner_status .runner_name ,
209+ return_type = return_type ,
210+ device = device ,
211+ inplace = inplace ,
212+ runner_config = runner_config ,
203213 )
204214
205215 def get_best_model_status (
@@ -239,7 +249,13 @@ def is_empty(self) -> bool:
239249 return True
240250
241251 def _get_runner (
242- self , model_key : str , runner_name : str , device : str , return_type : TensorType , inplace : bool = False
252+ self ,
253+ model_key : str ,
254+ runner_name : str ,
255+ device : str ,
256+ return_type : TensorType ,
257+ inplace : bool = False ,
258+ runner_config : Optional [RunnerConfig ] = None ,
243259 ) -> NavigatorRunner :
244260 """Load runner.
245261
@@ -249,6 +265,7 @@ def _get_runner(
249265 return_type: Type of the runner output.
250266 device: Device on which the model has been executed
251267 inplace: Indicate if runner is in inplace mode.
268+ runner_config: Runner configuration.
252269
253270 Raises:
254271 ModelNavigatorNotFoundError when no runner found for provided constraints.
@@ -266,15 +283,23 @@ def _get_runner(
266283 else :
267284 model = self .workspace .path / model_config .path
268285
286+ if runner_config is None :
287+ runner_config = {}
288+
269289 device_kind = get_device_kind_from_device_string (device )
270290 LOGGER .info (f"Creating model `{ model_key } ` on runner `{ runner_name } ` and device `{ device } `" )
291+ # TODO: implement better handling for redundant device argument in _get_runner and runner_config
292+ runner_config_dict = runner_config .to_dict (parse = True ) if runner_config else {}
293+ runner_config_dict ["device" ] = device
294+
271295 return get_runner (runner_name , device_kind )(
272296 model = model ,
273297 input_metadata = self .status .input_metadata ,
274298 output_metadata = self .status .output_metadata ,
275299 return_type = return_type ,
276- device = device ,
300+ # device=device, # TODO: remove redundant device argument and use runner_config
277301 inplace = inplace ,
302+ ** runner_config_dict ,
278303 ) # pytype: disable=not-instantiable
279304
280305 def get_best_runtime (
0 commit comments