17 changes: 17 additions & 0 deletions csrc/cpu/cpu_attn.cpp
@@ -13,6 +13,18 @@
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif

#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
#endif // #ifdef __aarch64__

#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
@@ -73,6 +86,8 @@ torch::Tensor get_scheduler_metadata(
isa = cpu_attention::ISA::VEC;
} else if (isa_hint == "vec16") {
isa = cpu_attention::ISA::VEC16;
} else if (isa_hint == "neon") {
isa = cpu_attention::ISA::NEON;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
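
For context, a minimal standalone sketch of the hint-to-enum mapping that the branch above extends: the new `"neon"` hint selects `ISA::NEON`, and unrecognized hints still fail. The helper name `parse_isa_hint` is hypothetical and not part of this PR; the real code sits inline in `get_scheduler_metadata` and reports errors via `TORCH_CHECK`.

```cpp
// Hypothetical sketch of the hint parsing shown above (not the PR's code);
// the "amx" branch is collapsed in the diff and assumed here.
#include <stdexcept>
#include <string>

enum class ISA { AMX, VEC, VEC16, NEON };

inline ISA parse_isa_hint(const std::string& isa_hint) {
  if (isa_hint == "amx") return ISA::AMX;      // assumed, collapsed in the diff
  if (isa_hint == "vec") return ISA::VEC;
  if (isa_hint == "vec16") return ISA::VEC16;
  if (isa_hint == "neon") return ISA::NEON;    // added by this change
  throw std::invalid_argument("Unsupported CPU attention ISA hint: " + isa_hint);
}
```
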
@@ -158,6 +173,8 @@ void cpu_attn_reshape_and_cache(
return cpu_attention::ISA::VEC;
} else if (isa == "vec16") {
return cpu_attention::ISA::VEC16;
} else if (isa == "neon") {
return cpu_attention::ISA::NEON;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}
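
To illustrate the dispatch change in this file, here is a minimal sketch of what the `NEON_DISPATCH` case expands to inside the ISA switch: on AArch64 builds the runtime `ISA` tag selects a compile-time `AttentionImpl` specialization, while on other targets the stub macro leaves an empty case that falls through. All names here (`ToyAttentionImpl`, `toy_dispatch`) are hypothetical stand-ins, not the PR's types.

```cpp
// Toy sketch of the macro-generated switch; not the actual vLLM code.
#include <cstdio>

enum class ISA { AMX, VEC, VEC16, NEON };

template <ISA isa, typename scalar_t, int head_dim>
struct ToyAttentionImpl {
  static void run() { std::printf("dispatched head_dim=%d\n", head_dim); }
};

template <typename scalar_t, int head_dim>
void toy_dispatch(ISA isa) {
  switch (isa) {
#ifdef __aarch64__
    case ISA::NEON:
      // AArch64 builds instantiate the NEON specialization here, mirroring
      // the body that NEON_DISPATCH expands to.
      return ToyAttentionImpl<ISA::NEON, scalar_t, head_dim>::run();
#else
    case ISA::NEON:  // no body on non-ARM builds: falls through, like the stub macro
#endif
    case ISA::VEC:
    default:
      return ToyAttentionImpl<ISA::VEC, scalar_t, head_dim>::run();
  }
}

int main() { toy_dispatch<float, 128>(ISA::NEON); }
```
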
8 changes: 7 additions & 1 deletion csrc/cpu/cpu_attn_impl.hpp
@@ -14,7 +14,7 @@
#include "utils.hpp"

namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16 };
enum class ISA { AMX, VEC, VEC16, NEON };

template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};
@@ -143,6 +143,12 @@ struct AttentionMetadata {
case ISA::VEC:
ss << "VEC, ";
break;
case ISA::VEC16:
ss << "VEC16, ";
break;
case ISA::NEON:
ss << "NEON, ";
break;
}
ss << "workitem_group_num: " << workitem_group_num
<< ", reduction_item_num: " << reduction_item_num
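
Finally, a small sketch of the behavior of the extended switch in `AttentionMetadata`'s printer: every enumerator, including the newly added `NEON` and the previously unhandled `VEC16`, now yields a label. The free function `isa_name` is hypothetical; the real code streams the label into a `std::stringstream` inside the struct.

```cpp
// Hypothetical standalone equivalent of the extended switch above.
enum class ISA { AMX, VEC, VEC16, NEON };

inline const char* isa_name(ISA isa) {
  switch (isa) {
    case ISA::AMX:
      return "AMX";
    case ISA::VEC:
      return "VEC";
    case ISA::VEC16:
      return "VEC16";
    case ISA::NEON:
      return "NEON";
  }
  return "unknown";  // unreachable for valid enumerators
}
```
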