diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 50f17c758c14..92f8bee5a47a 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -13,6 +13,18 @@ #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX: #endif +#ifdef __aarch64__ + #include "cpu_attn_neon.hpp" + #define NEON_DISPATCH(...) \ + case cpu_attention::ISA::NEON: { \ + using attn_impl = cpu_attention::AttentionImpl; \ + return __VA_ARGS__(); \ + } +#else + #define NEON_DISPATCH(...) case cpu_attention::ISA::NEON: +#endif // #ifdef __aarch64__ + #define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \ case HEAD_DIM: { \ constexpr size_t head_dim = HEAD_DIM; \ @@ -41,6 +53,7 @@ [&] { \ switch (ISA_TYPE) { \ AMX_DISPATCH(__VA_ARGS__) \ + NEON_DISPATCH(__VA_ARGS__) \ case cpu_attention::ISA::VEC: { \ using attn_impl = \ cpu_attention::AttentionImpl class AttentionImpl {}; @@ -143,6 +143,12 @@ struct AttentionMetadata { case ISA::VEC: ss << "VEC, "; break; + case ISA::VEC16: + ss << "VEC16, "; + break; + case ISA::NEON: + ss << "NEON, "; + break; } ss << "workitem_group_num: " << workitem_group_num << ", reduction_item_num: " << reduction_item_num diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp new file mode 100644 index 000000000000..f901dd259c80 --- /dev/null +++ b/csrc/cpu/cpu_attn_neon.hpp @@ -0,0 +1,379 @@ +#ifndef CPU_ATTN_NEON_HPP +#define CPU_ATTN_NEON_HPP + +#include "cpu_attn_impl.hpp" +#include +#include +namespace cpu_attention { + +namespace { + +#define BLOCK_SIZE_ALIGNMENT 32 +#define HEAD_SIZE_ALIGNMENT 32 +#define MAX_Q_HEAD_NUM_PER_ITER 16 + +// These do not use vectorized class for loading / converting +// because csrc/cpu/cpu_types_arm.hpp does not have fallback options +// for vec_op::BF16Vec* / vec_op::BF16Vec* on Arm HW that +// doesn't support BF16. +// We don't use vec_op::FP32Vec* or vec_op::FP16Vec* for consistency. +template +FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, float32x4_t& b0, + float32x4_t& b1); + +template <> +FORCE_INLINE void load_row8_B_as_f32(const float* p, float32x4_t& b0, + float32x4_t& b1) { + b0 = vld1q_f32(p + 0); + b1 = vld1q_f32(p + 4); +} + +template <> +FORCE_INLINE void load_row8_B_as_f32(const c10::Half* p, + float32x4_t& b0, + float32x4_t& b1) { + const float16_t* h = reinterpret_cast(p); + float16x8_t v = vld1q_f16(h); + b0 = vcvt_f32_f16(vget_low_f16(v)); + b1 = vcvt_f32_f16(vget_high_f16(v)); +} + +template <> +FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p, + float32x4_t& b0, + float32x4_t& b1) { + const uint16_t* u = reinterpret_cast(p); +#ifdef ARM_BF16_SUPPORT + uint16x8_t u0 = vld1q_u16(u); + bfloat16x8_t bf0 = vreinterpretq_bf16_u16(u0); + b0 = vcvtq_low_f32_bf16(bf0); + b1 = vcvtq_high_f32_bf16(bf0); +#else + uint16x8_t x0 = vld1q_u16(u); + uint32x4_t lo = vshlq_n_u32(vmovl_u16(vget_low_u16(x0)), 16); + uint32x4_t hi = vshlq_n_u32(vmovl_u16(vget_high_u16(x0)), 16); + b0 = vreinterpretq_f32_u32(lo); + b1 = vreinterpretq_f32_u32(hi); +#endif +} + +#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7) +#define IF_M(i) if constexpr (M > (i)) + +// Mx8, with 1 <= M <= 8 , K streamed, unroll-by-4 with NEON FMLAs +// #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2) +// #FMLAs = (K // 4) * (4 * 2 * M) +// We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads +template +FORCE_INLINE void gemm_micro_neon_fmla_Mx8_Ku4( + const float* __restrict A, // [M x K], + const kv_cache_t* __restrict B, // [K x 8], + float* __restrict C, // [M x 8], + int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) { + // kernel supports max M of 8, as it'd spill for larger M + static_assert(1 <= M && M <= 8, "M must be in [1,8]"); + + // A row base pointers +#define DECL_A(i) const float* a##i = A + (i) * lda; + ROWS_APPLY(DECL_A) +#undef DECL_A + + // declare 2 accumulators per row of M +#define DECL_ACC(i) float32x4_t acc##i##_0, acc##i##_1; + ROWS_APPLY(DECL_ACC) +#undef DECL_ACC + + // initialize accumulators +#define INIT_ACC(i) \ + IF_M(i) { \ + if (accumulate) { \ + acc##i##_0 = vld1q_f32(C + (i) * ldc + 0); \ + acc##i##_1 = vld1q_f32(C + (i) * ldc + 4); \ + } else { \ + acc##i##_0 = vdupq_n_f32(0.f); \ + acc##i##_1 = vdupq_n_f32(0.f); \ + } \ + } + ROWS_APPLY(INIT_ACC) +#undef INIT_ACC + + int32_t k = 0; + + // K unrolled by 4 + for (; k + 3 < K; k += 4) { + // load A[k..k+3] for each active row (M) +#define LOAD_A4(i) \ + float32x4_t a##i##v; \ + IF_M(i) a##i##v = vld1q_f32(a##i + k); + ROWS_APPLY(LOAD_A4) +#undef LOAD_A4 + + // helper: FMA lane L from aiv +#define FMAS_LANE(i, aiv, L) \ + IF_M(i) { \ + acc##i##_0 = vfmaq_laneq_f32(acc##i##_0, b0, aiv, L); \ + acc##i##_1 = vfmaq_laneq_f32(acc##i##_1, b1, aiv, L); \ + } + + // k + 0 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb, b0, b1); +#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0) + ROWS_APPLY(STEP_K0) +#undef STEP_K0 + } + // k + 1 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb, b0, b1); +#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1) + ROWS_APPLY(STEP_K1) +#undef STEP_K1 + } + // k + 2 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb, b0, b1); +#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2) + ROWS_APPLY(STEP_K2) +#undef STEP_K2 + } + // k + 3 + { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb, b0, b1); +#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3) + ROWS_APPLY(STEP_K3) +#undef STEP_K3 + } +#undef FMAS_LANE + } + + // K tail + for (; k < K; ++k) { + float32x4_t b0, b1; + load_row8_B_as_f32(B + (int64_t)k * ldb, b0, b1); +#define TAIL_ROW(i) \ + IF_M(i) { \ + float32x4_t ai = vdupq_n_f32(*(a##i + k)); \ + acc##i##_0 = vfmaq_f32(acc##i##_0, b0, ai); \ + acc##i##_1 = vfmaq_f32(acc##i##_1, b1, ai); \ + } + ROWS_APPLY(TAIL_ROW) +#undef TAIL_ROW + } + + // store accumulators to C +#define STORE_ROW(i) \ + IF_M(i) { \ + vst1q_f32(C + (i) * ldc + 0, acc##i##_0); \ + vst1q_f32(C + (i) * ldc + 4, acc##i##_1); \ + } + ROWS_APPLY(STORE_ROW) +#undef STORE_ROW +} + +template +FORCE_INLINE void gemm_macro_neon_fmla_Mx8_Ku4(const float* __restrict A, + const kv_cache_t* __restrict B, + float* __restrict C, int64_t M, + int32_t K, int64_t lda, + int64_t ldb, int64_t ldc, + bool accumulate) { + for (int64_t m = 0; m < M;) { + int mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1; + const float* Ab = A + m * lda; + float* Cb = C + m * ldc; + + for (int64_t n = 0; n < N; n += 8) { + const kv_cache_t* Bn = B + n; + float* Cn = Cb + n; + switch (mb) { + case 8: + gemm_micro_neon_fmla_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + case 4: + gemm_micro_neon_fmla_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + case 2: + gemm_micro_neon_fmla_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + default: + gemm_micro_neon_fmla_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, + K, accumulate); + break; + } + } + m += mb; + } +} + +template +class TileGemmNeonFMLA { + public: + template + FORCE_INLINE static void gemm(const int32_t m_size, + float* __restrict__ a_tile, + kv_cache_t* __restrict__ b_tile, + float* __restrict__ c_tile, const int64_t lda, + const int64_t ldb, const int64_t ldc, + const int32_t block_size, + const int32_t dynamic_k_size, + const bool accum_c) { + if constexpr (phase == AttentionGemmPhase::QK) { + gemm_macro_neon_fmla_Mx8_Ku4( + a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c); + } else { + gemm_macro_neon_fmla_Mx8_Ku4( + a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc, + accum_c); + } + } +}; + +} // namespace + +// this is similar to "ISA::VEC" at the moment +template +class AttentionImpl { + public: + using query_t = scalar_t; + using q_buffer_t = float; + using kv_cache_t = scalar_t; + using logits_buffer_t = float; + using partial_output_buffer_t = float; + using prob_buffer_t = float; + + constexpr static int64_t BlockSizeAlignment = + BLOCK_SIZE_ALIGNMENT; // KV token num unit of QK and PV phases + constexpr static int64_t HeadDimAlignment = + HEAD_SIZE_ALIGNMENT; // headdim num unit of PV phase + constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER; + constexpr static int64_t HeadDim = head_dim; + constexpr static ISA ISAType = ISA::NEON; + constexpr static bool scale_on_logits = false; // apply scale on q_buffer + + static_assert(HeadDim % HeadDimAlignment == 0); + // the gemm micro kernel is Mx16 + static_assert(HeadDimAlignment % 16 == 0); + static_assert(BlockSizeAlignment % 16 == 0); + + public: + template