4 changes: 2 additions & 2 deletions .github/workflows/check-urls.yml
@@ -42,6 +42,6 @@ jobs:
print_all: false
timeout: 2
retry_count# : 2
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/,https://docs.scipy.org/doc/scipy/
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/,https://docs.scipy.org/doc/scipy/
# force_pass : true
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -5,6 +5,7 @@ Change Logs
+++++

* :pr:`330`: fixes access rope_parameters for ``transformers>=5``
* :pr:`329`: supports lists with OnnxruntimeEvaluator
* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
4 changes: 3 additions & 1 deletion _unittests/ut_helpers/test_onnx_helper.py
@@ -402,7 +402,9 @@ def test_enumerate_results_loop(self):
new_axis=0,
),
],
)
),
ir_version=10,
opset_imports=[oh.make_opsetid("", 22)],
)
res = list(enumerate_results(model, "slice_start", verbose=2))
self.assertEqual(len(res), 2)
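
A note on the recurring fix above (the same pin also appears in test_torch_onnx_evaluator.py below): the test models are now built with an explicit ir_version and default opset instead of whatever the installed onnx package defaults to. A minimal sketch of the pattern, not taken from the PR:

# Minimal sketch, not from the PR: pinning ir_version and the default opset
# when building a model with onnx.helper, as the updated tests do. Without
# the pin, oh.make_model uses the defaults of the installed onnx package,
# which may be newer than what the installed onnxruntime accepts.
import onnx
import onnx.helper as oh

graph = oh.make_graph(
    [oh.make_node("Identity", ["x"], ["y"])],
    "pinned_versions",
    [oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [None])],
    [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [None])],
)
model = oh.make_model(
    graph,
    ir_version=10,                               # explicit IR version
    opset_imports=[oh.make_opsetid("", 22)],     # explicit default opset
)
onnx.checker.check_model(model)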
118 changes: 118 additions & 0 deletions _unittests/ut_reference/test_onnxruntime_evaluator.py
@@ -20,6 +20,7 @@


TFLOAT = onnx.TensorProto.FLOAT
TINT64 = onnx.TensorProto.INT64


class TestOnnxruntimeEvaluator(ExtTestCase):
@@ -319,6 +320,123 @@ def test_function_proto_with_kwargs(self):
got = sess.run(None, feeds)
self.assertEqualArray(expected, got[0], atol=1e-5)

@hide_stdout()
def test_ort_eval_loop_seq(self):
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
_mkv_ = oh.make_tensor_value_info
model = oh.make_model(
graph=oh.make_graph(
name="loop_test",
inputs=[
oh.make_tensor_value_info("trip_count", TINT64, ["a"]),
oh.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
],
outputs=[oh.make_tensor_value_info("res", TFLOAT, [])],
nodes=[
oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT),
oh.make_node(
"Loop",
inputs=["trip_count", "cond", "seq_empty"],
outputs=["seq_res"],
body=oh.make_graph(
[
oh.make_node(
"Identity", inputs=["cond_in"], outputs=["cond_out"]
),
oh.make_node(
"Constant",
inputs=[],
outputs=["x"],
value=oh.make_tensor(
name="const_tensor_x",
data_type=TFLOAT,
dims=x.shape,
vals=x.flatten().astype(float),
),
),
oh.make_node(
"Constant",
inputs=[],
outputs=["one"],
value=oh.make_tensor(
name="const_tensor_one",
data_type=TINT64,
dims=(),
vals=[1],
),
),
oh.make_node(
"Constant",
inputs=[],
outputs=["slice_start"],
value=oh.make_tensor(
name="const_tensor_zero",
data_type=TINT64,
dims=(1,),
vals=[0],
),
),
oh.make_node(
"Add", inputs=["iter_count", "one"], outputs=["end"]
),
oh.make_node(
"Constant",
inputs=[],
outputs=["axes"],
value=oh.make_tensor(
name="const_tensor_axes",
data_type=TINT64,
dims=(1,),
vals=[0],
),
),
oh.make_node(
"Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"]
),
oh.make_node(
"Slice",
inputs=["x", "slice_start", "slice_end"],
outputs=["slice_out"],
),
oh.make_node(
"SequenceInsert",
inputs=["seq_in", "slice_out"],
outputs=["seq_out"],
),
],
"loop_body",
[
_mkv_("iter_count", TINT64, []),
_mkv_("cond_in", onnx.TensorProto.BOOL, []),
oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None),
],
[
_mkv_("cond_out", onnx.TensorProto.BOOL, []),
oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None),
],
),
),
oh.make_node(
"ConcatFromSequence",
inputs=["seq_res"],
outputs=["res"],
axis=0,
new_axis=0,
),
],
),
ir_version=10,
opset_imports=[oh.make_opsetid("", 22)],
)
ev = OnnxruntimeEvaluator(model, verbose=10)
feeds = dict(trip_count=torch.tensor([3], dtype=torch.int64), cond=torch.tensor(True))
got = ev.run(None, feeds)
self.assertEqual((6,), got[0].shape)
self.assertEqualArray(
torch.tensor([1.0, 1.0, 2.0, 1.0, 2.0, 3.0], dtype=torch.float32), got[0]
)
self.assertIsInstance(got[0], torch.Tensor)


if __name__ == "__main__":
unittest.main(verbosity=2)
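
To explain the expected output asserted in test_ort_eval_loop_seq, here is a rough numpy re-statement of what the Loop body computes. It is not part of the PR, only an illustration of the ONNX graph above:

# Rough numpy re-statement, not part of the PR: iteration i of the Loop
# slices x[0:i+1] and appends it to the sequence; ConcatFromSequence then
# flattens the sequence along axis 0.
import numpy as np

x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
trip_count = 3
seq = []
for i in range(trip_count):        # the Loop runs trip_count iterations
    seq.append(x[0 : i + 1])       # Slice(x, slice_start=[0], slice_end=[i+1])
res = np.concatenate(seq, axis=0)  # ConcatFromSequence(axis=0, new_axis=0)
# res == [1., 1., 2., 1., 2., 3.], shape (6,), matching the assertions above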
4 changes: 3 additions & 1 deletion _unittests/ut_reference/test_torch_onnx_evaluator.py
@@ -1123,7 +1123,9 @@ def test_loop(self):
new_axis=0,
),
],
)
),
ir_version=10,
opset_imports=[oh.make_opsetid("", 22)],
)
self._finalize_test(
model, torch.tensor(5, dtype=torch.int64), torch.tensor(1, dtype=torch.bool)
120 changes: 120 additions & 0 deletions _unittests/ut_torch_onnx/test_sbs.py
@@ -8,8 +8,10 @@
ignore_errors,
requires_cuda,
)
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
from onnx_diagnostic.torch_onnx.sbs import run_aligned
from onnx_diagnostic.torch_onnx.sbs_dataclasses import RunAlignedRecord, ReplayConfiguration
from onnx_diagnostic.export.api import to_onnx
@@ -671,6 +673,124 @@ def forward(self, x):
)
self.clean_dump()

@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
@hide_stdout()
def test_sbs_with_loops(self):
import torch
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
PLUGS_Qwen25,
)
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
qwen_sdpa_attention_loopmha_versatile,
)

class Model(torch.nn.Module):
def forward(self, query, key, value, seq_lens):
rg1 = torch.arange(4, dtype=torch.int32).unsqueeze(0)
rg0 = torch.arange(4, dtype=torch.int32).unsqueeze(1)
mask = (rg0 <= rg1).flatten().reshape((1, -1, 1, 1)).to(query.dtype)
qs = query * mask
ks = key * mask
vs = value * mask
attn_output = qwen_sdpa_attention_loopmha_versatile(
qs,
ks,
vs,
seq_lens,
0.11,
16,
(
onnx.TensorProto.FLOAT
if query.dtype == torch.float32
else (
onnx.TensorProto.FLOAT16
if query.dtype == torch.float16
else onnx.TensorProto.BFLOAT16
)
),
)
red = attn_output.mean(dim=-1, keepdim=True)
return attn_output - red

model = Model()
inputs = (
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
torch.tensor(
[
0,
64,
128,
192,
256,
304,
368,
432,
496,
560,
608,
672,
736,
800,
864,
912,
976,
1040,
1104,
1168,
1216,
1232,
1248,
1264,
1280,
1292,
],
dtype=torch.int64,
),
)
expected = model(*inputs)
ds = ({2: "seq_length"}, {2: "seq_length"}, {2: "seq_length"}, {0: "num_patches"})
onnx_file = self.get_dump_file("test_sbs_with_loops.onnx")
ep_file = self.get_dump_file("test_sbs_with_loops")
to_onnx(
model,
inputs,
dynamic_shapes=ds,
filename=onnx_file,
save_ep=(ep_file, 2**28),
exporter="custom",
onnx_plugs=PLUGS_Qwen25,
target_opset=22,
)
input_file = ep_file + ".input.pt"
ep_file = ep_file + ".ep.pt2"
self.assertExists(onnx_file)
self.assertExists(ep_file)
self.assertExists(input_file)
sess = self.check_ort(onnx_file)
input_names = [i.name for i in sess.get_inputs()]
feeds = make_feeds(input_names, inputs, use_numpy=True)
got = sess.run(None, feeds)
self.assertEqualArray(expected, got[0], atol=1e-3)
# sbs
ep = torch.export.load(ep_file)
onx = onnx.load(onnx_file)
kwargs = make_feeds(input_names, inputs, use_numpy=False)
results = list(
run_aligned(
ep,
onx,
kwargs=kwargs,
run_cls=OnnxruntimeEvaluator,
verbose=11,
use_tensor=True,
),
)
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
# self.clean_dump()


if __name__ == "__main__":
unittest.main(verbosity=2)
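
A side note on the inputs built in Model.forward of test_sbs_with_loops, in case the mask reads like an attention mask: it is a 4x4 upper-triangular pattern flattened into 16 per-head flags (the query has 16 heads), so some heads receive zeroed query/key/value, presumably just to make the heads differ before measuring discrepancies. A small sketch, not part of the PR:

# Sketch, not part of the PR: the boolean pattern behind the mask in
# Model.forward, one flag per attention head rather than per position.
import torch

rg1 = torch.arange(4, dtype=torch.int32).unsqueeze(0)  # row 0..3, shape (1, 4)
rg0 = torch.arange(4, dtype=torch.int32).unsqueeze(1)  # column 0..3, shape (4, 1)
mask = (rg0 <= rg1).flatten().reshape((1, -1, 1, 1))    # shape (1, 16, 1, 1)
print(mask.to(torch.int8).flatten().tolist())
# [1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1]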
5 changes: 4 additions & 1 deletion onnx_diagnostic/ext_test_case.py
@@ -1111,7 +1111,10 @@ def check_ort(
) -> "onnxruntime.InferenceSession": # noqa: F821
from onnxruntime import InferenceSession

return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
return InferenceSession(
onx if isinstance(onx, str) else onx.SerializeToString(),
providers=["CPUExecutionProvider"],
)

def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
"""In the name"""
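
The change above lets check_ort accept either a file path or an in-memory model. This works because onnxruntime.InferenceSession itself takes a filename (str) or serialized bytes; a self-contained sketch of the pattern, not the library code itself:

# Sketch, not the library code: InferenceSession accepts either a model path
# (str) or the serialized bytes of a ModelProto, which is what the updated
# check_ort relies on.
import onnx
from onnxruntime import InferenceSession

def load_session(onx) -> InferenceSession:
    model_or_path = onx if isinstance(onx, str) else onx.SerializeToString()
    return InferenceSession(model_or_path, providers=["CPUExecutionProvider"])

# load_session("model.onnx") and load_session(model_proto) both return a session.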
5 changes: 4 additions & 1 deletion onnx_diagnostic/helpers/ort_session.py
@@ -108,7 +108,10 @@ def __init__(
session_options,
providers=providers,
)
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
except (
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
) as e:
if isinstance(sess, onnx.ModelProto):
debug_path = "_debug_InferenceSession_last_failure.onnx"
onnx.save(
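
The change above broadens the exceptions caught when session creation fails, so the failing model is still dumped to disk when onnxruntime rejects the graph rather than failing generically. A standalone sketch of the pattern, not the library code itself; the behavioral distinction between the two exception classes is my reading, not stated in the PR:

# Sketch, not the library code: onnxruntime raises exception classes from
# onnxruntime.capi.onnxruntime_pybind11_state depending on the failure
# (roughly, Fail for generic errors, InvalidGraph when the graph itself is
# rejected); catching both lets the wrapper save the failing model before
# re-raising.
import onnx
import onnxruntime

def make_session(model: onnx.ModelProto, debug_path="_debug_failure.onnx"):
    try:
        return onnxruntime.InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
    except (
        onnxruntime.capi.onnxruntime_pybind11_state.Fail,
        onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
    ) as e:
        onnx.save(model, debug_path)  # keep the rejected model for inspection
        raise RuntimeError(f"session creation failed, model saved to {debug_path!r}") from e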