Skip to content

Commit 329474b

Browse files
committed
more robust
1 parent 63abbbb commit 329474b

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
validate_model,
1616
filter_inputs,
1717
run_ort_fusion,
18+
empty,
1819
)
1920
from onnx_diagnostic.tasks import supported_tasks
2021

@@ -32,6 +33,9 @@ def test_get_inputs_for_task(self):
3233
self.assertIn("dynamic_shapes", data)
3334
copy.deepcopy(data["inputs"])
3435

36+
def test_empty(self):
37+
self.assertFalse(empty("float16"))
38+
3539
@hide_stdout()
3640
def test_validate_model(self):
3741
mid = "arnir0/Tiny-LLM"

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,15 @@ def forward(self, input_ids):
689689
raise NotImplementedError(f"cls_name={cls_name}")
690690

691691

692-
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
692+
def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
693693
"""Applies torch.to if applicable. Goes recursively."""
694694
if isinstance(value, (torch.nn.Module, torch.Tensor)):
695+
if (
696+
isinstance(to_value, torch.dtype)
697+
or to_value in {"float16", "bfloat16", "float32", "float64"}
698+
) and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}:
699+
# int vector should not be changed.
700+
return value
695701
return value.to(to_value)
696702
if isinstance(value, list):
697703
return [to_any(t, to_value) for t in value]
@@ -712,11 +718,20 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
712718
)
713719
)
714720
)
721+
if value.__class__.__name__ == "EncoderDecoderCache":
722+
return make_encoder_decoder_cache(
723+
to_any(value.self_attention_cache, to_value),
724+
to_any(value.cross_attention_cache, to_value),
725+
)
715726
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
716727
args, spec = torch.utils._pytree.tree_flatten(value)
717728
new_args = to_any(args, to_value)
718729
return torch.utils._pytree.tree_unflatten(new_args, spec)
719730

731+
assert "Cache" not in value.__class__.__name__, (
732+
f"Class {value.__class__.__name__!r} should be registered "
733+
f"to be able to change the type in every tensor it contains."
734+
)
720735
assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
721736
return value
722737

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
def empty(value: Any) -> bool:
2424
"""Tells if the value is empty."""
2525
if isinstance(value, (str, list, dict, tuple, set)):
26-
return bool(value)
26+
return not bool(value)
2727
if value is None:
2828
return True
2929
return False

0 commit comments

Comments
 (0)