Commit 7f09a68

ggml-cpu : optimize RVV q2_k and q3_k kernels (ggml-org#16887)
1 parent aa37417 commit 7f09a68
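
The change restructures the hot loops of the q2_K and q3_K dot-product kernels: the vslideup/vredsum gather of eight per-sub-block partial sums is replaced by loading the eight scales into scalar registers and folding them in with vmul.vx/vmacc.vx, the q2_K inner loop is unrolled into 16-element (m1) chunks so loads and shifts interleave, and "lb zero, ..." touch loads pull upcoming cache lines in ahead of the vector loads. For orientation, a scalar sketch of the per-128-quant reduction both kernels vectorize; identifiers here are illustrative, not from quants.c, and the separate dmin/bsums correction is omitted:

    /* q2_K shown; q3_K is analogous with signed quants and scales. */
    static int chunk_dot(const unsigned char *q2,    /* 32 bytes: 128 2-bit quants  */
                         const signed char   *q8,    /* 128 int8 activations        */
                         const unsigned char *scale) /* 8 unpacked sub-block scales */
    {
        int isum = 0;
        for (int k = 0; k < 8; ++k) {        /* 8 sub-blocks of 16 quants        */
            int sub_sum = 0;                 /* one vwmul.vv + vwredsum.vs chain */
            for (int l = 0; l < 16; ++l) {
                const int q = (q2[(k & 1) * 16 + l] >> (2 * (k >> 1))) & 0x3;
                sub_sum += q * q8[k * 16 + l];
            }
            isum += scale[k] * sub_sum;      /* what vmul.vx / vmacc.vx now fuse */
        }
        return isum;
    }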

1 file changed: +108 -49 lines changed

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 108 additions & 49 deletions
@@ -580,16 +580,19 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
         uint8_t *patmp = atmp;
         int vsums;
-        int tmp;
+        int tmp, t1, t2, t3, t4, t5, t6, t7;
         __asm__ __volatile__(
             "vsetivli zero, 16, e8, m1\n\t"
             "vmv.v.x v8, zero\n\t"
+            "lb zero, 15(%[sc])\n\t"
             "vle8.v v1, (%[sc])\n\t"
+            "vle8.v v2, (%[bsums])\n\t"
+            "addi %[tmp], %[bsums], 16\n\t"
             "vand.vi v0, v1, 0xF\n\t"
             "vsrl.vi v1, v1, 4\n\t"
+            "vle8.v v3, (%[tmp])\n\t"
             "vse8.v v0, (%[scale])\n\t"
             "vsetivli zero, 16, e16, m2\n\t"
-            "vle16.v v2, (%[bsums])\n\t"
             "vzext.vf2 v0, v1\n\t"
             "vwmul.vv v4, v0, v2\n\t"
             "vsetivli zero, 16, e32, m4\n\t"
@@ -608,46 +611,89 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
         for (int j = 0; j < QK_K/128; ++j) {
             __asm__ __volatile__(
-                "vsetvli zero, %[vl32], e8, m2\n\t"
+                "lb zero, 31(%[q2])\n\t"
+                "addi %[tmp], %[q2], 16\n\t"
+                "addi %[t1], %[q8], 16\n\t"
+                "vsetivli zero, 16, e8, m1\n\t"
                 "vle8.v v0, (%[q2])\n\t"
+                "vle8.v v1, (%[tmp])\n\t"
                 "vsrl.vi v2, v0, 2\n\t"
+                "vsrl.vi v3, v1, 2\n\t"
                 "vsrl.vi v4, v0, 4\n\t"
+                "addi %[tmp], %[q8], 32\n\t"
+                "vle8.v v8, (%[q8])\n\t"
+                "vle8.v v9, (%[t1])\n\t"
+                "addi %[t1], %[t1], 32\n\t"
+                "vsrl.vi v5, v1, 4\n\t"
                 "vsrl.vi v6, v0, 6\n\t"
+                "vsrl.vi v7, v1, 6\n\t"
+                "vle8.v v10, (%[tmp])\n\t"
+                "vle8.v v11, (%[t1])\n\t"
+                "addi %[tmp], %[tmp], 32\n\t"
+                "addi %[t1], %[t1], 32\n\t"
                 "vand.vi v0, v0, 0x3\n\t"
+                "vand.vi v1, v1, 0x3\n\t"
                 "vand.vi v2, v2, 0x3\n\t"
+                "vle8.v v12, (%[tmp])\n\t"
+                "vle8.v v13, (%[t1])\n\t"
+                "addi %[tmp], %[tmp], 32\n\t"
+                "addi %[t1], %[t1], 32\n\t"
+                "vand.vi v3, v3, 0x3\n\t"
                 "vand.vi v4, v4, 0x3\n\t"
-                "vsetvli zero, %[vl128], e8, m8\n\t"
-                "vle8.v v8, (%[q8])\n\t"
-                "vsetvli zero, %[vl64], e8, m4\n\t"
+                "vand.vi v5, v5, 0x3\n\t"
+                "vle8.v v14, (%[tmp])\n\t"
+                "vle8.v v15, (%[t1])\n\t"
                 "vwmul.vv v16, v0, v8\n\t"
+                "vwmul.vv v18, v1, v9\n\t"
+                "vwmul.vv v20, v2, v10\n\t"
+                "vwmul.vv v22, v3, v11\n\t"
                 "vwmul.vv v24, v4, v12\n\t"
-                "vsetivli zero, 16, e16, m2\n\t"
+                "vwmul.vv v26, v5, v13\n\t"
+                "vwmul.vv v28, v6, v14\n\t"
+                "vwmul.vv v30, v7, v15\n\t"
+                "vsetivli zero, 8, e16, m1\n\t"
                 "vmv.v.x v0, zero\n\t"
-                "vwredsum.vs v10, v16, v0\n\t"
+                "lbu %[tmp], 0(%[scale])\n\t"
+                "vwredsum.vs v8, v16, v0\n\t"
                 "vwredsum.vs v9, v18, v0\n\t"
-                "vwredsum.vs v8, v20, v0\n\t"
-                "vwredsum.vs v7, v22, v0\n\t"
-                "vwredsum.vs v11, v24, v0\n\t"
-                "vwredsum.vs v12, v26, v0\n\t"
-                "vwredsum.vs v13, v28, v0\n\t"
-                "vwredsum.vs v14, v30, v0\n\t"
+                "lbu %[t1], 1(%[scale])\n\t"
+                "vwredsum.vs v10, v20, v0\n\t"
+                "vwredsum.vs v11, v22, v0\n\t"
+                "lbu %[t2], 2(%[scale])\n\t"
+                "vwredsum.vs v12, v24, v0\n\t"
+                "vwredsum.vs v13, v26, v0\n\t"
+                "lbu %[t3], 3(%[scale])\n\t"
+                "vwredsum.vs v14, v28, v0\n\t"
+                "vwredsum.vs v15, v30, v0\n\t"
+                "lbu %[t4], 4(%[scale])\n\t"
+                "vwredsum.vs v8, v17, v8\n\t"
+                "vwredsum.vs v9, v19, v9\n\t"
+                "lbu %[t5], 5(%[scale])\n\t"
+                "vwredsum.vs v10, v21, v10\n\t"
+                "vwredsum.vs v11, v23, v11\n\t"
+                "lbu %[t6], 6(%[scale])\n\t"
+                "vwredsum.vs v12, v25, v12\n\t"
+                "vwredsum.vs v13, v27, v13\n\t"
+                "lbu %[t7], 7(%[scale])\n\t"
+                "vwredsum.vs v14, v29, v14\n\t"
+                "vwredsum.vs v15, v31, v15\n\t"
                 "vsetivli zero, 4, e32, m1\n\t"
-                "vslideup.vi v10, v9, 1\n\t"
-                "vslideup.vi v8, v7, 1\n\t"
-                "vslideup.vi v11, v12, 1\n\t"
-                "vslideup.vi v13, v14, 1\n\t"
-                "vslideup.vi v10, v8, 2\n\t"
-                "vslideup.vi v11, v13, 2\n\t"
-                "vsetivli zero, 8, e32, m2\n\t"
-                "vle8.v v15, (%[scale])\n\t"
-                "vzext.vf4 v12, v15\n\t"
-                "vmul.vv v10, v10, v12\n\t"
-                "vredsum.vs v0, v10, v0\n\t"
+                "vmul.vx v0, v8, %[tmp]\n\t"
+                "vmul.vx v1, v9, %[t1]\n\t"
+                "vmacc.vx v0, %[t2], v10\n\t"
+                "vmacc.vx v1, %[t3], v11\n\t"
+                "vmacc.vx v0, %[t4], v12\n\t"
+                "vmacc.vx v1, %[t5], v13\n\t"
+                "vmacc.vx v0, %[t6], v14\n\t"
+                "vmacc.vx v1, %[t7], v15\n\t"
                 "vmv.x.s %[tmp], v0\n\t"
-                "add %[isum], %[isum], %[tmp]"
-                : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+                "vmv.x.s %[t1], v1\n\t"
+                "add %[isum], %[isum], %[tmp]\n\t"
+                "add %[isum], %[isum], %[t1]"
+                : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+                , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+                , [isum] "+&r" (isum)
                 : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
-                , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                 : "memory"
                 , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                 , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
@@ -929,7 +975,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const int8_t * restrict q8 = y[i].qs;
 
         int8_t * scale = (int8_t *)utmp;
-        int tmp;
+        int tmp, t1, t2, t3, t4, t5, t6, t7;
         __asm__ __volatile__(
             "vsetivli zero, 12, e8, m1\n\t"
             "vle8.v v0, (%[s6b])\n\t"
@@ -967,19 +1013,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         int isum = 0;
         for (int j = 0; j < QK_K; j += 128) {
             __asm__ __volatile__(
+                "lb zero, 31(%[q3])\n\t"
                 "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
                 "vle8.v v8, (%[q3])\n\t"
                 "vsrl.vi v10, v8, 2\n\t"
                 "vsrl.vi v12, v8, 4\n\t"
                 "vsrl.vi v14, v8, 6\n\t"
+                "lb zero, 64(%[q8])\n\t"
                 "vand.vi v8, v8, 3\n\t"
                 "vand.vi v10, v10, 3\n\t"
                 "vand.vi v12, v12, 3\n\t"
                 "vle8.v v2, (%[qh])\n\t"
+                "lb zero, 127(%[q8])\n\t"
                 "vand.vx v4, v2, %[m]\n\t"
                 "slli %[m], %[m], 1\n\t"
                 "vmseq.vx v0, v4, zero\n\t"
                 "vadd.vi v8, v8, -4, v0.t\n\t"
+                "lb zero, 0(%[q8])\n\t"
                 "vand.vx v4, v2, %[m]\n\t"
                 "slli %[m], %[m], 1\n\t"
                 "vmseq.vx v0, v4, zero\n\t"
@@ -994,34 +1044,43 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                 "vadd.vi v14, v14, -4, v0.t\n\t"
                 "vsetvli zero, %[vl128], e8, m8\n\t"
                 "vle8.v v0, (%[q8])\n\t"
+                "lb %[tmp], 0(%[scale])\n\t"
+                "lb %[t1], 1(%[scale])\n\t"
+                "lb %[t2], 2(%[scale])\n\t"
+                "lb %[t3], 3(%[scale])\n\t"
                 "vsetvli zero, %[vl64], e8, m4\n\t"
                 "vwmul.vv v16, v0, v8\n\t"
                 "vwmul.vv v24, v4, v12\n\t"
                 "vsetivli zero, 16, e16, m2\n\t"
                 "vmv.v.x v0, zero\n\t"
-                "vwredsum.vs v10, v16, v0\n\t"
+                "vwredsum.vs v8, v16, v0\n\t"
+                "lb %[t4], 4(%[scale])\n\t"
+                "lb %[t5], 5(%[scale])\n\t"
                 "vwredsum.vs v9, v18, v0\n\t"
-                "vwredsum.vs v8, v20, v0\n\t"
-                "vwredsum.vs v7, v22, v0\n\t"
-                "vwredsum.vs v11, v24, v0\n\t"
-                "vwredsum.vs v12, v26, v0\n\t"
-                "vwredsum.vs v13, v28, v0\n\t"
-                "vwredsum.vs v14, v30, v0\n\t"
+                "vwredsum.vs v10, v20, v0\n\t"
+                "vwredsum.vs v11, v22, v0\n\t"
+                "vwredsum.vs v12, v24, v0\n\t"
+                "lb %[t6], 6(%[scale])\n\t"
+                "lb %[t7], 7(%[scale])\n\t"
+                "vwredsum.vs v13, v26, v0\n\t"
+                "vwredsum.vs v14, v28, v0\n\t"
+                "vwredsum.vs v15, v30, v0\n\t"
                 "vsetivli zero, 4, e32, m1\n\t"
-                "vslideup.vi v10, v9, 1\n\t"
-                "vslideup.vi v8, v7, 1\n\t"
-                "vslideup.vi v11, v12, 1\n\t"
-                "vslideup.vi v13, v14, 1\n\t"
-                "vslideup.vi v10, v8, 2\n\t"
-                "vslideup.vi v11, v13, 2\n\t"
-                "vsetivli zero, 8, e32, m2\n\t"
-                "vle8.v v15, (%[scale])\n\t"
-                "vsext.vf4 v12, v15\n\t"
-                "vmul.vv v10, v10, v12\n\t"
-                "vredsum.vs v0, v10, v0\n\t"
+                "vmul.vx v0, v8, %[tmp]\n\t"
+                "vmul.vx v1, v9, %[t1]\n\t"
+                "vmacc.vx v0, %[t2], v10\n\t"
+                "vmacc.vx v1, %[t3], v11\n\t"
+                "vmacc.vx v0, %[t4], v12\n\t"
+                "vmacc.vx v1, %[t5], v13\n\t"
+                "vmacc.vx v0, %[t6], v14\n\t"
+                "vmacc.vx v1, %[t7], v15\n\t"
                 "vmv.x.s %[tmp], v0\n\t"
-                "add %[isum], %[isum], %[tmp]"
-                : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+                "vmv.x.s %[t1], v1\n\t"
+                "add %[isum], %[isum], %[tmp]\n\t"
+                "add %[isum], %[isum], %[t1]"
+                : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+                , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+                , [m] "+&r" (m), [isum] "+&r" (isum)
                 : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
                 , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
                 : "memory"
