Skip to content

Commit 9b18b46

Browse files
authored
Fixes extract_subset_of_nodes (#337)
* Fix extract_sub_model * changelogs * handle empty input
1 parent d311955 commit 9b18b46

File tree

10 files changed

+173
-25
lines changed

10 files changed

+173
-25
lines changed

.github/workflows/ci.yml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,16 @@ jobs:
106106
pip install torch==${{ matrix.torch }} torchvision torchaudio
107107
fi
108108
109-
- name: Cache pip
110-
if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }}
111-
uses: actions/cache@v4
112-
with:
113-
path: ~/.cache/pip
114-
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
115-
restore-keys: |
116-
${{ runner.os }}-pip-
117-
${{ runner.os }}-
109+
# too slow
110+
#- name: Cache pip
111+
# if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }}
112+
# uses: actions/cache@v4
113+
# with:
114+
# path: ~/.cache/pip
115+
# key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
116+
# restore-keys: |
117+
# ${{ runner.os }}-pip-
118+
# ${{ runner.os }}-
118119

119120
- name: pip freeze
120121
run: python -m pip freeze

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.4
55
+++++
66

7+
* :pr:`337`: fixes extract_subset_of_nodes
78
* :pr:`336`: implements versioned onnx plugs
89

910
0.8.3
10.2 KB
Binary file not shown.
Binary file not shown.

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import os
12
import unittest
23
from typing import Any, Dict, List
34
import numpy as np
5+
import onnx
46
import onnx.helper as oh
57
import onnx.numpy_helper as onh
68
from onnx import TensorProto, FunctionProto, ValueInfoProto
@@ -475,7 +477,7 @@ def _mkv_(name):
475477

476478
def test_onnx_dtype_name(self):
477479
for k in dir(TensorProto):
478-
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL"}:
480+
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
479481
self.assertEqual(k, onnx_dtype_name(getattr(TensorProto, k)))
480482
self.assertRaise(lambda: onnx_dtype_name(1000), ValueError)
481483
self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED")
@@ -532,6 +534,42 @@ def _type_rank_fn(name):
532534
check_model(new_model)
533535
self.check_ort(new_model)
534536

537+
def test_extract_subset_of_nodes_bigger(self):
538+
model = onnx.load(
539+
os.path.join(
540+
os.path.dirname(__file__), "data", "test_sbs_mha_split_every_piece.onnx"
541+
)
542+
)
543+
nodes = extract_subset_of_nodes(
544+
model=model,
545+
name="scaled_dot_product_attention",
546+
node_index=16,
547+
cut_points={
548+
"linear",
549+
"linear_1",
550+
"linear_2",
551+
"output_0",
552+
"scaled_dot_product_attention",
553+
"transpose_2",
554+
"view_2",
555+
"x",
556+
},
557+
)
558+
self.assertEqual(
559+
[
560+
"Mul",
561+
"Reshape",
562+
"Transpose",
563+
"Mul",
564+
"Reshape",
565+
"Transpose",
566+
"FusedMatMul",
567+
"Softmax",
568+
"MatMul",
569+
],
570+
[n.op_type for n in nodes],
571+
)
572+
535573

536574
if __name__ == "__main__":
537575
unittest.main(verbosity=2)

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import unittest
23
import pandas
34
import onnx
@@ -777,6 +778,81 @@ def forward(self, query, key, value, seq_lens):
777778
df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
778779
# self.clean_dump()
779780

