
Commit 5af535e

Add Binary Parameter Search to Brute (#681)
* Initial changes - still need to add parameter support
* Moved checking for request rate
* ConcurrencySearch to ParameterSearch
* Adding request rate binary search to brute
* Fixing QL errors
* Making is_request_rate a class member function
* Fixing check of when BCS can occur
1 parent 9d86539 commit 5af535e

8 files changed: 342 additions & 111 deletions
model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Generator, Dict

from .config_generator_interface import ConfigGeneratorInterface

from model_analyzer.config.generate.brute_run_config_generator import BruteRunConfigGenerator
from model_analyzer.config.generate.model_variant_name_manager import ModelVariantNameManager
from model_analyzer.config.run.run_config import RunConfig
from model_analyzer.triton.client.client import TritonClient
from model_analyzer.device.gpu_device import GPUDevice
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
from model_analyzer.result.result_manager import ResultManager
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
from model_analyzer.result.parameter_search import ParameterSearch

from model_analyzer.constants import LOGGER_NAME

from copy import deepcopy

import logging

logger = logging.getLogger(LOGGER_NAME)


class BrutePlusBinaryParameterSearchRunConfigGenerator(ConfigGeneratorInterface):
    """
    First run BruteRunConfigGenerator for a brute search, then for
    automatic searches use ParameterSearch to perform a binary search
    """

    def __init__(self, config: ConfigCommandProfile, gpus: List[GPUDevice],
                 models: List[ModelProfileSpec], client: TritonClient,
                 result_manager: ResultManager,
                 model_variant_name_manager: ModelVariantNameManager):
        """
        Parameters
        ----------
        config: ConfigCommandProfile
            Profile configuration information
        gpus: List of GPUDevices
        models: List of ModelProfileSpec
            List of models to profile
        client: TritonClient
        result_manager: ResultManager
            The object that handles storing and sorting the results from the perf analyzer
        model_variant_name_manager: ModelVariantNameManager
            Maps model variants to config names
        """
        self._config = config
        self._gpus = gpus
        self._models = models
        self._client = client
        self._result_manager = result_manager
        self._model_variant_name_manager = model_variant_name_manager

    def set_last_results(
            self, measurements: List[Optional[RunConfigMeasurement]]) -> None:
        self._last_measurement = measurements[-1]
        self._rcg.set_last_results(measurements)

    def get_configs(self) -> Generator[RunConfig, None, None]:
        """
        Returns
        -------
        RunConfig
            The next RunConfig generated by this class
        """

        logger.info("")
        logger.info("Starting brute mode search")
        logger.info("")
        yield from self._execute_brute_search()
        logger.info("")
        logger.info("Done with brute mode search.")
        logger.info("")

        if self._can_binary_search_top_results():
            yield from self._binary_search_over_top_results()
            logger.info("")
            logger.info(
                "Done gathering concurrency sweep measurements for reports")
            logger.info("")

    def _execute_brute_search(self) -> Generator[RunConfig, None, None]:
        self._rcg: ConfigGeneratorInterface = self._create_brute_run_config_generator()

        yield from self._rcg.get_configs()

    def _create_brute_run_config_generator(self) -> BruteRunConfigGenerator:
        return BruteRunConfigGenerator(
            config=self._config,
            gpus=self._gpus,
            models=self._models,
            client=self._client,
            model_variant_name_manager=self._model_variant_name_manager)

    def _can_binary_search_top_results(self) -> bool:
        for model in self._models:
            if model.parameters()['concurrency'] or model.parameters()['request_rate']:
                return False

        return True

    def _binary_search_over_top_results(
            self) -> Generator[RunConfig, None, None]:
        for model_name in self._result_manager.get_model_names():
            top_results = self._result_manager.top_n_results(
                model_name=model_name,
                n=self._config.num_configs_per_model,
                include_default=True)

            for result in top_results:
                run_config = deepcopy(result.run_config())
                model_parameters = self._get_model_parameters(model_name)
                parameter_search = ParameterSearch(
                    config=self._config,
                    model_parameters=model_parameters,
                    skip_parameter_sweep=True)
                for parameter in parameter_search.search_parameters():
                    run_config = self._set_parameter(run_config,
                                                     model_parameters,
                                                     parameter)
                    yield run_config
                    parameter_search.add_run_config_measurement(
                        self._last_measurement)

    def _get_model_parameters(self, model_name: str) -> Dict:
        for model in self._models:
            if model_name == model.model_name():
                return model.parameters()

        return {}

    def _set_parameter(self, run_config: RunConfig, model_parameters: Dict,
                       parameter: int) -> RunConfig:
        for model_run_config in run_config.model_run_configs():
            perf_config = model_run_config.perf_config()
            if self._config.is_request_rate_specified(model_parameters):
                perf_config.update_config({'request-rate-range': parameter})
            else:
                perf_config.update_config({'concurrency-range': parameter})

        return run_config

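The ParameterSearch class the new generator relies on lives in model_analyzer/result/parameter_search.py and is not part of this diff, so only its hand-off pattern is visible above: the generator yields a RunConfig for each parameter value produced by search_parameters(), and the measurement taken for that config is fed back through add_run_config_measurement() before the next value is produced. The toy sketch below illustrates that pattern with a sweep-then-binary-search strategy; the name ToyParameterSearch, the fake_measure stand-in, and all of the search internals are illustrative assumptions, not Model Analyzer's actual implementation.

from typing import Generator, List, Optional, Tuple


class ToyParameterSearch:
    """Illustrative only: mirrors the search_parameters() /
    add_run_config_measurement() hand-off, with a plain float standing in
    for RunConfigMeasurement and made-up search internals."""

    def __init__(self, max_parameter: int = 1024) -> None:
        self._max_parameter = max_parameter
        self._results: List[Tuple[float, int]] = []
        self._last_throughput: Optional[float] = None

    def add_run_config_measurement(self, throughput: Optional[float]) -> None:
        # The caller reports the measurement for the most recently yielded value
        self._last_throughput = throughput

    def search_parameters(self) -> Generator[int, None, None]:
        # Phase 1: exponential sweep (1, 2, 4, ...) up to the cap
        parameter = 1
        while parameter <= self._max_parameter:
            yield parameter
            self._results.append((self._last_throughput or 0.0, parameter))
            parameter *= 2

        # Phase 2: binary search between the best value and its larger neighbor
        best_throughput, best_parameter = max(self._results)
        lower = best_parameter
        upper = min(best_parameter * 2, self._max_parameter)
        while upper - lower > 1:
            midpoint = (lower + upper) // 2
            yield midpoint
            current = self._last_throughput or 0.0
            if current > best_throughput:
                best_throughput, lower = current, midpoint
            else:
                upper = midpoint


def fake_measure(parameter: int) -> float:
    """Hypothetical stand-in for a perf_analyzer run at one parameter value."""
    return min(parameter, 64) / (1.0 + parameter / 256.0)


search = ToyParameterSearch()
for parameter in search.search_parameters():
    search.add_run_config_measurement(fake_measure(parameter))

The driver loop at the bottom has the same shape as _binary_search_over_top_results above: yield a value, measure, report, repeat.
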
model_analyzer/config/generate/perf_analyzer_config_generator.py

Lines changed: 8 additions & 12 deletions
@@ -19,7 +19,7 @@
 from .config_generator_interface import ConfigGeneratorInterface
 from .generator_utils import GeneratorUtils as utils

-from model_analyzer.constants import LOGGER_NAME, THROUGHPUT_MINIMUM_GAIN, THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES, THROUGHPUT_MINIMUM_CONSECUTIVE_BATCH_SIZE_TRIES
+from model_analyzer.constants import LOGGER_NAME, THROUGHPUT_MINIMUM_GAIN, THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES, THROUGHPUT_MINIMUM_CONSECUTIVE_BATCH_SIZE_TRIES
 from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig
 from model_analyzer.result.run_config_measurement import RunConfigMeasurement

@@ -90,7 +90,7 @@ def __init__(self, cli_config: ConfigCommandProfile, model_name: str,
     @staticmethod
     def throughput_gain_valid_helper(
             throughputs: List[Optional[RunConfigMeasurement]],
-            min_tries: int = THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES,
+            min_tries: int = THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES,
             min_gain: float = THROUGHPUT_MINIMUM_GAIN) -> bool:
         if len(throughputs) < min_tries:
             return True

@@ -159,17 +159,11 @@ def _create_parameter_list(self) -> List[int]:
         # The two possible parameters are request rate or concurrency
         # Concurrency is the default and will be used unless the user specifies
         # request rate, either as a model parameter or a config option
-        if self._config_specifies_request_rate():
+        if self._cli_config.is_request_rate_specified(self._model_parameters):
             return self._create_request_rate_list()
         else:
             return self._create_concurrency_list()

-    def _config_specifies_request_rate(self) -> bool:
-        return self._model_parameters['request_rate'] or \
-            self._cli_config.request_rate_search_enable or \
-            self._cli_config.get_config()['run_config_search_min_request_rate'].is_set_by_user() or \
-            self._cli_config.get_config()['run_config_search_max_request_rate'].is_set_by_user()
-
     def _create_request_rate_list(self) -> List[int]:
         if self._model_parameters['request_rate']:
             return sorted(self._model_parameters['request_rate'])

@@ -205,7 +199,8 @@ def _generate_perf_configs(self) -> None:

                 new_perf_config.update_config(params)

-                if self._config_specifies_request_rate():
+                if self._cli_config.is_request_rate_specified(
+                        self._model_parameters):
                     new_perf_config.update_config(
                         {'request-rate-range': parameter})
                 else:

@@ -259,7 +254,8 @@ def _done_walking_parameters(self) -> bool:
         if self._early_exit_enable and not self._parameter_throughput_gain_valid(
         ):
             if not self._parameter_warning_printed:
-                if self._config_specifies_request_rate():
+                if self._cli_config.is_request_rate_specified(
+                        self._model_parameters):
                     logger.info(
                         "No longer increasing request rate as throughput has plateaued"
                     )

@@ -292,7 +288,7 @@ def _parameter_throughput_gain_valid(self) -> bool:
         """ Check if any of the last X parameter results resulted in valid gain """
         return PerfAnalyzerConfigGenerator.throughput_gain_valid_helper(
             throughputs=self._parameter_results,
-            min_tries=THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES,
+            min_tries=THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES,
             min_gain=THROUGHPUT_MINIMUM_GAIN)

     def _batch_size_throughput_gain_valid(self) -> bool:

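The renamed THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES constant feeds the early-exit logic touched in the last two hunks: once several consecutive measurements stop improving throughput by at least THROUGHPUT_MINIMUM_GAIN, the generator stops walking the swept parameter (concurrency or request rate). The snippet below is a simplified, standalone restatement of that plateau test using plain floats; the real helper compares RunConfigMeasurement objects, so treat it as an approximation of the idea rather than the repository's implementation.

from typing import List, Optional

THROUGHPUT_MINIMUM_GAIN = 0.05
THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES = 4


def throughput_gain_valid_sketch(
        throughputs: List[Optional[float]],
        min_tries: int = THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES,
        min_gain: float = THROUGHPUT_MINIMUM_GAIN) -> bool:
    """Keep searching while there are fewer than min_tries results, or while
    any of the most recent steps improved throughput by at least min_gain."""
    if len(throughputs) < min_tries:
        return True

    recent = [t for t in throughputs[-min_tries:] if t is not None]
    if len(recent) < 2:
        return False

    gains = [(curr - prev) / prev for prev, curr in zip(recent, recent[1:])]
    return any(gain > min_gain for gain in gains)


# Throughput has flattened out over the last four tries, so the sweep stops:
print(throughput_gain_valid_sketch([100.0, 150.0, 152.0, 153.0, 153.5]))  # False
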
model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py

Lines changed: 6 additions & 7 deletions
@@ -27,7 +27,7 @@
 from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
 from model_analyzer.result.result_manager import ResultManager
 from model_analyzer.result.run_config_measurement import RunConfigMeasurement
-from model_analyzer.result.concurrency_search import ConcurrencySearch
+from model_analyzer.result.parameter_search import ParameterSearch

 from model_analyzer.constants import LOGGER_NAME

@@ -42,7 +42,8 @@
 class QuickPlusConcurrencySweepRunConfigGenerator(ConfigGeneratorInterface):
     """
     First run QuickRunConfigGenerator for a hill climbing search, then use
-    Brute for a concurrency sweep of the default and Top N results
+    ParameterSearch for a concurrency sweep + binary search of the default
+    and Top N results
     """

     def __init__(self, search_config: SearchConfig,

@@ -68,8 +69,6 @@ def __init__(self, search_config: SearchConfig,
             The object that handles storing and sorting the results from the perf analyzer
         model_variant_name_manager: ModelVariantNameManager
             Maps model variants to config names
-
-        model_variant_name_manager: ModelVariantNameManager
         """
         self._search_config = search_config
         self._config = config

@@ -133,11 +132,11 @@ def _sweep_concurrency_over_top_results(

             for result in top_results:
                 run_config = deepcopy(result.run_config())
-                concurrency_search = ConcurrencySearch(self._config)
-                for concurrency in concurrency_search.search_concurrencies():
+                parameter_search = ParameterSearch(self._config)
+                for concurrency in parameter_search.search_parameters():
                     run_config = self._set_concurrency(run_config, concurrency)
                     yield run_config
-                    concurrency_search.add_run_config_measurement(
+                    parameter_search.add_run_config_measurement(
                         self._last_measurement)

     def _set_concurrency(self, run_config: RunConfig,

model_analyzer/config/generate/run_config_generator_factory.py

Lines changed: 7 additions & 3 deletions
@@ -25,6 +25,7 @@
 from model_analyzer.result.result_manager import ResultManager
 from .brute_run_config_generator import BruteRunConfigGenerator
 from .quick_plus_concurrency_sweep_run_config_generator import QuickPlusConcurrencySweepRunConfigGenerator
+from .brute_plus_binary_parameter_search_run_config_generator import BrutePlusBinaryParameterSearchRunConfigGenerator
 from .search_dimensions import SearchDimensions
 from .search_dimension import SearchDimension
 from .search_config import SearchConfig

@@ -84,28 +85,31 @@ def create_run_config_generator(
                 result_manager=result_manager,
                 model_variant_name_manager=model_variant_name_manager)
         elif (command_config.run_config_search_mode == "brute"):
-            return RunConfigGeneratorFactory._create_brute_run_config_generator(
+            return RunConfigGeneratorFactory._create_brute_plus_binary_parameter_search_run_config_generator(
                 command_config=command_config,
                 gpus=gpus,
                 models=new_models,
                 client=client,
+                result_manager=result_manager,
                 model_variant_name_manager=model_variant_name_manager)
         else:
             raise TritonModelAnalyzerException(
                 f"Unexpected search mode {command_config.run_config_search_mode}"
             )

     @staticmethod
-    def _create_brute_run_config_generator(
+    def _create_brute_plus_binary_parameter_search_run_config_generator(
             command_config: ConfigCommandProfile, gpus: List[GPUDevice],
             models: List[ModelProfileSpec], client: TritonClient,
+            result_manager: ResultManager,
             model_variant_name_manager: ModelVariantNameManager
     ) -> ConfigGeneratorInterface:
-        return BruteRunConfigGenerator(
+        return BrutePlusBinaryParameterSearchRunConfigGenerator(
             config=command_config,
             gpus=gpus,
             models=models,
             client=client,
+            result_manager=result_manager,
             model_variant_name_manager=model_variant_name_manager)

     @staticmethod

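For orientation, every generator the factory can return implements ConfigGeneratorInterface, and the profiling loop drives them identically: pull the next RunConfig, run it, then report the measurements back before asking for another. A rough sketch of that driver loop follows; drive_search and execute_run_config are hypothetical names standing in for Model Analyzer's model manager and its perf_analyzer invocation, and only get_configs() and set_last_results() come from the interface shown in this commit.

def drive_search(rcg, execute_run_config):
    """Minimal driver sketch: 'rcg' is any ConfigGeneratorInterface instance
    produced by RunConfigGeneratorFactory; execute_run_config() stands in for
    running perf_analyzer and returning List[Optional[RunConfigMeasurement]]."""
    for run_config in rcg.get_configs():
        measurements = execute_run_config(run_config)
        # Feeding results back is what lets the brute-plus generator's
        # binary search react to the previous measurement before it
        # yields the next parameter value.
        rcg.set_last_results(measurements)

The new result_manager argument is what lets the binary-search phase pull top_n_results from the data gathered during the brute phase, while set_last_results keeps self._last_measurement current for ParameterSearch.
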
model_analyzer/config/input/config_command_profile.py

Lines changed: 10 additions & 1 deletion
@@ -1239,4 +1239,13 @@ def _are_models_using_request_rate(self) -> bool:
             raise TritonModelAnalyzerException("Parameters in all profiled models must use request-rate-range. "\
                 "Model Analyzer does not support mixing concurrency-range and request-rate-range.")
         else:
-            return model_using_request_rate
+            return model_using_request_rate
+
+    def is_request_rate_specified(self, model_parameters: dict) -> bool:
+        """
+        Returns true if either the model or the config specified request rate
+        """
+        return 'request_rate' in model_parameters and model_parameters['request_rate'] or \
+            self.request_rate_search_enable or \
+            self.get_config()['run_config_search_min_request_rate'].is_set_by_user() or \
+            self.get_config()['run_config_search_max_request_rate'].is_set_by_user()

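The new helper centralizes the check that _config_specifies_request_rate used to perform privately in the perf analyzer config generator: request rate is used when a model's own parameters list request rates, or when any of the request-rate CLI options (request_rate_search_enable, run_config_search_min_request_rate, run_config_search_max_request_rate) were set by the user. Below is a minimal standalone restatement of that decision, with a single boolean standing in for the three CLI options and hypothetical parameter dictionaries.

def request_rate_specified(model_parameters: dict,
                           cli_request_rate_options_set: bool) -> bool:
    # Mirrors the new helper's logic: model-level request_rate values win,
    # otherwise fall back to whether any request-rate CLI option was set.
    return bool(model_parameters.get('request_rate')) or cli_request_rate_options_set


assert not request_rate_specified({'concurrency': [1, 2, 4], 'request_rate': []},
                                  cli_request_rate_options_set=False)
assert request_rate_specified({'concurrency': [], 'request_rate': [100, 200]},
                              cli_request_rate_options_set=False)
assert request_rate_specified({'concurrency': [1], 'request_rate': []},
                              cli_request_rate_options_set=True)

In _set_parameter above, this decision picks which perf_analyzer flag gets swept: request-rate-range when it is true, concurrency-range otherwise.
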
model_analyzer/constants.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@

 # Run Search
 THROUGHPUT_MINIMUM_GAIN = 0.05
-THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES = 4
+THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES = 4
 THROUGHPUT_MINIMUM_CONSECUTIVE_BATCH_SIZE_TRIES = 4

 # Quick search algorithm constants
