diff --git a/src/cpu/rv64/rvv_softmax.cpp b/src/cpu/rv64/rvv_softmax.cpp
index 75e3c86166a..d6c2179cf95 100644
--- a/src/cpu/rv64/rvv_softmax.cpp
+++ b/src/cpu/rv64/rvv_softmax.cpp
@@ -17,6 +17,9 @@
 #include <math.h>
 #include <riscv_vector.h>
 
+#include "common/float16.hpp"
+#include "common/nstl.hpp"
+
 #include "cpu/rv64/rvv_softmax.hpp"
 
 namespace dnnl {
@@ -67,6 +70,62 @@ void compute_softmax_f32_rvv(
     }
 }
 
+#if defined(DNNL_RISCV_USE_ZVFH_INTRINSICS)
+// f16 kernel: scalar f32 reductions for accuracy, RVV normalization.
+void compute_softmax_f16_rvv(const dnnl::impl::float16_t *src,
+        dnnl::impl::float16_t *dst, dim_t len, bool is_logsoftmax) {
+
+    float max_val
+            = (float)nstl::numeric_limits<dnnl::impl::float16_t>::lowest();
+    for (dim_t i = 0; i < len; ++i) {
+        float val = (float)src[i];
+        if (val > max_val) max_val = val;
+    }
+
+    if (is_logsoftmax) {
+        float sum_exp = 0.f;
+        for (dim_t i = 0; i < len; ++i) {
+            sum_exp += expf((float)src[i] - max_val);
+        }
+        const float log_sum = logf(sum_exp);
+
+        for (dim_t i = 0; i < len;) {
+            size_t vl = __riscv_vsetvl_e16m1((size_t)(len - i));
+            vfloat16m1_t v_src
+                    = __riscv_vle16_v_f16m1((const _Float16 *)(src + i), vl);
+            vfloat32m2_t v_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_src, vl);
+            vfloat32m2_t v_res = __riscv_vfsub_vf_f32m2(v_f32, max_val, vl);
+            v_res = __riscv_vfsub_vf_f32m2(v_res, log_sum, vl);
+            vfloat16m1_t v_out = __riscv_vfncvt_f_f_w_f16m1(v_res, vl);
+            __riscv_vse16_v_f16m1((_Float16 *)(dst + i), v_out, vl);
+            i += (dim_t)vl;
+        }
+    } else {
+        // Cache exp values in f32 so the vector pass is a single multiply.
+        float *tmp_dst = new float[len];
+        float sum_exp = 0.f;
+        for (dim_t i = 0; i < len; ++i) {
+            float e = expf((float)src[i] - max_val);
+            tmp_dst[i] = e;
+            sum_exp += e;
+        }
+        const float inv_sum = 1.0f / sum_exp;
+
+        for (dim_t i = 0; i < len;) {
+            size_t vl = __riscv_vsetvl_e16m1((size_t)(len - i));
+
+            vfloat32m2_t v_f32 = __riscv_vle32_v_f32m2(tmp_dst + i, vl);
+            vfloat32m2_t v_res = __riscv_vfmul_vf_f32m2(v_f32, inv_sum, vl);
+            vfloat16m1_t v_out = __riscv_vfncvt_f_f_w_f16m1(v_res, vl);
+            __riscv_vse16_v_f16m1((_Float16 *)(dst + i), v_out, vl);
+
+            i += (dim_t)vl;
+        }
+        delete[] tmp_dst;
+    }
+}
+#endif
+
 } // namespace
 
 status_t rvv_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
@@ -121,6 +180,72 @@ status_t rvv_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
                 });
             }
         } break;
+#if defined(DNNL_RISCV_USE_ZVFH_INTRINSICS)
+        case data_type::f16: {
+            const auto *src_f16
+                    = static_cast<const dnnl::impl::float16_t *>(src);
+            auto *dst_f16 = static_cast<dnnl::impl::float16_t *>(dst);
+
+            const dim_t outer_stride = pd()->axis_size(true) * rsp.inner_size;
+            const int nthr = pd()->nthr_;
+
+            if (rsp.inner_size == 1) {
+                parallel_nd(rsp.outer_size, [&](dim_t outer) {
+                    const dim_t base = outer * outer_stride;
+                    compute_softmax_f16_rvv(src_f16 + base, dst_f16 + base,
+                            rsp.axis_size, rsp.is_logsoftmax);
+                });
+            } else {
+                auto scratch = ctx.get_scratchpad_grantor().template get<char>(
+                        memory_tracking::names::key_softmax_interim_store);
+
+                parallel(nthr, [&](int ithr, int nthr) {
+                    auto *tmp
+                            = reinterpret_cast<dnnl::impl::float16_t *>(scratch)
+                            + static_cast<size_t>(ithr)
+                                    * static_cast<size_t>(rsp.axis_size);
+
+                    const dim_t work_amount = rsp.outer_size * rsp.inner_size;
+                    dim_t start {0}, end {0};
+                    balance211(work_amount, nthr, ithr, start, end);
+
+                    size_t stride_bytes = rsp.inner_size * sizeof(_Float16);
+
+                    for (dim_t idx = start; idx < end; ++idx) {
+                        const dim_t outer = idx / rsp.inner_size;
+                        const dim_t i = idx % rsp.inner_size;
+                        const dim_t base = outer * outer_stride + i;
+
+                        // Gather the strided axis into contiguous scratch.
+                        for (dim_t a = 0; a < rsp.axis_size;) {
+                            size_t vl = __riscv_vsetvl_e16m1(rsp.axis_size - a);
+                            vfloat16m1_t v = __riscv_vlse16_v_f16m1(
+                                    (const _Float16 *)(src_f16 + base
+                                            + a * rsp.inner_size),
+                                    stride_bytes, vl);
+                            __riscv_vse16_v_f16m1((_Float16 *)(tmp + a), v, vl);
+                            a += (dim_t)vl;
+                        }
+
+                        compute_softmax_f16_rvv(
+                                tmp, tmp, rsp.axis_size, rsp.is_logsoftmax);
+
+                        // Scatter results back to the strided destination.
+                        for (dim_t a = 0; a < rsp.axis_size;) {
+                            size_t vl = __riscv_vsetvl_e16m1(rsp.axis_size - a);
+                            vfloat16m1_t v = __riscv_vle16_v_f16m1(
+                                    (const _Float16 *)(tmp + a), vl);
+                            __riscv_vsse16_v_f16m1(
+                                    (_Float16 *)(dst_f16 + base
+                                            + a * rsp.inner_size),
+                                    stride_bytes, v, vl);
+                            a += (dim_t)vl;
+                        }
+                    }
+                });
+            }
+        } break;
+#endif
         default: return status::unimplemented;
     }
 
diff --git a/src/cpu/rv64/rvv_softmax.hpp b/src/cpu/rv64/rvv_softmax.hpp
index cefe85b75f3..effd277d722 100644
--- a/src/cpu/rv64/rvv_softmax.hpp
+++ b/src/cpu/rv64/rvv_softmax.hpp
@@ -21,6 +21,7 @@
 #include "common/memory_tracking.hpp"
 #include "common/primitive.hpp"
 #include "cpu/cpu_softmax_pd.hpp"
+#include "cpu/rv64/cpu_isa_traits.hpp"
 
 namespace dnnl {
 namespace impl {
@@ -61,8 +62,17 @@ struct rvv_softmax_fwd_t : public primitive_t {
             rsp_.outer_size = src_d.nelems(true)
                     / (rsp_.inner_size * axis_size(true));
 
-            VDISPATCH_SOFTMAX(rsp_.data_type == data_type::f32
-                            && dst_md()->data_type == rsp_.data_type,
+            const bool is_f16 = src_md()->data_type == data_type::f16;
+            VDISPATCH_SOFTMAX(utils::one_of(src_md()->data_type,
+                                      data_type::f32, data_type::f16),
+                    VERBOSE_UNSUPPORTED_DT);
+            VDISPATCH_SOFTMAX(src_md()->data_type == dst_md()->data_type,
+                    VERBOSE_UNSUPPORTED_DT);
+            if (is_f16) {
+                VDISPATCH_SOFTMAX(mayiuse(zvfh), VERBOSE_UNSUPPORTED_ISA);
+            }
+            VDISPATCH_SOFTMAX(
+                    platform::has_data_type_support(src_md()->data_type),
                     VERBOSE_UNSUPPORTED_DT);
             VDISPATCH_SOFTMAX(
                     check_layouts(src_d, dst_d), VERBOSE_UNSUPPORTED_TAG);
@@ -83,10 +93,11 @@ struct rvv_softmax_fwd_t : public primitive_t {
         void init_scratchpad() {
             auto scratchpad = scratchpad_registry().registrar();
             nthr_ = rsp_.inner_size > 1 ? dnnl_get_max_threads() : 1;
+            const size_t dt_size = types::data_type_size(rsp_.data_type);
             if (rsp_.inner_size > 1) {
                 scratchpad.template book<char>(
                         memory_tracking::names::key_softmax_interim_store,
-                        static_cast<size_t>(axis_size(true)) * sizeof(float)
+                        static_cast<size_t>(axis_size(true)) * dt_size
                                 * static_cast<size_t>(nthr_));
             }
         }
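
For validation, each axis the strided (`inner_size > 1`) path processes is mathematically just softmax/logsoftmax over elements spaced `inner_size` apart: gather, run the dense kernel, scatter back. A scalar f32 oracle for that math might look like the sketch below; it is illustrative only, and `ref_softmax_strided` with its parameter names is an assumption, not part of the patch or of oneDNN.

```cpp
#include <cmath>
#include <cstdint>

// Illustrative scalar oracle for one softmax/logsoftmax axis whose elements
// sit `inner_size` floats apart, mirroring the gather-compute-scatter math
// of the RVV f16 path (which additionally rounds its results to f16).
void ref_softmax_strided(const float *src, float *dst, int64_t axis_size,
        int64_t inner_size, bool is_logsoftmax) {
    float max_val = -INFINITY; // running max for numerical stability
    for (int64_t a = 0; a < axis_size; ++a)
        max_val = std::fmax(max_val, src[a * inner_size]);

    float sum_exp = 0.f; // softmax denominator
    for (int64_t a = 0; a < axis_size; ++a)
        sum_exp += std::exp(src[a * inner_size] - max_val);

    for (int64_t a = 0; a < axis_size; ++a) {
        const float shifted = src[a * inner_size] - max_val;
        dst[a * inner_size] = is_logsoftmax ? shifted - std::log(sum_exp)
                                            : std::exp(shifted) / sum_exp;
    }
}
```

Comparing the RVV output against such an oracle needs an f16-sized tolerance: the kernel accumulates in f32 but stores results rounded to f16.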