Skip to content

Commit 228062b

Browse files
authored
Truncate model configuration name in plots (#757)
* Truncate longer model_config_name * Update lengths * Update truncate_model_config_name() * Add report_utils.py and test case * Fix Pre-commit errors
1 parent 7f06f57 commit 228062b

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed

model_analyzer/plots/simple_plot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig
2323
from model_analyzer.record.metrics_manager import MetricsManager
24+
from model_analyzer.reports.report_utils import truncate_model_config_name
2425

2526

2627
class SimplePlot:
@@ -156,6 +157,8 @@ def plot_data_and_constraints(self, constraints):
156157
list(t) for t in zip(*sorted(zip(data["x_data"], data["y_data"])))
157158
)
158159

160+
model_config_name = truncate_model_config_name(model_config_name)
161+
159162
if self._monotonic:
160163
filtered_x, filtered_y = [x_data[0]], [y_data[0]]
161164
for i in range(1, len(x_data)):
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
def truncate_model_config_name(model_config_name):
19+
"""
20+
Truncates the model configuration name if its length exceeds the threshold length.
21+
ex: long_model_name_config_4 --> long_mod..._config_4
22+
Parameters
23+
----------
24+
model_config_name: string
25+
Returns
26+
-------
27+
string
28+
The truncated model configuration name,
29+
or the original name if it is shorter than the threshold length.
30+
"""
31+
max_model_config_name_len = 35
32+
33+
if len(model_config_name) > max_model_config_name_len:
34+
config_name = model_config_name[model_config_name.rfind("config_") :]
35+
36+
return (
37+
model_config_name[: (max_model_config_name_len - len(config_name) - 3)]
38+
+ "..."
39+
+ config_name
40+
)
41+
42+
return model_config_name

tests/test_report_utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import unittest
18+
from unittest.mock import patch
19+
20+
from model_analyzer.reports.report_utils import truncate_model_config_name
21+
22+
from .common import test_result_collector as trc
23+
24+
25+
class TestReportUtils(trc.TestResultCollector):
26+
"""
27+
Tests the report_utils functions
28+
"""
29+
30+
def tearDown(self):
31+
patch.stopall()
32+
33+
def test_truncate_model_config_name(self):
34+
"""
35+
Test the behavior of the truncate_longer_model_config_name function.
36+
"""
37+
38+
# Test: Shorter model config name (below 35 characters threshold)
39+
model_config_name = "ensemble_model_23"
40+
result = truncate_model_config_name(model_config_name)
41+
self.assertEqual(model_config_name, result)
42+
43+
# Test: Model config name ends with 'config_#'
44+
model_config_name = "long_pytorch_platform_handler_config_10"
45+
result = truncate_model_config_name(model_config_name)
46+
self.assertEqual(result, "long_pytorch_platform_h...config_10")
47+
48+
# Test: Model config name ends with 'config_default'.
49+
model_config_name = "long_pytorch_platform_handler_config_default"
50+
result = truncate_model_config_name(model_config_name)
51+
self.assertEqual(result, "long_pytorch_platf...config_default")
52+
53+
# Test: Model config name includes the "config" keyword in the model name
54+
model_config_name = "long_config_pytorch_platform_handler_config_128"
55+
result = truncate_model_config_name(model_config_name)
56+
self.assertEqual(result, "long_config_pytorch_pl...config_128")
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

0 commit comments

Comments
 (0)