Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions src/cpu/aarch64/cpu_isa_traits.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*******************************************************************************
* Copyright 2018 Intel Corporation
* Copyright 2020-2024 FUJITSU LIMITED
* Copyright 2023, 2025 Arm Ltd. and affiliates
* Copyright 2023, 2025, 2026 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,27 +53,25 @@ enum {
dnnl_cpu_isa_sve_128 = 0x3,
/// AARCH64 SVE 256 bits
dnnl_cpu_isa_sve_256 = 0x7,
/// AARCH64 SVE 384 bits
dnnl_cpu_isa_sve_384 = 0xf,
/// AARCH64 SVE 512 bits
dnnl_cpu_isa_sve_512 = 0x27,
};

enum cpu_isa_bit_t : unsigned {
asimd_bit = 1u << 0,
sve_128_bit = 1u << 1,
sve_256_bit = 1u << 2,
sve_384_bit = 1u << 3,
sve_bit = 1u << 1,
sve_128_bit = 1u << 2,
sve_256_bit = 1u << 3,
sve_512_bit = 1u << 4,
};

enum cpu_isa_t : unsigned {
isa_undef = 0u,
asimd = asimd_bit,
sve_128 = sve_128_bit | asimd,
sve = sve_bit | asimd,
sve_128 = sve_128_bit | sve,
sve_256 = sve_256_bit | sve_128,
sve_384 = sve_384_bit | sve_256,
sve_512 = sve_512_bit | sve_384,
sve_512 = sve_512_bit | sve_256,
isa_all = ~0u,
};

Expand Down Expand Up @@ -102,9 +100,9 @@ cpu_isa_t DNNL_API get_max_cpu_isa_mask(bool soft = false);
status_t set_max_cpu_isa(dnnl_cpu_isa_t isa);
dnnl_cpu_isa_t get_effective_cpu_isa();

// If isa is a superset of sve_128, return sve_128, else return isa
// Reduce any vector length specific cpu_isa_t to vector length agnostic (VLA) sve
constexpr cpu_isa_t to_vla_sve(const cpu_isa_t isa) {
return (cpu_isa_t)(isa & sve_128);
return (cpu_isa_t)(isa & sve);
}

