Skip to content

Commit 5054c7a

Browse files
authored
Add to result heaps during profile (#503)
* Moved preamble for compile_and_sort into __init__ * Adding to heaps during profile * Fixing L0 issues * Renaming add_run_config_measurement to _add_rcm_to_results * Moving add_rcm from model to metrics * Getting closer. Still have unit tests failing * All unit tests passing * Refactoring __init__ * Fixing L0_state_manager * Creating empty model list if calling result_manager from report * Revert change * Fixing minor nits * Refactoring init * Fixing report manager unit test * Fixing result manager unit test * Add public get methods to RCRC * Adding back in exception if model is not found in checkpoint * Adding in exception for models present in checkpoint. Updated report manager units tests * Mocked check for models. All unit tests passing * Fixing directory check issue in report manager unit testing * Creating and removing test directory * Fixing report manager unit testing the right way! * Fixes based on Tim's review * Removing uneeded method * Changing Mocks to be CCRs
1 parent 2bafb26 commit 5054c7a

File tree

8 files changed

+192
-113
lines changed

8 files changed

+192
-113
lines changed

model_analyzer/analyzer.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ def profile(self, client: TritonClient, gpus: List[GPUDevice], mode: str,
113113
logger.info("")
114114

115115
if not self._config.skip_summary_reports:
116-
# TODO: TMA-792: This won't be needed once the Results class is used in profile
117-
self._analyze_models()
118-
119116
self._create_summary_tables(verbose)
120117
self._create_summary_reports(mode)
121118
logger.info(self._get_report_command_help_string())
@@ -150,7 +147,6 @@ def analyze(self, mode: str, verbose: bool):
150147
self._result_manager)
151148

152149
# Create result tables, put top results and get stats
153-
self._result_manager.compile_and_sort_results()
154150
self._report_manager.create_summaries()
155151
self._report_manager.export_summaries()
156152

@@ -252,13 +248,6 @@ def _profile_models(self):
252248
finally:
253249
self._state_manager.save_checkpoint()
254250

255-
def _analyze_models(self):
256-
# TODO: TMA-792: Until we get rid of analysis we need to copy some values from profile
257-
self._config._fields["analysis_models"] = self._config._fields[
258-
"profile_models"]
259-
260-
self._result_manager.compile_and_sort_results()
261-
262251
def _create_summary_tables(self, verbose: bool):
263252
self._result_table_manager = ResultTableManager(self._config,
264253
self._result_manager)

model_analyzer/result/result_manager.py

Lines changed: 104 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@
1919

2020
from .result_heap import ResultHeap
2121
from .run_config_result_comparator import RunConfigResultComparator
22+
from .run_config_measurement import RunConfigMeasurement
2223
from .run_config_result import RunConfigResult
2324
from .results import Results
2425

26+
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
27+
from model_analyzer.config.input.config_command_analyze import ConfigCommandAnalyze
28+
from model_analyzer.config.input.config_command_report import ConfigCommandReport
29+
2530
from collections import defaultdict
2631

2732

@@ -35,7 +40,7 @@ def __init__(self, config, state_manager):
3540
"""
3641
Parameters
3742
----------
38-
config :ConfigCommandProfile
43+
config :ConfigCommandProfile/ConfigCommandReport
3944
the model analyzer config
4045
state_manager: AnalyzerStateManager
4146
The object that allows control and update of state
@@ -44,13 +49,15 @@ def __init__(self, config, state_manager):
4449
self._config = config
4550
self._state_manager = state_manager
4651

47-
if state_manager.starting_fresh_run():
48-
self._init_state()
49-
5052
# Data structures for sorting results
5153
self._per_model_sorted_results = defaultdict(ResultHeap)
5254
self._across_model_sorted_results = ResultHeap()
5355

56+
if state_manager.starting_fresh_run():
57+
self._init_state()
58+
59+
self._complete_setup()
60+
5461
def get_model_names(self):
5562
"""
5663
Returns a list of model names that have sorted results
@@ -98,46 +105,31 @@ def add_server_data(self, data):
98105
self._state_manager.set_state_variable('ResultManager.server_only_data',
99106
data)
100107

101-
def add_run_config_measurement(self, run_config, run_config_measurement):
108+
def add_run_config_measurement(
109+
self, run_config, run_config_measurement: RunConfigMeasurement):
102110
"""
103-
This function adds model inference
104-
measurements to the required result
105-
106-
Parameters
107-
----------
108-
run_config : RunConfig
109-
Contains the parameters used to generate the measurment
110-
run_config_measurement: RunConfigMeasurement
111-
the measurement to be added
111+
Add measurement to individual result heap,
112+
global result heap and results class
112113
"""
114+
model_name = run_config.models_name()
113115

114-
# Get reference to results state and modify it
115-
results = self._state_manager.get_state_variable(
116-
'ResultManager.results')
117-
118-
results.add_run_config_measurement(run_config, run_config_measurement)
119-
120-
# Use set_state_variable to record that state may have been changed
121-
self._state_manager.set_state_variable(name='ResultManager.results',
122-
value=results)
116+
run_config_result = RunConfigResult(
117+
model_name=model_name,
118+
run_config=run_config,
119+
comparator=self._run_comparators[model_name],
120+
constraints=self._run_constraints[model_name])
123121

124-
def compile_and_sort_results(self):
125-
"""
126-
Collects objectives and constraints for
127-
each model, constructs results from the
128-
measurements obtained, and sorts and
129-
filters them according to constraints
130-
and objectives.
131-
"""
122+
run_config_measurement.set_metric_weightings(
123+
self._run_comparators[model_name].get_metric_weights())
132124

133-
self._create_concurrent_analysis_model_name()
125+
run_config_measurement.set_model_config_weighting(
126+
self._run_comparators[model_name].get_model_weights())
134127

135-
if self._analyzing_models_concurrently():
136-
self._setup_for_concurrent_analysis()
137-
else:
138-
self._setup_for_sequential_analysis()
128+
self._add_rcm_to_results(run_config, run_config_measurement)
129+
run_config_result.add_run_config_measurement(run_config_measurement)
139130

140-
self._add_results_to_heaps()
131+
self._per_model_sorted_results[model_name].add_result(run_config_result)
132+
self._across_model_sorted_results.add_result(run_config_result)
141133

142134
def get_model_configs_run_config_measurements(self, model_variants_name):
143135
"""
@@ -252,6 +244,44 @@ def _init_state(self):
252244
self._state_manager.set_state_variable('ResultManager.server_only_data',
253245
{})
254246

247+
def _complete_setup(self):
248+
# The Report subcommand can init, but nothing needs to be done
249+
if isinstance(self._config, ConfigCommandProfile):
250+
self._complete_profile_setup()
251+
elif isinstance(self._config, ConfigCommandAnalyze):
252+
self._complete_analyze_setup()
253+
elif isinstance(self._config, ConfigCommandReport):
254+
pass
255+
else:
256+
raise TritonModelAnalyzerException(
257+
f"Expected config of type ConfigCommandProfile/ConfigCommandAnalyze/ConfigCommandReport,"
258+
f" got {type(self._config)}.")
259+
260+
def _complete_profile_setup(self):
261+
#TODO: TMA-792: Until we get rid of analysis we need this
262+
self._config._fields["analysis_models"] = self._config._fields[
263+
"profile_models"]
264+
265+
self._create_concurrent_analysis_model_name()
266+
267+
if self._config.run_config_profile_models_concurrently_enable:
268+
self._setup_for_concurrent_analysis()
269+
else:
270+
self._setup_for_sequential_analysis()
271+
272+
self._add_results_to_heaps()
273+
274+
def _complete_analyze_setup(self):
275+
self._create_concurrent_analysis_model_name()
276+
277+
if self._analyzing_models_concurrently():
278+
self._setup_for_concurrent_analysis()
279+
else:
280+
self._setup_for_sequential_analysis()
281+
282+
self._check_for_models_in_checkpoint()
283+
self._add_results_to_heaps()
284+
255285
def _create_concurrent_analysis_model_name(self):
256286
analysis_model_names = [
257287
model.model_name() for model in self._config.analysis_models
@@ -310,6 +340,39 @@ def _setup_for_sequential_analysis(self):
310340
for model in self._config.analysis_models
311341
}
312342

343+
def _add_rcm_to_results(self, run_config, run_config_measurement):
344+
"""
345+
This function adds model inference
346+
measurements to the required result
347+
348+
Parameters
349+
----------
350+
run_config : RunConfig
351+
Contains the parameters used to generate the measurment
352+
run_config_measurement: RunConfigMeasurement
353+
the measurement to be added
354+
"""
355+
356+
# Get reference to results state and modify it
357+
results = self._state_manager.get_state_variable(
358+
'ResultManager.results')
359+
360+
results.add_run_config_measurement(run_config, run_config_measurement)
361+
362+
# Use set_state_variable to record that state may have been changed
363+
self._state_manager.set_state_variable(name='ResultManager.results',
364+
value=results)
365+
366+
def _check_for_models_in_checkpoint(self):
367+
results = self._state_manager.get_state_variable(
368+
'ResultManager.results')
369+
370+
for model_name in self._analysis_model_names:
371+
if not results.get_model_measurements_dict(model_name):
372+
raise TritonModelAnalyzerException(
373+
f"The model {model_name} was not found in the loaded checkpoint."
374+
)
375+
313376
def _add_results_to_heaps(self):
314377
"""
315378
Construct and add results to individual result heaps
@@ -321,10 +384,9 @@ def _add_results_to_heaps(self):
321384
for model_name in self._analysis_model_names:
322385
model_measurements = results.get_model_measurements_dict(model_name)
323386

387+
# Only add in models that exist in the checkpoint
324388
if not model_measurements:
325-
raise TritonModelAnalyzerException(
326-
f"The model {model_name} was not found in the loaded checkpoint."
327-
)
389+
continue
328390

329391
for (run_config,
330392
run_config_measurements) in model_measurements.values():
@@ -336,10 +398,10 @@ def _add_results_to_heaps(self):
336398

337399
for run_config_measurement in run_config_measurements.values():
338400
run_config_measurement.set_metric_weightings(
339-
self._run_comparators[model_name]._metric_weights)
401+
self._run_comparators[model_name].get_metric_weights())
340402

