Skip to content

Commit c130899

Browse files
committed
Fix empty outputs for OnnxruntimeEvaluator
1 parent 5d397f5 commit c130899

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,31 @@ def test_skip_layer_normalization(self):
259259
got = rt.run(None, feeds)
260260
self.assertEqualAny(expected, got, atol=1e-4)
261261

262+
@hide_stdout()
263+
def test_skip_simplified_layer_normalization(self):
264+
node = oh.make_node(
265+
"SkipSimplifiedLayerNormalization",
266+
["x", "skip", "beta", "gamma"],
267+
["Z", "", "", "bias"],
268+
epsilon=1.0e-5,
269+
domain="com.microsoft",
270+
)
271+
feeds = dict(
272+
x=self._range(2, 3, 8),
273+
skip=self._range(2, 3, 8, bias=3),
274+
beta=self._range(8, bias=1),
275+
gamma=self._range(8, bias=2),
276+
)
277+
rt = OnnxruntimeEvaluator(node, verbose=10, opsets={"": 22})
278+
got = rt.run(None, feeds)
279+
self.assertEqual(len(got), 2)
280+
self.assertIsInstance(got[0], np.ndarray)
281+
self.assertIsInstance(got[1], np.ndarray)
282+
self.assertEqual(got[0].shape, feeds["x"].shape)
283+
self.assertEqual(got[0].dtype, feeds["x"].dtype)
284+
self.assertEqual(got[1].shape, feeds["x"].shape)
285+
self.assertEqual(got[1].dtype, feeds["x"].dtype)
286+
262287

263288
if __name__ == "__main__":
264289
unittest.main(verbosity=2)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,11 @@ def run(
278278
outputs = self._run_local(node, inputs, results)
279279
else:
280280
outputs = self._run(node, inputs, results)
281-
for name, value in zip(node.output, outputs):
282-
if name == "":
283-
continue
281+
node_output = [o for o in node.output if o]
282+
assert len(node_output) == len(
283+
outputs
284+
), f"Length mismatch between node output={node.output} and outputs={outputs}"
285+
for name, value in zip(node_output, outputs):
284286
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
285287
assert isinstance(name, str), f"unexpected type for name {type(name)}"
286288
results[name] = value
@@ -384,6 +386,11 @@ def _make_model_proto(
384386
onx = shi.infer_shapes(onx)
385387
return onx
386388

389+
def _make_model_outputs(
390+
self, node: NodeProto, inputs: List[ValueInfoProto]
391+
) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
392+
return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
393+
387394
@classmethod
388395
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
389396
"""
@@ -424,6 +431,7 @@ def _get_sess(
424431
onx = node
425432
else:
426433
assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
434+
prenodes = []
427435
if node.op_type == "Constant":
428436
# We force the type to be a boolean.
429437
ref = ExtendedReferenceEvaluator(node)
@@ -447,9 +455,9 @@ def _get_sess(
447455
vinputs.append(value)
448456

449457
# no need to run shape inference
450-
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
458+
prenodes, voutputs = self._make_model_outputs(node, vinputs)
451459

452-
onx = self._make_model_proto([node], vinputs, voutputs)
460+
onx = self._make_model_proto([*prenodes, node], vinputs, voutputs)
453461
if node.op_type in {"Shape", "Size"}:
454462
on_cpu = True
455463

0 commit comments

Comments
 (0)