static inline bool compare_isa(
Expand Down Expand Up @@ -153,17 +151,21 @@ struct cpu_isa_traits<asimd> {
static constexpr const char *user_option_env = "advanced_simd";
};

// Traits for the plain `sve` ISA value (sve_bit | asimd). Only the register
// type aliases and the vector register count are declared here; unlike the
// sized specializations produced by CPU_ISA_SVE, this specialization declares
// no vlen or vlen_shift member.
template <>
struct cpu_isa_traits<sve> {
// Xbyak_aarch64 Z-register types: the untyped register plus its
// B/H/S/D-suffixed variants.
using TReg = Xbyak_aarch64::ZReg;
using TRegB = Xbyak_aarch64::ZRegB;
using TRegH = Xbyak_aarch64::ZRegH;
using TRegS = Xbyak_aarch64::ZRegS;
using TRegD = Xbyak_aarch64::ZRegD;
// Value returned by isa_num_vregs() for `sve`.
static constexpr int n_vregs = 32;
};

#define CPU_ISA_SVE(bits, shift) \
template <> \
struct cpu_isa_traits<sve_##bits> { \
typedef Xbyak_aarch64::ZReg TReg; \
typedef Xbyak_aarch64::ZRegB TRegB; \
typedef Xbyak_aarch64::ZRegH TRegH; \
typedef Xbyak_aarch64::ZRegS TRegS; \
typedef Xbyak_aarch64::ZRegD TRegD; \
struct cpu_isa_traits<sve_##bits> : public cpu_isa_traits<sve> { \
static constexpr int vlen_shift = shift; \
static constexpr int vlen = (bits) / 8; \
static constexpr int n_vregs = 32; \
static constexpr dnnl_cpu_isa_t user_option_val \
= static_cast<dnnl_cpu_isa_t>(dnnl_cpu_isa_sve_##bits); \
static constexpr const char *user_option_env = "sve_ ## bits"; \
Expand All @@ -189,15 +191,13 @@ inline bool mayiuse(const cpu_isa_t cpu_isa, bool soft = false) {

switch (cpu_isa) {
case asimd: return cpu().has(XBYAK_AARCH64_HWCAP_ADVSIMD);
case sve: return cpu().has(XBYAK_AARCH64_HWCAP_SVE);
case sve_128:
return cpu().has(XBYAK_AARCH64_HWCAP_SVE)
&& cpu().getSveLen() >= SVE_128;
case sve_256:
return cpu().has(XBYAK_AARCH64_HWCAP_SVE)
&& cpu().getSveLen() >= SVE_256;
case sve_384:
return cpu().has(XBYAK_AARCH64_HWCAP_SVE)
&& cpu().getSveLen() >= SVE_384;
case sve_512:
return cpu().has(XBYAK_AARCH64_HWCAP_SVE)
&& cpu().getSveLen() >= SVE_512;
Expand All @@ -207,22 +207,33 @@ inline bool mayiuse(const cpu_isa_t cpu_isa, bool soft = false) {
return false;
}

// Returns the SVE vector length, in bytes, as reported by cpu().getSveLen().
inline uint64_t get_sve_length() {
    const uint64_t sve_len_bytes = cpu().getSveLen();
    return sve_len_bytes;
}

// Returns the SVE vector length expressed in elements of `data_type`, i.e. the
// byte length from get_sve_length() divided by the size of one element.
// The element size is asserted to be non-zero before dividing.
inline uint64_t get_sve_length(data_type_t data_type) {
    const size_t elem_bytes = types::data_type_size(data_type);
    assert(elem_bytes > 0);
    const uint64_t vec_bytes = get_sve_length();
    return vec_bytes / elem_bytes;
}

// Returns the vector length associated with `isa`.
//
// For the sized ISAs (sve_512, sve_256, sve_128) and asimd this is the
// compile-time cpu_isa_traits<...>::vlen constant. For plain `sve` the length
// is not fixed at compile time, so it is taken from get_sve_length() at
// runtime. Any other value (including isa_undef and isa_all) yields 0.
inline int isa_max_vlen(cpu_isa_t isa) {
    switch (isa) {
        case sve_512: return cpu_isa_traits<sve_512>::vlen;
        case sve_256: return cpu_isa_traits<sve_256>::vlen;
        case sve_128: return cpu_isa_traits<sve_128>::vlen;
        case sve:
            // get_sve_length() returns uint64_t; narrow explicitly to match
            // the int return type instead of relying on implicit conversion.
            return static_cast<int>(get_sve_length());
        case asimd: return cpu_isa_traits<asimd>::vlen;
        default: return 0;
    }
}

// SVE length in bytes
inline uint64_t get_sve_length() {
return cpu().getSveLen();
}

inline bool mayiuse_atomic() {
using namespace Xbyak_aarch64::util;
return cpu().isAtomicSupported();
Expand All @@ -244,6 +255,10 @@ inline int isa_num_vregs(cpu_isa_t isa) {
return cpu_isa_traits<sve_256>::n_vregs;
else if (isa == sve_128)
return cpu_isa_traits<sve_128>::n_vregs;
else if (isa == sve)
return cpu_isa_traits<sve>::n_vregs;
else if (isa == asimd)
return cpu_isa_traits<asimd>::n_vregs;
else
return 0;
}
Expand All @@ -256,10 +271,11 @@ inline int isa_num_vregs(cpu_isa_t isa) {
#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \
((isa) == isa_undef ? prefix STRINGIFY(any) : \
((isa) == asimd ? prefix STRINGIFY(asimd) : \
((isa) == sve ? prefix STRINGIFY(sve) : \
((isa) == sve_128 ? prefix STRINGIFY(sve_128) : \
((isa) == sve_256 ? prefix STRINGIFY(sve_256) : \
((isa) == sve_512 ? prefix STRINGIFY(sve_512) : \
prefix suffix_if_any)))))
prefix suffix_if_any))))))
/* clang-format on */

inline size_t data_type_vnni_granularity(data_type_t data_type) {
Expand All @@ -285,13 +301,15 @@ inline size_t data_type_vnni_simd_elems(data_type_t data_type) {
}

// Maximum number of elements of a given type in a SIMD (SVE/Neon) vector for a
// given ISA
// given ISA. Note that if cpu_isa_t is just sve (not sve_vl) then the value is
// determined at runtime (unlike the others, which can be determined at compile time)
inline size_t simd_elems(data_type_t dt, cpu_isa_t cpu_isa) {
switch (cpu_isa) {
case sve_512: return data_type_vnni_simd_elems<sve_512>(dt);
case sve_256: return data_type_vnni_simd_elems<sve_256>(dt);
case sve_128:
case asimd: return data_type_vnni_simd_elems<sve_128>(dt);
case sve: return get_sve_length(dt);
default: {
// If this ISA does implement SIMD, then you need to add support for
// it in this function. If not, then you need to check earlier in
Expand All @@ -302,6 +320,10 @@ inline size_t simd_elems(data_type_t dt, cpu_isa_t cpu_isa) {
}
}

// Returns the SIMD vector width for `isa`, obtained as the number of s8
// elements per vector from simd_elems() (presumably one byte per s8 element,
// hence "bytes" -- confirm against data_type definitions).
inline size_t simd_bytes(cpu_isa_t isa) {
    const size_t n_s8_elems = simd_elems(data_type::s8, isa);
    return n_s8_elems;
}

} // namespace aarch64
} // namespace cpu
} // namespace impl
Expand Down
11 changes: 5 additions & 6 deletions src/cpu/aarch64/injectors/jit_uni_eltwise_injector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace aarch64 {
namespace eltwise_injector {

bool is_isa_supported(cpu_isa_t isa) {
return isa == asimd || isa == sve_128;
return isa == asimd || isa == sve;
}

bool is_alg_supported(alg_kind_t alg) {
Expand Down Expand Up @@ -2485,7 +2485,7 @@ size_t jit_uni_eltwise_injector_t<asimd>::get_vec_len() {
}

template <>
size_t jit_uni_eltwise_injector_t<sve_128>::get_vec_len() {
size_t jit_uni_eltwise_injector_t<sve>::get_vec_len() {
return get_sve_length();
}

Expand Down Expand Up @@ -2719,7 +2719,7 @@ void jit_uni_eltwise_injector_t<asimd>::load_1word_replicate(
h->ld1r(TRegS(vmm_src.getIdx()), ptr(x_addr));
}
template <>
void jit_uni_eltwise_injector_t<sve_128>::load_1word_replicate(
void jit_uni_eltwise_injector_t<sve>::load_1word_replicate(
const TRegS &vmm_src, Xbyak_aarch64::XReg x_addr) {
h->ld1rw(TRegS(vmm_src.getIdx()), p_all, ptr(x_addr));
}
Expand All @@ -2730,7 +2730,7 @@ void jit_uni_eltwise_injector_t<asimd>::load_vector(
h->ldr(QReg(vmm_src.getIdx()), ptr(x_addr));
}
template <>
void jit_uni_eltwise_injector_t<sve_128>::load_vector(
void jit_uni_eltwise_injector_t<sve>::load_vector(
const TRegS &vmm_src, Xbyak_aarch64::XReg x_addr) {
h->ldr(TReg(vmm_src.getIdx()), ptr(x_addr));
}
Expand Down Expand Up @@ -2778,8 +2778,7 @@ template <>
void jit_uni_eltwise_injector_t<asimd>::blend_with_mask(
const TRegS &vmm_dst, const TRegS &src) {};

// We only need sve_128 as the injector is fully vector length agnostic.
template struct jit_uni_eltwise_injector_t<sve_128>;
template struct jit_uni_eltwise_injector_t<sve>;
template struct jit_uni_eltwise_injector_t<asimd>;

} // namespace aarch64
Expand Down
Loading
Loading