Skip to content

Commit 814a5cf

Browse files
committed
assert
1 parent 3ddec0a commit 814a5cf

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,13 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
693693
"""Applies torch.to if applicable. Goes recursively."""
694694
if isinstance(value, (torch.nn.Module, torch.Tensor)):
695695
if (
696-
isinstance(to_value, torch.dtype)
697-
or to_value in {"float16", "bfloat16", "float32", "float64"}
698-
) and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}:
696+
(
697+
isinstance(to_value, torch.dtype)
698+
or to_value in {"float16", "bfloat16", "float32", "float64"}
699+
)
700+
and hasattr(value, "dtype")
701+
and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
702+
):
699703
# int vector should not be changed.
700704
return value
701705
return value.to(to_value)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,9 @@ def torch_export_patches(
188188
if rewrite:
189189
from .patch_module import torch_export_rewrite
190190

191-
with torch_export_rewrite( # type: ignore[var-annotated]
191+
with torch_export_rewrite(
192192
rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
193-
), torch_export_patches(
193+
), torch_export_patches( # type: ignore[var-annotated]
194194
patch_sympy=patch_sympy,
195195
patch_torch=patch_torch,
196196
patch_transformers=patch_transformers,

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,21 @@ def _rewrite_if(
177177
else_ret = else_exprs[0]
178178
then_exprs = [n for n in node.body if not isinstance(n, ast.Return)]
179179
else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
180-
assert type(then_ret.value) is type(else_ret.value), (
181-
f"Inconsistencies return then value={then_ret.value}, "
182-
f"else value={else_ret.value}"
180+
is_tuple_or_list = (
181+
isinstance(then_ret, (ast.Tuple, ast.List)),
182+
isinstance(else_ret, (ast.Tuple, ast.List)),
183183
)
184-
if isinstance(then_ret.value, (ast.Tuple, ast.list)):
185-
assert len(then_ret.value.elts) == len(else_ret.value.elts), (
184+
assert len(set(is_tuple_or_list)) == 1, (
185+
f"is_tuple_or_list={is_tuple_or_list}, inconsistencies return "
186+
f"then value={then_ret}, "
187+
f"else value={else_ret}"
188+
)
189+
if is_tuple_or_list[0]:
190+
assert len(then_ret.elts) == len(else_ret.elts), (
186191
f"Unexpected number of elements on both branches, "
187-
f"then:{then_ret.value.elts}, else:{else_ret.value.elts}"
192+
f"then:{then_ret.elts}, else:{else_ret.elts}"
188193
)
189-
n_returned_values = len(then_ret.value.elts)
194+
n_returned_values = len(then_ret.elts)
190195
else:
191196
n_returned_values = 0
192197
else:

0 commit comments

Comments
 (0)