|
18 | 18 | from model_navigator.configuration import ( |
19 | 19 | DEFAULT_TENSORRT_PRECISION, |
20 | 20 | DEFAULT_TENSORRT_PRECISION_MODE, |
| 21 | + AutocastType, |
21 | 22 | CustomConfigForFormat, |
22 | 23 | Format, |
23 | 24 | JitType, |
@@ -112,10 +113,34 @@ def test_torch_config_has_valid_name_and_format(): |
112 | 113 |
|
113 | 114 |
|
114 | 115 | def test_torch_config_defaults_reset_values_to_initial(): |
115 | | - config = TorchConfig(inference_mode=False, autocast=False, custom_args={"key": "value"}) |
| 116 | + config = TorchConfig( |
| 117 | + inference_mode=False, autocast=False, autocast_dtype=AutocastType.FP16, custom_args={"key": "value"} |
| 118 | + ) |
116 | 119 | config.defaults() |
117 | 120 | assert config.autocast is True |
118 | 121 | assert config.inference_mode is True |
| 122 | + assert config.autocast_dtype is AutocastType.DEVICE |
| 123 | + assert config.custom_args is None |
| 124 | + |
| 125 | + |
| 126 | +def test_torch_config_initialize_values_from_dict_correctly(): |
| 127 | + config = TorchConfig.from_dict({ |
| 128 | + "inference_mode": False, |
| 129 | + "autocast": False, |
| 130 | + "autocast_dtype": "torch.float16", |
| 131 | + "custom_args": {"key": "value"}, |
| 132 | + }) |
| 133 | + |
| 134 | + assert config.autocast is False |
| 135 | + assert config.inference_mode is False |
| 136 | + assert config.autocast_dtype is AutocastType.FP16 |
| 137 | + assert config.custom_args == {"key": "value"} |
| 138 | + |
| 139 | + config = TorchConfig.from_dict({}) |
| 140 | + |
| 141 | + assert config.autocast is True |
| 142 | + assert config.inference_mode is True |
| 143 | + assert config.autocast_dtype is AutocastType.DEVICE |
119 | 144 | assert config.custom_args is None |
120 | 145 |
|
121 | 146 |
|
|
0 commit comments