Skip to content

Commit 10608c0

Browse files
committed
add shadowing
1 parent 6705737 commit 10608c0

File tree

2 files changed

+120
-1
lines changed

2 files changed

+120
-1
lines changed

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from_array_extended,
1818
tensor_statistics,
1919
enumerate_results,
20+
shadowing_names,
2021
)
2122

2223

@@ -400,7 +401,64 @@ def test_enumerate_results_loop(self):
400401
)
401402
)
402403
res = list(enumerate_results(model, "slice_start", verbose=2))
403-
print(res)
404+
self.assertEqual(len(res), 2)
405+
406+
def test_shadowing_names(self):
407+
def _mkv_(name):
408+
value_info_proto = ValueInfoProto()
409+
value_info_proto.name = name
410+
return value_info_proto
411+
412+
model = oh.make_model(
413+
oh.make_graph(
414+
[
415+
oh.make_node("ReduceSum", ["X"], ["Xred"]),
416+
oh.make_node("Add", ["X", "two"], ["X0"]),
417+
oh.make_node("Add", ["X0", "zero"], ["X00"]),
418+
oh.make_node("CastLike", ["one", "Xred"], ["one_c"]),
419+
oh.make_node("Greater", ["Xred", "one_c"], ["cond"]),
420+
oh.make_node("Identity", ["two"], ["three"]),
421+
oh.make_node(
422+
"If",
423+
["cond"],
424+
["Z_c"],
425+
then_branch=oh.make_graph(
426+
[
427+
# shadowing
428+
oh.make_node("Constant", [], ["three"], value_floats=[2.1]),
429+
oh.make_node("Add", ["X00", "three"], ["Y"]),
430+
],
431+
"then",
432+
[],
433+
[_mkv_("Y")],
434+
),
435+
else_branch=oh.make_graph(
436+
[
437+
# not shadowing
438+
oh.make_node("Sub", ["X0", "three"], ["Y"]),
439+
],
440+
"else",
441+
[],
442+
[_mkv_("Y")],
443+
),
444+
),
445+
oh.make_node("CastLike", ["Z_c", "X"], ["Z"]),
446+
],
447+
"test",
448+
[
449+
oh.make_tensor_value_info("X", TensorProto.FLOAT, ["N"]),
450+
oh.make_tensor_value_info("one", TensorProto.FLOAT, ["N"]),
451+
],
452+
[oh.make_tensor_value_info("Z", TensorProto.UNDEFINED, ["N"])],
453+
[
454+
onh.from_array(np.array([0], dtype=np.float32), name="zero"),
455+
onh.from_array(np.array([2], dtype=np.float32), name="two"),
456+
],
457+
),
458+
opset_imports=[oh.make_operatorsetid("", 18)],
459+
ir_version=10,
460+
)
461+
self.assertEqual({"three"}, shadowing_names(model))
404462

405463

406464
if __name__ == "__main__":

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,3 +1119,64 @@ def enumerate_results(
11191119
yield r
11201120
if verbose:
11211121
print(f"[enumerate_results] {indent}done")
1122+
1123+
1124+
def shadowing_names(
1125+
proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
1126+
verbose: int = 0,
1127+
existing: Optional[Set[str]] = None,
1128+
shadow_context: Optional[Set[str]] = None,
1129+
) -> Set[str]:
1130+
"""Returns the shadowing names."""
1131+
if isinstance(proto, ModelProto):
1132+
return shadowing_names(proto.graph)
1133+
if isinstance(proto, GraphProto):
1134+
assert (
1135+
existing is None and shadow_context is None
1136+
), "existing must be None if nodes is None"
1137+
return shadowing_names(
1138+
proto.node,
1139+
verbose=verbose,
1140+
existing=set(i.name for i in proto.initializer)
1141+
| set(i.name for i in proto.sparse_initializer)
1142+
| set(i.name for i in proto.input if i.name),
1143+
shadow_context=set(),
1144+
)
1145+
if isinstance(proto, FunctionProto):
1146+
assert (
1147+
existing is None and shadow_context is None
1148+
), "existing must be None if nodes is None"
1149+
return shadowing_names(
1150+
proto.node,
1151+
verbose=verbose,
1152+
existing=set(i for i in proto.input if i),
1153+
shadow_context=set(),
1154+
)
1155+
1156+
assert (
1157+
existing is not None and shadow_context is not None
1158+
), "existing must not be None if nodes is not None"
1159+
shadow = set()
1160+
shadow_context = shadow_context.copy()
1161+
existing = existing.copy()
1162+
for node in proto:
1163+
not_empty = set(n for n in node.input if n)
1164+
intersection = not_empty & existing
1165+
assert len(intersection) == len(not_empty), (
1166+
f"One input in {not_empty}, node={pretty_onnx(node)} "
1167+
f"was not found in {existing}"
1168+
)
1169+
for att in node.attribute:
1170+
if att.type == AttributeProto.GRAPH:
1171+
g = att.g
1172+
shadow |= set(i.name for i in g.input) & shadow_context
1173+
shadow |= set(i.name for i in g.initializer) & shadow_context
1174+
shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
1175+
shadow |= shadowing_names(
1176+
g.node, verbose=verbose, existing=existing, shadow_context=existing
1177+
)
1178+
1179+
not_empty = set(n for n in node.output if n)
1180+
shadow |= not_empty & shadow_context
1181+
existing |= not_empty
1182+
return shadow

0 commit comments

Comments
 (0)