1- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
1+ # Copyright (c) 2021-22 , NVIDIA CORPORATION. All rights reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Dict , List
15+ from typing import Dict , List , Optional , Any
1616from model_analyzer .model_analyzer_exceptions \
1717 import TritonModelAnalyzerException
1818import yaml
19+ from argparse import Namespace
1920from .yaml_config_validator import YamlConfigValidator
2021
2122from copy import deepcopy
@@ -76,7 +77,7 @@ def _load_config_file(self, file_path):
7677 config = yaml .safe_load (config_file )
7778 return config
7879
79- def set_config_values (self , args ) :
80+ def set_config_values (self , args : Namespace ) -> None :
8081 """
8182 Set the config values. This function sets all the values for the
8283 config. CLI arguments have the highest priority, then YAML config
@@ -94,41 +95,183 @@ def set_config_values(self, args):
9495 this exception
9596 """
9697
97- # Config file has been specified
98+ yaml_config = self ._load_yaml_config (args )
99+ self ._check_for_illegal_config_settings (args , yaml_config )
100+ self ._set_field_values (args , yaml_config )
101+ self ._preprocess_and_verify_arguments ()
102+ self ._autofill_values ()
103+
104+ def _load_yaml_config (self , args : Namespace ) -> Optional [Dict [str , List ]]:
98105 if 'config_file' in args :
99106 yaml_config = self ._load_config_file (args .config_file )
100107 YamlConfigValidator .validate (yaml_config )
101108 else :
102109 yaml_config = None
103110
111+ return yaml_config
112+
113+ def _check_for_illegal_config_settings (
114+ self , args : Namespace , yaml_config : Optional [Dict [str ,
115+ List ]]) -> None :
116+ self ._check_for_duplicate_profile_models_option (args , yaml_config )
117+ self ._check_for_multi_model_incompatability (args , yaml_config )
118+ self ._check_for_quick_search_incompatability (args , yaml_config )
119+
120+ def _set_field_values (self , args : Namespace ,
121+ yaml_config : Optional [Dict [str , List ]]) -> None :
104122 for key , value in self ._fields .items ():
105123 self ._fields [key ].set_name (key )
106- if key in args :
107- self ._check_for_duplicate_profile_models_option (
108- yaml_config , key )
109- self ._fields [key ].set_value (getattr (args , key ))
110- elif yaml_config is not None and key in yaml_config :
111- self ._fields [key ].set_value (yaml_config [key ])
124+ config_value = self ._get_config_value (key , args , yaml_config )
125+
126+ if config_value :
127+ self ._fields [key ].set_value (config_value )
112128 elif value .default_value () is not None :
113129 self ._fields [key ].set_value (value .default_value ())
114130 elif value .required ():
115131 flags = ', ' .join (value .flags ())
116132 raise TritonModelAnalyzerException (
117133 f'Config for { value .name ()} is not specified. You need to specify it using the YAML config file or using the { flags } flags in CLI.'
118134 )
119- self ._preprocess_and_verify_arguments ()
120- self ._autofill_values ()
121135
122- def _check_for_duplicate_profile_models_option (self ,
123- yaml_config : Dict [str , List ],
124- key : str ) -> None :
125- if yaml_config is not None and key in yaml_config and key == 'profile_models' :
136+ def _get_config_value (
137+ self , key : str , args : Namespace ,
138+ yaml_config : Optional [Dict [str , List ]]) -> Optional [Any ]:
139+ if key in args :
140+ return getattr (args , key )
141+ elif yaml_config is not None and key in yaml_config :
142+ return yaml_config [key ]
143+ else :
144+ return None
145+
146+ def _check_for_duplicate_profile_models_option (
147+ self , args : Namespace , yaml_config : Optional [Dict [str ,
148+ List ]]) -> None :
149+ key_in_args = 'profile_models' in args
150+ key_in_yaml = yaml_config is not None and 'profile_models' in yaml_config
151+
152+ if key_in_args and key_in_yaml :
126153 raise TritonModelAnalyzerException (
127154 f'\n The profile model option is specified on both '
128155 'the CLI (--profile-models) and in the YAML config file.'
129156 '\n Please remove the option from one of the locations and try again'
130157 )
131158
159+ def _check_for_multi_model_incompatability (
160+ self , args : Namespace , yaml_config : Optional [Dict [str ,
161+ List ]]) -> None :
162+ if not self ._get_config_value (
163+ 'run_config_profile_models_concurrently_enable' , args ,
164+ yaml_config ):
165+ return
166+
167+ self ._check_multi_model_search_mode_incompatability (args , yaml_config )
168+
169+ def _check_multi_model_search_mode_incompatability (
170+ self , args : Namespace , yaml_config : Optional [Dict [str ,
171+ List ]]) -> None :
172+ if self ._get_config_value ('run_config_search_mode' , args ,
173+ yaml_config ) != 'quick' :
174+ raise TritonModelAnalyzerException (
175+ f'\n Concurrent profiling of models is only supported in quick search mode.'
176+ '\n Please use quick search mode or disable concurrent model profiling.'
177+ )
178+
179+ def _check_for_quick_search_incompatability (
180+ self , args : Namespace , yaml_config : Optional [Dict [str ,
181+ List ]]) -> None :
182+ if self ._get_config_value ('run_config_search_mode' , args ,
183+ yaml_config ) != 'quick' :
184+ return
185+
186+ self ._check_no_search_disable (args , yaml_config )
187+ self ._check_no_search_values (args , yaml_config )
188+ self ._check_no_global_list_values (args , yaml_config )
189+ self ._check_no_per_model_list_values (args , yaml_config )
190+
191+ def _check_no_search_disable (
192+ self , args : Namespace , yaml_config : Optional [Dict [str ,
193+ List ]]) -> None :
194+ if self ._get_config_value ('run_config_search_disable' , args ,
195+ yaml_config ):
196+ raise TritonModelAnalyzerException (
197+ f'\n Disabling of run config search is not supported in quick search mode.'
198+ '\n Please use brute search mode or remove --run-config-search-disable.'
199+ )
200+
201+ def _check_no_search_values (self , args : Namespace ,
202+ yaml_config : Optional [Dict [str , List ]]) -> None :
203+ max_concurrency = self ._get_config_value (
204+ 'run_config_search_max_concurrency' , args , yaml_config )
205+ min_concurrency = self ._get_config_value (
206+ 'run_config_search_min_concurrency' , args , yaml_config )
207+ max_instance = self ._get_config_value (
208+ 'run_config_search_max_instance_count' , args , yaml_config )
209+ min_instance = self ._get_config_value (
210+ 'run_config_search_min_instance_count' , args , yaml_config )
211+ max_batch_size = self ._get_config_value (
212+ 'run_config_search_max_model_batch_size' , args , yaml_config )
213+ min_batch_size = self ._get_config_value (
214+ 'run_config_search_min_model_batch_size' , args , yaml_config )
215+
216+ if max_concurrency or min_concurrency :
217+ raise TritonModelAnalyzerException (
218+ f'\n Profiling of models in quick search mode is not supported with min/max concurrency search values.'
219+ '\n Please use brute search mode or remove concurrency search values.'
220+ )
221+ if max_instance or min_instance :
222+ raise TritonModelAnalyzerException (
223+ f'\n Profiling of models in quick search mode is not supported with min/max instance search values.'
224+ '\n Please use brute search mode or remove instance search values.'
225+ )
226+ if max_batch_size or min_batch_size :
227+ raise TritonModelAnalyzerException (
228+ f'\n Profiling of models in quick search mode is not supported with min/max batch size search values.'
229+ '\n Please use brute search mode or remove batch size search values.'
230+ )
231+
232+ def _check_no_global_list_values (
233+ self , args : Namespace , yaml_config : Optional [Dict [str ,
234+ List ]]) -> None :
235+ concurrency = self ._get_config_value ('concurrency' , args , yaml_config )
236+ batch_sizes = self ._get_config_value ('batch_sizes' , args , yaml_config )
237+
238+ if concurrency or batch_sizes :
239+ raise TritonModelAnalyzerException (
240+ f'\n Profiling of models in quick search mode is not supported with lists of concurrencies or batch sizes.'
241+ '\n Please use brute search mode or remove concurrency/batch sizes list.'
242+ )
243+
244+ def _check_no_per_model_list_values (
245+ self , args : Namespace , yaml_config : Optional [Dict [str ,
246+ List ]]) -> None :
247+ profile_models = self ._get_config_value ('profile_models' , args ,
248+ yaml_config )
249+
250+ if not profile_models or type (profile_models ) is str or type (
251+ profile_models ) is list :
252+ return
253+
254+ for model in profile_models .values ():
255+ if not 'parameters' in model :
256+ continue
257+
258+ if 'concurrency' in model ['parameters' ] or 'batch size' in model [
259+ 'parameters' ]:
260+ raise TritonModelAnalyzerException (
261+ f'\n Profiling of models in quick search mode is not supported with lists of concurrencies or batch sizes.'
262+ '\n Please use brute search mode or remove concurrency/batch sizes list.'
263+ )
264+
265+ for model in profile_models .values ():
266+ if not 'model_config_parameters' in model :
267+ continue
268+
269+ if 'max_batch_size' in model ['model_config_parameters' ]:
270+ raise TritonModelAnalyzerException (
271+ f'\n Profiling of models in quick search mode is not supported with lists max batch sizes.'
272+ '\n Please use brute search mode or remov max batch size list.'
273+ )
274+
132275 def _preprocess_and_verify_arguments (self ):
133276 """
134277 Enforces some rules on the config.
0 commit comments