diff --git a/model_analyzer/config/generate/quick_run_config_generator.py b/model_analyzer/config/generate/quick_run_config_generator.py
index 5454765bf..348b2939f 100755
--- a/model_analyzer/config/generate/quick_run_config_generator.py
+++ b/model_analyzer/config/generate/quick_run_config_generator.py
@@ -513,7 +513,20 @@ def _get_next_perf_analyzer_config(
         concurrency = self._calculate_concurrency(dimension_values)
 
-        perf_config_params = {"batch-size": 1, "concurrency-range": concurrency}
+        # FIXME -- this path isn't catching the default config
+
+        model_params = model.parameters()
+        batch_sizes = (
+            model_params.get("batch_sizes", [1])
+            if isinstance(model_params, dict)
+            else [1]
+        )
+        assert len(batch_sizes) == 1
+
+        perf_config_params = {
+            "batch-size": batch_sizes[0],
+            "concurrency-range": concurrency,
+        }
 
         perf_analyzer_config.update_config(perf_config_params)
         perf_analyzer_config.update_config(model.perf_analyzer_flags())
 
diff --git a/model_analyzer/config/generate/run_config_generator_factory.py b/model_analyzer/config/generate/run_config_generator_factory.py
index 0cdcddeb6..7ab2e752b 100755
--- a/model_analyzer/config/generate/run_config_generator_factory.py
+++ b/model_analyzer/config/generate/run_config_generator_factory.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from math import log2
 from typing import List
 
 from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
@@ -165,9 +166,7 @@ def _create_search_config(
             if model.is_ensemble():
                 continue
 
-            dims = RunConfigGeneratorFactory._get_dimensions_for_model(
-                model.supports_batching()
-            )
+            dims = RunConfigGeneratorFactory._get_dimensions_for_model(model)
             dimensions.add_dimensions(index, dims)
             index += 1
 
@@ -178,17 +177,32 @@ def _create_search_config(
         return search_config
 
     @staticmethod
-    def _get_dimensions_for_model(is_batching_supported: bool) -> List[SearchDimension]:
-        if is_batching_supported:
-            return RunConfigGeneratorFactory._get_batching_supported_dimensions()
+    def _get_dimensions_for_model(model: ModelProfileSpec) -> List[SearchDimension]:
+        model_params = model.parameters()
+        batch_sizes = (
+            model_params.get("batch_sizes", [1])
+            if isinstance(model_params, dict)
+            else [1]
+        )
+
+        assert len(batch_sizes) == 1
+        if model.supports_batching():
+            return RunConfigGeneratorFactory._get_batching_supported_dimensions(
+                batch_sizes[0]
+            )
         else:
             return RunConfigGeneratorFactory._get_batching_not_supported_dimensions()
 
     @staticmethod
-    def _get_batching_supported_dimensions() -> List[SearchDimension]:
+    def _get_batching_supported_dimensions(
+        client_batch_size: int,
+    ) -> List[SearchDimension]:
+        min_dimension = int(log2(client_batch_size))
         return [
             SearchDimension(
-                f"max_batch_size", SearchDimension.DIMENSION_TYPE_EXPONENTIAL
+                f"max_batch_size",
+                SearchDimension.DIMENSION_TYPE_EXPONENTIAL,
+                min=min_dimension,
             ),
             SearchDimension(f"instance_count", SearchDimension.DIMENSION_TYPE_LINEAR),
         ]
diff --git a/model_analyzer/config/input/config_command.py b/model_analyzer/config/input/config_command.py
index 59d3e87ce..ff60c0ee2 100755
--- a/model_analyzer/config/input/config_command.py
+++ b/model_analyzer/config/input/config_command.py
@@ -233,7 +233,9 @@ def _check_quick_search_no_global_list_values(
         concurrency = self._get_config_value("concurrency", args, yaml_config)
         batch_sizes = self._get_config_value("batch_sizes", args, yaml_config)
 
-        if concurrency or batch_sizes:
+        if concurrency or (
+            batch_sizes and isinstance(batch_sizes, list) and len(batch_sizes) > 1
+        ):
             raise TritonModelAnalyzerException(
                 f"\nProfiling of models in quick search mode is not supported with lists of concurrencies or batch sizes."
                 "\nPlease use brute search mode or remove concurrency/batch sizes list."
@@ -259,9 +261,10 @@ def _check_per_model_parameters(self, profile_models: Dict) -> None:
             if not "parameters" in model:
                 continue
 
-            if (
-                "concurrency" in model["parameters"]
-                or "batch size" in model["parameters"]
+            if "concurrency" in model["parameters"] or (
+                "batch_sizes" in model["parameters"]
+                and isinstance(model["parameters"]["batch_sizes"], list)
+                and len(model["parameters"]["batch_sizes"]) > 1
             ):
                 raise TritonModelAnalyzerException(
                     f"\nProfiling of models in quick search mode is not supported with lists of concurrencies or batch sizes."