@@ -47,6 +47,35 @@ def _forward_(
4747 return res
4848
4949
50+ _additional_stolen_objects = {}
51+
52+
53+ def steal_append (name : str , obj : Any ):
54+ """
55+ When outside a forward method, it is still possible to add
56+ a python object which contains tensors and dump after the execution
57+ of the model.
58+
59+ .. code-block:: python
60+
61+ steal_append("quantize", [t1, t2])
62+
63+ The same code can executed multiple times, then
64+ the name can extended with a number.
65+ """
66+ if name in _additional_stolen_objects :
67+ i = 1
68+ n = f"{ name } _{ i } "
69+ while n in _additional_stolen_objects :
70+ i += 1
71+ n = f"{ name } _{ i } "
72+ print (f"-- stolen { name !r} renamed in { n !r} : { string_type (obj , with_shape = True )} " )
73+ _additional_stolen_objects [n ] = obj
74+ else :
75+ print (f"-- stolen { name !r} : { string_type (obj , with_shape = True )} " )
76+ _additional_stolen_objects [name ] = obj
77+
78+
5079@contextlib .contextmanager
5180def steal_forward (
5281 model : Union [
@@ -111,7 +140,11 @@ def forward(self, x, y):
111140 print("input", k, args, kwargs)
112141 else:
113142 print("output", k, v)
143+
144+ Function :func:`steal_append` can be used to dump more tensors.
114145 """
146+ # We clear the cache.
147+ _additional_stolen_objects .clear ()
115148 assert not submodules or isinstance (
116149 model , torch .nn .Module
117150 ), f"submodules can only be True if model is a module but is is { type (model )} ."
@@ -147,6 +180,10 @@ def forward(self, x, y):
147180 for f in keep_model_forward .values ():
148181 f [0 ].forward = f [1 ]
149182 if dump_file :
183+ # Let's add the cached tensor
184+ storage .update (_additional_stolen_objects )
185+ # We clear the cache.
186+ _additional_stolen_objects .clear ()
150187 proto = create_onnx_model_from_input_tensors (storage )
151188 onnx .save (
152189 proto ,
0 commit comments