Commit 71e7a90

Clean up code around load args

1 parent 0ebae08 · commit 71e7a90

4 files changed: +55 -40 lines

model_analyzer/config/generate/model_profile_spec.py

Lines changed: 10 additions & 0 deletions

@@ -22,6 +22,7 @@
     ConfigModelProfileSpec,
 )
 from model_analyzer.device.gpu_device import GPUDevice
+from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig
 from model_analyzer.triton.client.client import TritonClient
 from model_analyzer.triton.model.model_config import ModelConfig

@@ -72,3 +73,12 @@ def supports_dynamic_batching(self) -> bool:
     def is_ensemble(self) -> bool:
         """Returns true if the model is an ensemble"""
         return "ensemble_scheduling" in self._default_model_config
+
+    def is_load_specified(self) -> bool:
+        """
+        Returns true if the model's PA config has specified any of the
+        inference load args (such as concurrency). Else returns false
+        """
+        load_args = PerfAnalyzerConfig.get_inference_load_args()
+        pa_flags = self.perf_analyzer_flags()
+        return any(e in pa_flags for e in load_args)
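
For context, a minimal sketch of how the new check behaves. The flag dicts below are made up for illustration and stand in for whatever perf_analyzer_flags() returns; "collect-metrics" is just a non-load PA flag taken from perf_config.py, and the intervals filename is hypothetical:

    # Illustrative only -- these dicts stand in for model.perf_analyzer_flags().
    load_args = ["concurrency-range", "request-rate-range", "request-intervals"]

    flags_with_load = {"request-intervals": "intervals.json"}  # hypothetical file
    flags_without_load = {"collect-metrics": True}

    print(any(arg in flags_with_load for arg in load_args))     # True
    print(any(arg in flags_without_load for arg in load_args))  # False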

model_analyzer/config/generate/quick_run_config_generator.py

Lines changed: 1 addition & 2 deletions

@@ -512,8 +512,7 @@ def _get_next_perf_analyzer_config(

         perf_analyzer_config.update_config_from_profile_config(model_name, self._config)

-        # FIXME 1772 -- use new method in perf_config
-        if not "request-intervals" in model.perf_analyzer_flags():
+        if not model.is_load_specified():
             concurrency = self._calculate_concurrency(dimension_values)
             perf_config_params = {"concurrency-range": concurrency}
             perf_analyzer_config.update_config(perf_config_params)
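
A rough sketch of the resulting branch, with the generator internals stubbed out as hypothetical placeholders (names below are not from the commit):

    # Hypothetical stand-ins for the generator internals, for illustration only.
    def next_load_params(model, dimension_values, calculate_concurrency):
        # If the user already pinned the load (concurrency-range,
        # request-rate-range, or request-intervals), leave the PA config alone;
        # otherwise let the quick search pick a concurrency value itself.
        if not model.is_load_specified():
            return {"concurrency-range": calculate_concurrency(dimension_values)}
        return {}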

model_analyzer/perf_analyzer/perf_config.py

Lines changed: 17 additions & 0 deletions

@@ -96,6 +96,13 @@ class PerfAnalyzerConfig:
         "collect-metrics",
     ]

+    # Only one of these args can be sent to PA, as each one controls the inference load in a different way
+    inference_load_args = [
+        "concurrency-range",
+        "request-rate-range",
+        "request-intervals",
+    ]
+
     def __init__(self):
         """
         Construct a PerfAnalyzerConfig

@@ -160,6 +167,16 @@ def additive_keys(cls):

         return cls.additive_args[:]

+    @classmethod
+    def get_inference_load_args(cls):
+        """
+        Returns
+        -------
+        list of str
+            The Perf Analyzer args that control the inference load
+        """
+        return cls.inference_load_args
+
     def update_config(self, params=None):
         """
         Allows setting values from a params dict
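
A quick sanity check of the new classmethod, assuming model_analyzer is importable in the environment:

    from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig

    # get_inference_load_args() simply exposes the class-level list, so callers
    # never need to reach into PerfAnalyzerConfig.inference_load_args directly.
    assert PerfAnalyzerConfig.get_inference_load_args() == [
        "concurrency-range",
        "request-rate-range",
        "request-intervals",
    ]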

model_analyzer/plots/detailed_plot.py

Lines changed: 27 additions & 38 deletions

@@ -22,6 +22,7 @@
 from matplotlib import patches as mpatches

 from model_analyzer.constants import LOGGER_NAME
+from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig
 from model_analyzer.record.metrics_manager import MetricsManager

 logging.getLogger("matplotlib").setLevel(logging.ERROR)

@@ -119,42 +120,15 @@ def add_run_config_measurement(self, run_config_measurement):
         """

         # TODO-TMA-568: This needs to be updated because there will be multiple model configs
-        if (
-            "concurrency-range" in run_config_measurement.model_specific_pa_params()[0]
-            and run_config_measurement.model_specific_pa_params()[0][
-                "concurrency-range"
-            ]
-        ):
-            self._data["concurrency"].append(
-                run_config_measurement.model_specific_pa_params()[0][
-                    "concurrency-range"
-                ]
-            )
-
-        if (
-            "request-rate-range" in run_config_measurement.model_specific_pa_params()[0]
-            and run_config_measurement.model_specific_pa_params()[0][
-                "request-rate-range"
-            ]
-        ):
-            self._data["request_rate"].append(
-                run_config_measurement.model_specific_pa_params()[0][
-                    "request-rate-range"
-                ]
-            )
-
-        # FIXME 1772 -- clean this up??
-        if (
-            "request-intervals" in run_config_measurement.model_specific_pa_params()[0]
-            and run_config_measurement.model_specific_pa_params()[0][
-                "request-intervals"
-            ]
-        ):
-            self._data["request-intervals"].append(
-                run_config_measurement.model_specific_pa_params()[0][
-                    "request-intervals"
-                ]
-            )
+        for load_arg in PerfAnalyzerConfig.get_inference_load_args():
+            if (
+                load_arg in run_config_measurement.model_specific_pa_params()[0]
+                and run_config_measurement.model_specific_pa_params()[0][load_arg]
+            ):
+                data_key = self._get_data_key_from_load_arg(load_arg)
+                self._data[data_key].append(
+                    run_config_measurement.model_specific_pa_params()[0][load_arg]
+                )

         self._data["perf_throughput"].append(
             run_config_measurement.get_non_gpu_metric_value(tag="perf_throughput")

@@ -177,9 +151,9 @@ def plot_data(self):
         """

         # Update the x-axis plot title
-        if "request-intervals" in self._data and self._data["request-intervals"][0]:
+        if "request_intervals" in self._data and self._data["request_intervals"][0]:
             self._ax_latency.set_xlabel("Request Intervals File")
-            sort_indices_key = "request-intervals"
+            sort_indices_key = "request_intervals"
         elif "request_rate" in self._data and self._data["request_rate"][0]:
             self._ax_latency.set_xlabel("Client Request Rate")
             sort_indices_key = "request_rate"

@@ -274,3 +248,18 @@ def save(self, filepath):
         """

         self._fig.savefig(os.path.join(filepath, self._name))
+
+    def _get_data_key_from_load_arg(self, load_arg):
+        """
+        Gets the key into _data corresponding with the input load arg
+
+        For example, the load arg "request-rate-range" has the key "request_rate"
+        """
+        # Check if '-range' exists at the end of the input string and remove it
+        if load_arg.endswith("-range"):
+            load_arg = load_arg[:-6]
+
+        # Replace any '-' with '_' in the remaining string
+        data_key = load_arg.replace("-", "_")
+
+        return data_key
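
To illustrate the mapping the new helper performs, a standalone re-implementation (outside the class, for illustration only) and the _data keys it yields for each load arg:

    def get_data_key_from_load_arg(load_arg):
        # Mirrors the new helper in detailed_plot.py: drop a trailing
        # "-range", then convert dashes to underscores.
        if load_arg.endswith("-range"):
            load_arg = load_arg[:-6]
        return load_arg.replace("-", "_")

    print(get_data_key_from_load_arg("concurrency-range"))    # concurrency
    print(get_data_key_from_load_arg("request-rate-range"))   # request_rate
    print(get_data_key_from_load_arg("request-intervals"))    # request_intervals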
