Skip to content

Commit db3b67c

Browse files
committed
Handles sequences
1 parent a37a6d5 commit db3b67c

File tree

16 files changed

+381
-12
lines changed

16 files changed

+381
-12
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.control_flow_ops
3+
====================================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.control_flow_ops
6+
:members:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.generator_ops
3+
=================================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.generator_ops
6+
:members:

_doc/api/reference/torch_ops/index.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ onnx_diagnostic.reference.torch_ops
99

1010
access_ops
1111
binary_ops
12+
control_flow_ops
13+
generator_ops
14+
nn_ops
15+
other_ops
16+
reduce_ops
17+
sequence_ops
18+
shape_ops
19+
unary_ops
1220

1321
OpRun
1422
+++++
@@ -22,6 +30,12 @@ OpRunValue
2230
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValue
2331
:members:
2432

33+
OpRunValueSequence
34+
++++++++++++++++++
35+
36+
.. autoclass:: onnx_diagnostic.reference.torch_ops.OpRunValueSequence
37+
:members:
38+
2539
OpRunFunction
2640
+++++++++++++
2741

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.nn_ops
3+
==========================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.nn_ops
6+
:members:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.other_ops
3+
=============================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.other_ops
6+
:members:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.reduce_ops
3+
==============================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.reduce_ops
6+
:members:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.sequence_ops
3+
================================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.sequence_ops
6+
:members:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.shape_ops
3+
=============================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.shape_ops
6+
:members:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.reference.torch_ops.unary_ops
3+
=============================================
4+
5+
.. automodule:: onnx_diagnostic.reference.torch_ops.unary_ops
6+
:members:

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,119 @@ def _mkv_(name):
956956
torch.tensor([1], dtype=torch.float32),
957957
)
958958

959+
def test_loop(self):
960+
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
961+
962+
model = oh.make_model(
963+
graph=oh.make_graph(
964+
name="loop_test",
965+
inputs=[
966+
oh.make_tensor_value_info("trip_count", TINT64, ["a"]),
967+
oh.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [1]),
968+
],
969+
outputs=[oh.make_tensor_value_info("res", TFLOAT, [])],
970+
nodes=[
971+
oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT),
972+
oh.make_node(
973+
"Loop",
974+
inputs=["trip_count", "cond", "seq_empty"],
975+
outputs=["seq_res"],
976+
body=oh.make_graph(
977+
[
978+
oh.make_node(
979+
"Identity", inputs=["cond_in"], outputs=["cond_out"]
980+
),
981+
oh.make_node(
982+
"Constant",
983+
inputs=[],
984+
outputs=["x"],
985+
value=oh.make_tensor(
986+
name="const_tensor_x",
987+
data_type=TFLOAT,
988+
dims=x.shape,
989+
vals=x.flatten().astype(float),
990+
),
991+
),
992+
oh.make_node(
993+
"Constant",
994+
inputs=[],
995+
outputs=["one"],
996+
value=oh.make_tensor(
997+
name="const_tensor_one",
998+
data_type=TINT64,
999+
dims=(),
1000+
vals=[1],
1001+
),
1002+
),
1003+
oh.make_node(
1004+
"Constant",
1005+
inputs=[],
1006+
outputs=["slice_start"],
1007+
value=oh.make_tensor(
1008+
name="const_tensor_zero",
1009+
data_type=TINT64,
1010+
dims=(1,),
1011+
vals=[0],
1012+
),
1013+
),
1014+
oh.make_node(
1015+
"Add", inputs=["iter_count", "one"], outputs=["end"]
1016+
),
1017+
oh.make_node(
1018+
"Constant",
1019+
inputs=[],
1020+
outputs=["axes"],
1021+
value=oh.make_tensor(
1022+
name="const_tensor_axes",
1023+
data_type=TINT64,
1024+
dims=(1,),
1025+
vals=[0],
1026+
),
1027+
),
1028+
oh.make_node(
1029+
"Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"]
1030+
),
1031+
oh.make_node(
1032+
"Slice",
1033+
inputs=["x", "slice_start", "slice_end"],
1034+
outputs=["slice_out"],
1035+
),
1036+
oh.make_node(
1037+
"SequenceInsert",
1038+
inputs=["seq_in", "slice_out"],
1039+
outputs=["seq_out"],
1040+
),
1041+
],
1042+
"loop_body",
1043+
[
1044+
oh.make_tensor_value_info("iter_count", TINT64, []),
1045+
oh.make_tensor_value_info(
1046+
"cond_in", onnx.TensorProto.BOOL, []
1047+
),
1048+
oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None),
1049+
],
1050+
[
1051+
oh.make_tensor_value_info(
1052+
"cond_out", onnx.TensorProto.BOOL, []
1053+
),
1054+
oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None),
1055+
],
1056+
),
1057+
),
1058+
oh.make_node(
1059+
"ConcatFromSequence",
1060+
inputs=["seq_res"],
1061+
outputs=["res"],
1062+
axis=0,
1063+
new_axis=0,
1064+
),
1065+
],
1066+
)
1067+
)
1068+
self._finalize_test(
1069+
model, torch.tensor(5, dtype=torch.int64), torch.tensor(1, dtype=torch.bool)
1070+
)
1071+
9591072

9601073
if __name__ == "__main__":
9611074
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)