Skip to content

Commit a2f6914

Browse files
authored
Misc fixes (#799)
* Fix tests so that artifacts show up on failure
* Fix bug in extract_model_name_from_variant_name
* Discard auto-generated preferred_batch_size
* prep for ensemble fix. Boyscout some types
* fix type checking
* PR feedback
1 parent dcc2424 commit a2f6914

File tree

7 files changed

+41
-8
lines changed

7 files changed

+41
-8
lines changed

model_analyzer/config/generate/base_model_config_generator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,13 @@ def extract_model_name_from_variant_name(variant_name: str) -> str:
309309
Removes '_config_#/default' from the variant name and returns
310310
the model name, eg. model_name_config_10 -> model_name
311311
"""
312-
return variant_name[: variant_name.find("_config_")]
312+
model_name = variant_name
313+
config_index = variant_name.find("_config_")
314+
315+
if config_index != -1:
316+
model_name = variant_name[:config_index]
317+
318+
return model_name
313319

314320
@staticmethod
315321
def create_original_config_from_variant(variant_config: ModelConfig) -> ModelConfig:

model_analyzer/config/run/run_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from typing import List
18+
1719
from model_analyzer.config.run.model_run_config import ModelRunConfig
1820

1921

@@ -33,15 +35,15 @@ def __init__(self, triton_env):
3335
"""
3436

3537
self._triton_env = triton_env
36-
self._model_run_configs = []
38+
self._model_run_configs: List[ModelRunConfig] = []
3739

3840
def add_model_run_config(self, model_run_config):
3941
"""
4042
Add a ModelRunConfig to this RunConfig
4143
"""
4244
self._model_run_configs.append(model_run_config)
4345

44-
def model_run_configs(self):
46+
def model_run_configs(self) -> List[ModelRunConfig]:
4547
"""
4648
Returns the list of ModelRunConfigs to run concurrently
4749
"""

model_analyzer/record/metrics_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,11 @@ def _create_remote_mode_model_variant(
426426
# Ignore if the dir already exists
427427
pass
428428

429-
def _load_model_variants(self, run_config):
429+
def _load_model_variants(self, run_config: RunConfig) -> bool:
430430
"""
431431
Loads all model variants in the client
432432
"""
433+
# TODO TMA-1487: Make BLS and ensemble both load all composing model variants first
433434
for mrc in run_config.model_run_configs():
434435
if not self._load_model_variant(variant_config=mrc.model_config_variant()):
435436
return False
@@ -444,7 +445,7 @@ def _load_model_variants(self, run_config):
444445

445446
return True
446447

447-
def _load_model_variant(self, variant_config):
448+
def _load_model_variant(self, variant_config: ModelConfigVariant) -> bool:
448449
"""
449450
Conditionally loads a model variant in the client
450451
"""
@@ -458,7 +459,7 @@ def _load_model_variant(self, variant_config):
458459
retval = self._do_load_model_variant(variant_config)
459460
return retval
460461

461-
def _do_load_model_variant(self, variant_config):
462+
def _do_load_model_variant(self, variant_config: ModelConfigVariant) -> bool:
462463
"""
463464
Loads a model variant in the client
464465
"""

model_analyzer/triton/model/model_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ def create_model_config_dict(config, client, gpus, model_repository, model_name)
9292
config = ModelConfig._get_default_config_from_server(
9393
config, client, gpus, model_name
9494
)
95+
96+
# An auto-completed triton model config will set preferred_batch_size
97+
# to a default value. We do not want to keep and honor that
98+
# value when we are searching, so we discard it here
99+
if (
100+
"dynamic_batching" in config
101+
and "preferred_batch_size" in config["dynamic_batching"]
102+
):
103+
del config["dynamic_batching"]["preferred_batch_size"]
104+
95105
else:
96106
ModelConfig._check_default_config_exceptions(config, model_path)
97107

qa/L0_multi_model_profile/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
source ../common/util.sh
17-
create_logs_dir
17+
create_logs_dir "L0_multi_model_profile"
1818

1919
# Set test parameters
2020
MODEL_ANALYZER="`which model-analyzer`"

qa/L0_perf_analyzer/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
source ../common/util.sh
17-
create_logs_dir
17+
create_logs_dir "L0_perf_analyzer"
1818

1919
# Set test parameters
2020
MODEL_ANALYZER="`which model-analyzer`"

tests/test_model_config_generator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,20 @@ def test_early_exit_on_manual(self):
807807
early_exit_enable=True)
808808
# yapf: enable
809809

810+
def test_extract_model_name_from_variant_name(self):
811+
input_output_pairs = {}
812+
input_output_pairs[
813+
"onnx_int32_int32_int32_config_default"
814+
] = "onnx_int32_int32_int32"
815+
input_output_pairs["onnx_int32_int32_int32_config_2"] = "onnx_int32_int32_int32"
816+
input_output_pairs["onnx_int32_int32_int32"] = "onnx_int32_int32_int32"
817+
818+
for variant_name, expected_model_name in input_output_pairs.items():
819+
model_name = BaseModelConfigGenerator.extract_model_name_from_variant_name(
820+
variant_name
821+
)
822+
self.assertEqual(model_name, expected_model_name)
823+
810824
def _run_and_test_model_config_generator(
811825
self,
812826
yaml_str,

0 commit comments

Comments
 (0)