Skip to content

Commit 01ac364

Browse files
committed
Move to autocast adjustment to separate method
1 parent 9e13ebc commit 01ac364

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

model_navigator/runners/torch.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,7 @@ def activate_impl(self):
9393
self._input_module_device = get_module_device(self.model) or torch.device("cpu")
9494
self._loaded_model = self.model
9595
self._loaded_model.to(self.device).eval()
96-
97-
# TODO: Consider better handling for controlling autocast behavior
98-
try:
99-
if hasattr(self._loaded_model, "parameters") and self._autocast_dtype is None:
100-
param_dtype = next(self._loaded_model.parameters()).dtype
101-
if param_dtype in [torch.bfloat16, torch.int8, torch.uint8]:
102-
self._autocast = False
103-
LOGGER.warning(f"Model has {param_dtype} parameters. Disabling autocast.")
104-
except StopIteration:
105-
LOGGER.warning("Model has no parameters.")
96+
self._adjust_autocast()
10697

10798
def deactivate_impl(self):
10899
"""Deactivation implementation."""
@@ -233,6 +224,17 @@ def _to_torch_tensor(self, value, dtype):
233224
raise ValueError(f"Unsupported type {type(value)}")
234225
return value
235226

227+
def _adjust_autocast(self):
228+
# TODO: Consider better handling for controlling autocast behavior
229+
try:
230+
if hasattr(self._loaded_model, "parameters") and self._autocast_dtype is None:
231+
param_dtype = next(self._loaded_model.parameters()).dtype
232+
if param_dtype in [torch.bfloat16, torch.int8, torch.uint8]:
233+
self._autocast = False
234+
LOGGER.warning(f"Model has {param_dtype} parameters. Disabling autocast.")
235+
except StopIteration:
236+
LOGGER.warning("Model has no parameters.")
237+
236238

237239
class _BaseTorchScriptRunner(_BaseTorchRunner):
238240
"""Base runner for inference of TorchScript models."""
@@ -424,6 +426,7 @@ def activate_impl(self):
424426
exported_program = torch.export.load(str(self._model))
425427
self._loaded_model = exported_program.module()
426428
self._loaded_model.to(self.device)
429+
self._adjust_autocast()
427430

428431

429432
class TorchExportedProgramCPURunner(_BaseTorchExportedProgramRunner):

0 commit comments

Comments
 (0)