@@ -93,16 +93,7 @@ def activate_impl(self):
9393 self ._input_module_device = get_module_device (self .model ) or torch .device ("cpu" )
9494 self ._loaded_model = self .model
9595 self ._loaded_model .to (self .device ).eval ()
96-
97- # TODO: Consider better handling for controlling autocast behavior
98- try :
99- if hasattr (self ._loaded_model , "parameters" ) and self ._autocast_dtype is None :
100- param_dtype = next (self ._loaded_model .parameters ()).dtype
101- if param_dtype in [torch .bfloat16 , torch .int8 , torch .uint8 ]:
102- self ._autocast = False
103- LOGGER .warning (f"Model has { param_dtype } parameters. Disabling autocast." )
104- except StopIteration :
105- LOGGER .warning ("Model has no parameters." )
96+ self ._adjust_autocast ()
10697
10798 def deactivate_impl (self ):
10899 """Deactivation implementation."""
@@ -233,6 +224,17 @@ def _to_torch_tensor(self, value, dtype):
233224 raise ValueError (f"Unsupported type { type (value )} " )
234225 return value
235226
227+ def _adjust_autocast (self ):
228+ # TODO: Consider better handling for controlling autocast behavior
229+ try :
230+ if hasattr (self ._loaded_model , "parameters" ) and self ._autocast_dtype is None :
231+ param_dtype = next (self ._loaded_model .parameters ()).dtype
232+ if param_dtype in [torch .bfloat16 , torch .int8 , torch .uint8 ]:
233+ self ._autocast = False
234+ LOGGER .warning (f"Model has { param_dtype } parameters. Disabling autocast." )
235+ except StopIteration :
236+ LOGGER .warning ("Model has no parameters." )
237+
236238
237239class _BaseTorchScriptRunner (_BaseTorchRunner ):
238240 """Base runner for inference of TorchScript models."""
@@ -424,6 +426,7 @@ def activate_impl(self):
424426 exported_program = torch .export .load (str (self ._model ))
425427 self ._loaded_model = exported_program .module ()
426428 self ._loaded_model .to (self .device )
429+ self ._adjust_autocast ()
427430
428431
429432class TorchExportedProgramCPURunner (_BaseTorchExportedProgramRunner ):
0 commit comments