Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,16 @@ jobs:
pip install torch==${{ matrix.torch }} torchvision torchaudio
fi

- name: Cache pip
if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }}
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
# too slow
#- name: Cache pip
# if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }}
# uses: actions/cache@v4
# with:
# path: ~/.cache/pip
# key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
# restore-keys: |
# ${{ runner.os }}-pip-
# ${{ runner.os }}-

- name: pip freeze
run: python -m pip freeze
Expand Down
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.4
+++++

* :pr:`337`: fixes extract_subset_of_nodes
* :pr:`336`: implements versioned onnx plugs

0.8.3
Expand Down
Binary file not shown.
Binary file not shown.
40 changes: 39 additions & 1 deletion _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import unittest
from typing import Any, Dict, List
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx import TensorProto, FunctionProto, ValueInfoProto
Expand Down Expand Up @@ -475,7 +477,7 @@ def _mkv_(name):

def test_onnx_dtype_name(self):
for k in dir(TensorProto):
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL"}:
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
self.assertEqual(k, onnx_dtype_name(getattr(TensorProto, k)))
self.assertRaise(lambda: onnx_dtype_name(1000), ValueError)
self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED")
Expand Down Expand Up @@ -532,6 +534,42 @@ def _type_rank_fn(name):
check_model(new_model)
self.check_ort(new_model)

def test_extract_subset_of_nodes_bigger(self):
    """
    Checks :func:`extract_subset_of_nodes` on the model stored in
    ``data/test_sbs_mha_split_every_piece.onnx``: the nodes extracted for
    result ``scaled_dot_product_attention`` (node index 16) with the given
    cut points must have the expected operator types, in that order.
    """
    data_folder = os.path.join(os.path.dirname(__file__), "data")
    proto = onnx.load(os.path.join(data_folder, "test_sbs_mha_split_every_piece.onnx"))
    # Result names at which the extraction is asked to stop.
    boundaries = {
        "linear",
        "linear_1",
        "linear_2",
        "output_0",
        "scaled_dot_product_attention",
        "transpose_2",
        "view_2",
        "x",
    }
    subset = extract_subset_of_nodes(
        model=proto,
        name="scaled_dot_product_attention",
        node_index=16,
        cut_points=boundaries,
    )
    expected_op_types = [
        "Mul",
        "Reshape",
        "Transpose",
        "Mul",
        "Reshape",
        "Transpose",
        "FusedMatMul",
        "Softmax",
        "MatMul",
    ]
    observed_op_types = [node.op_type for node in subset]
    self.assertEqual(expected_op_types, observed_op_types)


if __name__ == "__main__":
unittest.main(verbosity=2)
76 changes: 76 additions & 0 deletions _unittests/ut_torch_onnx/test_sbs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest
import pandas
import onnx
Expand Down Expand Up @@ -777,6 +778,81 @@ def forward(self, query, key, value, seq_lens):
df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
# self.clean_dump()

@hide_stdout()
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
def test_sbs_mha_split_every_piece(self):
    """
    Exports a small multi-head attention model, runs :func:`run_aligned`
    between the exported program and the ONNX model, checks the maximum
    absolute error and verifies the replay model dumped for
    ``scaled_dot_product_attention`` has three inputs.
    """
    torch = self.torch

    class Model(torch.nn.Module):
        def __init__(self, embed_dim: int, num_heads: int):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads

            assert embed_dim % num_heads == 0, (
                f"embed_dim % num_heads != 0 -> "
                f"{embed_dim} % {num_heads} = {embed_dim % num_heads}"
            )

            self.W_q = torch.nn.Linear(embed_dim, embed_dim)
            self.W_k = torch.nn.Linear(embed_dim, embed_dim)
            self.W_v = torch.nn.Linear(embed_dim, embed_dim)

        def split_heads(self, t, seq_len):
            # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
            return t.view(t.shape[0], seq_len, self.num_heads, self.head_dim).transpose(
                1, 2
            )

        def forward(self, x):
            q = self.split_heads(self.W_q(x), x.shape[1])
            k = self.split_heads(self.W_k(x), x.shape[1])
            v = self.split_heads(self.W_v(x), x.shape[1])
            return (
                torch.nn.functional.scaled_dot_product_attention(q, k, v)
                .transpose(1, 2)
                .reshape(x.shape[0], x.shape[1], self.embed_dim)
            )

    embed_dim = 16
    num_heads = 4
    seq_len = 10
    batch_size = 2
    inputs = dict(x=torch.randn(batch_size, seq_len, embed_dim))
    model = Model(embed_dim, num_heads)
    # Runs the model once before exporting it.
    model(**inputs)
    ds = dict(x={0: "batch", 1: "seqlen"})

    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
    self.dump_text("test_sbs_mha_split_every_piece.ep", str(ep))
    filename = self.get_dump_file("test_sbs_mha_split_every_piece.onnx")
    to_onnx(ep, exporter="custom", filename=filename)
    replay = self.get_dump_folder("test_sbs_mha_split_every_piece_replay")
    onx = onnx.load(filename)
    results = list(
        run_aligned(
            ep,
            onx,
            kwargs=inputs,
            run_cls=OnnxruntimeEvaluator,
            verbose=11,
            use_tensor=True,
            run_onnx_with_torch_inputs=True,
            replay_configuration=ReplayConfiguration(
                dump_folder=replay, selected_op_types={"MatMul"}, threshold=2**20
            ),
        ),
    )
    # ``results`` is already a list, no extra copy is needed.
    df = pandas.DataFrame(results).dropna(axis=1, how="all")
    df.to_excel(self.get_dump_file("test_sbs_mha_split_every_piece.xlsx"))
    max_abs = df["err_abs"].max()
    self.assertLess(max_abs, 1e-5)
    # self.clean_dump()
    subonnx = onnx.load(
        os.path.join(replay, "scaled_dot_product_attention", "model.onnx")
    )
    self.assertEqual(len(subonnx.graph.input), 3)


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 7 additions & 0 deletions onnx_diagnostic/ext_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,13 @@ def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
f.write(proto.SerializeToString())
return fullname

