Skip to content

Commit 08b1cdf

Browse files
authored
Implements steal_append to dump additional tensors (#96)
* steal_append * fix stealing * fix mypy
1 parent 1a00215 commit 08b1cdf

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
8+
* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``
79
* :pr:`93`: introduce patched expression to get around annoying export issues
810
* :pr:`92`: support errors distribution in max_diff
911
* :pr:`91`: enable strings in ``guess_dynamic_shapes``

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
to_numpy,
1111
is_torchdynamo_exporting,
1212
model_statistics,
13+
steal_append,
1314
steal_forward,
1415
replace_string_by_dynamic,
1516
to_any,
@@ -145,6 +146,36 @@ def forward(self, x, y):
145146
self.assertEqualAny(restored["main", 0, "O"], res1)
146147
self.assertEqualAny(restored["main", 0, "O"], res2)
147148

149+
@hide_stdout()
150+
def test_steal_forward_dump_file_steal_append(self):
151+
class SubModel(torch.nn.Module):
152+
def forward(self, x):
153+
return x * x
154+
155+
class Model(torch.nn.Module):
156+
def __init__(self):
157+
super().__init__()
158+
self.s1 = SubModel()
159+
self.s2 = SubModel()
160+
161+
def forward(self, x, y):
162+
sx = self.s1(x)
163+
steal_append("sx", sx)
164+
return sx + self.s2(y)
165+
166+
inputs = torch.rand(3, 4), torch.rand(3, 4)
167+
model = Model()
168+
dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
169+
with steal_forward(model, dump_file=dump_file):
170+
model(*inputs)
171+
model(*inputs)
172+
self.assertExists(dump_file)
173+
restored = create_input_tensors_from_onnx_model(dump_file)
174+
self.assertEqual(
175+
{("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
176+
set(restored),
177+
)
178+
148179
@hide_stdout()
149180
def test_steal_forward_submodules(self):
150181
class SubModel(torch.nn.Module):
@@ -173,7 +204,7 @@ def forward(self, x, y):
173204
else:
174205
print("output", k, v)
175206
print(string_type(restored, with_shape=True))
176-
l1, l2 = 151, 160
207+
l1, l2 = 182, 191
177208
self.assertEqual(
178209
[
179210
(f"-Model-{l2}", 0, "I"),

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,42 @@ def _forward_(
4747
return res
4848

4949

50+
_steal_forward_status = [False]
51+
_additional_stolen_objects = {}
52+
53+
54+
def is_stealing() -> bool:
55+
"""Returns true if :func:`steal_forward` was yielded."""
56+
return _steal_forward_status[0]
57+
58+
59+
def steal_append(name: str, obj: Any):
60+
"""
61+
When outside a forward method, it is still possible to add
62+
a python object which contains tensors and dump after the execution
63+
of the model.
64+
65+
.. code-block:: python
66+
67+
steal_append("quantize", [t1, t2])
68+
69+
The same code can executed multiple times, then
70+
the name can extended with a number.
71+
"""
72+
if is_stealing():
73+
if name in _additional_stolen_objects:
74+
i = 1
75+
n = f"{name}_{i}"
76+
while n in _additional_stolen_objects:
77+
i += 1
78+
n = f"{name}_{i}"
79+
print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
80+
_additional_stolen_objects[n] = obj
81+
else:
82+
print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
83+
_additional_stolen_objects[name] = obj
84+
85+
5086
@contextlib.contextmanager
5187
def steal_forward(
5288
model: Union[
@@ -111,7 +147,14 @@ def forward(self, x, y):
111147
print("input", k, args, kwargs)
112148
else:
113149
print("output", k, v)
150+
151+
Function :func:`steal_append` can be used to dump more tensors.
152+
When inside the context, func:`is_stealing` returns True, False otherwise.
114153
"""
154+
assert not is_stealing(), "steal_forward was already called."
155+
# We clear the cache.
156+
_steal_forward_status[0] = True
157+
_additional_stolen_objects.clear()
115158
assert not submodules or isinstance(
116159
model, torch.nn.Module
117160
), f"submodules can only be True if model is a module but is is {type(model)}."
@@ -144,9 +187,15 @@ def forward(self, x, y):
144187
try:
145188
yield
146189
finally:
190+
_steal_forward_status[0] = False
147191
for f in keep_model_forward.values():
148192
f[0].forward = f[1]
149193
if dump_file:
194+
# Let's add the cached tensor
195+
assert storage is not None, "storage cannot be None but mypy is confused here."
196+
storage.update(_additional_stolen_objects)
197+
# We clear the cache.
198+
_additional_stolen_objects.clear()
150199
proto = create_onnx_model_from_input_tensors(storage)
151200
onnx.save(
152201
proto,

0 commit comments

Comments
 (0)