Skip to content

Commit 6660706

Browse files
committed
if
1 parent 8fb4e3b commit 6660706

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,65 @@ def test_local_function(self):
897897
atol=1e-6,
898898
)
899899

900+
def test_if(self):
901+
def _mkv_(name):
902+
value_info_proto = onnx.ValueInfoProto()
903+
value_info_proto.name = name
904+
return value_info_proto
905+
906+
model = oh.make_model(
907+
oh.make_graph(
908+
[
909+
oh.make_node("ReduceSum", ["X"], ["Xred"]),
910+
oh.make_node("Add", ["X", "two"], ["X0"]),
911+
oh.make_node("Add", ["X0", "zero"], ["X00"]),
912+
oh.make_node("CastLike", ["one", "Xred"], ["one_c"]),
913+
oh.make_node("Greater", ["Xred", "one_c"], ["cond"]),
914+
oh.make_node(
915+
"If",
916+
["cond"],
917+
["Z_c"],
918+
then_branch=oh.make_graph(
919+
[
920+
oh.make_node("Constant", [], ["t2"], value_floats=[2.1]),
921+
oh.make_node("Add", ["X00", "t2"], ["Y"]),
922+
],
923+
"then",
924+
[],
925+
[_mkv_("Y")],
926+
),
927+
else_branch=oh.make_graph(
928+
[
929+
oh.make_node("Constant", [], ["t2"], value_floats=[2.2]),
930+
oh.make_node("Sub", ["X0", "t2"], ["Y"]),
931+
],
932+
"else",
933+
[],
934+
[_mkv_("Y")],
935+
),
936+
),
937+
oh.make_node("CastLike", ["Z_c", "X"], ["Z"]),
938+
],
939+
"test",
940+
[
941+
oh.make_tensor_value_info("X", TFLOAT, ["N"]),
942+
oh.make_tensor_value_info("one", TFLOAT, ["N"]),
943+
],
944+
[oh.make_tensor_value_info("Z", onnx.TensorProto.UNDEFINED, ["N"])],
945+
[
946+
onh.from_array(np.array([0], dtype=np.float32), name="zero"),
947+
onh.from_array(np.array([2], dtype=np.float32), name="two"),
948+
],
949+
),
950+
opset_imports=[oh.make_operatorsetid("", 18)],
951+
ir_version=10,
952+
)
953+
self._finalize_test(
954+
model,
955+
torch.tensor([1, 2, 3], dtype=torch.float32),
956+
torch.tensor([1], dtype=torch.float32),
957+
)
958+
900959

901960
if __name__ == "__main__":
902961
unittest.main(verbosity=2)

onnx_diagnostic/reference/torch_ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from .generator_ops import Range_11
1919
from .nn_ops import LayerNormalization_17, Softmax_13, Tanh_6
20-
from .other_ops import Cast_6, Concat_1, Transpose_1, Trilu_14, Where_9
20+
from .other_ops import Cast_6, CastLike_15, Concat_1, Transpose_1, Trilu_14, Where_9
2121
from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_18, ReduceSum_18
2222
from .shape_ops import (
2323
ConstantOfShape_9,

onnx_diagnostic/reference/torch_ops/other_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,25 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
1313
to = self.get_attribute_int(node, "to", 0)
1414
assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
1515
self.to = onnx_dtype_to_torch_dtype(to)
16+
self.saturate = self.get_attribute_int(node, "saturate", 1)
17+
assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"
1618

1719
def run(self, data: OpRunValue) -> OpRunValue:
1820
return OpRunValue(data.tensor.to(self.to))
1921

2022

23+
class CastLike_15(OpRun):
24+
"Cast"
25+
26+
def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
27+
super().__init__(node, version)
28+
self.saturate = self.get_attribute_int(node, "saturate", 1)
29+
assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"
30+
31+
def run(self, data: OpRunValue, like: OpRunValue) -> OpRunValue:
32+
return OpRunValue(data.tensor.to(like.tensor.dtype))
33+
34+
2135
class Concat_1(OpRun):
2236
"Concat"
2337

0 commit comments

Comments
 (0)