@@ -691,15 +691,17 @@ def forward(self, input_ids):
 
 def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
     """Applies torch.to if applicable. Goes recursively."""
-    if isinstance(value, (torch.nn.Module, torch.Tensor)):
+    if isinstance(value, (torch.nn.Module, torch.Tensor)) and value.__class__.__name__ not in {
+        "DynamicCache",
+        "EncoderDecoderCache",
+    }:
         if (
             (
                 isinstance(to_value, torch.dtype)
                 or to_value in {"float16", "bfloat16", "float32", "float64"}
             )
             and hasattr(value, "dtype")
             and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
-            and value.__class__.__name__ not in {"DynamicCache", "EncoderDecoderCache"}
         ):
             # int vector should not be changed.
             return value
@@ -712,8 +714,6 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return {to_any(t, to_value) for t in value}
     if isinstance(value, dict):
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if hasattr(value, "to"):
-        return value.to(to_value)
     if value.__class__.__name__ == "DynamicCache":
         return make_dynamic_cache(
             list(
@@ -733,6 +733,9 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         new_args = to_any(args, to_value)
         return torch.utils._pytree.tree_unflatten(new_args, spec)
 
+    if hasattr(value, "to"):
+        return value.to(to_value)
+
     assert "Cache" not in value.__class__.__name__, (
         f"Class {value.__class__.__name__!r} should be registered "
         f"to be able to change the type in every tensor it contains."
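
A minimal usage sketch of the behavior the visible branches imply. The import path, the sample inputs, and the expected dtypes are assumptions for illustration; the commit does not show the file name, and to_any may handle more container types than appear in these hunks.

import torch
# from the patched module import to_any  # hypothetical import; the diff does not name the file

ids = torch.arange(4, dtype=torch.int64)
weights = torch.ones(2, 2, dtype=torch.float32)

converted = to_any({"input_ids": ids, "weights": weights}, torch.float16)

# Per the "int vector should not be changed" branch, integer tensors are expected
# to keep their dtype (torch.int64 here), while floating-point tensors are expected
# to follow the requested dtype (torch.float16).
# After this commit, cache objects such as DynamicCache or EncoderDecoderCache no
# longer fall into the plain tensor branch or the generic `value.to(to_value)`
# fallback; that fallback now runs after the dedicated cache handling
# (e.g. make_dynamic_cache).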