@@ -47,6 +47,42 @@ def _forward_(
4747 return res
4848
4949
50+ _steal_forward_status = [False ]
51+ _additional_stolen_objects = {}
52+
53+
54+ def is_stealing () -> bool :
55+ """Returns true if :func:`steal_forward` was yielded."""
56+ return _steal_forward_status [0 ]
57+
58+
59+ def steal_append (name : str , obj : Any ):
60+ """
61+ When outside a forward method, it is still possible to add
62+ a python object which contains tensors and dump after the execution
63+ of the model.
64+
65+ .. code-block:: python
66+
67+ steal_append("quantize", [t1, t2])
68+
69+ The same code can executed multiple times, then
70+ the name can extended with a number.
71+ """
72+ if is_stealing ():
73+ if name in _additional_stolen_objects :
74+ i = 1
75+ n = f"{ name } _{ i } "
76+ while n in _additional_stolen_objects :
77+ i += 1
78+ n = f"{ name } _{ i } "
79+ print (f"-- stolen { name !r} renamed in { n !r} : { string_type (obj , with_shape = True )} " )
80+ _additional_stolen_objects [n ] = obj
81+ else :
82+ print (f"-- stolen { name !r} : { string_type (obj , with_shape = True )} " )
83+ _additional_stolen_objects [name ] = obj
84+
85+
5086@contextlib .contextmanager
5187def steal_forward (
5288 model : Union [
@@ -111,7 +147,14 @@ def forward(self, x, y):
111147 print("input", k, args, kwargs)
112148 else:
113149 print("output", k, v)
150+
151+ Function :func:`steal_append` can be used to dump more tensors.
152+ When inside the context, func:`is_stealing` returns True, False otherwise.
114153 """
154+ assert not is_stealing (), "steal_forward was already called."
155+ # We clear the cache.
156+ _steal_forward_status [0 ] = True
157+ _additional_stolen_objects .clear ()
115158 assert not submodules or isinstance (
116159 model , torch .nn .Module
117160 ), f"submodules can only be True if model is a module but is is { type (model )} ."
@@ -144,9 +187,15 @@ def forward(self, x, y):
144187 try :
145188 yield
146189 finally :
190+ _steal_forward_status [0 ] = False
147191 for f in keep_model_forward .values ():
148192 f [0 ].forward = f [1 ]
149193 if dump_file :
194+ # Let's add the cached tensor
195+ assert storage is not None , "storage cannot be None but mypy is confused here."
196+ storage .update (_additional_stolen_objects )
197+ # We clear the cache.
198+ _additional_stolen_objects .clear ()
150199 proto = create_onnx_model_from_input_tensors (storage )
151200 onnx .save (
152201 proto ,
0 commit comments