Skip to content

Commit 48df450

Browse files
committed
add unit test
1 parent ed88bcd commit 48df450

File tree

2 files changed

+124
-1
lines changed

2 files changed

+124
-1
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
ignore_errors,
99
requires_cuda,
1010
)
11+
from onnx_diagnostic.helpers.rt_helper import make_feeds
1112
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
1213
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
14+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
1315
from onnx_diagnostic.torch_onnx.sbs import run_aligned
1416
from onnx_diagnostic.torch_onnx.sbs_dataclasses import RunAlignedRecord, ReplayConfiguration
1517
from onnx_diagnostic.export.api import to_onnx
@@ -671,6 +673,124 @@ def forward(self, x):
671673
)
672674
self.clean_dump()
673675

676+
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
@hide_stdout()
def test_sbs_with_loops(self):
    """Run the side-by-side comparison on a model calling the Qwen2.5 loop attention.

    The test:

    1. builds a small module around ``qwen_sdpa_attention_loopmha_versatile``,
    2. exports it with the ``custom`` exporter, saving the exported program
       and its inputs next to the ONNX file,
    3. checks the onnxruntime output against the eager output (``atol=1e-3``),
    4. reloads the exported program and the ONNX model, runs ``run_aligned``
       with ``OnnxruntimeEvaluator`` and dumps the records to an Excel file.
    """
    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
        PLUGS_Qwen25,
    )
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_loopmha_versatile,
    )

    class Model(torch.nn.Module):
        def forward(self, query, key, value, seq_lens):
            # ONNX element type matching the input dtype; anything that is
            # neither float32 nor float16 falls back to BFLOAT16.
            if query.dtype == torch.float32:
                itype = onnx.TensorProto.FLOAT
            elif query.dtype == torch.float16:
                itype = onnx.TensorProto.FLOAT16
            else:
                itype = onnx.TensorProto.BFLOAT16

            # 4x4 comparison of a column range against a row range, flattened
            # into 16 values and reshaped to (1, 16, 1, 1) so it broadcasts
            # over the second dimension of the inputs.
            columns = torch.arange(4, dtype=torch.int32).unsqueeze(0)
            rows = torch.arange(4, dtype=torch.int32).unsqueeze(1)
            mask = (rows <= columns).flatten().reshape((1, -1, 1, 1)).to(query.dtype)

            attn_output = qwen_sdpa_attention_loopmha_versatile(
                query * mask,
                key * mask,
                value * mask,
                seq_lens,
                0.11,
                16,
                itype,
            )
            # Center the output along the last axis.
            return attn_output - attn_output.mean(dim=-1, keepdim=True)

    model = Model()

    # Cumulative offsets: starts at 0, ends at 1292, the size of dimension 2
    # of the three random inputs below.
    seq_lens = torch.tensor(
        [
            0, 64, 128, 192, 256, 304, 368, 432, 496, 560,
            608, 672, 736, 800, 864, 912, 976, 1040, 1104, 1168,
            1216, 1232, 1248, 1264, 1280, 1292,
        ],
        dtype=torch.int64,
    )
    # query, key, value are drawn in that order, followed by seq_lens.
    qkv = tuple(torch.rand((1, 16, 1292, 80), dtype=torch.float16) for _ in range(3))
    inputs = (*qkv, seq_lens)
    expected = model(*inputs)

    ds = ({2: "seq_length"}, {2: "seq_length"}, {2: "seq_length"}, {0: "num_patches"})
    onnx_file = self.get_dump_file("test_sbs_with_loops.onnx")
    ep_prefix = self.get_dump_file("test_sbs_with_loops")
    to_onnx(
        model,
        inputs,
        dynamic_shapes=ds,
        filename=onnx_file,
        save_ep=(ep_prefix, 2**28),
        exporter="custom",
        onnx_plugs=PLUGS_Qwen25,
        target_opset=22,
    )

    # Files expected next to the ONNX model once save_ep was honoured.
    input_file = ep_prefix + ".input.pt"
    ep_file = ep_prefix + ".ep.pt2"
    for produced in (onnx_file, ep_file, input_file):
        self.assertExists(produced)

    # Numerical check: onnxruntime against the eager model.
    sess = self.check_ort(onnx_file)
    input_names = [i.name for i in sess.get_inputs()]
    numpy_feeds = make_feeds(input_names, inputs, use_numpy=True)
    got = sess.run(None, numpy_feeds)
    self.assertEqualArray(expected, got[0], atol=1e-3)

    # Side-by-side run on the reloaded exported program and ONNX model.
    ep = torch.export.load(ep_file)
    onx = onnx.load(onnx_file)
    tensor_feeds = make_feeds(input_names, inputs, use_numpy=False)
    results = list(
        run_aligned(
            ep,
            onx,
            kwargs=tensor_feeds,
            run_cls=OnnxruntimeEvaluator,
            verbose=11,
            use_tensor=True,
        ),
    )
    # Drop the columns which are empty for every record before dumping.
    df = pandas.DataFrame(results).dropna(axis=1, how="all")
    df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
    # NOTE(review): the cleanup call is commented out, so the dump files are
    # kept after the test; confirm whether this is intended.
    # self.clean_dump()
793+
674794

675795
# Run every test of this module with verbose output when executed as a script.
if __name__ == "__main__":
676796
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,10 @@ def check_ort(
11111111
) -> "onnxruntime.InferenceSession": # noqa: F821
11121112
from onnxruntime import InferenceSession
11131113

1114-
return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
1114+
return InferenceSession(
1115+
onx if isinstance(onx, str) else onx.SerializeToString(),
1116+
providers=["CPUExecutionProvider"],
1117+
)
11151118

11161119
def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
11171120
"""In the name"""

0 commit comments

Comments
 (0)