Skip to content

Commit 65157ab

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Improving 4bit quant mat mul performance by shifting position of -8 operation. (pytorch#15436)
Summary: This diff introduces performance improvements in the 4-bit quant matrix multiplication operation by adjusting the position of the -8 operation. Resulting in overall reduction in math operation performed during shader runtime. Differential Revision: D85721578
1 parent 30d7cae commit 65157ab

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ void main() {
6464

6565
T sums[TILE_ROWS * TILE_TXCOLS * 4];
6666

67+
$if QUANT_NBITS == 4:
68+
// accumulate mat1 elements sums so -8 bias can be applied to it later
69+
T mat1_accum[TILE_ROWS];
70+
$for r in range(TILE_ROWS):
71+
mat1_accum[${r}] = T(0.0);
72+
6773
for (int r = 0; r < TILE_ROWS; ++r) {
6874
$for c in range(TILE_TXCOLS):
6975
$for j in range(4):
@@ -86,6 +92,11 @@ void main() {
8692
VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
8793
$for j in range(4):
8894
mat1[i * 4 + ${j}] = mat1_vec4[${j}];
95+
96+
$if QUANT_NBITS == 4:
97+
// Apply -8 * mat1 bias here rather then below, to effectively reduce overall number of math operations performed during runtime.
98+
// Accumulate mat1 element sum, this will be multiplied with -8 later for converting 4 bit data to a signed number.
99+
mat1_accum[i] += mat1[i * 4 + 0] + mat1[i * 4 + 1] + mat1[i * 4 + 2] + mat1[i * 4 + 3];
89100
}
90101

91102
$if WEIGHT_STORAGE == "buffer":
@@ -109,13 +120,13 @@ void main() {
109120
packed_weight_tex = texelFetch(
110121
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
111122

112-
qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4) - 8.0);
123+
qmat2_vec4 = VEC4_T(packed_weight_tex >> 4);
113124
qmat2[${c} * 4 * TILE_TXCOLS + 0] = qmat2_vec4.x;
114125
qmat2[${c} * 4 * TILE_TXCOLS + 1] = qmat2_vec4.y;
115126
qmat2[${c} * 4 * TILE_TXCOLS + 2] = qmat2_vec4.z;
116127
qmat2[${c} * 4 * TILE_TXCOLS + 3] = qmat2_vec4.w;
117128

118-
qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
129+
qmat2_vec4 = VEC4_T(packed_weight_tex & 0x0F);
119130
qmat2[${c} * 4 * TILE_TXCOLS + 4] = qmat2_vec4.x;
120131
qmat2[${c} * 4 * TILE_TXCOLS + 5] = qmat2_vec4.y;
121132
qmat2[${c} * 4 * TILE_TXCOLS + 6] = qmat2_vec4.z;
@@ -156,10 +167,16 @@ void main() {
156167
for (int r = 0; r < TILE_ROWS; ++r) {
157168
VEC4_T scaled_sums;
158169
$for c in range(TILE_TXCOLS):
159-
scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x;
160-
scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y;
161-
scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z;
162-
scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w;
170+
$if QUANT_NBITS == 4:
171+
scaled_sums.x = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] + mat1_accum[r] * -8.0) * scales[${c}].x;
172+
scaled_sums.y = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] + mat1_accum[r] * -8.0) * scales[${c}].y;
173+
scaled_sums.z = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] + mat1_accum[r] * -8.0) * scales[${c}].z;
174+
scaled_sums.w = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] + mat1_accum[r] * -8.0) * scales[${c}].w;
175+
$else:
176+
scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x;
177+
scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y;
178+
scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z;
179+
scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w;
163180

164181
$if OUT_STORAGE == "buffer":
165182
if (out_row + r < out_sizes.y) {

0 commit comments

Comments
 (0)