781+
@hide_stdout()
782+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
783+
def test_sbs_mha_split_every_piece(self):
784+
torch = self.torch
785+
786+
class Model(self.torch.nn.Module):
787+
def __init__(self, embed_dim: int, num_heads: int):
788+
super(Model, self).__init__()
789+
self.embed_dim = embed_dim
790+
self.num_heads = num_heads
791+
self.head_dim = embed_dim // num_heads
792+
793+
assert embed_dim % num_heads == 0, (
794+
f"embed_dim % num_heads =! 0 -> "
795+
f"{embed_dim} % {num_heads} = {embed_dim % num_heads}"
796+
)
797+
798+
self.W_q = torch.nn.Linear(embed_dim, embed_dim)
799+
self.W_k = torch.nn.Linear(embed_dim, embed_dim)
800+
self.W_v = torch.nn.Linear(embed_dim, embed_dim)
801+
802+
def split_heads(self, t, seq_len):
803+
return t.view(t.shape[0], seq_len, self.num_heads, self.head_dim).transpose(
804+
1, 2
805+
)
806+
807+
def forward(self, x):
808+
q = self.split_heads(self.W_q(x), x.shape[1])
809+
k = self.split_heads(self.W_k(x), x.shape[1])
810+
v = self.split_heads(self.W_v(x), x.shape[1])
811+
return (
812+
torch.nn.functional.scaled_dot_product_attention(q, k, v)
813+
.transpose(1, 2)
814+
.reshape(x.shape[0], x.shape[1], self.embed_dim)
815+
)
816+
817+
embed_dim = 16
818+
num_heads = 4
819+
seq_len = 10
820+
batch_size = 2
821+
inputs = dict(x=torch.randn(batch_size, seq_len, embed_dim))
822+
model = Model(embed_dim, num_heads)
823+
model(**inputs)
824+
ds = dict(x={0: "batch", 1: "seqlen"})
825+
826+
ep = self.torch.export.export(
827+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
828+
)
829+
self.dump_text("test_sbs_mha_split_every_piece.ep", str(ep))
830+
filename = self.get_dump_file("test_sbs_mha_split_every_piece.onnx")
831+
to_onnx(ep, exporter="custom", filename=filename)
832+
replay = self.get_dump_folder("test_sbs_mha_split_every_piece_replay")
833+
onx = onnx.load(filename)
834+
results = list(
835+
run_aligned(
836+
ep,
837+
onx,
838+
kwargs=inputs,
839+
run_cls=OnnxruntimeEvaluator,
840+
verbose=11,
841+
use_tensor=True,
842+
run_onnx_with_torch_inputs=True,
843+
replay_configuration=ReplayConfiguration(
844+
dump_folder=replay, selected_op_types={"MatMul"}, threshold=2**20
845+
),
846+
),
847+
)
848+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
849+
df.to_excel(self.get_dump_file("test_sbs_mha_split_every_piece.xlsx"))
850+
max_abs = df["err_abs"].max()
851+
self.assertLess(max_abs, 1e-5)
852+
# self.clean_dump()
853+
subonnx = onnx.load(os.path.join(replay, "scaled_dot_product_attention", "model.onnx"))
854+
self.assertEqual(len(subonnx.graph.input), 3)
855+
780856

781857
if __name__ == "__main__":
782858
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,13 @@ def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
845845
f.write(proto.SerializeToString())
846846
return fullname
847847

