Skip to content

Commit 5ff7ead

Browse files
committed
cpu: aarch64: clean up SIMD getters in eltwise
- Add cpu_isa_t::sve to simd_elems, where the vector length is calculated at runtime (kernel generation time) - Add simd_bytes (in addition to simd_elems) - Add get_sve_elements(data_type) - Clean up use of SIMD width/bytes in jit_uni_eltwise (including one example where we were referring to cacheline as SIMD width, which is not necessarily equivalent)
1 parent 4231561 commit 5ff7ead

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

src/cpu/aarch64/cpu_isa_traits.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,13 @@ inline uint64_t get_sve_length() {
236236
return cpu().getSveLen();
237237
}
238238

239+
// SVE length in element type
240+
inline uint64_t get_sve_length(data_type_t data_type) {
241+
const size_t dt_size = types::data_type_size(data_type);
242+
assert(dt_size > 0);
243+
return get_sve_length() / dt_size;
244+
}
245+
239246
inline bool mayiuse_atomic() {
240247
using namespace Xbyak_aarch64::util;
241248
return cpu().isAtomicSupported();
@@ -292,13 +299,15 @@ inline size_t data_type_vnni_simd_elems(data_type_t data_type) {
292299
}
293300

294301
// Maximum number of elements of a given type in a SIMD (SVE/Neon) vector for a
295-
// given ISA
302+
// given ISA. Note that if cpu_isa_t is just sve (not sve_vl) then the value is
303+
// determined at runtime (unlike the others, which can be determiend at compile time)
296304
inline size_t simd_elems(data_type_t dt, cpu_isa_t cpu_isa) {
297305
switch (cpu_isa) {
298306
case sve_512: return data_type_vnni_simd_elems<sve_512>(dt);
299307
case sve_256: return data_type_vnni_simd_elems<sve_256>(dt);
300308
case sve_128:
301309
case asimd: return data_type_vnni_simd_elems<sve_128>(dt);
310+
case sve: return get_sve_length(dt);
302311
default: {
303312
// If this ISA does implement SIMD, then you need to add support for
304313
// it in this function. If not, then you need to check earlier in
@@ -309,6 +318,10 @@ inline size_t simd_elems(data_type_t dt, cpu_isa_t cpu_isa) {
309318
}
310319
}
311320

321+
inline size_t simd_bytes(cpu_isa_t isa) {
322+
return simd_elems(data_type::s8, isa);
323+
}
324+
312325
} // namespace aarch64
313326
} // namespace cpu
314327
} // namespace impl

src/cpu/aarch64/jit_uni_eltwise.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
101101

102102
void generate() override {
103103
const bool is_fwd = pd_->is_fwd();
104+
// Note: load type may not be the same as compute type
105+
const auto simd_elems_per_load = simd_elems(data_type(), isa);
104106

105107
preamble();
106108
XReg param = param1;
@@ -116,7 +118,7 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
116118
ldr(reg_work_amount, ptr(X_TMP_0));
117119
eltwise_injector_->load_table_addr();
118120
Label vectorized_loop_start, remainder_loop_start, remainder_loop_end;
119-
cmp(reg_work_amount, simd_w());
121+
cmp(reg_work_amount, simd_elems_per_load);
120122
b(LT, remainder_loop_start);
121123
L(vectorized_loop_start);
122124

@@ -163,17 +165,17 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
163165
}
164166
}
165167

166-
const auto shift = vlen();
167168
store_vector(reg_dst, vmm_src.s);
168169
// Update pointers for the next iteration
169170
// Note: we use X_TMP_0 as a temporary register to avoid conflicts with
170171
// other registers.
171-
add_imm(reg_src, reg_src, shift, X_TMP_0);
172-
add_imm(reg_dst, reg_dst, shift, X_TMP_0);
173-
if (!is_fwd) add_imm(reg_diff_dst, reg_diff_dst, shift, X_TMP_0);
172+
add_imm(reg_src, reg_src, simd_bytes(isa), X_TMP_0);
173+
add_imm(reg_dst, reg_dst, simd_bytes(isa), X_TMP_0);
174+
if (!is_fwd)
175+
add_imm(reg_diff_dst, reg_diff_dst, simd_bytes(isa), X_TMP_0);
174176

