|
| 1 | +import os |
1 | 2 | import unittest |
2 | 3 | import pandas |
3 | 4 | import onnx |
@@ -777,6 +778,81 @@ def forward(self, query, key, value, seq_lens): |
777 | 778 | df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx")) |
778 | 779 | # self.clean_dump() |
779 | 780 |
|
@hide_stdout()
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
def test_sbs_mha_split_every_piece(self):
    """Side-by-side check of a small multi-head attention model.

    Steps:

    1. Build a module made of three linear projections followed by
       ``scaled_dot_product_attention`` and run it once eagerly.
    2. Export it with ``torch.export.export`` (the first two dimensions
       of ``x`` are declared dynamic) and convert the exported program
       with ``to_onnx(..., exporter="custom")``.
    3. Collect the results produced by ``run_aligned`` into a DataFrame,
       dump it to Excel and require the largest ``err_abs`` value to
       stay below ``1e-5``.
    4. Load ``scaled_dot_product_attention/model.onnx`` from the replay
       folder and require its graph to have exactly three inputs.
    """
    torch = self.torch

    class Model(torch.nn.Module):
        """Self-attention block: Q/K/V projections, head split, SDPA."""

        def __init__(self, embed_dim: int, num_heads: int):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads

            assert embed_dim % num_heads == 0, (
                f"embed_dim % num_heads != 0 -> "
                f"{embed_dim} % {num_heads} = {embed_dim % num_heads}"
            )

            self.W_q = torch.nn.Linear(embed_dim, embed_dim)
            self.W_k = torch.nn.Linear(embed_dim, embed_dim)
            self.W_v = torch.nn.Linear(embed_dim, embed_dim)

        def split_heads(self, t, seq_len):
            # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
            return t.view(
                t.shape[0], seq_len, self.num_heads, self.head_dim
            ).transpose(1, 2)

        def forward(self, x):
            q = self.split_heads(self.W_q(x), x.shape[1])
            k = self.split_heads(self.W_k(x), x.shape[1])
            v = self.split_heads(self.W_v(x), x.shape[1])
            # Undo the head split: back to (batch, seq_len, embed_dim).
            return (
                torch.nn.functional.scaled_dot_product_attention(q, k, v)
                .transpose(1, 2)
                .reshape(x.shape[0], x.shape[1], self.embed_dim)
            )

    embed_dim = 16
    num_heads = 4
    seq_len = 10
    batch_size = 2
    inputs = dict(x=torch.randn(batch_size, seq_len, embed_dim))
    model = Model(embed_dim, num_heads)
    # Eager forward pass: fails early if the model itself is broken,
    # before any export is attempted.
    model(**inputs)
    ds = dict(x={0: "batch", 1: "seqlen"})

    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
    self.dump_text("test_sbs_mha_split_every_piece.ep", str(ep))
    filename = self.get_dump_file("test_sbs_mha_split_every_piece.onnx")
    to_onnx(ep, exporter="custom", filename=filename)
    replay = self.get_dump_folder("test_sbs_mha_split_every_piece_replay")
    onx = onnx.load(filename)
    results = list(
        run_aligned(
            ep,
            onx,
            kwargs=inputs,
            run_cls=OnnxruntimeEvaluator,
            verbose=11,
            use_tensor=True,
            run_onnx_with_torch_inputs=True,
            replay_configuration=ReplayConfiguration(
                dump_folder=replay, selected_op_types={"MatMul"}, threshold=2**20
            ),
        ),
    )
    # ``results`` is already a list; columns that are empty for every row
    # are dropped before the table is dumped.
    df = pandas.DataFrame(results).dropna(axis=1, how="all")
    df.to_excel(self.get_dump_file("test_sbs_mha_split_every_piece.xlsx"))
    max_abs = df["err_abs"].max()
    self.assertLess(max_abs, 1e-5)
    # self.clean_dump()
    # The test expects a sub-model for the attention piece in the replay
    # folder; it must expose exactly three graph inputs.
    subonnx = onnx.load(
        os.path.join(replay, "scaled_dot_product_attention", "model.onnx")
    )
    self.assertEqual(len(subonnx.graph.input), 3)
| 855 | + |
780 | 856 |
|
# Script entry point: when this file is executed directly, run the tests
# of this module through unittest with verbosity level 2.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments