Skip to content

Commit b4f4d54

Browse files
committed
msg
1 parent e9c25fb commit b4f4d54

File tree

6 files changed

+27
-2
lines changed

6 files changed

+27
-2
lines changed

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,8 @@ def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
878878

879879
if tensor.HasField("raw_data"):
880880
raw_data = tensor.raw_data
881+
if len(raw_data) == 0:
882+
return torch.tensor([], dtype=torch_dtype).reshape(dims)
881883
if sys.byteorder == "big":
882884
# Convert endian from little to big
883885
raw_data = torch.frombuffer(raw_data, dtype=torch_dtype).byteswap().tobytes()

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,14 @@ def __init__(
117117
self.runtime_info = first_used_last_used(proto, constant_as_initializer=True)
118118
self.last_used: List[List[str]] = [[] for _ in self.kernels]
119119
for name, info in self.runtime_info.items():
120-
assert isinstance(info.last_used, int), f"Missing field last_used in {info!r}"
121-
if not info.is_output and not info.is_initializer:
120+
assert isinstance(info.last_used, int) or info.is_input, (
121+
f"Missing field last_used in {info!r}, last_used={info.last_used!r}, "
122+
f"This may mean the node is unused and it should be removed."
123+
)
124+
if info.last_used is None:
125+
# Not used.
126+
self.last_used[0].append(name)
127+
elif not info.is_output and not info.is_initializer:
122128
self.last_used[info.last_used].append(name)
123129

124130
@property

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Abs_1,
3434
Cos_1,
3535
Exp_1,
36+
Identity_1,
3637
Log_1,
3738
Neg_1,
3839
Not_1,

onnx_diagnostic/reference/torch_ops/unary_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ def run(self, x: OpRunValue) -> OpRunValue:
2323
return OpRunValue(x.tensor.exp())
2424

2525

26+
class Identity_1(OpRun):
27+
"Identity"
28+
29+
def run(self, x: OpRunValue) -> OpRunValue:
30+
return OpRunValue(x.tensor)
31+
32+
2633
class Log_1(OpRun):
2734
"""Log"""
2835

onnx_diagnostic/torch_onnx/runtime_info.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def is_output(self) -> bool:
113113
"Tells if it is an output."
114114
return self.kind == RuntimeValueKind.OUTPUT
115115

116+
@property
117+
def is_input(self) -> bool:
118+
"Tells if it is an input."
119+
return self.kind == RuntimeValueKind.INPUT
120+
116121
@property
117122
def is_initializer(self) -> bool:
118123
"Tells if it is an initializer."

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ disable_error_code = ["union-attr"]
4848
module = ["onnx_diagnostic.reference.torch_ops.*"]
4949
disable_error_code = ["override"]
5050

51+
[[tool.mypy.overrides]]
52+
module = ["onnx_diagnostic.reference.torch_ops.control_flow"]
53+
disable_error_code = ["name-defined"]
54+
5155
[[tool.mypy.overrides]]
5256
module = ["onnx_diagnostic.reference.torch_ops._op_run"]
5357
disable_error_code = ["name-defined"]

0 commit comments

Comments
 (0)