Skip to content

Commit ea91d48

Browse files
committed
Fixed initialization of defaults for autocast_dtype
1 parent 49a4bab commit ea91d48

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

model_navigator/configuration/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,10 @@ class TorchConfig(CustomConfigForFormat):
708708
inference_mode: bool = True
709709
custom_args: Optional[Dict[str, Any]] = None
710710

711+
def __post_init__(self):
712+
"""Post initialization to handle correctly enums."""
713+
self.autocast_dtype: AutocastType = AutocastType(self.autocast_dtype)
714+
711715
@property
712716
def format(self) -> Format:
713717
"""Returns Format.TORCH.
@@ -762,6 +766,7 @@ def __post_init__(self) -> None:
762766
"""Parse dataclass enums."""
763767
jit_type = (self.jit_type,) if not isinstance(self.jit_type, (list, tuple)) else self.jit_type
764768
self.jit_type = tuple(JitType(j) for j in jit_type)
769+
self.autocast_dtype: AutocastType = AutocastType(self.autocast_dtype)
765770

766771
@property
767772
def format(self) -> Format:
@@ -800,6 +805,10 @@ class TorchExportConfig(CustomConfigForFormat):
800805
autocast_dtype: AutocastType = AutocastType.DEVICE
801806
inference_mode: bool = True
802807

808+
def __post_init__(self):
809+
"""Post initialization to handle correctly enums."""
810+
self.autocast_dtype: AutocastType = AutocastType(self.autocast_dtype)
811+
803812
@property
804813
def format(self) -> Format:
805814
"""Returns Format.

tests/functional/L1_Inplace_Optimize_wo_MaxBS_Search_torch_HuggingFace_Whisper/test.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ THIS_SCRIPT_PATH="$(realpath --relative-to="$(pwd)" "$0")"
1919
TEST_MODULE="$(dirname "${THIS_SCRIPT_PATH}"|sed 's/\//./g').test"
2020

2121

22+
pip install --upgrade setuptools
2223
pip install "transformers<=4.38.2" datasets librosa --upgrade
23-
pip install "numpy<=1.24.0"
24+
pip install "numpy<2"
2425

2526
export NAVIGATOR_CONSOLE_OUTPUT=LOGS
2627
export NAVIGATOR_LOG_LEVEL=DEBUG

tests/unit/base/test_api_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ def test_torch_config_has_valid_name_and_format():
112112
assert config.custom_args is None
113113

114114

115+
def test_torch_config_handle_autocast_dtype_mapping():
116+
config = TorchConfig(autocast_dtype=None) # pytype: disable=wrong-arg-types
117+
assert config.autocast_dtype is AutocastType.DEVICE
118+
119+
config = TorchConfig(autocast_dtype="torch.float16") # pytype: disable=wrong-arg-types
120+
assert config.autocast_dtype is AutocastType.FP16
121+
122+
115123
def test_torch_config_defaults_reset_values_to_initial():
116124
config = TorchConfig(
117125
inference_mode=False, autocast=False, autocast_dtype=AutocastType.FP16, custom_args={"key": "value"}

tests/unit/base/test_custom_configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_default_values_for_torch_script_custom_config():
7575
assert config.format == Format.TORCHSCRIPT
7676

7777

78-
def test_default_values_for_torch_tensorrt_custom_confg():
78+
def test_default_values_for_torch_tensorrt_custom_config():
7979
torch_tensorrt_config = TorchTensorRTConfig()
8080
assert torch_tensorrt_config.precision == (
8181
TensorRTPrecision.FP32,

0 commit comments

Comments
 (0)