Skip to content

Commit 49a4bab

Browse files
committed
Fixed parsing TorchConfig from dict
1 parent 431618e commit 49a4bab

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

model_navigator/configuration/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,16 @@ def defaults(self) -> None:
729729
self.inference_mode = True
730730
self.custom_args = None
731731

732+
@classmethod
733+
def from_dict(cls, config_dict: Dict[str, Any]) -> "TorchConfig":
734+
"""Instantiate TorchConfig from a dictionary."""
735+
return cls(
736+
autocast=config_dict.get("autocast", True),
737+
autocast_dtype=AutocastType(config_dict.get("autocast_dtype", AutocastType.DEVICE)),
738+
inference_mode=config_dict.get("inference_mode", True),
739+
custom_args=config_dict.get("custom_args"),
740+
)
741+
732742

733743
@dataclasses.dataclass
734744
class TorchScriptConfig(CustomConfigForFormat):

tests/unit/base/test_api_config.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from model_navigator.configuration import (
1919
DEFAULT_TENSORRT_PRECISION,
2020
DEFAULT_TENSORRT_PRECISION_MODE,
21+
AutocastType,
2122
CustomConfigForFormat,
2223
Format,
2324
JitType,
@@ -112,10 +113,34 @@ def test_torch_config_has_valid_name_and_format():
112113

113114

114115
def test_torch_config_defaults_reset_values_to_initial():
115-
config = TorchConfig(inference_mode=False, autocast=False, custom_args={"key": "value"})
116+
config = TorchConfig(
117+
inference_mode=False, autocast=False, autocast_dtype=AutocastType.FP16, custom_args={"key": "value"}
118+
)
116119
config.defaults()
117120
assert config.autocast is True
118121
assert config.inference_mode is True
122+
assert config.autocast_dtype is AutocastType.DEVICE
123+
assert config.custom_args is None
124+
125+
126+
def test_torch_config_initialize_values_from_dict_correctly():
127+
config = TorchConfig.from_dict({
128+
"inference_mode": False,
129+
"autocast": False,
130+
"autocast_dtype": "torch.float16",
131+
"custom_args": {"key": "value"},
132+
})
133+
134+
assert config.autocast is False
135+
assert config.inference_mode is False
136+
assert config.autocast_dtype is AutocastType.FP16
137+
assert config.custom_args == {"key": "value"}
138+
139+
config = TorchConfig.from_dict({})
140+
141+
assert config.autocast is True
142+
assert config.inference_mode is True
143+
assert config.autocast_dtype is AutocastType.DEVICE
119144
assert config.custom_args is None
120145

121146

0 commit comments

Comments
 (0)