|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | from typing import List, Union, Optional |
| 16 | +from copy import deepcopy |
16 | 17 | import sys |
17 | 18 | from model_analyzer.constants import LOGGER_NAME, PA_ERROR_LOG_FILENAME |
18 | 19 | from .model_manager import ModelManager |
|
32 | 33 | from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager |
33 | 34 | from model_analyzer.triton.server.server import TritonServer |
34 | 35 |
|
| 36 | +from model_analyzer.cli.cli import CLI |
| 37 | + |
35 | 38 | from model_analyzer.config.generate.base_model_config_generator import BaseModelConfigGenerator |
36 | 39 |
|
37 | 40 | from .triton.client.client import TritonClient |
@@ -126,17 +129,9 @@ def profile(self, client: TritonClient, gpus: List[GPUDevice], mode: str, |
126 | 129 | if not self._config.skip_summary_reports: |
127 | 130 | self._create_summary_tables(verbose) |
128 | 131 | self._create_summary_reports(mode) |
| 132 | + self._create_detailed_reports(mode) |
129 | 133 |
|
130 | | - # TODO-TMA-650: Detailed reporting not supported for multi-model |
131 | | - if not self._config.run_config_profile_models_concurrently_enable: |
132 | | - for model in self._config.profile_models: |
133 | | - logger.info( |
134 | | - self._get_report_command_help_string( |
135 | | - model.model_name())) |
136 | | - |
137 | | - if self._metrics_manager.encountered_perf_analyzer_error(): |
138 | | - logger.warning(f"Perf Analyzer encountered an error when profiling one or more configurations. " \ |
139 | | - f"See {self._config.export_path}/{PA_ERROR_LOG_FILENAME} for further details.\n") |
| 134 | + self._check_for_perf_analyzer_errors() |
140 | 135 |
|
141 | 136 | def report(self, mode: str) -> None: |
142 | 137 | """ |
@@ -280,18 +275,35 @@ def _get_num_profiled_configs(self): |
280 | 275 | ]) |
281 | 276 |
|
282 | 277 | def _get_report_command_help_string(self, model_name: str) -> str: |
283 | | - top_3_model_config_names = self._get_top_n_model_config_names( |
284 | | - n=3, model_name=model_name) |
| 278 | + top_n_model_config_names = self._get_top_n_model_config_names( |
| 279 | + n=self._config.num_configs_per_model, model_name=model_name) |
285 | 280 | return ( |
286 | 281 | f'To generate detailed reports for the ' |
287 | | - f'{len(top_3_model_config_names)} best {model_name} configurations, run ' |
288 | | - f'`{self._get_report_command_string(top_3_model_config_names)}`') |
| 282 | + f'{len(top_n_model_config_names)} best {model_name} configurations, run ' |
| 283 | + f'`{self._get_report_command_string(top_n_model_config_names)}`') |
| 284 | + |
| 285 | + def _run_report_command(self, model_name: str, mode: str) -> None: |
| 286 | + top_n_model_config_names = self._get_top_n_model_config_names( |
| 287 | + n=self._config.num_configs_per_model, model_name=model_name) |
| 288 | + top_n_string = ','.join(top_n_model_config_names) |
| 289 | + logger.info( |
| 290 | + f'Generating detailed reports for the best configurations {top_n_string}:' |
| 291 | + ) |
| 292 | + |
| 293 | + # [1:] removes 'model-analyzer' from the args |
| 294 | + args = self._get_report_command_string(top_n_model_config_names).split( |
| 295 | + ' ')[1:] |
| 296 | + |
| 297 | + original_profile_config = deepcopy(self._config) |
| 298 | + self._config = self._create_report_config(args) |
| 299 | + self.report(mode) |
| 300 | + self._config = original_profile_config |
289 | 301 |
|
290 | 302 | def _get_report_command_string(self, |
291 | | - top_3_model_config_names: List[str]) -> str: |
| 303 | + top_n_model_config_names: List[str]) -> str: |
292 | 304 | report_command_string = (f'model-analyzer report ' |
293 | 305 | f'--report-model-configs ' |
294 | | - f'{",".join(top_3_model_config_names)}') |
| 306 | + f'{",".join(top_n_model_config_names)}') |
295 | 307 |
|
296 | 308 | if self._config.export_path is not None: |
297 | 309 | report_command_string += (f' --export-path ' |
@@ -336,3 +348,26 @@ def _multiple_models_in_report_model_config(self) -> bool: |
336 | 348 | ] |
337 | 349 |
|
338 | 350 | return len(set(model_names)) > 1 |
| 351 | + |
| 352 | + def _check_for_perf_analyzer_errors(self) -> None: |
| 353 | + if self._metrics_manager.encountered_perf_analyzer_error(): |
| 354 | + logger.warning(f"Perf Analyzer encountered an error when profiling one or more configurations. " \ |
| 355 | + f"See {self._config.export_path}/{PA_ERROR_LOG_FILENAME} for further details.\n") |
| 356 | + |
| 357 | + def _create_detailed_reports(self, mode: str) -> None: |
| 358 | + # TODO-TMA-650: Detailed reporting not supported for multi-model |
| 359 | + if not self._config.run_config_profile_models_concurrently_enable: |
| 360 | + for model in self._config.profile_models: |
| 361 | + if not self._config.skip_detailed_reports: |
| 362 | + self._run_report_command(model.model_name(), mode) |
| 363 | + else: |
| 364 | + logger.info( |
| 365 | + self._get_report_command_help_string( |
| 366 | + model.model_name())) |
| 367 | + |
| 368 | + def _create_report_config(self, args: list) -> ConfigCommandReport: |
| 369 | + config = ConfigCommandReport() |
| 370 | + cli = CLI() |
| 371 | + cli.add_subcommand(cmd='report', help="", config=config) |
| 372 | + cli.parse(args) |
| 373 | + return config |
0 commit comments