@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x: replicate the 16 packed nibble bytes across the whole vector
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit: AND keeps low nibbles in lanes 0..15, LSR pulls high nibbles in the rest
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8: remove the q4_0 zero-point bias
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y: 32 int8 values per q8_0 block
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product, accumulated in f32 and scaled by the product of the block scales
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
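
A minimal scalar sketch (not part of the patch; the helper name is hypothetical) of the nibble unpack the predicated SVE sequence above performs: svld1rq_u8 replicates the 16 packed bytes across the vector, the AND under ptrueh keeps the low nibbles in the first 16 lanes, the LSR under ptruel extracts the high nibbles from the replicated copy in the remaining lanes, and the subtract removes the q4_0 bias. Since the loads and predicates cover exactly 32 lanes, this path appears to assume a 256-bit SVE vector length.

    #include <stdint.h>

    // Scalar equivalent of the replicate/AND/LSR/sub-8 steps, assuming the
    // usual ggml q4_0 packing: qs[j] holds element j in the low nibble and
    // element j + 16 in the high nibble, both biased by +8.
    static void unpack_q4_0_scalar(const uint8_t qs[16], int8_t out[32]) {
        for (int j = 0; j < 16; ++j) {
            out[j]      = (int8_t)(qs[j] & 0x0F) - 8; // low nibbles  -> lanes 0..15
            out[j + 16] = (int8_t)(qs[j] >> 4)   - 8; // high nibbles -> lanes 16..31
        }
    }
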
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x: 32 int8 values per q8_0 block
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y, then take the int8 dot product and scale by the block scales
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
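
For reference, a minimal scalar sketch (not from the patch; the helper name is hypothetical) of what each svdot/svmla pair above accumulates: the int8 dot product of two 32-element quant blocks, scaled by the product of their fp16 block scales. svdot_s32 reduces the byte products into int32 lanes, which svcvt_f32_s32_x and svmla_n_f32_x then fold into the running f32 sum.

    #include <stdint.h>

    #define QK8_0 32 // q8_0 block size in ggml

    // One block-pair contribution: sum_j x[j]*y[j], scaled by d_x * d_y.
    static float block_dot_q8_0(const int8_t xqs[QK8_0], float xd,
                                const int8_t yqs[QK8_0], float yd) {
        int32_t isum = 0;
        for (int j = 0; j < QK8_0; ++j) {
            isum += (int32_t)xqs[j] * (int32_t)yqs[j];
        }
        return xd * yd * (float)isum;
    }
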