341403
run_config_measurement.set_model_config_weighting(
342-
self._run_comparators[model_name]._model_weights)
404+
self._run_comparators[model_name].get_model_weights())
343405

344406
run_config_result.add_run_config_measurement(
345407
run_config_measurement)

model_analyzer/result/run_config_result_comparator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ def __init__(self, metric_objectives_list):
4242
# TODO-TMA-571: Need to add support for model weighting
4343
self._model_weights.append(1)
4444

45+
def get_metric_weights(self):
46+
return self._metric_weights
47+
48+
def get_model_weights(self):
49+
return self._model_weights
50+
4551
def is_better_than(self, run_config_result1, run_config_result2):
4652
"""
4753
Aggregates and compares the score for two RunConfigResults

tests/test_analyzer.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,14 @@ def mock_get_state_variable(self, name):
5757
def mock_get_list_of_models(self):
5858
return ['model1']
5959

60-
@patch.multiple(
61-
f'{AnalyzerStateManager.__module__}.AnalyzerStateManager',
62-
get_state_variable=mock_get_state_variable,
63-
exiting=lambda _: False
64-
)
65-
@patch.multiple(
66-
f'{Analyzer.__module__}.Analyzer',
67-
_create_metrics_manager=MagicMock(),
68-
_create_model_manager=MagicMock(),
69-
_get_server_only_metrics=MagicMock(),
70-
_analyze_models=MagicMock(),
71-
_profile_models=MagicMock()
72-
)
60+
@patch.multiple(f'{AnalyzerStateManager.__module__}.AnalyzerStateManager',
61+
get_state_variable=mock_get_state_variable,
62+
exiting=lambda _: False)
63+
@patch.multiple(f'{Analyzer.__module__}.Analyzer',
64+
_create_metrics_manager=MagicMock(),
65+
_create_model_manager=MagicMock(),
66+
_get_server_only_metrics=MagicMock(),
67+
_profile_models=MagicMock())
7368
def test_profile_skip_summary_reports(self, **mocks):
7469
"""
7570
Tests when the skip_summary_reports config option is turned on,
@@ -81,18 +76,16 @@ def test_profile_skip_summary_reports(self, **mocks):
8176
args = [
8277
'model-analyzer', 'profile', '--model-repository', '/tmp',
8378
'--profile-models', 'model1', '--config-file', '/tmp/my_config.yml',
84-
'--checkpoint-directory', '/tmp/my_checkpoints', '--skip-summary-reports'
79+
'--checkpoint-directory', '/tmp/my_checkpoints',
80+
'--skip-summary-reports'
8581
]
8682
config = evaluate_mock_config(args, '', subcommand="profile")
8783
state_manager = AnalyzerStateManager(config, None)
8884
analyzer = Analyzer(config,
8985
None,
9086
state_manager,
9187
checkpoint_required=False)
92-
analyzer.profile(client=None,
93-
gpus=None,
94-
mode=None,
95-
verbose=False)
88+
analyzer.profile(client=None, gpus=None, mode=None, verbose=False)
9689

