Skip to content

Commit 265cbae

Browse files
authored
Add Optuna Seed to Checkpoint (#913)
* Adding seed * Adding checkpointing of seed to optuna * Fixing issue if checkpoint doesn't contain optuna seed
1 parent 020ecb6 commit 265cbae

File tree

6 files changed

+62
-14
lines changed

6 files changed

+62
-14
lines changed

model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from model_analyzer.result.parameter_search import ParameterSearch
3333
from model_analyzer.result.result_manager import ResultManager
3434
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
35+
from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager
3536

3637
from .config_generator_interface import ConfigGeneratorInterface
3738

@@ -48,6 +49,7 @@ class OptunaPlusConcurrencySweepRunConfigGenerator(ConfigGeneratorInterface):
4849
def __init__(
4950
self,
5051
config: ConfigCommandProfile,
52+
state_manager: AnalyzerStateManager,
5153
gpu_count: int,
5254
models: List[ModelProfileSpec],
5355
composing_models: List[ModelProfileSpec],
@@ -61,6 +63,8 @@ def __init__(
6163
----------
6264
config: ConfigCommandProfile
6365
Profile configuration information
66+
state_manager: AnalyzerStateManager
67+
The object that allows control and update of checkpoint state
6468
gpu_count: Number of gpus in the system
6569
models: List of ModelProfileSpec
6670
List of models to profile
@@ -76,6 +80,7 @@ def __init__(
7680
The object that handles the users configuration search parameters for composing models
7781
"""
7882
self._config = config
83+
self._state_manager = state_manager
7984
self._gpu_count = gpu_count
8085
self._models = models
8186
self._composing_models = composing_models
@@ -123,6 +128,7 @@ def _execute_optuna_search(self) -> Generator[RunConfig, None, None]:
123128
def _create_optuna_run_config_generator(self) -> OptunaRunConfigGenerator:
124129
return OptunaRunConfigGenerator(
125130
config=self._config,
131+
state_manager=self._state_manager,
126132
gpu_count=self._gpu_count,
127133
models=self._models,
128134
composing_models=self._composing_models,

model_analyzer/config/generate/optuna_run_config_generator.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import logging
18+
from random import randint
1819
from sys import maxsize
1920
from typing import Any, Dict, Generator, List, Optional, TypeAlias, Union
2021

@@ -43,6 +44,7 @@
4344
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
4445
from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig
4546
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
47+
from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager
4648
from model_analyzer.triton.model.model_config import ModelConfig
4749
from model_analyzer.triton.model.model_config_variant import ModelConfigVariant
4850

@@ -82,19 +84,22 @@ class OptunaRunConfigGenerator(ConfigGeneratorInterface):
8284
def __init__(
8385
self,
8486
config: ConfigCommandProfile,
87+
state_manager: AnalyzerStateManager,
8588
gpu_count: int,
8689
models: List[ModelProfileSpec],
8790
composing_models: List[ModelProfileSpec],
8891
model_variant_name_manager: ModelVariantNameManager,
8992
search_parameters: Dict[str, SearchParameters],
9093
composing_search_parameters: Dict[str, SearchParameters],
91-
seed: Optional[int] = None,
94+
user_seed: Optional[int] = None,
9295
):
9396
"""
9497
Parameters
9598
----------
9699
config: ConfigCommandProfile
97100
Profile configuration information
101+
state_manager: AnalyzerStateManager
102+
The object that allows control and update of checkpoint state
98103
gpu_count: Number of gpus in the system
99104
models: List of ModelProfileSpec
100105
List of models to profile
@@ -105,8 +110,11 @@ def __init__(
105110
The object that handles the users configuration search parameters
106111
composing_search_parameters: SearchParameters
107112
The object that handles the users configuration search parameters for composing models
113+
user_seed: int
114+
The seed to use. If not provided, one will be generated (fresh run) or read from checkpoint
108115
"""
109116
self._config = config
117+
self._state_manager = state_manager
110118
self._gpu_count = gpu_count
111119
self._models = models
112120
self._composing_models = composing_models
@@ -132,10 +140,9 @@ def __init__(
132140

133141
self._done = False
134142

135-
if seed is not None:
136-
self._sampler = optuna.samplers.TPESampler(seed=seed)
137-
else:
138-
self._sampler = optuna.samplers.TPESampler()
143+
self._seed = self._create_seed(user_seed)
144+
145+
self._sampler = optuna.samplers.TPESampler(seed=self._seed)
139146

140147
self._study_name = ",".join([model.model_name() for model in self._models])
141148

@@ -145,6 +152,24 @@ def __init__(
145152
sampler=self._sampler,
146153
)
147154

155+
self._init_state()
156+
157+
def _get_seed(self) -> int:
158+
return self._state_manager.get_state_variable("OptunaRunConfigGenerator.seed")
159+
160+
def _create_seed(self, user_seed: Optional[int]) -> int:
161+
if self._state_manager.starting_fresh_run():
162+
seed = randint(0, 10000) if user_seed is None else user_seed
163+
else:
164+
seed = self._get_seed() if user_seed is None else user_seed
165+
166+
return seed
167+
168+
def _init_state(self) -> None:
169+
self._state_manager.set_state_variable(
170+
"OptunaRunConfigGenerator.seed", self._seed
171+
)
172+
148173
def _is_done(self) -> bool:
149174
return self._done
150175

model_analyzer/config/generate/run_config_generator_factory.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from model_analyzer.device.gpu_device import GPUDevice
3030
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
3131
from model_analyzer.result.result_manager import ResultManager
32+
from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager
3233
from model_analyzer.triton.client.client import TritonClient
3334
from model_analyzer.triton.model.model_config import ModelConfig
3435

@@ -55,6 +56,7 @@ class RunConfigGeneratorFactory:
5556
@staticmethod
5657
def create_run_config_generator(
5758
command_config: ConfigCommandProfile,
59+
state_manager: AnalyzerStateManager,
5860
gpus: List[GPUDevice],
5961
models: List[ConfigModelProfileSpec],
6062
client: TritonClient,
@@ -68,6 +70,8 @@ def create_run_config_generator(
6870
----------
6971
command_config: ConfigCommandProfile
7072
The Model Analyzer config file for the profile step
73+
state_manager: AnalyzerStateManager
74+
The object that allows control and update of checkpoint state
7175
gpus: List of GPUDevices
7276
models: list of ConfigModelProfileSpec
7377
The models to generate RunConfigs for
@@ -107,6 +111,7 @@ def create_run_config_generator(
107111
if command_config.run_config_search_mode == "optuna":
108112
return RunConfigGeneratorFactory._create_optuna_plus_concurrency_sweep_run_config_generator(
109113
command_config=command_config,
114+
state_manager=state_manager,
110115
gpu_count=len(gpus),
111116
models=new_models,
112117
composing_models=composing_models,
@@ -159,6 +164,7 @@ def _create_brute_plus_binary_parameter_search_run_config_generator(
159164
@staticmethod
160165
def _create_optuna_plus_concurrency_sweep_run_config_generator(
161166
command_config: ConfigCommandProfile,
167+
state_manager: AnalyzerStateManager,
162168
gpu_count: int,
163169
models: List[ModelProfileSpec],
164170
composing_models: List[ModelProfileSpec],
@@ -169,6 +175,7 @@ def _create_optuna_plus_concurrency_sweep_run_config_generator(
169175
) -> ConfigGeneratorInterface:
170176
return OptunaPlusConcurrencySweepRunConfigGenerator(
171177
config=command_config,
178+
state_manager=state_manager,
172179
gpu_count=gpu_count,
173180
composing_models=composing_models,
174181
models=models,

model_analyzer/model_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def run_models(self, models: List[ConfigModelProfileSpec]) -> None:
136136

137137
rcg = RunConfigGeneratorFactory.create_run_config_generator(
138138
command_config=self._config,
139+
state_manager=self._state_manager,
139140
gpus=self._gpus,
140141
models=models,
141142
client=self._client,

model_analyzer/state/analyzer_state.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def from_dict(cls, state_dict):
5959
# GPU data
6060
state._state_dict["MetricsManager.gpus"] = state_dict["MetricsManager.gpus"]
6161

62+
# Optuna Seed
63+
state._state_dict["OptunaRunConfigGenerator.seed"] = state_dict.get(
64+
"OptunaRunConfigGenerator.seed", 0
65+
)
66+
6267
return state
6368

6469
def get(self, name):

tests/test_optuna_run_config_generator.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ def setUp(self):
6868

6969
self._rcg = OptunaRunConfigGenerator(
7070
config=config,
71+
state_manager=MagicMock(),
7172
gpu_count=1,
7273
models=self._mock_models,
7374
composing_models=[],
7475
model_variant_name_manager=ModelVariantNameManager(),
7576
search_parameters={"add_sub": search_parameters},
7677
composing_search_parameters={},
77-
seed=100,
78+
user_seed=100,
7879
)
7980

8081
def test_max_number_of_configs_to_search_percentage(self):
@@ -204,7 +205,7 @@ def test_create_objective_based_run_config(self):
204205

205206
self.assertEqual(model_config.to_dict()["name"], self._test_config_dict["name"])
206207

207-
# These values are the result of using a fixed seed of 100
208+
# These values are the result of using a fixed user_seed of 100
208209
self.assertEqual(model_config.to_dict()["maxBatchSize"], 16)
209210
self.assertEqual(model_config.to_dict()["instanceGroup"][0]["count"], 2)
210211
self.assertEqual(
@@ -232,13 +233,14 @@ def test_create_run_config_with_concurrency_formula(self):
232233

233234
rcg = OptunaRunConfigGenerator(
234235
config=config,
236+
state_manager=MagicMock(),
235237
gpu_count=1,
236238
models=self._mock_models,
237239
composing_models=[],
238240
model_variant_name_manager=ModelVariantNameManager(),
239241
search_parameters={"add_sub": search_parameters},
240242
composing_search_parameters={},
241-
seed=100,
243+
user_seed=100,
242244
)
243245

244246
trial = rcg._study.ask()
@@ -250,7 +252,7 @@ def test_create_run_config_with_concurrency_formula(self):
250252

251253
self.assertEqual(model_config.to_dict()["name"], self._test_config_dict["name"])
252254

253-
# These values are the result of using a fixed seed of 100
255+
# These values are the result of using a fixed user_seed of 100
254256
self.assertEqual(model_config.to_dict()["maxBatchSize"], 16)
255257
self.assertEqual(model_config.to_dict()["instanceGroup"][0]["count"], 2)
256258
self.assertEqual(
@@ -291,6 +293,7 @@ def test_create_run_bls_config(self):
291293
)
292294
rcg = OptunaRunConfigGenerator(
293295
config=config,
296+
state_manager=MagicMock(),
294297
gpu_count=1,
295298
models=[bls_model],
296299
composing_models=[add_model, sub_model],
@@ -300,7 +303,7 @@ def test_create_run_bls_config(self):
300303
"add": add_search_parameters,
301304
"sub": sub_search_parameters,
302305
},
303-
seed=100,
306+
user_seed=100,
304307
)
305308

306309
trial = rcg._study.ask()
@@ -315,7 +318,7 @@ def test_create_run_bls_config(self):
315318
sub_model_config = run_config.model_run_configs()[0].composing_configs()[1]
316319
perf_config = run_config.model_run_configs()[0].perf_config()
317320

318-
# BLS (Top Level Model) + PA Config (Seed=100)
321+
# BLS (Top Level Model) + PA Config (user_seed=100)
319322
# =====================================================================
320323
self.assertEqual(bls_model_config.to_dict()["name"], "bls")
321324
self.assertEqual(bls_model_config.to_dict()["instanceGroup"][0]["count"], 3)
@@ -364,6 +367,7 @@ def test_create_run_multi_model_config(self):
364367
)
365368
rcg = OptunaRunConfigGenerator(
366369
config=config,
370+
state_manager=MagicMock(),
367371
gpu_count=1,
368372
models=[add_model, vgg_model],
369373
composing_models=[],
@@ -373,7 +377,7 @@ def test_create_run_multi_model_config(self):
373377
"vgg19_libtorch": vgg_search_parameters,
374378
},
375379
composing_search_parameters={},
376-
seed=100,
380+
user_seed=100,
377381
)
378382

379383
trial = rcg._study.ask()
@@ -388,7 +392,7 @@ def test_create_run_multi_model_config(self):
388392
add_perf_config = run_config.model_run_configs()[0].perf_config()
389393
vgg_perf_config = run_config.model_run_configs()[0].perf_config()
390394

391-
# ADD_SUB + PA Config (Seed=100)
395+
# ADD_SUB + PA Config (user_seed=100)
392396
# =====================================================================
393397
self.assertEqual(add_model_config.to_dict()["name"], "add_sub")
394398
self.assertEqual(add_model_config.to_dict()["maxBatchSize"], 16)
@@ -400,7 +404,7 @@ def test_create_run_multi_model_config(self):
400404
self.assertEqual(add_perf_config["batch-size"], DEFAULT_BATCH_SIZES)
401405
self.assertEqual(add_perf_config["concurrency-range"], 16)
402406

403-
# VGG19_LIBTORCH + PA Config (Seed=100)
407+
# VGG19_LIBTORCH + PA Config (user_seed=100)
404408
# =====================================================================
405409
self.assertEqual(vgg_model_config.to_dict()["name"], "vgg19_libtorch")
406410
self.assertEqual(vgg_model_config.to_dict()["instanceGroup"][0]["count"], 4)

0 commit comments

Comments (0)