@@ -689,9 +689,22 @@ def forward(self, input_ids):
     raise NotImplementedError(f"cls_name={cls_name}")
 
 
-def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
+def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
     """Applies torch.to if applicable, recursively."""
-    if isinstance(value, (torch.nn.Module, torch.Tensor)):
+    if isinstance(value, (torch.nn.Module, torch.Tensor)) and value.__class__.__name__ not in {
+        "DynamicCache",
+        "EncoderDecoderCache",
+    }:
+        if (
+            (
+                isinstance(to_value, torch.dtype)
+                or to_value in {"float16", "bfloat16", "float32", "float64"}
+            )
+            and hasattr(value, "dtype")
+            and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
+        ):
+            # Integer tensors (e.g. input ids, masks) must keep their dtype.
+            return value
         return value.to(to_value)
     if isinstance(value, list):
         return [to_any(t, to_value) for t in value]
@@ -701,8 +714,6 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
         return {to_any(t, to_value) for t in value}
     if isinstance(value, dict):
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if hasattr(value, "to"):
-        return value.to(to_value)
     if value.__class__.__name__ == "DynamicCache":
         return make_dynamic_cache(
             list(
@@ -712,11 +723,23 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
                 )
             )
         )
+    if value.__class__.__name__ == "EncoderDecoderCache":
+        return make_encoder_decoder_cache(
+            to_any(value.self_attention_cache, to_value),
+            to_any(value.cross_attention_cache, to_value),
+        )
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
         args, spec = torch.utils._pytree.tree_flatten(value)
         new_args = to_any(args, to_value)
         return torch.utils._pytree.tree_unflatten(new_args, spec)
 
+    if hasattr(value, "to"):
+        return value.to(to_value)
+
+    assert "Cache" not in value.__class__.__name__, (
+        f"Class {value.__class__.__name__!r} should be registered "
+        f"so that the dtype of every tensor it contains can be converted."
+    )
     assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
     return value
 
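For context when reviewing, here is a minimal usage sketch of the behavior this change targets. It assumes to_any and make_dynamic_cache are importable from this module (the import path below is a guess, hence commented out); the batch keys and shapes are illustrative only, not part of the diff.

    import torch

    # Assumed import path; adjust to wherever this module lives:
    # from onnx_diagnostic.helpers.torch_helper import to_any, make_dynamic_cache

    batch = {
        "input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64),
        "attention_mask": torch.ones((1, 3), dtype=torch.int64),
        "logits": torch.randn(1, 3, 8, dtype=torch.float32),
    }

    half = to_any(batch, torch.float16)
    assert half["input_ids"].dtype == torch.int64  # integer dtypes are preserved
    assert half["logits"].dtype == torch.float16   # float tensors are converted

    # Caches recurse: a DynamicCache built from (key, value) pairs comes back
    # as a new cache whose tensors were converted.
    cache = make_dynamic_cache([(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))])
    half_cache = to_any(cache, torch.float16)

    # Device moves still apply to every tensor, integers included.
    cpu_batch = to_any(batch, torch.device("cpu"))

Design note: the generic hasattr(value, "to") fallback now runs only after the cache and pytree branches, and the first isinstance branch excludes DynamicCache and EncoderDecoderCache by name, presumably because some transformers versions make caches nn.Module subclasses, so Module.to would otherwise catch them without rebuilding their tensor lists.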