@@ -689,9 +689,15 @@ def forward(self, input_ids):
     raise NotImplementedError(f"cls_name={cls_name}")
 
 
-def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
+def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
     """Applies torch.to if applicable. Goes recursively."""
     if isinstance(value, (torch.nn.Module, torch.Tensor)):
+        if (
+            isinstance(to_value, torch.dtype)
+            or to_value in {"float16", "bfloat16", "float32", "float64"}
+        ) and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}:
+            # integer tensors should not be changed
+            return value
         return value.to(to_value)
     if isinstance(value, list):
         return [to_any(t, to_value) for t in value]
@@ -712,11 +718,20 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
                 )
             )
         )
+    if value.__class__.__name__ == "EncoderDecoderCache":
+        return make_encoder_decoder_cache(
+            to_any(value.self_attention_cache, to_value),
+            to_any(value.cross_attention_cache, to_value),
+        )
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
         args, spec = torch.utils._pytree.tree_flatten(value)
         new_args = to_any(args, to_value)
         return torch.utils._pytree.tree_unflatten(new_args, spec)
 
+    assert "Cache" not in value.__class__.__name__, (
+        f"Class {value.__class__.__name__!r} should be registered "
+        f"to be able to change the type in every tensor it contains."
+    )
     assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
     return value
 
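
A minimal usage sketch of the change above (the import path and variable names are illustrative, not part of the commit): integer tensors are now returned unchanged when a floating-point dtype is requested, while floating-point tensors are still converted, including inside nested containers.

    import torch

    # from <this helper module> import to_any  # actual import path depends on the project

    ids = torch.arange(4, dtype=torch.int64)      # e.g. token ids, must stay integer
    emb = torch.randn(4, 8, dtype=torch.float32)  # floating-point tensor, safe to cast

    assert to_any(ids, torch.float16).dtype == torch.int64    # skipped by the new check
    assert to_any(emb, torch.float16).dtype == torch.float16  # converted as before
    assert to_any([ids, emb], torch.float16)[0].dtype == torch.int64  # recursion over lists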