Commit 9a7fb42
Arm backend: Fix torch.matmul() failures for 2D tensor inputs (pytorch#14624)
- ConvertMmToBmmPass converts an MM node to BMM nodes, turns input and
output tensors from rank-2 to rank-3 via unsqueeze/squeeze, and inserts
q-dq before and after BMM node when necessary.
- After ConvertMmToBmmPass:
```
x -> q -> dq -> unsqueeze -> q_2 -> dq_2 ->
\
bmm -> q_4 -> dq_4
/
y -> q_1 -> dq_1 -> unsqueeze -> q_3 -> dq_3 ->
```
- Therefore, if the original matmul was 2D, the bmm already has DQ nodes
on its inputs and Q node on its output. If AnnotateDecomposedMatmulPass
(pytorch#10654) is still applied in this case, it produces illegal sequences
such as: x -> q -> unsqueeze -> q_2 (invalid)
- Fix by checking whether the BMM is already surrounded by DQ nodes on
its inputs and Q nodes on its output.
Change-Id: I9949d59b0b4a96fa34a88b0734014567ea6f24cc
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218
Signed-off-by: Yufeng Shi <[email protected]>
Co-authored-by: Oscar Andersson <[email protected]>1 parent 75ebd05 commit 9a7fb42
File tree
2 files changed
+14
-2
lines changed- backends/arm
- _passes
- test/ops
2 files changed
+14
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
73 | 73 | | |
74 | 74 | | |
75 | 75 | | |
76 | | - | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
77 | 80 | | |
78 | 81 | | |
79 | 82 | | |
| |||
99 | 102 | | |
100 | 103 | | |
101 | 104 | | |
102 | | - | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
103 | 108 | | |
104 | 109 | | |
105 | 110 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
32 | 33 | | |
33 | 34 | | |
34 | 35 | | |
| 36 | + | |
35 | 37 | | |
36 | 38 | | |
37 | 39 | | |
| |||
42 | 44 | | |
43 | 45 | | |
44 | 46 | | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
45 | 52 | | |
46 | 53 | | |
47 | 54 | | |
| |||
0 commit comments