Skip to content

Commit e7e63b1

Browse files
committed
post-shadow
1 parent bb4aa54 commit e7e63b1

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,14 @@ def _mkv_(name):
458458
opset_imports=[oh.make_operatorsetid("", 18)],
459459
ir_version=10,
460460
)
461-
self.assertEqual({"three"}, shadowing_names(model))
461+
self.assertEqual(
462+
(
463+
{"three"},
464+
set(),
465+
{"cond", "Z", "X0", "Z_c", "three", "one_c", "Xred", "X00", "Y"},
466+
),
467+
shadowing_names(model),
468+
)
462469

463470

464471
if __name__ == "__main__":

onnx_diagnostic/_command_lines_parser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def _cmd_find(argv: List[Any]):
218218
args = parser.parse_args(argv[1:])
219219
if args.names == "SHADOW":
220220
onx = onnx.load(args.input, load_external_data=False)
221-
print(f"shadowing names: {shadowing_names(onx)}")
221+
s, ps = shadowing_names(onx)[:2]
222+
print(f"shadowing names: {s}")
223+
print(f"post-shadowing names: {ps}")
222224
elif args.v2:
223225
onx = onnx.load(args.input, load_external_data=False)
224226
res = list(

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,12 @@ def shadowing_names(
11261126
verbose: int = 0,
11271127
existing: Optional[Set[str]] = None,
11281128
shadow_context: Optional[Set[str]] = None,
1129-
) -> Set[str]:
1130-
"""Returns the shadowing names."""
1129+
post_shadow_context: Optional[Set[str]] = None,
1130+
) -> Tuple[Set[str], Set[str], Set[str]]:
1131+
"""
1132+
Returns the shadowing names, the names created in the main graph
1133+
after they were created in a subgraphs and the names created by the nodes.
1134+
"""
11311135
if isinstance(proto, ModelProto):
11321136
return shadowing_names(proto.graph)
11331137
if isinstance(proto, GraphProto):
@@ -1141,6 +1145,7 @@ def shadowing_names(
11411145
| set(i.name for i in proto.sparse_initializer)
11421146
| set(i.name for i in proto.input if i.name),
11431147
shadow_context=set(),
1148+
post_shadow_context=set(),
11441149
)
11451150
if isinstance(proto, FunctionProto):
11461151
assert (
@@ -1151,6 +1156,7 @@ def shadowing_names(
11511156
verbose=verbose,
11521157
existing=set(i for i in proto.input if i),
11531158
shadow_context=set(),
1159+
post_shadow_context=set(),
11541160
)
11551161

11561162
assert (
@@ -1159,6 +1165,8 @@ def shadowing_names(
11591165
shadow = set()
11601166
shadow_context = shadow_context.copy()
11611167
existing = existing.copy()
1168+
created = set()
1169+
post_shadow = set()
11621170
for node in proto:
11631171
not_empty = set(n for n in node.input if n)
11641172
intersection = not_empty & existing
@@ -1172,11 +1180,15 @@ def shadowing_names(
11721180
shadow |= set(i.name for i in g.input) & shadow_context
11731181
shadow |= set(i.name for i in g.initializer) & shadow_context
11741182
shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
1175-
shadow |= shadowing_names(
1183+
s, ps, c = shadowing_names(
11761184
g.node, verbose=verbose, existing=existing, shadow_context=existing
11771185
)
1186+
shadow |= s
1187+
created |= c
11781188

11791189
not_empty = set(n for n in node.output if n)
1190+
post_shadow |= not_empty & created
11801191
shadow |= not_empty & shadow_context
11811192
existing |= not_empty
1182-
return shadow
1193+
created |= not_empty
1194+
return shadow, post_shadow, created

0 commit comments

Comments
 (0)