Skip to content

Commit c4637b6

Browse files
committed
fix loops
1 parent 7571ae0 commit c4637b6

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

_unittests/ut_export/test_control_flow.py renamed to _unittests/ut_export/test_control_flow_onnx.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from onnx_diagnostic.export.api import to_onnx
1313

1414

15-
class TestControlFlow(ExtTestCase):
15+
class TestControlFlowOnnx(ExtTestCase):
1616
@never_test()
1717
def test_loop_one_research(self):
1818
class Model(torch.nn.Module):
@@ -70,7 +70,7 @@ def body(i, x):
7070
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
7171
)
7272
self.assertIn(
73-
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlow_test_loop_one_custom_L_Model_forward_L_body_",
73+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_L_Model_forward_L_body_",
7474
str(ep),
7575
)
7676

@@ -104,7 +104,7 @@ def body(i, x):
104104
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
105105
)
106106
self.assertIn(
107-
"torch.ops.onnx_higher_ops.loop_for_TestControlFlow_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
107+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_one_custom_different_opset_L_Model_forward_L_body_",
108108
str(ep),
109109
)
110110

@@ -142,7 +142,7 @@ def body(i, x):
142142
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
143143
)
144144
self.assertIn(
145-
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlow_test_loop_two_custom_L_Model_forward_L_body_",
145+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_L_Model_forward_L_body_",
146146
str(ep),
147147
)
148148

@@ -177,7 +177,7 @@ def body(i, x):
177177
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
178178
)
179179
self.assertIn(
180-
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlow_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
180+
"torch.ops.onnx_higher_ops.loop_for_onnx_TestControlFlowOnnx_test_loop_two_custom_reduction_dim_L_Model_forward_L_body_",
181181
str(ep),
182182
)
183183

onnx_diagnostic/export/control_flow_onnx.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,17 @@ def convert_custom_loop_into_onnx(
265265
nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
266266
)
267267

268-
sequences = [g.op.SequenceEmpty() for _ in outputs]
268+
itypes = [
269+
graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type
270+
for i in range(1, len(graph.output))
271+
]
272+
assert len(outputs) == len(
273+
itypes
274+
), f"Length mismatch between outputs={outputs} and graph.output={graph.output}"
275+
assert (
276+
0 not in itypes
277+
), f"Undefined types are not allowed in itype={itypes}, graph.output={graph.output}"
278+
sequences = [g.op.SequenceEmpty(dtype=itype) for itype in itypes]
269279

270280
outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]
271281

@@ -285,8 +295,10 @@ def convert_custom_loop_into_onnx(
285295
]
286296
if not sts:
287297
for i, o in enumerate(outputs):
288-
g.set_type(o, graph.output[i].type.tensor_type.elem_type)
289-
g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
298+
g.set_type(o, graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type)
299+
g.set_rank(
300+
o, len(graph.output[i].type.sequence_type.elem_type.tensor_type.shape.dims)
301+
)
290302
return outputs if len(outputs) > 1 else outputs[0]
291303

292304

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def append_output_sequence(
159159
"""
160160
if not tensors:
161161
# empty list
162-
self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
162+
self.nodes.append(
163+
oh.make_node("SequenceEmpty", [], [name], dtype=TensorProto.FLOAT)
164+
)
163165
tensor_type_proto = oh.make_tensor_type_proto(
164166
elem_type=TensorProto.FLOAT, shape=None
165167
)

0 commit comments

Comments
 (0)