Skip to content

Commit db573e3

Browse files
committed
fix
1 parent b3c411f commit db573e3

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -691,15 +691,17 @@ def forward(self, input_ids):
691691

692692
def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
693693
"""Applies torch.to if applicable. Goes recursively."""
694-
if isinstance(value, (torch.nn.Module, torch.Tensor)):
694+
if isinstance(value, (torch.nn.Module, torch.Tensor)) and value.__class__.__name__ not in {
695+
"DynamicCache",
696+
"EncoderDecoderCache",
697+
}:
695698
if (
696699
(
697700
isinstance(to_value, torch.dtype)
698701
or to_value in {"float16", "bfloat16", "float32", "float64"}
699702
)
700703
and hasattr(value, "dtype")
701704
and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
702-
and value.__class__.__name__ not in {"DynamicCache", "EncoderDecoderCache"}
703705
):
704706
# int vector should not be changed.
705707
return value
@@ -712,8 +714,6 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
712714
return {to_any(t, to_value) for t in value}
713715
if isinstance(value, dict):
714716
return {k: to_any(t, to_value) for k, t in value.items()}
715-
if hasattr(value, "to"):
716-
return value.to(to_value)
717717
if value.__class__.__name__ == "DynamicCache":
718718
return make_dynamic_cache(
719719
list(
@@ -733,6 +733,9 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
733733
new_args = to_any(args, to_value)
734734
return torch.utils._pytree.tree_unflatten(new_args, spec)
735735

736+
if hasattr(value, "to"):
737+
return value.to(to_value)
738+
736739
assert "Cache" not in value.__class__.__name__, (
737740
f"Class {value.__class__.__name__!r} should be registered "
738741
f"to be able to change the type in every tensor it contains."

0 commit comments

Comments
 (0)