Commit 1846ffd

Accelerate Arm CPU Attention GEMMs with NEON
PR #27954 added cpu_attention_with_kv_cache, which supports chunked prefill, prefix caching, SWA, alibi, softcap, and sinks. However, it is currently disabled for prefill on Arm CPUs because it is slower than torch.sdpa for relatively long prefills, so chunked prefill, prefix caching, sinks, etc. remained unsupported on Arm.

This PR accelerates cpu_attention_with_kv_cache on Arm CPUs by introducing NEON-accelerated GEMMs (enabled with ISA::NEON) for QK and PV. With the new GEMMs, the performance of cpu_attention_with_kv_cache is similar to torch.sdpa for long prefills, which allows us to enable cpu_attention_with_kv_cache on the prefill path on Arm and thus enable chunked prefill, prefix caching, sinks, alibi, softcap, etc.

Performance (uplift with ISA::NEON vs ISA::VEC):

- For batch size = 64, query tokens = kv tokens = 512, q heads = 32, kv heads = 8, head size = 128, block size = 128: using ISA::NEON for cpu_attention_with_kv_cache accelerates prefill attention by 2x compared to the current state with ISA::VEC.
- For the throughput benchmark below on Arm Neoverse-V2, using cpu_attention_with_kv_cache for prefills and decodes: ISA::NEON yields ~13% higher throughput than ISA::VEC and similar throughput to using torch.sdpa for prefill.

```
export VLLM_CPU_OMP_THREADS_BIND=0-63
export LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4:/usr/lib/aarch64-linux-gnu/libgomp.so.1"
export VLLM_TARGET_DEVICE=cpu
export VLLM_CPU_KVCACHE_SPACE=64
vllm bench throughput \
    --num-prompts 128 \
    --seed 0 \
    --dataset-name sharegpt \
    --input-len 1024 \
    --output-len 128 \
    --max-model-len 2048 \
    --max-num-batched-tokens 8192 \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --load-format dummy
```

Future PRs will accelerate attention further by introducing faster/vectorized exp implementations and leveraging bfmmla/bfdot for QK and PV on Arm CPUs with bf16.

Signed-off-by: Fadi Arafeh <[email protected]>
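To make "NEON-accelerated GEMMs for QK and PV" concrete, below is a minimal sketch of the kind of NEON FMA reduction such a kernel builds on. This is illustrative only: the function name, fp32 types, and row-major layout are assumptions, not the actual cpu_attn_neon.hpp implementation, which tiles over queries and kv tokens and works on the paged KV-cache layout used by cpu_attention_with_kv_cache.

```cpp
// Illustrative sketch only (not the vLLM kernel): compute one row of
// logits = scale * (q_row @ K^T) for a block of kv tokens with NEON FMAs.
#include <arm_neon.h>
#include <cstddef>

void qk_row_neon(const float* q_row,    // [head_dim]
                 const float* k_block,  // [kv_len][head_dim], row-major
                 float* logits,         // [kv_len]
                 size_t head_dim, size_t kv_len, float scale) {
  for (size_t t = 0; t < kv_len; ++t) {
    const float* k_row = k_block + t * head_dim;
    float32x4_t acc = vdupq_n_f32(0.0f);
    size_t d = 0;
    // Main vectorized loop: 4-lane fused multiply-accumulate per iteration.
    for (; d + 4 <= head_dim; d += 4) {
      acc = vfmaq_f32(acc, vld1q_f32(q_row + d), vld1q_f32(k_row + d));
    }
    float sum = vaddvq_f32(acc);  // horizontal add across the 4 lanes
    // Scalar tail for head_dim not divisible by 4 (128 divides evenly).
    for (; d < head_dim; ++d) sum += q_row[d] * k_row[d];
    logits[t] = sum * scale;
  }
}
```

A production GEMM microkernel would additionally keep several accumulators live across a block of query rows and kv columns so each load from Q and K is reused many times; the sketch above only shows the per-row reduction.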
1 parent b7f1f49 commit 1846ffd

File tree

5 files changed: +409 -5 lines

csrc/cpu/cpu_attn.cpp

Lines changed: 17 additions & 0 deletions
@@ -13,6 +13,18 @@
 #define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
 #endif

+#ifdef __aarch64__
+#include "cpu_attn_neon.hpp"
+#define NEON_DISPATCH(...)                                                   \
+  case cpu_attention::ISA::NEON: {                                           \
+    using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
+                                                   scalar_t, head_dim>;      \
+    return __VA_ARGS__();                                                    \
+  }
+#else
+#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
+#endif  // #ifdef __aarch64__
+
 #define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
   case HEAD_DIM: { \
     constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
   [&] { \
     switch (ISA_TYPE) { \
       AMX_DISPATCH(__VA_ARGS__) \
+      NEON_DISPATCH(__VA_ARGS__) \
       case cpu_attention::ISA::VEC: { \
         using attn_impl = \
             cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
@@ -73,6 +86,8 @@ torch::Tensor get_scheduler_metadata(
     isa = cpu_attention::ISA::VEC;
   } else if (isa_hint == "vec16") {
     isa = cpu_attention::ISA::VEC16;
+  } else if (isa_hint == "neon") {
+    isa = cpu_attention::ISA::NEON;
   } else {
     TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
   }
@@ -158,6 +173,8 @@ void cpu_attn_reshape_and_cache(
     return cpu_attention::ISA::VEC;
   } else if (isa == "vec16") {
     return cpu_attention::ISA::VEC16;
+  } else if (isa == "neon") {
+    return cpu_attention::ISA::NEON;
   } else {
     TORCH_CHECK(false, "Invalid ISA type: " + isa);
   }
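The NEON_DISPATCH macro follows the same pattern as the existing AMX/VEC cases: the switch case binds a concrete AttentionImpl specialization to the attn_impl alias and then invokes the lambda passed through the dispatch macro. Below is a self-contained sketch of that mechanism with simplified stand-ins (the AttentionImpl body, the float/128 template arguments, and the dispatch() wrapper are illustrative, not the vLLM code):

```cpp
#include <cstdint>
#include <iostream>

namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON };

// Simplified stand-in for the real attention implementation class.
template <ISA isa, typename scalar_t, int64_t head_dim>
struct AttentionImpl {
  static constexpr const char* name() {
    return isa == ISA::NEON ? "NEON impl" : "VEC impl";
  }
};
}  // namespace cpu_attention

// Because the macro body is pasted textually at the call site, the lambda
// passed as __VA_ARGS__ can refer to the `attn_impl` alias introduced by
// the case just before it is invoked.
#define NEON_DISPATCH(...)                                                    \
  case cpu_attention::ISA::NEON: {                                           \
    using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
                                                   float, 128>;              \
    return __VA_ARGS__();                                                     \
  }

const char* dispatch(cpu_attention::ISA isa_type) {
  switch (isa_type) {
    NEON_DISPATCH([&] { return attn_impl::name(); })
    default: {
      using attn_impl =
          cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, float, 128>;
      return [&] { return attn_impl::name(); }();
    }
  }
}

int main() { std::cout << dispatch(cpu_attention::ISA::NEON) << '\n'; }
```

On non-aarch64 builds the real macro degenerates to a bare `case` label, so the NEON enumerator still falls through to the next case and the switch stays exhaustive without pulling in cpu_attn_neon.hpp.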

csrc/cpu/cpu_attn_impl.hpp

Lines changed: 7 additions & 1 deletion
@@ -14,7 +14,7 @@
 #include "utils.hpp"

 namespace cpu_attention {
-enum class ISA { AMX, VEC, VEC16 };
+enum class ISA { AMX, VEC, VEC16, NEON };

 template <ISA isa, typename scalar_t, int64_t head_dim>
 class AttentionImpl {};
@@ -143,6 +143,12 @@ struct AttentionMetadata {
       case ISA::VEC:
         ss << "VEC, ";
         break;
+      case ISA::VEC16:
+        ss << "VEC16, ";
+        break;
+      case ISA::NEON:
+        ss << "NEON, ";
+        break;
     }
     ss << "workitem_group_num: " << workitem_group_num
        << ", reduction_item_num: " << reduction_item_num
