@@ -305,7 +305,7 @@ def forward(self, input_ids):
305305
306306def to_any (value : Any , to_value : Union [torch .dtype , torch .device ]) -> Any :
307307 """
308- Applies torch.to is applicables .
308+ Applies torch.to if applicable.
309309 Goes recursively.
310310 """
311311 if isinstance (value , (torch .nn .Module , torch .Tensor )):
@@ -329,6 +329,10 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
329329 )
330330 )
331331 )
332+ if value .__class__ in torch .utils ._pytree .SUPPORTED_NODES :
333+ args , spec = torch .utils ._pytree .tree_flatten (value )
334+ new_args = to_any (args , to_value )
335+ return torch .utils ._pytree .tree_unflatten (new_args , spec )
332336
333337 assert not isinstance (value , Iterable ), f"Unsupported type { type (value )} "
334338 return value
@@ -361,6 +365,11 @@ def torch_deepcopy(value: Any) -> Any:
361365 torch_deepcopy (value .self_attention_cache ),
362366 torch_deepcopy (value .cross_attention_cache ),
363367 )
368+ if value .__class__ in torch .utils ._pytree .SUPPORTED_NODES :
369+ args , spec = torch .utils ._pytree .tree_flatten (value )
370+ new_args = torch_deepcopy (args )
371+ return torch .utils ._pytree .tree_unflatten (new_args , spec )
372+
364373 # We should have code relying on serialization and deserialization, assuming
365374 # a model cannot be exported without them.
366375 raise NotImplementedError (f"torch_deepcopy not implemented for type { type (value )} " )
0 commit comments