Commit 85960d6

dump stolen

1 parent ab387bf commit 85960d6
File tree

3 files changed: +118 -17 lines

_unittests/ut_helpers/test_torch_test_helper.py
Lines changed: 56 additions & 0 deletions

@@ -20,6 +20,7 @@
     make_mamba_cache,
     make_sliding_window_cache,
 )
+from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model

 TFLOAT = onnx.TensorProto.FLOAT

@@ -88,6 +89,61 @@ def forward(self, x, y):
         ):
             model(*inputs)

+    @hide_stdout()
+    def test_steal_forward_dump_file(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                return self.s1(x) + self.s2(y)
+
+        inputs = torch.rand(3, 4), torch.rand(3, 4)
+        model = Model()
+        dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
+        with steal_forward(
+            [
+                (
+                    "main",
+                    model,
+                ),
+                (" s1", model.s1),
+                (" s2", model.s2),
+            ],
+            dump_file=dump_file,
+        ):
+            res1 = model(*inputs)
+            res2 = model(*inputs)
+        self.assertExists(dump_file)
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        self.assertEqual(
+            [
+                ("main", 0, "I"),
+                ("main", 0, "O"),
+                ("main", 1, "I"),
+                ("main", 1, "O"),
+                ("s1", 0, "I"),
+                ("s1", 0, "O"),
+                ("s1", 1, "I"),
+                ("s1", 1, "O"),
+                ("s2", 0, "I"),
+                ("s2", 0, "O"),
+                ("s2", 1, "I"),
+                ("s2", 1, "O"),
+            ],
+            sorted(restored),
+        )
+        self.assertEqualAny(restored["main", 0, "I"], (inputs, {}))
+        self.assertEqualAny(restored["main", 1, "I"], (inputs, {}))
+        self.assertEqualAny(restored["main", 0, "O"], res1)
+        self.assertEqualAny(restored["main", 1, "O"], res2)
+
     def test_replace_string_by_dynamic(self):
         example = {
             "input_ids": {0: "batch_size", 1: "sequence_length"},

onnx_diagnostic/helpers/mini_onnx_builder.py
Lines changed: 36 additions & 14 deletions

@@ -2,7 +2,7 @@
 import sys
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 import numpy as np
-from onnx import GraphProto, ModelProto, TensorProto
+from onnx import GraphProto, ModelProto, NodeProto, TensorProto
 import onnx.helper as oh
 import torch
 from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype, from_array_extended

@@ -34,10 +34,7 @@ def proto_from_array(
     )

     # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
-    if arr.is_contiguous():
-        arr_cpu = arr.cpu()
-    else:
-        arr_cpu = arr.contiguous().cpu()
+    arr_cpu = arr.cpu() if arr.is_contiguous() else arr.contiguous().cpu()

     numel = torch.numel(arr_cpu)
     element_size = arr_cpu.element_size()

@@ -91,10 +88,10 @@ class MiniOnnxBuilder:
     """

     def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___"):
-        self.initializers_dict = {}
-        self.inputs = []
-        self.outputs = []
-        self.nodes = []
+        self.initializers_dict: Dict[str, Any] = {}
+        self.inputs: List[Any] = []
+        self.outputs: List[Any] = []
+        self.nodes: List[NodeProto] = []
         self.opsets = {"": target_opset}
         self.ir_version = ir_version
         self.torch = torch

@@ -270,7 +267,7 @@ def _build_initializers(
             return initializer

-        res = []
+        res: List[TensorProto] = []
         for k, v in init_dict.items():
             if isinstance(v, TensorProto):
                 res.append(v)

@@ -354,12 +351,19 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
                 f"Key {k!r} cannot contain '{sep}'. "
                 f"It would interfere with the serialization."
             )
+
+            def _mk(k):
+                if isinstance(k, tuple):
+                    # this assumes the tuple contains simple types
+                    return f"(({','.join(map(str,k))}))"
+                return str(k)
+
             if i == len(obj) - 1:
                 for p, o in _flatten_iterator(v, sep):
-                    yield f"dict._{k}{sep}{p}", o
+                    yield f"dict._{_mk(k)}{sep}{p}", o
             else:
                 for p, o in _flatten_iterator(v, sep):
-                    yield f"dict_{k}{sep}{p}", o
+                    yield f"dict_{_mk(k)}{sep}{p}", o
     elif obj.__class__.__name__ == "DynamicCache":
         # transformers
         import transformers

@@ -420,7 +424,7 @@ def _unflatten(
     pos: int = 0,
     level: int = 0,
     device: str = "cpu",
-) -> Tuple[int, Tuple[Any, ...]]:
+) -> Tuple[int, Any]:
     """Unflattens a list of outputs flattened with :func:`flatten_iterator`."""
     name = names[pos]
     spl = name.split(sep)

@@ -465,7 +469,7 @@ def _unflatten(

     if end:
         if prefix.startswith("dict"):
-            ty = dict
+            ty: type = dict
         elif prefix.startswith("list"):
             ty = list
         elif prefix.startswith("tuple"):

@@ -479,12 +483,30 @@ def _unflatten(
             break
         pos = next_pos

+    def _tryint(s):
+        try:
+            return int(s)
+        except (ValueError, TypeError):
+            if s in {"True", "False"}:
+                return s == "True"
+            return s
+
     def _make(ty: type, res: Any) -> Any:
         if ty.__name__ == "DynamicCache":
             r = ty()
             for k, v in res:
                 setattr(r, k, v)
             return r
+        if ty is dict:
+            d = {}
+            for k, v in res:
+                if k.startswith("((") and k.endswith("))"):
+                    spl = k[2:-2].split(",")
+                    key = tuple(_tryint(s) for s in spl)
+                else:
+                    key = _tryint(k)
+                d[key] = v
+            return d
         return ty(res)

     return next_pos, (
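
The new `_mk` / `_tryint` pair is the heart of this change: dictionary keys that are tuples of simple types get encoded into the flattened tensor names and decoded back when `_make` rebuilds the dict. A standalone re-statement of that encoding for illustration only; the `_decode` helper below is hypothetical and mirrors the branch added to `_make`:

    # encode: tuple keys become "((a,b,c))", everything else goes through str()
    def _mk(k):
        if isinstance(k, tuple):
            # assumes the tuple contains simple types (str, int, bool)
            return f"(({','.join(map(str, k))}))"
        return str(k)

    # decode one component back to int, bool, or str
    def _tryint(s):
        try:
            return int(s)
        except (ValueError, TypeError):
            if s in {"True", "False"}:
                return s == "True"
            return s

    # hypothetical inverse of _mk, mirroring the dict branch of _make
    def _decode(name):
        if name.startswith("((") and name.endswith("))"):
            return tuple(_tryint(s) for s in name[2:-2].split(","))
        return _tryint(name)

    assert _decode(_mk(("main", 0, "I"))) == ("main", 0, "I")
    assert _decode(_mk(7)) == 7

Since the encoding goes through str(), an int-like string key such as "0" would come back as the int 0; the diff's own comment, "this assumes the tuple contains simple types", flags the same limitation.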

onnx_diagnostic/helpers/torch_test_helper.py
Lines changed: 26 additions & 3 deletions

@@ -2,6 +2,7 @@
 from collections.abc import Iterable
 from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
+import onnx
 import torch
 from .helper import string_type
 from .cache_helper import (

@@ -10,9 +11,12 @@
     make_sliding_window_cache,
     make_mamba_cache,
 )
+from .mini_onnx_builder import create_onnx_model_from_input_tensors


-def _forward_(*args, _f=None, _fprint=string_type, _prefix="", _context=None, **kwargs):
+def _forward_(
+    *args, _f=None, _fprint=string_type, _prefix="", _context=None, _storage=None, **kwargs
+):
     assert _f is not None, "_f cannot be None"
     assert _context is not None, "_context cannot be None"
     indent = " " * (len(_prefix) - len(_prefix.lstrip()))

@@ -28,10 +32,16 @@ def _forward_(*args, _f=None, _fprint=string_type, _prefix="", _context=None, **
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
         # torch.compiler.is_exporting requires torch>=2.7
         print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
+    if _storage is not None:
+        it = _context["iteration"]
+        key = (_prefix, it)
+        _storage[(*key, "I")] = (torch_deepcopy(args), torch_deepcopy(kwargs))
     res = _f(*args, **kwargs)
     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
         print(f"{indent} -> {_fprint(res, **kws)}")
         print(f"{indent}-{_prefix}.")
+    if _storage is not None:
+        _storage[(*key, "O")] = torch_deepcopy(res)
     _context["iteration"] += 1
     return res

@@ -43,6 +53,7 @@ def steal_forward(
         List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
     ],
     fprint: Callable = string_type,
+    dump_file: Optional[str] = None,
     **kwargs,
 ):
     """

@@ -56,26 +67,38 @@ def steal_forward(
         :func:`onnx_diagnostic.helpers.string_type`
     :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
         or any other function defined by ``fprint``
+    :param dump_file: dumps stolen inputs and outputs in an onnx model,
+        they can be restored with :func:`create_input_tensors_from_onnx_model
+        <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
     """
     context = dict(iteration=0, **kwargs)
     if "with_shape" not in context and fprint == string_type:
         context["with_shape"] = True
     if not isinstance(model, list):
         model = [model]
     keep_model_forward = {}
+    storage = {} if dump_file else None
     for mt in model:
         name, m = mt if isinstance(mt, tuple) else ("", mt)
         keep_model_forward[id(m)] = (m, m.forward)
         c = context.copy()
         c["class_name"] = m.__class__.__name__
-        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, **kws: _forward_(
-            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, **kws
+        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, **kws: _forward_(  # noqa: E501
+            *args, _f=_f, _fprint=_fp, _context=_c, _prefix=_p, _storage=_s, **kws
         )
     try:
         yield
     finally:
         for f in keep_model_forward.values():
             f[0].forward = f[1]
+        if dump_file:
+            proto = create_onnx_model_from_input_tensors(storage)
+            onnx.save(
+                proto,
+                dump_file,
+                save_as_external_data=False,
+                all_tensors_to_one_file=True,
+            )


 def is_torchdynamo_exporting() -> bool:
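
Putting the three files together, the new dump_file argument turns steal_forward into a capture-and-replay tool: every forward call made inside the context is deep-copied into storage, then serialized as a single onnx file on exit. A minimal usage sketch, under the assumption that the import paths follow the file layout above and with a placeholder model:

    import torch
    from onnx_diagnostic.helpers.torch_test_helper import steal_forward
    from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model

    model = torch.nn.Linear(4, 4)
    x = torch.rand(2, 4)

    # every call made inside the context is recorded, then dumped on exit
    with steal_forward([("main", model)], dump_file="stolen.onnx"):
        model(x)

    restored = create_input_tensors_from_onnx_model("stolen.onnx")
    args, kwargs = restored["main", 0, "I"]  # inputs of the first recorded call
    output = restored["main", 0, "O"]        # its output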
