diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 42bc5036..cb22f19b 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
 0.5.0
 +++++
 
+* :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
+* :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``
 * :pr:`93`: introduce patched expression to get around annoying export issues
 * :pr:`92`: support errors distribution in max_diff
 * :pr:`91`: enable strings in ``guess_dynamic_shapes``
diff --git a/_unittests/ut_helpers/test_torch_test_helper.py b/_unittests/ut_helpers/test_torch_test_helper.py
index 6416188d..dc169d03 100644
--- a/_unittests/ut_helpers/test_torch_test_helper.py
+++ b/_unittests/ut_helpers/test_torch_test_helper.py
@@ -10,6 +10,7 @@
     to_numpy,
     is_torchdynamo_exporting,
     model_statistics,
+    steal_append,
     steal_forward,
     replace_string_by_dynamic,
     to_any,
@@ -145,6 +146,36 @@ def forward(self, x, y):
         self.assertEqualAny(restored["main", 0, "O"], res1)
         self.assertEqualAny(restored["main", 0, "O"], res2)
 
+    @hide_stdout()
+    def test_steal_forward_dump_file_steal_append(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                sx = self.s1(x)
+                steal_append("sx", sx)
+                return sx + self.s2(y)
+
+        inputs = torch.rand(3, 4), torch.rand(3, 4)
+        model = Model()
+        dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
+        with steal_forward(model, dump_file=dump_file):
+            model(*inputs)
+            model(*inputs)
+        self.assertExists(dump_file)
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        self.assertEqual(
+            {("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
+            set(restored),
+        )
+
     @hide_stdout()
     def test_steal_forward_submodules(self):
         class SubModel(torch.nn.Module):
@@ -173,7 +204,7 @@ def forward(self, x, y):
             else:
                 print("output", k, v)
         print(string_type(restored, with_shape=True))
-        l1, l2 = 151, 160
+        l1, l2 = 182, 191
         self.assertEqual(
             [
                 (f"-Model-{l2}", 0, "I"),
diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py
index 81e713b6..85b3cdec 100644
--- a/onnx_diagnostic/helpers/torch_test_helper.py
+++ b/onnx_diagnostic/helpers/torch_test_helper.py
@@ -47,6 +47,42 @@ def _forward_(
     return res
 
 
+_steal_forward_status = [False]
+_additional_stolen_objects = {}
+
+
+def is_stealing() -> bool:
+    """Returns True while the :func:`steal_forward` context is active."""
+    return _steal_forward_status[0]
+
+
+def steal_append(name: str, obj: Any):
+    """
+    Even outside a forward method, it is still possible to register
+    a python object containing tensors so that it is dumped after
+    the execution of the model.
+
+    .. code-block:: python
+
+        steal_append("quantize", [t1, t2])
+
+    The same code can be executed multiple times; in that case,
+    the name is extended with a number.
+    """
+    if is_stealing():
+        if name in _additional_stolen_objects:
+            i = 1
+            n = f"{name}_{i}"
+            while n in _additional_stolen_objects:
+                i += 1
+                n = f"{name}_{i}"
+            print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
+            _additional_stolen_objects[n] = obj
+        else:
+            print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
+            _additional_stolen_objects[name] = obj
+
+
 @contextlib.contextmanager
 def steal_forward(
     model: Union[
@@ -111,7 +147,14 @@ def forward(self, x, y):
                 print("input", k, args, kwargs)
             else:
                 print("output", k, v)
+
+    Function :func:`steal_append` can be used to dump more tensors.
+    When inside the context, :func:`is_stealing` returns True, and False otherwise.
     """
+    assert not is_stealing(), "steal_forward was already called."
+    # We clear the cache.
+    _steal_forward_status[0] = True
+    _additional_stolen_objects.clear()
     assert not submodules or isinstance(
         model, torch.nn.Module
     ), f"submodules can only be True if model is a module but is is {type(model)}."
@@ -144,9 +187,15 @@ def forward(self, x, y):
     try:
         yield
     finally:
+        _steal_forward_status[0] = False
         for f in keep_model_forward.values():
             f[0].forward = f[1]
         if dump_file:
+            # Let's add the cached tensors.
+            assert storage is not None, "storage cannot be None but mypy is confused here."
+            storage.update(_additional_stolen_objects)
+            # We clear the cache.
+            _additional_stolen_objects.clear()
             proto = create_onnx_model_from_input_tensors(storage)
             onnx.save(
                 proto,