9790
path = os.getcwd()
9891
self.assertFalse(os.path.exists(os.path.join(path, "plots")))
@@ -155,6 +148,9 @@ def mock_top_n_results(self, model_name=None, n=-1):
155148
RunConfigResult("fake_model_name", rc3, MagicMock())
156149
]
157150

151+
def mock_check_for_models_in_checkpoint(self):
152+
return True
153+
158154
@patch(
159155
'model_analyzer.config.input.config_command_analyze.file_path_validator',
160156
lambda _: ConfigStatus(status=CONFIG_PARSER_SUCCESS))
@@ -163,6 +159,9 @@ def mock_top_n_results(self, model_name=None, n=-1):
163159
lambda _: None)
164160
@patch('model_analyzer.result.result_manager.ResultManager.top_n_results',
165161
mock_top_n_results)
162+
@patch(
163+
'model_analyzer.result.result_manager.ResultManager._check_for_models_in_checkpoint',
164+
mock_check_for_models_in_checkpoint)
166165
def test_get_report_command_help_string(self):
167166
"""
168167
Tests that the member function returning the report command help string

tests/test_plot_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,6 @@ def _create_single_model_result_manager(self):
103103

104104
self._single_model_config = config
105105

106-
self._single_model_result_manager.compile_and_sort_results()
107-
108106
def _create_multi_model_result_manager(self):
109107
args = [
110108
'model-analyzer', 'analyze', '-f', 'config.yml',
@@ -123,8 +121,6 @@ def _create_multi_model_result_manager(self):
123121

124122
self._multi_model_config = config
125123

126-
self._multi_model_result_manager.compile_and_sort_results()
127-
128124
def _plot_manager_to_dict(self, plot_manager):
129125
plot_manager_dict = {}
130126
plot_manager_dict['_simple_plots'] = {}

0 commit comments

Comments
 (0)