def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
    """
    Dumps text in a file.

    :param name: file name, resolved with ``self.get_dump_file``
    :param text: text to write
    :param folder: optional folder forwarded to ``self.get_dump_file``
    :return: the full name of the written file
    """
    fullname = self.get_dump_file(name, folder=folder)
    # An explicit encoding keeps the dump independent of the platform locale
    # and avoids UnicodeEncodeError on non-ASCII text.
    with open(fullname, "w", encoding="utf-8") as f:
        f.write(text)
    return fullname

def assertExists(self, name):
"""Checks the existing of a file."""
if not os.path.exists(name):
Expand Down
31 changes: 22 additions & 9 deletions onnx_diagnostic/helpers/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
print(onnx_dtype_name(7))
"""
for k in dir(TensorProto):
if k.upper() == k and k != "EXTERNAL":
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
v = getattr(TensorProto, k)
if v == itype:
return k
Expand Down Expand Up @@ -1219,11 +1219,14 @@ def extract_subset_of_nodes(
if name in node.output:
node_index = i
break
assert (
node_index is not None
and node_index < len(model.graph.node)
and name in model.graph.node[node_index].output
), f"node_index is still empty or wrong for result {name!r}"
assert node_index is not None and node_index < len(model.graph.node), (
f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
f"is still empty or wrong for result {name!r}"
)
assert name in model.graph.node[node_index].output, (
f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
f"node={pretty_onnx(model.graph.node[node_index])}"
)
if cut_points is None:
cut_points = {n.name for n in model.graph.input} | {
n.name for n in model.graph.initializer
Expand All @@ -1236,16 +1239,26 @@ def extract_subset_of_nodes(
current_node_index = node_index
current_input_index = 0
intermediate = {name}
cut_points -= {name}
inputs = set(k for k in node.input if k)
while not (inputs <= cut_points) and current_node_index >= 0:
node = model.graph.node[current_node_index]
if current_input_index == 0:
if current_input_index == 0 or not node.input:
needs = [o for o in node.output if o in intermediate and o not in cut_points]
if needs:
selected.add(current_node_index)
if not node.input:
current_node_index -= 1
current_input_index = 0
continue
else:
current_node_index -= 1
current_input_index = 0
continue
assert current_input_index < len(node.input), (
f"current_input_index={current_input_index} but node.input={node.input}, "
f"node={pretty_onnx(node)}"
)
res = node.input[current_input_index]
if res not in cut_points:
intermediate.add(res)
Expand Down Expand Up @@ -1290,8 +1303,8 @@ def _mkv_(name, itype, irank):
oh.make_graph(
nodes,
"submodel",
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
[_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
[_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
),
ir_version=ir_version,
opset_imports=opset_imports,
Expand Down
9 changes: 5 additions & 4 deletions onnx_diagnostic/torch_onnx/sbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ def _preparation_with_fx_graph(
assert len(torch_input_names) < len(onx.graph.input), (
f"torch_input_names={torch_input_names!r}, "
f"onnx_input_names={[n.name for n in onx.graph.input]}, "
f"node.name={node.name!r} cannot be an input"
f"node.name={node.name!r} cannot be an input, "
f"placeholders_to_state_dict={sorted(placeholders_to_state_dict)}"
)
assert node.name not in skip_mapping_torch_onnx, (
f"{node.name!r} is ambiguous, cannot be mapped due to "
Expand Down Expand Up @@ -772,9 +773,9 @@ def forward(self, x):
# preparation with ep.graph.nodes
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
placeholders_to_state_dict = {
**{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
**{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
**{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
**{f"p_{name.replace('.', '_').lower()}": name for name in ep.state_dict},
**{f"b_{name.replace('.', '_').lower()}": name for name, _ in ep.named_buffers()},
**{f"c_{name.replace('.', '_').lower()}": name for name in ep.tensor_constants},
}
skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
placeholders = {}
Expand Down
15 changes: 13 additions & 2 deletions onnx_diagnostic/torch_onnx/sbs_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,17 @@ def dump(
:return: the folder created to dump everything
"""
if verbose:
print(f"[ReplayConfiguration.dump] extract subset of node for {name!r}")
print(
f"[ReplayConfiguration.dump] extract subset of nodes for "
f"{name!r} (onnx_id_node={onnx_id_node})"
)
if verbose >= 10:
print(f"[ReplayConfiguration.dump] onnx_results={sorted(onnx_results)}")
print(f"[ReplayConfiguration.dump] torch_results={sorted(torch_results)}")
print(
f"[ReplayConfiguration.dump] onnx_name_to_ep_name="
f"{sorted(onnx_name_to_ep_name)}"
)
nodes = extract_subset_of_nodes(
model=model,
name=name,
Expand All @@ -253,7 +263,8 @@ def dump(
if not nodes:
if verbose:
print(
f"[ReplayConfiguration.dump] could not extract subset of node for {name!r}"
f"[ReplayConfiguration.dump] could not extract subset of "
f"nodes for {name!r}"
)
return None
if verbose:
Expand Down
Loading