@@ -135,6 +135,7 @@ def __init__(
135135 )
136136 if local_functions :
137137 self .local_functions .update (local_functions )
138+ self .garbage_collector = self ._build_garbage_collector () if self .rt_nodes_ else {}
138139
139140 @property
140141 def input_names (self ) -> List [str ]:
@@ -242,7 +243,7 @@ def run(
242243 self ._log (2 , " +I %s: %s" , k , v )
243244 results [k ] = v
244245
245- for node in self .rt_nodes_ or []:
246+ for i_node , node in enumerate ( self .rt_nodes_ or []) :
246247 self ._log (1 , "%s(%s) -> %s" , node .op_type , node .input , node .output )
247248 for i in node .input :
248249 if i != "" and i not in results :
@@ -266,6 +267,7 @@ def run(
266267 self ._log (2 , " + %s: %s" , name , value ) # type: ignore[arg-type]
267268 assert isinstance (name , str ), f"unexpected type for name { type (name )} "
268269 results [name ] = value
270+ self ._clean_unused_inplace (i_node , node , results )
269271
270272 if intermediate :
271273 return results
@@ -280,6 +282,50 @@ def run(
280282 )
281283 return [results [name ] for name in output_names if name != "" ]
282284
def _build_garbage_collector(self) -> Dict[str, int]:
    """
    Computes, for every result name, the index of the last node consuming it.

    Both the explicit inputs of a node and, for ``Scan``, ``If`` and ``Loop``,
    the outer-scope names used by its subgraphs count as consumers.
    Outputs of the evaluated proto are mapped to ``len(self.rt_nodes_)``,
    an index no node has, so they are never considered released.

    :return: dictionary ``{result name: index of its last consumer}``
    """
    last_use: Dict[str, int] = {}
    for position, rt_node in enumerate(self.rt_nodes_):
        consumed = list(rt_node.input)
        if rt_node.op_type in {"Scan", "If", "Loop"}:
            # Subgraphs may read outer-scope results not listed in node.input.
            consumed.extend(self._get_hidden_node_inputs(rt_node))
        last_use.update(dict.fromkeys(consumed, position))

    proto = self.proto
    if isinstance(proto, ModelProto):
        output_names = [out.name for out in proto.graph.output]
    elif isinstance(proto, GraphProto):
        output_names = [out.name for out in proto.output]
    elif isinstance(proto, FunctionProto):
        # Function outputs are plain strings, not ValueInfoProto.
        output_names = list(proto.output)
    else:
        output_names = []

    never_released = len(self.rt_nodes_)
    for output_name in output_names:
        last_use[output_name] = never_released
    return last_use
308+
309+ def _clean_unused_inplace (self , i_node : int , node : NodeProto , results : Dict [str , Any ]):
310+ """
311+ Cleans all results not needed anymore. Some models requires to clean the memory
312+ to be able to run.
313+ """
314+ if not self .garbage_collector :
315+ return
316+ for name in node .input :
317+ if self .garbage_collector [name ] == i_node :
318+ if self .verbose :
319+ print (f" - deletes: { name } " )
320+ del results [name ]
321+ if node .op_type in {"Scan" , "If" , "Loop" }:
322+ hidden = self ._get_hidden_node_inputs (node )
323+ for name in hidden :
324+ if self .garbage_collector [name ] == i_node and name in results :
325+ if self .verbose :
326+ print (f" - deletes: { name } " )
327+ del results [name ]
328+
283329 def _make_model_proto (
284330 self ,
285331 nodes : Sequence [NodeProto ],
@@ -304,6 +350,8 @@ def _make_model_proto(
304350 else :
305351 onx .opset_import .append (oh .make_opsetid ("" , onnx_opset_version ()))
306352
353+ # That helps fixing bugs.
354+ onx = shi .infer_shapes (onx )
307355 return onx
308356
309357 @classmethod
@@ -327,6 +375,17 @@ def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
327375 memo |= set (node .output )
328376 return hidden
329377
@classmethod
def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
    """
    Returns the names the subgraphs of *node* use without them being
    explicit inputs of *node*.

    Only ``Loop``, ``Scan`` and ``If`` nodes are inspected; any other
    operator yields an empty set. ``_get_hidden_inputs`` is called on
    every attribute of type GRAPH and the results are merged.
    """
    if node.op_type in {"Loop", "Scan", "If"}:
        collected = set()
        subgraphs = [a.g for a in node.attribute if a.type == AttributeProto.GRAPH]
        for subgraph in subgraphs:
            collected |= self._get_hidden_inputs(subgraph)
        # Names already listed as explicit inputs are not hidden.
        return collected.difference(node.input)
    return set()
388+
330389 def _get_sess (
331390 self , node : Union [ModelProto , NodeProto ], inputs : List [Any ]
332391 ) -> Tuple [ModelProto , _InferenceSession ]:
@@ -360,8 +419,6 @@ def _get_sess(
360419 voutputs = [oh .make_value_info (o , TypeProto ()) for o in node .output ]
361420
362421 onx = self ._make_model_proto ([node ], vinputs , voutputs )
363- # That helps fixing bugs.
364- onx = shi .infer_shapes (onx )
365422
366423 cls = (
367424 InferenceSessionForNumpy
@@ -496,11 +553,11 @@ def _get_sess_scan(
496553 assert g , f"Missing attribute { branch !r} "
497554 vinputs = self ._get_sess_init_subgraph (node , inputs , context , g )
498555
556+ begin = 0 if node .op_type == "Scan" else 1
499557 voutputs = []
500- for name , goutput in zip (node .output , g .output ):
501- b = goutput .SerializeToString ()
558+ for name , _goutput in zip (node .output , g .output [begin :]):
502559 v = ValueInfoProto ()
503- v .ParseFromString (b )
560+ # v.ParseFromString(goutput.SerializeToString() )
504561 v .name = name
505562 voutputs .append (v )
506563
0 commit comments