Skip to content

Commit 4a05d09

Browse files
authored
Added deepcopy to config_command (#505)
1 parent 8169200 commit 4a05d09

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

model_analyzer/config/input/config_command.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import yaml
1818
from .yaml_config_validator import YamlConfigValidator
1919

20+
from copy import deepcopy
21+
2022

2123
class ConfigCommand:
2224
"""
@@ -167,5 +169,13 @@ def get_all_config(self):
167169

168170
return config_dict
169171

172+
def __deepcopy__(self, memo):
173+
cls = self.__class__
174+
result = cls.__new__(cls)
175+
memo[id(self)] = result
176+
for k, v in self.__dict__.items():
177+
setattr(result, k, deepcopy(v, memo))
178+
return result
179+
170180
def __getattr__(self, name):
171181
return self._fields[name].value()

tests/test_config.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
from model_analyzer.constants import \
4545
CONFIG_PARSER_FAILURE
4646

47+
from copy import deepcopy
48+
4749
from unittest.mock import patch
4850

4951

@@ -1711,6 +1713,31 @@ def test_path_validation(self):
17111713
self._evaluate_config(args, yaml_content, subcommand='profile')
17121714
self.mock_os.set_os_path_exists_return_value(True)
17131715

1716+
def test_copy(self):
1717+
"""
1718+
Test that deepcopy works correctly
1719+
"""
1720+
args = [
1721+
'model-analyzer', 'profile', '--model-repository', 'cli_repository',
1722+
'-f', 'path-to-config-file', '--profile-models', 'vgg11'
1723+
]
1724+
yaml_content = 'model_repository: yaml_repository'
1725+
configA = self._evaluate_config(args, yaml_content)
1726+
1727+
configB = deepcopy(configA)
1728+
1729+
self._assert_equality_of_model_configs(
1730+
configA.get_all_config()['profile_models'],
1731+
configB.get_all_config()['profile_models'])
1732+
1733+
self.assertEqual(configA.run_config_search_mode,
1734+
configB.run_config_search_mode)
1735+
1736+
configB.run_config_search_mode = 'quick'
1737+
1738+
self.assertNotEqual(configA.run_config_search_mode,
1739+
configB.run_config_search_mode)
1740+
17141741

17151742
if __name__ == '__main__':
17161743
unittest.main()

0 commit comments

Comments
 (0)