2222from matplotlib import patches as mpatches
2323
2424from model_analyzer .constants import LOGGER_NAME
25+ from model_analyzer .perf_analyzer .perf_config import PerfAnalyzerConfig
2526from model_analyzer .record .metrics_manager import MetricsManager
2627
2728logging .getLogger ("matplotlib" ).setLevel (logging .ERROR )
@@ -119,42 +120,15 @@ def add_run_config_measurement(self, run_config_measurement):
119120 """
120121
121122 # TODO-TMA-568: This needs to be updated because there will be multiple model configs
122- if (
123- "concurrency-range" in run_config_measurement .model_specific_pa_params ()[0 ]
124- and run_config_measurement .model_specific_pa_params ()[0 ][
125- "concurrency-range"
126- ]
127- ):
128- self ._data ["concurrency" ].append (
129- run_config_measurement .model_specific_pa_params ()[0 ][
130- "concurrency-range"
131- ]
132- )
133-
134- if (
135- "request-rate-range" in run_config_measurement .model_specific_pa_params ()[0 ]
136- and run_config_measurement .model_specific_pa_params ()[0 ][
137- "request-rate-range"
138- ]
139- ):
140- self ._data ["request_rate" ].append (
141- run_config_measurement .model_specific_pa_params ()[0 ][
142- "request-rate-range"
143- ]
144- )
145-
146- # FIXME 1772 -- clean this up??
147- if (
148- "request-intervals" in run_config_measurement .model_specific_pa_params ()[0 ]
149- and run_config_measurement .model_specific_pa_params ()[0 ][
150- "request-intervals"
151- ]
152- ):
153- self ._data ["request-intervals" ].append (
154- run_config_measurement .model_specific_pa_params ()[0 ][
155- "request-intervals"
156- ]
157- )
123+ for load_arg in PerfAnalyzerConfig .get_inference_load_args ():
124+ if (
125+ load_arg in run_config_measurement .model_specific_pa_params ()[0 ]
126+ and run_config_measurement .model_specific_pa_params ()[0 ][load_arg ]
127+ ):
128+ data_key = self ._get_data_key_from_load_arg (load_arg )
129+ self ._data [data_key ].append (
130+ run_config_measurement .model_specific_pa_params ()[0 ][load_arg ]
131+ )
158132
159133 self ._data ["perf_throughput" ].append (
160134 run_config_measurement .get_non_gpu_metric_value (tag = "perf_throughput" )
@@ -177,9 +151,9 @@ def plot_data(self):
177151 """
178152
179153 # Update the x-axis plot title
180- if "request-intervals " in self ._data and self ._data ["request-intervals " ][0 ]:
154+ if "request_intervals " in self ._data and self ._data ["request_intervals " ][0 ]:
181155 self ._ax_latency .set_xlabel ("Request Intervals File" )
182- sort_indices_key = "request-intervals "
156+ sort_indices_key = "request_intervals "
183157 elif "request_rate" in self ._data and self ._data ["request_rate" ][0 ]:
184158 self ._ax_latency .set_xlabel ("Client Request Rate" )
185159 sort_indices_key = "request_rate"
@@ -274,3 +248,18 @@ def save(self, filepath):
274248 """
275249
276250 self ._fig .savefig (os .path .join (filepath , self ._name ))
251+
252+ def _get_data_key_from_load_arg (self , load_arg ):
253+ """
254+ Gets the key into _data corresponding with the input load arg
255+
256+ For example, the load arg "request-rate-range" has the key "request_rate"
257+ """
258+ # Check if '-range' exists at the end of the input string and remove it
259+ if load_arg .endswith ("-range" ):
260+ load_arg = load_arg [:- 6 ]
261+
262+ # Replace any '-' with '_' in the remaining string
263+ data_key = load_arg .replace ("-" , "_" )
264+
265+ return data_key
0 commit comments