@@ -67,9 +67,12 @@ def _create_parameters(
6767 "max_token_count" : max_token_count ,
6868 }
6969
70- def _evaluate_config (self , args , yaml_content , subcommand = "profile" ):
70+ def _evaluate_config (
71+ self , args , yaml_content , subcommand = "profile" , numba_available = True
72+ ):
7173 mock_numba = MockNumba (
72- mock_paths = ["model_analyzer.config.input.config_command_profile" ]
74+ mock_paths = ["model_analyzer.config.input.config_command_profile" ],
75+ is_available = numba_available ,
7376 )
7477
7578 mock_config = MockConfig (args , yaml_content )
@@ -109,6 +112,7 @@ def _assert_equality_of_model_configs(self, model_configs, expected_model_config
109112 for model_config , expected_model_config in zip (
110113 model_configs , expected_model_configs
111114 ):
115+ self .assertEqual (expected_model_config .cpu_only (), model_config .cpu_only ())
112116 self .assertEqual (
113117 expected_model_config .model_name (), model_config .model_name ()
114118 )
@@ -1384,6 +1388,32 @@ def test_autofill(self):
13841388 ]
13851389 self ._assert_equality_of_model_configs (model_configs , expected_model_configs )
13861390
1391+ # Test autofill CPU_ONLY. It will only be false if no local gpus are available AND we are not in remote mode
1392+ yaml_content = """
1393+ profile_models:
1394+ - vgg_16_graphdef
1395+ """
1396+ for launch_mode in ["remote" , "c_api" , "docker" , "local" ]:
1397+ for local_gpus_available in [True , False ]:
1398+ new_args = args .copy ()
1399+ new_args .extend (["--triton-launch-mode" , launch_mode ])
1400+ config = self ._evaluate_config (
1401+ new_args , yaml_content , numba_available = local_gpus_available
1402+ )
1403+ model_configs = config .get_all_config ()["profile_models" ]
1404+ expected_cpu_only = not local_gpus_available and launch_mode != "remote"
1405+ expected_model_configs = [
1406+ ConfigModelProfileSpec (
1407+ "vgg_16_graphdef" ,
1408+ cpu_only = expected_cpu_only ,
1409+ parameters = self ._create_parameters (batch_sizes = [1 ]),
1410+ objectives = {"perf_throughput" : 10 },
1411+ )
1412+ ]
1413+ self ._assert_equality_of_model_configs (
1414+ model_configs , expected_model_configs
1415+ )
1416+
13871417 def test_config_shorthands (self ):
13881418 """
13891419 test flags like --latency-budget
0 commit comments