Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.5.0
+++++

* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``
* :pr:`93`: introduce patched expression to get around annoying export issues
* :pr:`92`: support errors distribution in max_diff
* :pr:`91`: enable strings in ``guess_dynamic_shapes``
Expand Down
33 changes: 32 additions & 1 deletion _unittests/ut_helpers/test_torch_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
to_numpy,
is_torchdynamo_exporting,
model_statistics,
steal_append,
steal_forward,
replace_string_by_dynamic,
to_any,
Expand Down Expand Up @@ -145,6 +146,36 @@ def forward(self, x, y):
self.assertEqualAny(restored["main", 0, "O"], res1)
self.assertEqualAny(restored["main", 0, "O"], res2)

@hide_stdout()
def test_steal_forward_dump_file_steal_append(self):
class SubModel(torch.nn.Module):
def forward(self, x):
return x * x

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.s1 = SubModel()
self.s2 = SubModel()

def forward(self, x, y):
sx = self.s1(x)
steal_append("sx", sx)
return sx + self.s2(y)

inputs = torch.rand(3, 4), torch.rand(3, 4)
model = Model()
dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
with steal_forward(model, dump_file=dump_file):
model(*inputs)
model(*inputs)
self.assertExists(dump_file)
restored = create_input_tensors_from_onnx_model(dump_file)
self.assertEqual(
{("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
set(restored),
)

@hide_stdout()
def test_steal_forward_submodules(self):
class SubModel(torch.nn.Module):
Expand Down Expand Up @@ -173,7 +204,7 @@ def forward(self, x, y):
else:
print("output", k, v)
print(string_type(restored, with_shape=True))
l1, l2 = 151, 160
l1, l2 = 182, 191
self.assertEqual(
[
(f"-Model-{l2}", 0, "I"),
Expand Down
49 changes: 49 additions & 0 deletions onnx_diagnostic/helpers/torch_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,42 @@ def _forward_(
return res


_steal_forward_status = [False]
_additional_stolen_objects = {}


def is_stealing() -> bool:
"""Returns true if :func:`steal_forward` was yielded."""
return _steal_forward_status[0]


def steal_append(name: str, obj: Any):
"""
When outside a forward method, it is still possible to add
a python object which contains tensors and dump after the execution
of the model.

.. code-block:: python

steal_append("quantize", [t1, t2])

The same code can executed multiple times, then
the name can extended with a number.
"""
if is_stealing():
if name in _additional_stolen_objects:
i = 1
n = f"{name}_{i}"
while n in _additional_stolen_objects:
i += 1
n = f"{name}_{i}"
print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
_additional_stolen_objects[n] = obj
else:
print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
_additional_stolen_objects[name] = obj


@contextlib.contextmanager
def steal_forward(
model: Union[
Expand Down Expand Up @@ -111,7 +147,14 @@ def forward(self, x, y):
print("input", k, args, kwargs)
else:
print("output", k, v)

Function :func:`steal_append` can be used to dump more tensors.
When inside the context, func:`is_stealing` returns True, False otherwise.
"""
assert not is_stealing(), "steal_forward was already called."
# We clear the cache.
_steal_forward_status[0] = True
_additional_stolen_objects.clear()
assert not submodules or isinstance(
model, torch.nn.Module
), f"submodules can only be True if model is a module but is is {type(model)}."
Expand Down Expand Up @@ -144,9 +187,15 @@ def forward(self, x, y):
try:
yield
finally:
_steal_forward_status[0] = False
for f in keep_model_forward.values():
f[0].forward = f[1]
if dump_file:
# Let's add the cached tensor
assert storage is not None, "storage cannot be None but mypy is confused here."
storage.update(_additional_stolen_objects)
# We clear the cache.
_additional_stolen_objects.clear()
proto = create_onnx_model_from_input_tensors(storage)
onnx.save(
proto,
Expand Down
Loading