Skip to content

Commit 929ea50

Browse files
committed
fix stealing
1 parent 3a64053 commit 929ea50

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
8+
* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``
79
* :pr:`93`: introduce patched expression to get around annoying export issues
810
* :pr:`92`: support errors distribution in max_diff
911
* :pr:`91`: enable strings in ``guess_dynamic_shapes``

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,15 @@ def _forward_(
4747
return res
4848

4949

50+
_steal_forward_status = [False]
5051
_additional_stolen_objects = {}
5152

5253

54+
def is_stealing() -> bool:
55+
"""Returns true if :func:`steal_forward` was yielded."""
56+
return _steal_forward_status[0]
57+
58+
5359
def steal_append(name: str, obj: Any):
5460
"""
5561
When outside a forward method, it is still possible to add
@@ -63,17 +69,18 @@ def steal_append(name: str, obj: Any):
6369
The same code can executed multiple times, then
6470
the name can extended with a number.
6571
"""
66-
if name in _additional_stolen_objects:
67-
i = 1
68-
n = f"{name}_{i}"
69-
while n in _additional_stolen_objects:
70-
i += 1
72+
if is_stealing():
73+
if name in _additional_stolen_objects:
74+
i = 1
7175
n = f"{name}_{i}"
72-
print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
73-
_additional_stolen_objects[n] = obj
74-
else:
75-
print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
76-
_additional_stolen_objects[name] = obj
76+
while n in _additional_stolen_objects:
77+
i += 1
78+
n = f"{name}_{i}"
79+
print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
80+
_additional_stolen_objects[n] = obj
81+
else:
82+
print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
83+
_additional_stolen_objects[name] = obj
7784

7885

7986
@contextlib.contextmanager
@@ -142,8 +149,11 @@ def forward(self, x, y):
142149
print("output", k, v)
143150
144151
Function :func:`steal_append` can be used to dump more tensors.
152+
When inside the context, func:`is_stealing` returns True, False otherwise.
145153
"""
154+
assert not is_stealing(), "steal_forward was already called."
146155
# We clear the cache.
156+
_steal_forward_status[0] = True
147157
_additional_stolen_objects.clear()
148158
assert not submodules or isinstance(
149159
model, torch.nn.Module
@@ -177,6 +187,7 @@ def forward(self, x, y):
177187
try:
178188
yield
179189
finally:
190+
_steal_forward_status[0] = False
180191
for f in keep_model_forward.values():
181192
f[0].forward = f[1]
182193
if dump_file:

0 commit comments

Comments
 (0)