Skip to content

Commit 9d86539

Browse files
authored
Add Binary Concurrency Search to Quick Search (#679)
* Config Search Class (#675) * Framework of class and testing env created * Testing and logic for objective saturation * Adding check to ensure measurements are added * Scaffolding completed for binary search * Binary search code and testing * Fixing type checking * Refactoring * Full sweep before bcs * Minor refactoring * Changes based on Tim's review * Adding TMA to fixme * Adding config option of max BCS steps (#676) * Adding new config option for max BCS steps * Adding documentation * Changing config name * Integrating concurrency search into quick search (#677) * Replacing magic numbers with constants * Adding constraints to L0 quick search. Added checks for cases when PA returns no result (#678) * Fixing QL import error * Fixing another QL error * Using config's max binary search steps instead of default
1 parent ac56b06 commit 9d86539

File tree

9 files changed

+494
-21
lines changed

9 files changed

+494
-21
lines changed

docs/config.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ bls_composing_models: <comma-delimited-string-list>
212212
# Maximum request rate used for the automatic/quick config search
213213
[ run_config_search_max_request_rate: <int> | default: 8092 ]
214214
215+
# Maximum number of steps taken during a binary search
216+
[ run_config_search_max_binary_search_steps: <int> | default: 5 ]
217+
215218
# Disables automatic config search
216219
[ run_config_search_disable: <bool> | default: false ]
217220

model_analyzer/config/generate/perf_analyzer_config_generator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,14 @@ def _done_walking_parameters(self) -> bool:
259259
if self._early_exit_enable and not self._parameter_throughput_gain_valid(
260260
):
261261
if not self._parameter_warning_printed:
262-
logger.info(
263-
"No longer increasing concurrency as throughput has plateaued"
264-
)
262+
if self._config_specifies_request_rate():
263+
logger.info(
264+
"No longer increasing request rate as throughput has plateaued"
265+
)
266+
else:
267+
logger.info(
268+
"No longer increasing concurrency as throughput has plateaued"
269+
)
265270
self._parameter_warning_printed = True
266271
return True
267272
return False

model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
2828
from model_analyzer.result.result_manager import ResultManager
2929
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
30+
from model_analyzer.result.concurrency_search import ConcurrencySearch
3031

3132
from model_analyzer.constants import LOGGER_NAME
3233

@@ -130,26 +131,14 @@ def _sweep_concurrency_over_top_results(
130131
n=self._config.num_configs_per_model,
131132
include_default=True)
132133

133-
for count, result in enumerate(top_results):
134+
for result in top_results:
134135
run_config = deepcopy(result.run_config())
135-
136-
max_concurrency_index = int(
137-
log2(self._config.run_config_search_max_concurrency))
138-
139-
run_config_measurements = []
140-
for concurrency in (
141-
2**i for i in range(0, max_concurrency_index + 1)):
136+
concurrency_search = ConcurrencySearch(self._config)
137+
for concurrency in concurrency_search.search_concurrencies():
142138
run_config = self._set_concurrency(run_config, concurrency)
143139
yield run_config
144-
145-
run_config_measurements.append(self._last_measurement)
146-
147-
if not PerfAnalyzerConfigGenerator.throughput_gain_valid_helper(
148-
throughputs=run_config_measurements):
149-
logger.info(
150-
"Terminating concurrency sweep - throughput is decreasing"
151-
)
152-
break
140+
concurrency_search.add_run_config_measurement(
141+
self._last_measurement)
153142

154143
def _set_concurrency(self, run_config: RunConfig,
155144
concurrency: int) -> RunConfig:

model_analyzer/config/input/config_command_profile.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
DEFAULT_PERF_OUTPUT_FLAG, DEFAULT_RUN_CONFIG_MAX_CONCURRENCY, DEFAULT_RUN_CONFIG_MIN_CONCURRENCY, \
3737
DEFAULT_RUN_CONFIG_MAX_REQUEST_RATE, DEFAULT_RUN_CONFIG_MIN_REQUEST_RATE, \
3838
DEFAULT_RUN_CONFIG_PROFILE_MODELS_CONCURRENTLY_ENABLE, DEFAULT_RUN_CONFIG_SEARCH_MODE, \
39-
DEFAULT_REQUEST_RATE_SEARCH_ENABLE, \
39+
DEFAULT_RUN_CONFIG_MAX_BINARY_SEARCH_STEPS, DEFAULT_REQUEST_RATE_SEARCH_ENABLE, \
4040
DEFAULT_RUN_CONFIG_MAX_INSTANCE_COUNT, DEFAULT_RUN_CONFIG_MIN_INSTANCE_COUNT, \
4141
DEFAULT_RUN_CONFIG_MAX_MODEL_BATCH_SIZE, DEFAULT_RUN_CONFIG_MIN_MODEL_BATCH_SIZE, \
4242
DEFAULT_RUN_CONFIG_SEARCH_DISABLE, DEFAULT_TRITON_DOCKER_IMAGE, DEFAULT_TRITON_GRPC_ENDPOINT, \
@@ -596,6 +596,15 @@ def _add_run_search_configs(self):
596596
description=
597597
"Value for the model's max_batch_size that run config search will start from."
598598
))
599+
self._add_config(
600+
ConfigField(
601+
'run_config_search_max_binary_search_steps',
602+
flags=['--run-config-search-max-binary-search-steps'],
603+
field_type=ConfigPrimitive(int),
604+
default_value=DEFAULT_RUN_CONFIG_MAX_BINARY_SEARCH_STEPS,
605+
description=
606+
"Maximum number of steps take during the binary concurrency search."
607+
))
599608
self._add_config(
600609
ConfigField(
601610
'run_config_search_mode',

model_analyzer/config/input/config_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
DEFAULT_RUN_CONFIG_MIN_INSTANCE_COUNT = 1
4848
DEFAULT_RUN_CONFIG_MIN_MODEL_BATCH_SIZE = 1
4949
DEFAULT_RUN_CONFIG_MAX_MODEL_BATCH_SIZE = 128
50+
DEFAULT_RUN_CONFIG_MAX_BINARY_SEARCH_STEPS = 5
5051
DEFAULT_RUN_CONFIG_SEARCH_DISABLE = False
5152
DEFAULT_RUN_CONFIG_SEARCH_MODE = 'brute'
5253
DEFAULT_RUN_CONFIG_PROFILE_MODELS_CONCURRENTLY_ENABLE = False
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional, Generator
16+
17+
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
18+
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
19+
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
20+
21+
from math import log2
22+
23+
import logging
24+
from model_analyzer.constants import LOGGER_NAME, THROUGHPUT_MINIMUM_GAIN, THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES
25+
26+
logger = logging.getLogger(LOGGER_NAME)
27+
28+
29+
class ConcurrencySearch():
30+
"""
31+
Generates the next concurrency value to use when searching through
32+
RunConfigMeasurements for the best value (according to the users objective)
33+
- Will sweep from by powers of two from min to max concurrency
34+
- If the user specifies a constraint, the algorithm will perform a binary search
35+
around the boundary if the constraint is violated
36+
37+
Invariant: It is necessary for the user to add new measurements as they are taken
38+
"""
39+
40+
def __init__(self, config: ConfigCommandProfile) -> None:
41+
"""
42+
Parameters
43+
----------
44+
config: ConfigCommandProfile
45+
Profile configuration information
46+
"""
47+
self._min_concurrency_index = int(
48+
log2(config.run_config_search_min_concurrency))
49+
self._max_concurrency_index = int(
50+
log2(config.run_config_search_max_concurrency))
51+
self._max_binary_search_steps = config.run_config_search_max_binary_search_steps
52+
53+
self._run_config_measurements: List[Optional[RunConfigMeasurement]] = []
54+
self._concurrencies: List[int] = []
55+
self._last_failing_concurrency = 0
56+
self._last_passing_concurrency = 0
57+
58+
def add_run_config_measurement(
59+
self,
60+
run_config_measurement: Optional[RunConfigMeasurement]) -> None:
61+
"""
62+
Adds a new RunConfigMeasurement
63+
Invariant: Assumed that RCMs are added in the same order they are measured
64+
"""
65+
self._run_config_measurements.append(run_config_measurement)
66+
67+
def search_concurrencies(self) -> Generator[int, None, None]:
68+
"""
69+
First performs a concurrency sweep, and then, if necessary, perform
70+
a binary concurrency search around the point where the constraint
71+
violated
72+
"""
73+
yield from self._perform_concurrency_sweep()
74+
75+
if self._was_constraint_violated():
76+
yield from self._perform_binary_concurrency_search()
77+
78+
def _perform_concurrency_sweep(self) -> Generator[int, None, None]:
79+
for concurrency in (2**i for i in range(
80+
self._min_concurrency_index, self._max_concurrency_index + 1)):
81+
if self._should_continue_concurrency_sweep():
82+
self._concurrencies.append(concurrency)
83+
yield concurrency
84+
else:
85+
logger.info(
86+
"Terminating concurrency sweep - throughput is decreasing")
87+
return
88+
89+
def _should_continue_concurrency_sweep(self) -> bool:
90+
self._check_measurement_count()
91+
92+
if not self._are_minimum_tries_reached():
93+
return True
94+
else:
95+
return not self._has_objective_gain_saturated()
96+
97+
def _check_measurement_count(self) -> None:
98+
if len(self._run_config_measurements) != len(self._concurrencies):
99+
raise TritonModelAnalyzerException(f"Internal Measurement count: {self._concurrencies}, doesn't match number " \
100+
f"of measurements added: {len(self._run_config_measurements)}.")
101+
102+
def _are_minimum_tries_reached(self) -> bool:
103+
if len(self._run_config_measurements
104+
) < THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES:
105+
return False
106+
else:
107+
return True
108+
109+
def _has_objective_gain_saturated(self) -> bool:
110+
gain = self._calculate_gain()
111+
return gain < THROUGHPUT_MINIMUM_GAIN
112+
113+
def _calculate_gain(self) -> float:
114+
first_rcm = self._run_config_measurements[
115+
-THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES]
116+
117+
best_rcm = self._get_best_rcm()
118+
119+
# These cover the cases where we don't get a result from PA
120+
if not first_rcm and not best_rcm:
121+
return 0
122+
if not first_rcm:
123+
return 1
124+
elif not best_rcm:
125+
return -1
126+
else:
127+
gain = first_rcm.compare_measurements(best_rcm)
128+
129+
return gain
130+
131+
def _get_best_rcm(self) -> Optional[RunConfigMeasurement]:
132+
# Need to remove entries (None) with no result from PA before sorting
133+
pruned_rcms = [
134+
rcm for rcm in self._run_config_measurements[
135+
-THROUGHPUT_MINIMUM_CONSECUTIVE_CONCURRENCY_TRIES:] if rcm
136+
]
137+
best_rcm = max(pruned_rcms) if pruned_rcms else None
138+
139+
return best_rcm
140+
141+
def _was_constraint_violated(self) -> bool:
142+
for i in range(len(self._run_config_measurements) - 1, 1, -1):
143+
if self._at_constraint_failure_boundary(i):
144+
self._last_failing_concurrency = self._concurrencies[i]
145+
self._last_passing_concurrency = self._concurrencies[i - 1]
146+
return True
147+
148+
if self._run_config_measurements[
149+
0] and not self._run_config_measurements[
150+
0].is_passing_constraints():
151+
self._last_failing_concurrency = self._concurrencies[i]
152+
self._last_passing_concurrency = 0
153+
return True
154+
else:
155+
return False
156+
157+
def _at_constraint_failure_boundary(self, index: int) -> bool:
158+
if not self._run_config_measurements[
159+
index] or not self._run_config_measurements[index - 1]:
160+
return False
161+
162+
at_failure_boundary = not self._run_config_measurements[ # type: ignore
163+
index].is_passing_constraints() and self._run_config_measurements[
164+
index - # type: ignore
165+
1].is_passing_constraints()
166+
167+
return at_failure_boundary
168+
169+
def _perform_binary_concurrency_search(self) -> Generator[int, None, None]:
170+
# This is needed because we are going to restart the search from the
171+
# concurrency that failed - so we expect this to be at the end of the list
172+
self._concurrencies.append(self._last_failing_concurrency)
173+
174+
for i in range(0, self._max_binary_search_steps):
175+
concurrency = self._determine_next_binary_concurrency()
176+
177+
if concurrency != self._concurrencies[-1]:
178+
self._concurrencies.append(concurrency)
179+
yield concurrency
180+
181+
def _determine_next_binary_concurrency(self) -> int:
182+
if not self._run_config_measurements[-1]:
183+
return 0
184+
185+
if self._run_config_measurements[-1].is_passing_constraints():
186+
self._last_passing_concurrency = self._concurrencies[-1]
187+
concurrency = int(
188+
(self._last_failing_concurrency + self._concurrencies[-1]) / 2)
189+
else:
190+
self._last_failing_concurrency = self._concurrencies[-1]
191+
concurrency = int(
192+
(self._last_passing_concurrency + self._concurrencies[-1]) / 2)
193+
194+
return concurrency

qa/L0_quick_search/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ MODEL_ANALYZER_ARGS="$MODEL_ANALYZER_ARGS --triton-metrics-url http://localhost:
5252
MODEL_ANALYZER_ARGS="$MODEL_ANALYZER_ARGS --output-model-repository-path $OUTPUT_MODEL_REPOSITORY --override-output-model-repository"
5353
MODEL_ANALYZER_ARGS="$MODEL_ANALYZER_ARGS -e $EXPORT_PATH --filename-server-only=$FILENAME_SERVER_ONLY"
5454
MODEL_ANALYZER_ARGS="$MODEL_ANALYZER_ARGS --filename-model-inference=$FILENAME_INFERENCE_MODEL --filename-model-gpu=$FILENAME_GPU_MODEL"
55+
MODEL_ANALYZER_ARGS="$MODEL_ANALYZER_ARGS --latency-budget 10"
5556
MODEL_ANALYZER_ARGS="$MODEL_ANALYZER_ARGS --skip-summary-reports"
5657
MODEL_ANALYZER_SUBCOMMAND="profile"
5758
run_analyzer

tests/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def get_test_options():
8585
OptionStruct("int", "profile", "--run-config-search-max-model-batch-size", None, "100", "128"),
8686
OptionStruct("int", "profile", "--run-config-search-min-instance-count", None, "2", "1"),
8787
OptionStruct("int", "profile", "--run-config-search-max-instance-count", None, "10", "5"),
88+
OptionStruct("int", "profile", "--run-config-search-max-binary-search-steps", None, "10", "5"),
8889
OptionStruct("float", "profile", "--monitoring-interval", "-i", "10.0", "1.0"),
8990
OptionStruct("float", "profile", "--perf-analyzer-cpu-util", None, "10.0", str(psutil.cpu_count() * 80.0)),
9091
OptionStruct("int", "profile", "--num-configs-per-model", None, "10", "3"),

0 commit comments

Comments
 (0)