Skip to content

Commit 1237844

Browse files
authored
Creating new generator class for concurrency sweeping (#921)
* Creating new generator class for concurrency sweeping * Removing unused import
1 parent 08590a2 commit 1237844

File tree

3 files changed

+83
-48
lines changed

3 files changed

+83
-48
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import logging
18+
from copy import deepcopy
19+
from typing import Generator, List, Optional
20+
21+
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
22+
from model_analyzer.config.run.run_config import RunConfig
23+
from model_analyzer.constants import LOGGER_NAME
24+
from model_analyzer.result.parameter_search import ParameterSearch
25+
from model_analyzer.result.result_manager import ResultManager
26+
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
27+
28+
logger = logging.getLogger(LOGGER_NAME)
29+
30+
31+
class ConcurrencySweeper:
32+
"""
33+
Sweeps concurrency for the top-N model configs
34+
"""
35+
36+
def __init__(
37+
self,
38+
config: ConfigCommandProfile,
39+
result_manager: ResultManager,
40+
):
41+
self._config = config
42+
self._result_manager = result_manager
43+
self._last_measurement: Optional[RunConfigMeasurement] = None
44+
45+
def set_last_results(
46+
self, measurements: List[Optional[RunConfigMeasurement]]
47+
) -> None:
48+
self._last_measurement = measurements[-1]
49+
50+
def get_configs(self) -> Generator[RunConfig, None, None]:
51+
"""
52+
A generator which creates RunConfigs based on sweeping
53+
concurrency over the top-N models
54+
"""
55+
for model_name in self._result_manager.get_model_names():
56+
top_results = self._result_manager.top_n_results(
57+
model_name=model_name,
58+
n=self._config.num_configs_per_model,
59+
include_default=True,
60+
)
61+
62+
for result in top_results:
63+
run_config = deepcopy(result.run_config())
64+
parameter_search = ParameterSearch(self._config)
65+
for concurrency in parameter_search.search_parameters():
66+
run_config = self._create_run_config(run_config, concurrency)
67+
yield run_config
68+
parameter_search.add_run_config_measurement(self._last_measurement)
69+
70+
def _create_run_config(self, run_config: RunConfig, concurrency: int) -> RunConfig:
71+
for model_run_config in run_config.model_run_configs():
72+
perf_config = model_run_config.perf_config()
73+
perf_config.update_config({"concurrency-range": concurrency})
74+
75+
return run_config

model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from copy import deepcopy
1919
from typing import Dict, Generator, List, Optional
2020

21+
from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper
2122
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
2223
from model_analyzer.config.generate.model_variant_name_manager import (
2324
ModelVariantNameManager,
@@ -115,7 +116,9 @@ def get_configs(self) -> Generator[RunConfig, None, None]:
115116
"Done with Optuna mode search. Gathering concurrency sweep measurements for reports"
116117
)
117118
logger.info("")
118-
yield from self._sweep_concurrency_over_top_results()
119+
yield from ConcurrencySweeper(
120+
config=self._config, result_manager=self._result_manager
121+
).get_configs()
119122
logger.info("")
120123
logger.info("Done gathering concurrency sweep measurements for reports")
121124
logger.info("")
@@ -136,26 +139,3 @@ def _create_optuna_run_config_generator(self) -> OptunaRunConfigGenerator:
136139
search_parameters=self._search_parameters,
137140
composing_search_parameters=self._composing_search_parameters,
138141
)
139-
140-
def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]:
141-
for model_name in self._result_manager.get_model_names():
142-
top_results = self._result_manager.top_n_results(
143-
model_name=model_name,
144-
n=self._config.num_configs_per_model,
145-
include_default=True,
146-
)
147-
148-
for result in top_results:
149-
run_config = deepcopy(result.run_config())
150-
parameter_search = ParameterSearch(self._config)
151-
for concurrency in parameter_search.search_parameters():
152-
run_config = self._set_concurrency(run_config, concurrency)
153-
yield run_config
154-
parameter_search.add_run_config_measurement(self._last_measurement)
155-
156-
def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig:
157-
for model_run_config in run_config.model_run_configs():
158-
perf_config = model_run_config.perf_config()
159-
perf_config.update_config({"concurrency-range": concurrency})
160-
161-
return run_config

model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from copy import deepcopy
1919
from typing import Generator, List, Optional
2020

21+
from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper
2122
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
2223
from model_analyzer.config.generate.model_variant_name_manager import (
2324
ModelVariantNameManager,
@@ -106,7 +107,9 @@ def get_configs(self) -> Generator[RunConfig, None, None]:
106107
"Done with quick mode search. Gathering concurrency sweep measurements for reports"
107108
)
108109
logger.info("")
109-
yield from self._sweep_concurrency_over_top_results()
110+
yield from ConcurrencySweeper(
111+
config=self._config, result_manager=self._result_manager
112+
).get_configs()
110113
logger.info("")
111114
logger.info("Done gathering concurrency sweep measurements for reports")
112115
logger.info("")
@@ -125,26 +128,3 @@ def _create_quick_run_config_generator(self) -> QuickRunConfigGenerator:
125128
composing_models=self._composing_models,
126129
model_variant_name_manager=self._model_variant_name_manager,
127130
)
128-
129-
def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]:
130-
for model_name in self._result_manager.get_model_names():
131-
top_results = self._result_manager.top_n_results(
132-
model_name=model_name,
133-
n=self._config.num_configs_per_model,
134-
include_default=True,
135-
)
136-
137-
for result in top_results:
138-
run_config = deepcopy(result.run_config())
139-
parameter_search = ParameterSearch(self._config)
140-
for concurrency in parameter_search.search_parameters():
141-
run_config = self._set_concurrency(run_config, concurrency)
142-
yield run_config
143-
parameter_search.add_run_config_measurement(self._last_measurement)
144-
145-
def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig:
146-
for model_run_config in run_config.model_run_configs():
147-
perf_config = model_run_config.perf_config()
148-
perf_config.update_config({"concurrency-range": concurrency})
149-
150-
return run_config

0 commit comments

Comments
 (0)