@@ -101,6 +101,8 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
101101
102102 void generate () override {
103103 const bool is_fwd = pd_->is_fwd ();
104+ // Note: load type may not be the same as compute type
105+ const auto simd_elems_per_load = simd_elems (data_type (), isa);
104106
105107 preamble ();
106108 XReg param = param1;
@@ -116,7 +118,7 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
116118 ldr (reg_work_amount, ptr (X_TMP_0));
117119 eltwise_injector_->load_table_addr ();
118120 Label vectorized_loop_start, remainder_loop_start, remainder_loop_end;
119- cmp (reg_work_amount, simd_w () );
121+ cmp (reg_work_amount, simd_elems_per_load );
120122 b (LT, remainder_loop_start);
121123 L (vectorized_loop_start);
122124
@@ -163,17 +165,17 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
163165 }
164166 }
165167
166- const auto shift = vlen ();
167168 store_vector (reg_dst, vmm_src.s );
168169 // Update pointers for the next iteration
169170 // Note: we use X_TMP_0 as a temporary register to avoid conflicts with
170171 // other registers.
171- add_imm (reg_src, reg_src, shift, X_TMP_0);
172- add_imm (reg_dst, reg_dst, shift, X_TMP_0);
173- if (!is_fwd) add_imm (reg_diff_dst, reg_diff_dst, shift, X_TMP_0);
172+ add_imm (reg_src, reg_src, simd_bytes (isa), X_TMP_0);
173+ add_imm (reg_dst, reg_dst, simd_bytes (isa), X_TMP_0);
174+ if (!is_fwd)
175+ add_imm (reg_diff_dst, reg_diff_dst, simd_bytes (isa), X_TMP_0);
174176
175- sub_imm (reg_work_amount, reg_work_amount, simd_w () , X_TMP_0);
176- cmp (reg_work_amount, simd_w () );
177+ sub_imm (reg_work_amount, reg_work_amount, simd_elems_per_load , X_TMP_0);
178+ cmp (reg_work_amount, simd_elems_per_load );
177179 b (GE, vectorized_loop_start);
178180
179181 // tail processing
@@ -214,7 +216,7 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
214216
215217 add_imm (reg_src, reg_src, dtype_size (), X_TMP_0);
216218 add_imm (reg_dst, reg_dst, dtype_size (), X_TMP_0);
217- add_imm (reg_diff_dst, reg_diff_dst, dtype_size (), X_TMP_0);
219+ if (!is_fwd) add_imm (reg_diff_dst, reg_diff_dst, dtype_size (), X_TMP_0);
218220 subs (reg_work_amount, reg_work_amount, 1 );
219221
220222 b (remainder_loop_start);
@@ -229,12 +231,6 @@ struct jit_uni_kernel_t : public jit_uni_eltwise_kernel_t {
229231private:
230232 using TReg = typename cpu_isa_traits<isa>::TReg;
231233 using TRegS = typename cpu_isa_traits<isa>::TRegS;
232- int vlen () {
233- // TODO: If we do decide to add a different enum for
234- // VLA SVE, we should handle this in cpu_isa_traits
235- return isa == asimd ? cpu_isa_traits<asimd>::vlen : get_sve_length ();
236- }
237- int simd_w () { return vlen () / dtype_size (); }
238234
239235 XReg reg_src = x11;
240236 XReg reg_dst = x8;
@@ -472,7 +468,8 @@ status_t jit_uni_eltwise_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
472468
473469 const memory_desc_wrapper data_d (pd ()->src_md ());
474470 const auto nelems = data_d.nelems (true );
475- const int simd_w = 64 / data_d.data_type_size ();
471+ // Number of elements in a cacheline. We don't want threads to share a cacheline.
472+ const int cacheline_elems = 64 / data_d.data_type_size ();
476473
477474 const data_type_t src_dt = pd ()->src_md ()->data_type ;
478475 const auto offset_bytes
@@ -484,9 +481,10 @@ status_t jit_uni_eltwise_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
484481 parallel (0 , [&](const int ithr, const int nthr) {
485482 dim_t start {0 }, end {0 };
486483
487- balance211 (utils::div_up (nelems, simd_w), nthr, ithr, start, end);
488- start = nstl::min (nelems, start * simd_w);
489- end = nstl::min (nelems, end * simd_w);
484+ balance211 (
485+ utils::div_up (nelems, cacheline_elems), nthr, ithr, start, end);
486+ start = nstl::min (nelems, start * cacheline_elems);
487+ end = nstl::min (nelems, end * cacheline_elems);
490488 if (start == end) return ;
491489
492490 jit_args_t args;
@@ -563,7 +561,8 @@ status_t jit_uni_eltwise_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
563561 const memory_desc_wrapper data_d (pd ()->data_md ());
564562 const memory_desc_wrapper diff_data_d (pd ()->diff_src_md ());
565563 const auto nelems = data_d.nelems (true );
566- const int simd_w = 64 / data_d.data_type_size ();
564+ // Number of elements in a cacheline. We don't want threads to share a cacheline.
565+ const int cacheline_elems = 64 / data_d.data_type_size ();
567566
568567 const data_type_t data_dt = pd ()->use_dst () ? pd ()->dst_md ()->data_type
569568 : pd ()->src_md ()->data_type ;
@@ -579,9 +578,10 @@ status_t jit_uni_eltwise_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
579578 parallel (0 , [&](const int ithr, const int nthr) {
580579 dim_t start {0 }, end {0 };
581580
582- balance211 (utils::div_up (nelems, simd_w), nthr, ithr, start, end);
583- start = nstl::min (nelems, start * simd_w);
584- end = nstl::min (nelems, end * simd_w);
581+ balance211 (
582+ utils::div_up (nelems, cacheline_elems), nthr, ithr, start, end);
583+ start = nstl::min (nelems, start * cacheline_elems);
584+ end = nstl::min (nelems, end * cacheline_elems);
585585 if (start == end) return ;
586586
587587 jit_args_t args;
0 commit comments