Skip to content

Commit 50f551d

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Reformulating matrix multiplication scale equation to reduce math ops and improve power and performance. (pytorch#6437)
Summary: This diff simplifies the the matrix multiplication scale equation in q_linear op. The existing equation in q_linear op is: ``` for i in K / 4 sums[c] = mat1_tex . (qmat2(c) scales[c]) out += sums ``` where c = [0, 4), out, sums, mat1_tex and qmat2 are vectors and scales is a scalar. The dot product is associative with respect to scalar multiplication as mentioned in https://en.wikipedia.org/wiki/Dot_product ie. (ac1).(bc2) = c1c2(a.b) Thus, the multiplication can be rearranged as: ``` for i in K / 4 sums[c] = (mat1_tex . qmat2(c)) scales[c] out += sums ``` Using distributive property of multiplication ie. ab + ac + ad ... = a(b + c+ d...) the code can be further simplified to: ``` for i in K / 4 sums[c] = mat1_tex . qmat2(c) out += sums out *= scale ``` This rearrangement significantly reduces redundant multiplications. Reviewed By: SS-JIA Differential Revision: D64479405
1 parent f6778d5 commit 50f551d

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,20 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
102102

103103
for (int i = 0; i < K; i += 4) {
104104
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
105-
106105
const VEC4_T sums = VEC4_T(
107-
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x),
108-
dot(mat1_tex,
109-
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0)) * scales.y),
110-
dot(mat1_tex,
111-
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0)) * scales.z),
112-
dot(mat1_tex,
113-
load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0)) * scales.w));
106+
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos)),
107+
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0))),
108+
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0))),
109+
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0))));
114110

115111
outtex += sums;
116112

117113
mat1_pos.x++;
118114
qmat2_pos.x++;
119115
}
120116

117+
outtex *= scales;
118+
121119
return outtex;
122120
}
123121

0 commit comments

Comments
 (0)