1919
2020from .result_heap import ResultHeap
2121from .run_config_result_comparator import RunConfigResultComparator
22+ from .run_config_measurement import RunConfigMeasurement
2223from .run_config_result import RunConfigResult
2324from .results import Results
2425
26+ from model_analyzer .config .input .config_command_profile import ConfigCommandProfile
27+ from model_analyzer .config .input .config_command_analyze import ConfigCommandAnalyze
28+ from model_analyzer .config .input .config_command_report import ConfigCommandReport
29+
2530from collections import defaultdict
2631
2732
@@ -35,7 +40,7 @@ def __init__(self, config, state_manager):
3540 """
3641 Parameters
3742 ----------
38- config :ConfigCommandProfile
43+ config :ConfigCommandProfile/ConfigCommandReport
3944 the model analyzer config
4045 state_manager: AnalyzerStateManager
4146 The object that allows control and update of state
@@ -44,13 +49,15 @@ def __init__(self, config, state_manager):
4449 self ._config = config
4550 self ._state_manager = state_manager
4651
47- if state_manager .starting_fresh_run ():
48- self ._init_state ()
49-
5052 # Data structures for sorting results
5153 self ._per_model_sorted_results = defaultdict (ResultHeap )
5254 self ._across_model_sorted_results = ResultHeap ()
5355
56+ if state_manager .starting_fresh_run ():
57+ self ._init_state ()
58+
59+ self ._complete_setup ()
60+
5461 def get_model_names (self ):
5562 """
5663 Returns a list of model names that have sorted results
@@ -98,46 +105,31 @@ def add_server_data(self, data):
98105 self ._state_manager .set_state_variable ('ResultManager.server_only_data' ,
99106 data )
100107
101- def add_run_config_measurement (self , run_config , run_config_measurement ):
108+ def add_run_config_measurement (
109+ self , run_config , run_config_measurement : RunConfigMeasurement ):
102110 """
103- This function adds model inference
104- measurements to the required result
105-
106- Parameters
107- ----------
108- run_config : RunConfig
109- Contains the parameters used to generate the measurment
110- run_config_measurement: RunConfigMeasurement
111- the measurement to be added
111+ Add measurement to individual result heap,
112+ global result heap and results class
112113 """
114+ model_name = run_config .models_name ()
113115
114- # Get reference to results state and modify it
115- results = self ._state_manager .get_state_variable (
116- 'ResultManager.results' )
117-
118- results .add_run_config_measurement (run_config , run_config_measurement )
119-
120- # Use set_state_variable to record that state may have been changed
121- self ._state_manager .set_state_variable (name = 'ResultManager.results' ,
122- value = results )
116+ run_config_result = RunConfigResult (
117+ model_name = model_name ,
118+ run_config = run_config ,
119+ comparator = self ._run_comparators [model_name ],
120+ constraints = self ._run_constraints [model_name ])
123121
124- def compile_and_sort_results (self ):
125- """
126- Collects objectives and constraints for
127- each model, constructs results from the
128- measurements obtained, and sorts and
129- filters them according to constraints
130- and objectives.
131- """
122+ run_config_measurement .set_metric_weightings (
123+ self ._run_comparators [model_name ].get_metric_weights ())
132124
133- self ._create_concurrent_analysis_model_name ()
125+ run_config_measurement .set_model_config_weighting (
126+ self ._run_comparators [model_name ].get_model_weights ())
134127
135- if self ._analyzing_models_concurrently ():
136- self ._setup_for_concurrent_analysis ()
137- else :
138- self ._setup_for_sequential_analysis ()
128+ self ._add_rcm_to_results (run_config , run_config_measurement )
129+ run_config_result .add_run_config_measurement (run_config_measurement )
139130
140- self ._add_results_to_heaps ()
131+ self ._per_model_sorted_results [model_name ].add_result (run_config_result )
132+ self ._across_model_sorted_results .add_result (run_config_result )
141133
142134 def get_model_configs_run_config_measurements (self , model_variants_name ):
143135 """
@@ -252,6 +244,44 @@ def _init_state(self):
252244 self ._state_manager .set_state_variable ('ResultManager.server_only_data' ,
253245 {})
254246
247+ def _complete_setup (self ):
248+ # The Report subcommand can init, but nothing needs to be done
249+ if isinstance (self ._config , ConfigCommandProfile ):
250+ self ._complete_profile_setup ()
251+ elif isinstance (self ._config , ConfigCommandAnalyze ):
252+ self ._complete_analyze_setup ()
253+ elif isinstance (self ._config , ConfigCommandReport ):
254+ pass
255+ else :
256+ raise TritonModelAnalyzerException (
257+ f"Expected config of type ConfigCommandProfile/ConfigCommandAnalyze/ConfigCommandReport,"
258+ f" got { type (self ._config )} ." )
259+
260+ def _complete_profile_setup (self ):
261+ #TODO: TMA-792: Until we get rid of analysis we need this
262+ self ._config ._fields ["analysis_models" ] = self ._config ._fields [
263+ "profile_models" ]
264+
265+ self ._create_concurrent_analysis_model_name ()
266+
267+ if self ._config .run_config_profile_models_concurrently_enable :
268+ self ._setup_for_concurrent_analysis ()
269+ else :
270+ self ._setup_for_sequential_analysis ()
271+
272+ self ._add_results_to_heaps ()
273+
274+ def _complete_analyze_setup (self ):
275+ self ._create_concurrent_analysis_model_name ()
276+
277+ if self ._analyzing_models_concurrently ():
278+ self ._setup_for_concurrent_analysis ()
279+ else :
280+ self ._setup_for_sequential_analysis ()
281+
282+ self ._check_for_models_in_checkpoint ()
283+ self ._add_results_to_heaps ()
284+
255285 def _create_concurrent_analysis_model_name (self ):
256286 analysis_model_names = [
257287 model .model_name () for model in self ._config .analysis_models
@@ -310,6 +340,39 @@ def _setup_for_sequential_analysis(self):
310340 for model in self ._config .analysis_models
311341 }
312342
343+ def _add_rcm_to_results (self , run_config , run_config_measurement ):
344+ """
345+ This function adds model inference
346+ measurements to the required result
347+
348+ Parameters
349+ ----------
350+ run_config : RunConfig
351+ Contains the parameters used to generate the measurment
352+ run_config_measurement: RunConfigMeasurement
353+ the measurement to be added
354+ """
355+
356+ # Get reference to results state and modify it
357+ results = self ._state_manager .get_state_variable (
358+ 'ResultManager.results' )
359+
360+ results .add_run_config_measurement (run_config , run_config_measurement )
361+
362+ # Use set_state_variable to record that state may have been changed
363+ self ._state_manager .set_state_variable (name = 'ResultManager.results' ,
364+ value = results )
365+
366+ def _check_for_models_in_checkpoint (self ):
367+ results = self ._state_manager .get_state_variable (
368+ 'ResultManager.results' )
369+
370+ for model_name in self ._analysis_model_names :
371+ if not results .get_model_measurements_dict (model_name ):
372+ raise TritonModelAnalyzerException (
373+ f"The model { model_name } was not found in the loaded checkpoint."
374+ )
375+
313376 def _add_results_to_heaps (self ):
314377 """
315378 Construct and add results to individual result heaps
@@ -321,10 +384,9 @@ def _add_results_to_heaps(self):
321384 for model_name in self ._analysis_model_names :
322385 model_measurements = results .get_model_measurements_dict (model_name )
323386
387+ # Only add in models that exist in the checkpoint
324388 if not model_measurements :
325- raise TritonModelAnalyzerException (
326- f"The model { model_name } was not found in the loaded checkpoint."
327- )
389+ continue
328390
329391 for (run_config ,
330392 run_config_measurements ) in model_measurements .values ():
@@ -336,10 +398,10 @@ def _add_results_to_heaps(self):
336398
337399 for run_config_measurement in run_config_measurements .values ():
338400 run_config_measurement .set_metric_weightings (
339- self ._run_comparators [model_name ]._metric_weights )
401+ self ._run_comparators [model_name ].get_metric_weights () )
340402
341403 run_config_measurement .set_model_config_weighting (
342- self ._run_comparators [model_name ]._model_weights )
404+ self ._run_comparators [model_name ].get_model_weights () )
343405
344406 run_config_result .add_run_config_measurement (
345407 run_config_measurement )
0 commit comments