@@ -47,9 +47,15 @@ def _forward_(
4747 return res
4848
4949
50+ _steal_forward_status = [False ]
5051_additional_stolen_objects = {}
5152
5253
54+ def is_stealing () -> bool :
55+ """Returns true if :func:`steal_forward` was yielded."""
56+ return _steal_forward_status [0 ]
57+
58+
5359def steal_append (name : str , obj : Any ):
5460 """
5561 When outside a forward method, it is still possible to add
@@ -63,17 +69,18 @@ def steal_append(name: str, obj: Any):
6369 The same code can executed multiple times, then
6470 the name can extended with a number.
6571 """
66- if name in _additional_stolen_objects :
67- i = 1
68- n = f"{ name } _{ i } "
69- while n in _additional_stolen_objects :
70- i += 1
72+ if is_stealing ():
73+ if name in _additional_stolen_objects :
74+ i = 1
7175 n = f"{ name } _{ i } "
72- print (f"-- stolen { name !r} renamed in { n !r} : { string_type (obj , with_shape = True )} " )
73- _additional_stolen_objects [n ] = obj
74- else :
75- print (f"-- stolen { name !r} : { string_type (obj , with_shape = True )} " )
76- _additional_stolen_objects [name ] = obj
76+ while n in _additional_stolen_objects :
77+ i += 1
78+ n = f"{ name } _{ i } "
79+ print (f"-- stolen { name !r} renamed in { n !r} : { string_type (obj , with_shape = True )} " )
80+ _additional_stolen_objects [n ] = obj
81+ else :
82+ print (f"-- stolen { name !r} : { string_type (obj , with_shape = True )} " )
83+ _additional_stolen_objects [name ] = obj
7784
7885
7986@contextlib .contextmanager
@@ -142,8 +149,11 @@ def forward(self, x, y):
142149 print("output", k, v)
143150
144151 Function :func:`steal_append` can be used to dump more tensors.
152+ When inside the context, func:`is_stealing` returns True, False otherwise.
145153 """
154+ assert not is_stealing (), "steal_forward was already called."
146155 # We clear the cache.
156+ _steal_forward_status [0 ] = True
147157 _additional_stolen_objects .clear ()
148158 assert not submodules or isinstance (
149159 model , torch .nn .Module
@@ -177,6 +187,7 @@ def forward(self, x, y):
177187 try :
178188 yield
179189 finally :
190+ _steal_forward_status [0 ] = False
180191 for f in keep_model_forward .values ():
181192 f [0 ].forward = f [1 ]
182193 if dump_file :
0 commit comments