|
8 | 8 | ignore_errors, |
9 | 9 | requires_cuda, |
10 | 10 | ) |
| 11 | +from onnx_diagnostic.helpers.rt_helper import make_feeds |
11 | 12 | from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator |
12 | 13 | from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str |
| 14 | +from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5 |
13 | 15 | from onnx_diagnostic.torch_onnx.sbs import run_aligned |
14 | 16 | from onnx_diagnostic.torch_onnx.sbs_dataclasses import RunAlignedRecord, ReplayConfiguration |
15 | 17 | from onnx_diagnostic.export.api import to_onnx |
@@ -671,6 +673,124 @@ def forward(self, x): |
671 | 673 | ) |
672 | 674 | self.clean_dump() |
673 | 675 |
|
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
@hide_stdout()
def test_sbs_with_loops(self):
    # Exports a small module built around the Qwen2.5 loop attention plug,
    # checks the ONNX model against eager execution with onnxruntime, then
    # runs the side-by-side comparison and dumps its records to a spreadsheet.
    # (Kept as comments rather than a docstring so the name reported by
    # unittest is unchanged.)
    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
        PLUGS_Qwen25,
    )
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        qwen_sdpa_attention_loopmha_versatile,
    )

    class Model(torch.nn.Module):
        def forward(self, query, key, value, seq_lens):
            # 4x4 comparison grid flattened to 16 values and reshaped to
            # (1, 16, 1, 1) so it broadcasts over dimension 1 of the inputs.
            columns = torch.arange(4, dtype=torch.int32).unsqueeze(0)
            rows = torch.arange(4, dtype=torch.int32).unsqueeze(1)
            mask = (rows <= columns).flatten().reshape((1, -1, 1, 1)).to(query.dtype)
            masked_query = query * mask
            masked_key = key * mask
            masked_value = value * mask

            # ONNX element type matching the input dtype; anything that is
            # neither float32 nor float16 is treated as bfloat16.
            if query.dtype == torch.float32:
                onnx_itype = onnx.TensorProto.FLOAT
            elif query.dtype == torch.float16:
                onnx_itype = onnx.TensorProto.FLOAT16
            else:
                onnx_itype = onnx.TensorProto.BFLOAT16

            attn_output = qwen_sdpa_attention_loopmha_versatile(
                masked_query,
                masked_key,
                masked_value,
                seq_lens,
                0.11,
                16,
                onnx_itype,
            )
            # Center the output along the last axis.
            return attn_output - attn_output.mean(dim=-1, keepdim=True)

    model = Model()

    # Three float16 tensors of the same shape followed by the int64
    # boundaries passed as ``seq_lens``.
    qkv = tuple(torch.rand((1, 16, 1292, 80), dtype=torch.float16) for _ in range(3))
    boundaries = [
        0, 64, 128, 192, 256, 304, 368, 432, 496, 560, 608, 672, 736,
        800, 864, 912, 976, 1040, 1104, 1168, 1216, 1232, 1248, 1264,
        1280, 1292,
    ]  # fmt: skip
    inputs = (*qkv, torch.tensor(boundaries, dtype=torch.int64))
    expected = model(*inputs)

    # Axis 2 of query/key/value and axis 0 of seq_lens are dynamic.
    ds = tuple({2: "seq_length"} for _ in range(3)) + ({0: "num_patches"},)

    onnx_file = self.get_dump_file("test_sbs_with_loops.onnx")
    ep_prefix = self.get_dump_file("test_sbs_with_loops")
    to_onnx(
        model,
        inputs,
        dynamic_shapes=ds,
        filename=onnx_file,
        save_ep=(ep_prefix, 2**28),
        exporter="custom",
        onnx_plugs=PLUGS_Qwen25,
        target_opset=22,
    )

    # Files expected next to the ONNX model after the export.
    input_file = f"{ep_prefix}.input.pt"
    ep_file = f"{ep_prefix}.ep.pt2"
    for path in (onnx_file, ep_file, input_file):
        self.assertExists(path)

    # Numerical check of the exported model with onnxruntime.
    sess = self.check_ort(onnx_file)
    input_names = [node.name for node in sess.get_inputs()]
    numpy_feeds = make_feeds(input_names, inputs, use_numpy=True)
    got = sess.run(None, numpy_feeds)
    self.assertEqualArray(expected, got[0], atol=1e-3)

    # Side-by-side run of the exported program and the ONNX model.
    exported_program = torch.export.load(ep_file)
    onnx_model = onnx.load(onnx_file)
    tensor_feeds = make_feeds(input_names, inputs, use_numpy=False)
    results = list(
        run_aligned(
            exported_program,
            onnx_model,
            kwargs=tensor_feeds,
            run_cls=OnnxruntimeEvaluator,
            verbose=11,
            use_tensor=True,
        ),
    )

    # Drop columns that are empty for every record before writing the report.
    df = pandas.DataFrame(results).dropna(axis=1, how="all")
    df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
    # NOTE(review): self.clean_dump() is not called here, so the dumped
    # files stay on disk after the test.
| 793 | + |
674 | 794 |
|
if __name__ == "__main__":
    # Executed as a script: run every test of this module and report each
    # test individually.
    unittest.main(verbosity=2)
0 commit comments