848+
def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
849+
"""Dumps text in a file."""
850+
fullname = self.get_dump_file(name, folder=folder)
851+
with open(fullname, "w") as f:
852+
f.write(text)
853+
return fullname
854+
848855
def assertExists(self, name):
849856
"""Checks the existence of a file."""
850857
if not os.path.exists(name):

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
332332
print(onnx_dtype_name(7))
333333
"""
334334
for k in dir(TensorProto):
335-
if k.upper() == k and k != "EXTERNAL":
335+
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
336336
v = getattr(TensorProto, k)
337337
if v == itype:
338338
return k
@@ -1219,11 +1219,14 @@ def extract_subset_of_nodes(
12191219
if name in node.output:
12201220
node_index = i
12211221
break
1222-
assert (
1223-
node_index is not None
1224-
and node_index < len(model.graph.node)
1225-
and name in model.graph.node[node_index].output
1226-
), f"node_index is still empty or wrong for result {name!r}"
1222+
assert node_index is not None and node_index < len(model.graph.node), (
1223+
f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
1224+
f"is still empty or wrong for result {name!r}"
1225+
)
1226+
assert name in model.graph.node[node_index].output, (
1227+
f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
1228+
f"node={pretty_onnx(model.graph.node[node_index])}"
1229+
)
12271230
if cut_points is None:
12281231
cut_points = {n.name for n in model.graph.input} | {
12291232
n.name for n in model.graph.initializer
@@ -1236,16 +1239,26 @@ def extract_subset_of_nodes(
12361239
current_node_index = node_index
12371240
current_input_index = 0
12381241
intermediate = {name}
1242+
cut_points -= {name}
12391243
inputs = set(k for k in node.input if k)
12401244
while not (inputs <= cut_points) and current_node_index >= 0:
12411245
node = model.graph.node[current_node_index]
1242-
if current_input_index == 0:
1246+
if current_input_index == 0 or not node.input:
12431247
needs = [o for o in node.output if o in intermediate and o not in cut_points]
12441248
if needs:
12451249
selected.add(current_node_index)
1250+
if not node.input:
1251+
current_node_index -= 1
1252+
current_input_index = 0
1253+
continue
12461254
else:
12471255
current_node_index -= 1
1256+
current_input_index = 0
12481257
continue
1258+
assert current_input_index < len(node.input), (
1259+
f"current_input_index={current_input_index} but node.input={node.input}, "
1260+
f"node={pretty_onnx(node)}"
1261+
)
12491262
res = node.input[current_input_index]
12501263
if res not in cut_points:
12511264
intermediate.add(res)
@@ -1290,8 +1303,8 @@ def _mkv_(name, itype, irank):
12901303
oh.make_graph(
12911304
nodes,
12921305
"submodel",
1293-
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
1294-
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
1306+
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
1307+
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
12951308
),
12961309
ir_version=ir_version,
12971310
opset_imports=opset_imports,

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@ def _preparation_with_fx_graph(
381381
assert len(torch_input_names) < len(onx.graph.input), (
382382
f"torch_input_names={torch_input_names!r}, "
383383
f"onnx_input_names={[n.name for n in onx.graph.input]}, "
384-
f"node.name={node.name!r} cannot be an input"
384+
f"node.name={node.name!r} cannot be an input, "
385+
f"placeholders_to_state_dict={sorted(placeholders_to_state_dict)}"
385386
)
386387
assert node.name not in skip_mapping_torch_onnx, (
387388
f"{node.name!r} is ambiguous, cannot be mapped due to "
@@ -772,9 +773,9 @@ def forward(self, x):
772773
# preparation with ep.graph.nodes
773774
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
774775
placeholders_to_state_dict = {
775-
**{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
776-
**{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
777-
**{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
776+
**{f"p_{name.replace('.', '_').lower()}": name for name in ep.state_dict},
777+
**{f"b_{name.replace('.', '_').lower()}": name for name, _ in ep.named_buffers()},
778+
**{f"c_{name.replace('.', '_').lower()}": name for name in ep.tensor_constants},
778779
}
779780
skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
780781
placeholders = {}

onnx_diagnostic/torch_onnx/sbs_dataclasses.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,17 @@ def dump(
243243
:return: the folder created to dump everything
244244
"""
245245
if verbose:
246-
print(f"[ReplayConfiguration.dump] extract subset of node for {name!r}")
246+
print(
247+
f"[ReplayConfiguration.dump] extract subset of nodes for "
248+
f"{name!r} (onnx_id_node={onnx_id_node})"
249+
)
250+
if verbose >= 10:
251+
print(f"[ReplayConfiguration.dump] onnx_results={sorted(onnx_results)}")
252+
print(f"[ReplayConfiguration.dump] torch_results={sorted(torch_results)}")
253+
print(
254+
f"[ReplayConfiguration.dump] onnx_name_to_ep_name="
255+
f"{sorted(onnx_name_to_ep_name)}"
256+
)
247257
nodes = extract_subset_of_nodes(
248258
model=model,
249259
name=name,
@@ -253,7 +263,8 @@ def dump(
253263
if not nodes:
254264
if verbose:
255265
print(
256-
f"[ReplayConfiguration.dump] could not extract subset of node for {name!r}"
266+
f"[ReplayConfiguration.dump] could not extract subset of "
267+
f"nodes for {name!r}"
257268
)
258269
return None
259270
if verbose:

0 commit comments

Comments
 (0)