175-
sub_imm(reg_work_amount, reg_work_amount, simd_w(), X_TMP_0);
176-
cmp(reg_work_amount, simd_w());
177+
sub_imm(reg_work_amount, reg_work_amount, simd_elems_per_load, X_TMP_0);
178+
cmp(reg_work_amount, simd_elems_per_load);
177179
b(GE, vectorized_loop_start);
178180

179181
// tail processing
@@ -214,7 +216,7 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
214216

215217
add_imm(reg_src, reg_src, dtype_size(), X_TMP_0);
216218
add_imm(reg_dst, reg_dst, dtype_size(), X_TMP_0);
217-
add_imm(reg_diff_dst, reg_diff_dst, dtype_size(), X_TMP_0);
219+
if (!is_fwd) add_imm(reg_diff_dst, reg_diff_dst, dtype_size(), X_TMP_0);
218220
subs(reg_work_amount, reg_work_amount, 1);
219221

220222
b(remainder_loop_start);
@@ -229,12 +231,6 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
229231
private:
230232
using TReg = typename cpu_isa_traits<isa>::TReg;
231233
using TRegS = typename cpu_isa_traits<isa>::TRegS;
232-
int vlen() {
233-
// TODO: If we do decide to add a different enum for
234-
// VLA SVE, we should handle this in cpu_isa_traits
235-
return isa == asimd ? cpu_isa_traits<asimd>::vlen : get_sve_length();
236-
}
237-
int simd_w() { return vlen() / dtype_size(); }
238234

239235
XReg reg_src = x11;
240236
XReg reg_dst = x8;
@@ -472,7 +468,8 @@ status_t jit_uni_eltwise_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
472468

473469
const memory_desc_wrapper data_d(pd()->src_md());
474470
const auto nelems = data_d.nelems(true);
475-
const int simd_w = 64 / data_d.data_type_size();
471+
// Number of elements in a cacheline. We don't want threads to share a cacheline.
472+
const int cacheline_elems = 64 / data_d.data_type_size();
476473

477474
const data_type_t src_dt = pd()->src_md()->data_type;
478475
const auto offset_bytes
@@ -484,9 +481,10 @@ status_t jit_uni_eltwise_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
484481
parallel(0, [&](const int ithr, const int nthr) {
485482
dim_t start {0}, end {0};
486483

487-
balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
488-
start = nstl::min(nelems, start * simd_w);
489-
end = nstl::min(nelems, end * simd_w);
484+
balance211(
485+
utils::div_up(nelems, cacheline_elems), nthr, ithr, start, end);
486+
start = nstl::min(nelems, start * cacheline_elems);
487+
end = nstl::min(nelems, end * cacheline_elems);
490488
if (start == end) return;
491489

492490
jit_args_t args;
@@ -563,7 +561,8 @@ status_t jit_uni_eltwise_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
563561
const memory_desc_wrapper data_d(pd()->data_md());
564562
const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
565563
const auto nelems = data_d.nelems(true);
566-
const int simd_w = 64 / data_d.data_type_size();
564+
// Number of elements in a cacheline. We don't want threads to share a cacheline.
565+
const int cacheline_elems = 64 / data_d.data_type_size();
567566

568567
const data_type_t data_dt = pd()->use_dst() ? pd()->dst_md()->data_type
569568
: pd()->src_md()->data_type;
@@ -579,9 +578,10 @@ status_t jit_uni_eltwise_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
579578
parallel(0, [&](const int ithr, const int nthr) {
580579
dim_t start {0}, end {0};
581580

582-
balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
583-
start = nstl::min(nelems, start * simd_w);
584-
end = nstl::min(nelems, end * simd_w);
581+
balance211(
582+
utils::div_up(nelems, cacheline_elems), nthr, ithr, start, end);
583+
start = nstl::min(nelems, start * cacheline_elems);
584+
end = nstl::min(nelems, end * cacheline_elems);
585585
if (start == end) return;
586586

587587
jit_args_t args;

0 commit comments

Comments
 (0)