Skip to content

Commit 11d65c3

Browse files
committed
if code coverage
1 parent 525ae26 commit 11d65c3

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import ml_dtypes
55
from onnx import ModelProto, TensorProto
66
from onnx.checker import check_model
7+
import onnx
78
import onnx.helper as oh
89
import onnx.numpy_helper as onh
910
import torch
@@ -248,6 +249,71 @@ def test_init_torch_bfloat16(self):
248249
self.assertIsInstance(got[0], (torch.Tensor, np.ndarray))
249250
self.assertEqualArray(expected[0], got[0])
250251

252+
@hide_stdout()
253+
def test_if(self):
254+
255+
def _mkv_(name):
256+
value_info_proto = onnx.ValueInfoProto()
257+
value_info_proto.name = name
258+
return value_info_proto
259+
260+
model = oh.make_model(
261+
oh.make_graph(
262+
[
263+
oh.make_node("ReduceSum", ["X"], ["Xred"]),
264+
oh.make_node("Add", ["X", "two"], ["X0"]),
265+
oh.make_node("Add", ["X0", "zero"], ["X00"]),
266+
oh.make_node("CastLike", ["one", "Xred"], ["one_c"]),
267+
oh.make_node("Greater", ["Xred", "one_c"], ["cond"]),
268+
oh.make_node(
269+
"If",
270+
["cond"],
271+
["Z_c"],
272+
then_branch=oh.make_graph(
273+
[
274+
oh.make_node("Constant", [], ["two"], value_floats=[2.1]),
275+
oh.make_node("Add", ["X00", "two"], ["Y"]),
276+
],
277+
"then",
278+
[],
279+
[_mkv_("Y")],
280+
),
281+
else_branch=oh.make_graph(
282+
[
283+
oh.make_node("Constant", [], ["two"], value_floats=[2.2]),
284+
oh.make_node("Sub", ["X0", "two"], ["Y"]),
285+
],
286+
"else",
287+
[],
288+
[_mkv_("Y")],
289+
),
290+
),
291+
oh.make_node("CastLike", ["Z_c", "X"], ["Z"]),
292+
],
293+
"test",
294+
[
295+
oh.make_tensor_value_info("X", TensorProto.FLOAT, ["N"]),
296+
oh.make_tensor_value_info("one", TensorProto.FLOAT, ["N"]),
297+
],
298+
[oh.make_tensor_value_info("Z", TensorProto.UNDEFINED, ["N"])],
299+
[
300+
onh.from_array(np.array([0], dtype=np.float32), name="zero"),
301+
onh.from_array(np.array([2], dtype=np.float32), name="two"),
302+
],
303+
),
304+
opset_imports=[oh.make_operatorsetid("", 18)],
305+
ir_version=10,
306+
)
307+
feeds = {
308+
"X": np.array([1, 2, 3], dtype=np.float32),
309+
"one": np.array([1], dtype=np.float32),
310+
}
311+
ref = ExtendedReferenceEvaluator(model, verbose=10)
312+
expected = ref.run(None, feeds)[0]
313+
sess = OnnxruntimeEvaluator(model, verbose=10)
314+
got = sess.run(None, feeds)[0]
315+
self.assertEqualArray(expected[0], got[0])
316+
251317

252318
if __name__ == "__main__":
253319
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)