Skip to content

Commit acf085f

Browse files
authored
Optuna Search Mode (Alpha) Release (#896)
* Adding cli option for optuna search (#867) * Adding cli option for optuna search * Changed RCS description * Class to hold info about parameters (#868) * Initial code for ConfigParameters class * Fixing codeql issue * Fixes based on review * Connect up parameter description class (#869) * Added hooks for creating search parameters with some basic unit testing * Adding more unit testing * Cleaning up codeql * Adding story ref for TODO * Changes based on review comments * Refactored ConfigParameters * Renaming to SearchParameter(s) * Moving unit testing into SearchParameters test class * Fix codeql issues * Creating Optuna RCG factory (#878) * Creating optuna RCG factory * fixing codeql issues * Removing metrics manager * Fixing mypy failure * Optuna Search Class (#877) * Base Optuna class plus unit testing * codeql fixes * more codeql fixes * Removing metrics manager * Removing metrics manager from Optuna RCG unit test * Removing client from quick/optuna RCGs * Changing gpus to gpu_count in quick/optuna RCGs * Removing magic number * Fixing codeql issue * Fixing optuna version * Adding todo comment about client batch size support * Using SearchParameters in OptunaRCG (#881) * Using SearchParameters in OptunaRCG * Fixing search parameter unit tests * Removing debug line * Changes based on PR * Adding call for default parameters * Added todo for dynamic batching * Add Percentage Search Space to Optuna (#882) * Added method for calculating total possible configurations * Added min/max percentage of search space to CLI * Connected up in optuna RCG * Added in support to cap optuna search based on a strict number of trials (#884) * Adding support for concurrency formula as an option in Optuna search (#885) * Fixing merge conflict * Adding --use-concurrency-formula to unit testing * Add Debug info to Optuna (#889) * Adding debug info + bug fixes * Fixes based on PR * Optuna Early Exit (#890) * Add logic to enable early exit along with CLI hooks.
* Changes based on PR * Check that model supports dynamic batching when creating param_combo (#891) * Adding option to disable concurrency sweeping (#893) * Adding support for client batch size (#892) * Adding support for client batch size * Fixes based on PR * Removing redundant keys() * Fixing codeQL issue * Attempt to fix unittest issue * Removing 3.8 testing
1 parent ef12a85 commit acf085f

21 files changed

+2114
-81
lines changed

.github/workflows/python-package.yaml

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,21 @@ jobs:
3939
fail-fast: false
4040
matrix:
4141
os: ["ubuntu-22.04"]
42-
python-version: ["3.8", "3.11"]
42+
python-version: ["3.11"]
4343
env:
4444
SKIP_GPU_TESTS: 1
4545

4646
steps:
47-
- uses: actions/checkout@v3
48-
- name: Set up Python ${{ matrix.python-version }}
49-
uses: actions/setup-python@v3
50-
with:
51-
python-version: ${{ matrix.python-version }}
52-
- name: Install dependencies
53-
run: |
54-
python -m pip install --upgrade pip
55-
python -m pip install -e .
56-
- name: Test with unittest
57-
run: |
58-
pip install unittest-parallel
59-
python3 -m unittest_parallel -v -s ./tests -t .
47+
- uses: actions/checkout@v3
48+
- name: Set up Python ${{ matrix.python-version }}
49+
uses: actions/setup-python@v3
50+
with:
51+
python-version: ${{ matrix.python-version }}
52+
- name: Install dependencies
53+
run: |
54+
python -m pip install --upgrade pip
55+
python -m pip install -e .
56+
- name: Test with unittest
57+
run: |
58+
pip install unittest-parallel
59+
python3 -m unittest_parallel -v -s ./tests -t .

model_analyzer/analyzer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
import logging
1818
import sys
1919
from copy import deepcopy
20-
from typing import List, Optional, Union
20+
from typing import Dict, List, Optional, Union
2121

2222
from model_analyzer.cli.cli import CLI
2323
from model_analyzer.config.generate.base_model_config_generator import (
2424
BaseModelConfigGenerator,
2525
)
26+
from model_analyzer.config.generate.search_parameters import SearchParameters
2627
from model_analyzer.constants import LOGGER_NAME, PA_ERROR_LOG_FILENAME
2728
from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager
2829
from model_analyzer.triton.server.server import TritonServer
@@ -82,6 +83,8 @@ def __init__(
8283
constraint_manager=self._constraint_manager,
8384
)
8485

86+
self._search_parameters: Dict[str, SearchParameters] = {}
87+
8588
def profile(
8689
self, client: TritonClient, gpus: List[GPUDevice], mode: str, verbose: bool
8790
) -> None:
@@ -115,6 +118,7 @@ def profile(
115118

116119
self._create_metrics_manager(client, gpus)
117120
self._create_model_manager(client, gpus)
121+
self._populate_search_parameters()
118122

119123
if self._config.triton_launch_mode == "remote":
120124
self._warn_if_other_models_loaded_on_remote_server(client)
@@ -200,6 +204,7 @@ def _create_model_manager(self, client, gpus):
200204
metrics_manager=self._metrics_manager,
201205
state_manager=self._state_manager,
202206
constraint_manager=self._constraint_manager,
207+
search_parameters=self._search_parameters,
203208
)
204209

205210
def _get_server_only_metrics(self, client, gpus):
@@ -414,3 +419,9 @@ def _warn_if_other_models_loaded_on_remote_server(self, client):
414419
f"A model not being profiled ({model_name}) is loaded on the remote Tritonserver. "
415420
"This could impact the profile results."
416421
)
422+
423+
def _populate_search_parameters(self):
424+
for model in self._config.profile_models:
425+
self._search_parameters[model.model_name()] = SearchParameters(
426+
self._config, model.parameters(), model.model_config_parameters()
427+
)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
from copy import deepcopy
19+
from typing import Dict, Generator, List, Optional
20+
21+
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
22+
from model_analyzer.config.generate.model_variant_name_manager import (
23+
ModelVariantNameManager,
24+
)
25+
from model_analyzer.config.generate.optuna_run_config_generator import (
26+
OptunaRunConfigGenerator,
27+
)
28+
from model_analyzer.config.generate.search_parameters import SearchParameters
29+
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
30+
from model_analyzer.config.run.run_config import RunConfig
31+
from model_analyzer.constants import LOGGER_NAME
32+
from model_analyzer.result.parameter_search import ParameterSearch
33+
from model_analyzer.result.result_manager import ResultManager
34+
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
35+
36+
from .config_generator_interface import ConfigGeneratorInterface
37+
38+
logger = logging.getLogger(LOGGER_NAME)
39+
40+
41+
class OptunaPlusConcurrencySweepRunConfigGenerator(ConfigGeneratorInterface):
    """
    First run OptunaConfigGenerator for an Optuna search, then use
    ParameterSearch for a concurrency sweep + binary search of the default
    and Top N results
    """

    def __init__(
        self,
        config: ConfigCommandProfile,
        gpu_count: int,
        models: List[ModelProfileSpec],
        result_manager: ResultManager,
        model_variant_name_manager: ModelVariantNameManager,
        search_parameters: Dict[str, SearchParameters],
    ):
        """
        Parameters
        ----------
        config: ConfigCommandProfile
            Profile configuration information
        gpu_count: int
            Number of gpus in the system
        models: List[ModelProfileSpec]
            List of models to profile
        result_manager: ResultManager
            The object that handles storing and sorting the results from the perf analyzer
        model_variant_name_manager: ModelVariantNameManager
            Maps model variants to config names
        search_parameters: Dict[str, SearchParameters]
            Per-model-name objects that hold the user's configuration search parameters
        """
        self._config = config
        self._gpu_count = gpu_count
        self._models = models
        self._result_manager = result_manager
        self._model_variant_name_manager = model_variant_name_manager
        self._search_parameters = search_parameters

    def set_last_results(
        self, measurements: List[Optional[RunConfigMeasurement]]
    ) -> None:
        # Record the newest measurement for the concurrency sweep and
        # forward the full list to the active run config generator.
        self._last_measurement = measurements[-1]
        self._rcg.set_last_results(measurements)

    def get_configs(self) -> Generator[RunConfig, None, None]:
        """
        Returns
        -------
        RunConfig
            The next RunConfig generated by this class
        """

        logger.info("")
        logger.info("Starting Optuna mode search to find optimal configs")
        logger.info("")
        yield from self._execute_optuna_search()
        logger.info("")
        if self._config.concurrency_sweep_disable:
            logger.info("Done with Optuna mode search.")
        else:
            logger.info(
                "Done with Optuna mode search. Gathering concurrency sweep measurements for reports"
            )
            logger.info("")
            yield from self._sweep_concurrency_over_top_results()
            logger.info("")
            logger.info("Done gathering concurrency sweep measurements for reports")
        logger.info("")

    def _execute_optuna_search(self) -> Generator[RunConfig, None, None]:
        # The generator is stored on self so set_last_results() can reach it.
        self._rcg: ConfigGeneratorInterface = self._create_optuna_run_config_generator()

        yield from self._rcg.get_configs()

    def _create_optuna_run_config_generator(self) -> OptunaRunConfigGenerator:
        return OptunaRunConfigGenerator(
            config=self._config,
            gpu_count=self._gpu_count,
            models=self._models,
            model_variant_name_manager=self._model_variant_name_manager,
            search_parameters=self._search_parameters,
        )

    def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]:
        # For each model, re-run the default config plus the Top N results
        # across a full concurrency sweep so reports have complete curves.
        for model_name in self._result_manager.get_model_names():
            top_results = self._result_manager.top_n_results(
                model_name=model_name,
                n=self._config.num_configs_per_model,
                include_default=True,
            )

            for result in top_results:
                # Deep copy so the stored result's run config is not mutated
                # by the concurrency updates below.
                run_config = deepcopy(result.run_config())
                parameter_search = ParameterSearch(self._config)
                for concurrency in parameter_search.search_parameters():
                    run_config = self._set_concurrency(run_config, concurrency)
                    yield run_config
                    parameter_search.add_run_config_measurement(self._last_measurement)

    def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig:
        # Apply the concurrency to every model run config in the RunConfig.
        for model_run_config in run_config.model_run_configs():
            perf_config = model_run_config.perf_config()
            perf_config.update_config({"concurrency-range": concurrency})

        return run_config

0 commit comments

Comments
 (0)