Skip to content

Commit a6cfc23

Browse files
committed
add garbe
1 parent 43b14f0 commit a6cfc23

File tree

2 files changed

+80
-6
lines changed

2 files changed

+80
-6
lines changed

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,23 @@ def test_constant_bool_input(self):
173173
self.assertEqual(got.dtype, torch.bool)
174174
self.assertEqual(got[0], True)
175175

176+
def test_ort_eval_loop(self):
177+
model = torch.nn.EmbeddingBag(num_embeddings=49157, embedding_dim=32, mode="sum")
178+
a = torch.tensor([[39906, 39906]]).long()
179+
example_args = (a,)
180+
model_eval = model.eval()
181+
expected = model(*example_args)
182+
183+
onx = to_onnx(model_eval, example_args, optimize=True)
184+
self.assertIn("Loop", set(n.op_type for n in onx.graph.node))
185+
186+
ref = OnnxruntimeEvaluator(onx, verbose=10)
187+
feeds = dict(
188+
zip([i.name for i in onx.graph.input], [t.detach().numpy() for t in example_args])
189+
)
190+
got = ref.run(None, feeds)
191+
self.assertEqualArray(expected, got[0])
192+
176193

177194
if __name__ == "__main__":
178195
unittest.main(verbosity=2)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(
135135
)
136136
if local_functions:
137137
self.local_functions.update(local_functions)
138+
self.garbage_collector = self._build_garbage_collector() if self.rt_nodes_ else {}
138139

139140
@property
140141
def input_names(self) -> List[str]:
@@ -242,7 +243,7 @@ def run(
242243
self._log(2, " +I %s: %s", k, v)
243244
results[k] = v
244245

245-
for node in self.rt_nodes_ or []:
246+
for i_node, node in enumerate(self.rt_nodes_ or []):
246247
self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
247248
for i in node.input:
248249
if i != "" and i not in results:
@@ -266,6 +267,7 @@ def run(
266267
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
267268
assert isinstance(name, str), f"unexpected type for name {type(name)}"
268269
results[name] = value
270+
self._clean_unused_inplace(i_node, node, results)
269271

270272
if intermediate:
271273
return results
@@ -280,6 +282,50 @@ def run(
280282
)
281283
return [results[name] for name in output_names if name != ""]
282284

285+
def _build_garbage_collector(self) -> Dict[str, int]:
286+
"""
287+
Memorizes the results not needed anymore for every node.
288+
Returns a dictionary with the last node using the results.
289+
"""
290+
needed = {}
291+
for i, node in enumerate(self.rt_nodes_):
292+
for name in node.input:
293+
needed[name] = i
294+
if node.op_type in {"Scan", "If", "Loop"}:
295+
hidden = self._get_hidden_node_inputs(node)
296+
for name in hidden:
297+
needed[name] = i
298+
if isinstance(self.proto, ModelProto):
299+
for o in self.proto.graph.output:
300+
needed[o.name] = len(self.rt_nodes_)
301+
elif isinstance(self.proto, GraphProto):
302+
for o in self.proto.output:
303+
needed[o.name] = len(self.rt_nodes_)
304+
elif isinstance(self.proto, FunctionProto):
305+
for o in self.proto.output:
306+
needed[o] = len(self.rt_nodes_)
307+
return needed
308+
309+
def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):
310+
"""
311+
Cleans all results not needed anymore. Some models requires to clean the memory
312+
to be able to run.
313+
"""
314+
if not self.garbage_collector:
315+
return
316+
for name in node.input:
317+
if self.garbage_collector[name] == i_node:
318+
if self.verbose:
319+
print(f" - deletes: {name}")
320+
del results[name]
321+
if node.op_type in {"Scan", "If", "Loop"}:
322+
hidden = self._get_hidden_node_inputs(node)
323+
for name in hidden:
324+
if self.garbage_collector[name] == i_node and name in results:
325+
if self.verbose:
326+
print(f" - deletes: {name}")
327+
del results[name]
328+
283329
def _make_model_proto(
284330
self,
285331
nodes: Sequence[NodeProto],
@@ -304,6 +350,8 @@ def _make_model_proto(
304350
else:
305351
onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
306352

353+
# That helps fixing bugs.
354+
onx = shi.infer_shapes(onx)
307355
return onx
308356

309357
@classmethod
@@ -327,6 +375,17 @@ def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
327375
memo |= set(node.output)
328376
return hidden
329377

378+
@classmethod
379+
def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
380+
"""Calls multiple _get_hidden_inputs on every attribute."""
381+
if node.op_type not in {"Loop", "Scan", "If"}:
382+
return set()
383+
hidden = set()
384+
for att in node.attribute:
385+
if att.type == AttributeProto.GRAPH:
386+
hidden |= self._get_hidden_inputs(att.g)
387+
return hidden - (hidden & set(node.input))
388+
330389
def _get_sess(
331390
self, node: Union[ModelProto, NodeProto], inputs: List[Any]
332391
) -> Tuple[ModelProto, _InferenceSession]:
@@ -360,8 +419,6 @@ def _get_sess(
360419
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
361420

362421
onx = self._make_model_proto([node], vinputs, voutputs)
363-
# That helps fixing bugs.
364-
onx = shi.infer_shapes(onx)
365422

366423
cls = (
367424
InferenceSessionForNumpy
@@ -496,11 +553,11 @@ def _get_sess_scan(
496553
assert g, f"Missing attribute {branch!r}"
497554
vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
498555

556+
begin = 0 if node.op_type == "Scan" else 1
499557
voutputs = []
500-
for name, goutput in zip(node.output, g.output):
501-
b = goutput.SerializeToString()
558+
for name, _goutput in zip(node.output, g.output[begin:]):
502559
v = ValueInfoProto()
503-
v.ParseFromString(b)
560+
# v.ParseFromString(goutput.SerializeToString())
504561
v.name = name
505562
voutputs.append(v)
506563

0 commit comments

Comments
 (0)