Skip to content

Commit 3a64053

Browse files
committed
steal_append
1 parent 1a00215 commit 3a64053

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
to_numpy,
1111
is_torchdynamo_exporting,
1212
model_statistics,
13+
steal_append,
1314
steal_forward,
1415
replace_string_by_dynamic,
1516
to_any,
@@ -145,6 +146,36 @@ def forward(self, x, y):
145146
self.assertEqualAny(restored["main", 0, "O"], res1)
146147
self.assertEqualAny(restored["main", 0, "O"], res2)
147148

149+
@hide_stdout()
150+
def test_steal_forward_dump_file_steal_append(self):
151+
class SubModel(torch.nn.Module):
152+
def forward(self, x):
153+
return x * x
154+
155+
class Model(torch.nn.Module):
156+
def __init__(self):
157+
super().__init__()
158+
self.s1 = SubModel()
159+
self.s2 = SubModel()
160+
161+
def forward(self, x, y):
162+
sx = self.s1(x)
163+
steal_append("sx", sx)
164+
return sx + self.s2(y)
165+
166+
inputs = torch.rand(3, 4), torch.rand(3, 4)
167+
model = Model()
168+
dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
169+
with steal_forward(model, dump_file=dump_file):
170+
model(*inputs)
171+
model(*inputs)
172+
self.assertExists(dump_file)
173+
restored = create_input_tensors_from_onnx_model(dump_file)
174+
self.assertEqual(
175+
{("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
176+
set(restored),
177+
)
178+
148179
@hide_stdout()
149180
def test_steal_forward_submodules(self):
150181
class SubModel(torch.nn.Module):

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,35 @@ def _forward_(
4747
return res
4848

4949

50+
_additional_stolen_objects = {}
51+
52+
53+
def steal_append(name: str, obj: Any):
54+
"""
55+
When outside a forward method, it is still possible to add
56+
a python object which contains tensors and dump after the execution
57+
of the model.
58+
59+
.. code-block:: python
60+
61+
steal_append("quantize", [t1, t2])
62+
63+
The same code can executed multiple times, then
64+
the name can extended with a number.
65+
"""
66+
if name in _additional_stolen_objects:
67+
i = 1
68+
n = f"{name}_{i}"
69+
while n in _additional_stolen_objects:
70+
i += 1
71+
n = f"{name}_{i}"
72+
print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
73+
_additional_stolen_objects[n] = obj
74+
else:
75+
print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
76+
_additional_stolen_objects[name] = obj
77+
78+
5079
@contextlib.contextmanager
5180
def steal_forward(
5281
model: Union[
@@ -111,7 +140,11 @@ def forward(self, x, y):
111140
print("input", k, args, kwargs)
112141
else:
113142
print("output", k, v)
143+
144+
Function :func:`steal_append` can be used to dump more tensors.
114145
"""
146+
# We clear the cache.
147+
_additional_stolen_objects.clear()
115148
assert not submodules or isinstance(
116149
model, torch.nn.Module
117150
), f"submodules can only be True if model is a module but is is {type(model)}."
@@ -147,6 +180,10 @@ def forward(self, x, y):
147180
for f in keep_model_forward.values():
148181
f[0].forward = f[1]
149182
if dump_file:
183+
# Let's add the cached tensor
184+
storage.update(_additional_stolen_objects)
185+
# We clear the cache.
186+
_additional_stolen_objects.clear()
150187
proto = create_onnx_model_from_input_tensors(storage)
151188
onnx.save(
152189
proto,

0 commit comments

Comments
 (0)