122 changes: 122 additions & 0 deletions src/cpu/rv64/rvv_softmax.cpp
@@ -17,6 +17,9 @@
#include <math.h>
#include <riscv_vector.h>

#include "common/float16.hpp"
#include "common/nstl.hpp"

#include "cpu/rv64/rvv_softmax.hpp"

namespace dnnl {
@@ -67,6 +70,61 @@ void compute_softmax_f32_rvv(
}
}

#if defined(DNNL_RISCV_USE_ZVFH_INTRINSICS)
// f16 compute kernel
void compute_softmax_f16_rvv(const dnnl::impl::float16_t *src,
dnnl::impl::float16_t *dst, dim_t len, bool is_logsoftmax) {

float max_val
= (float)nstl::numeric_limits<dnnl::impl::float16_t>::lowest();
for (dim_t i = 0; i < len; ++i) {
float val = (float)src[i];
if (val > max_val) max_val = val;
}

if (is_logsoftmax) {
float sum_exp = 0.f;
for (dim_t i = 0; i < len; ++i) {
sum_exp += expf((float)src[i] - max_val);
}
const float log_sum = logf(sum_exp);

for (dim_t i = 0; i < len;) {
size_t vl = __riscv_vsetvl_e16m1((size_t)(len - i));
vfloat16m1_t v_src
= __riscv_vle16_v_f16m1((const _Float16 *)(src + i), vl);
vfloat32m2_t v_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_src, vl);
vfloat32m2_t v_res = __riscv_vfsub_vf_f32m2(v_f32, max_val, vl);
v_res = __riscv_vfsub_vf_f32m2(v_res, log_sum, vl);
vfloat16m1_t v_out = __riscv_vfncvt_f_f_w_f16m1(v_res, vl);
__riscv_vse16_v_f16m1((_Float16 *)(dst + i), v_out, vl);
i += (dim_t)vl;
}
} else {
float *tmp_dst = new float[len];
float sum_exp = 0.f;
for (dim_t i = 0; i < len; ++i) {
float e = expf((float)src[i] - max_val);
tmp_dst[i] = e;
sum_exp += e;
}
const float inv_sum = 1.0f / sum_exp;

for (dim_t i = 0; i < len;) {
size_t vl = __riscv_vsetvl_e16m1((size_t)(len - i));

vfloat32m2_t v_f32 = __riscv_vle32_v_f32m2(tmp_dst + i, vl);
vfloat32m2_t v_res = __riscv_vfmul_vf_f32m2(v_f32, inv_sum, vl);
vfloat16m1_t v_out = __riscv_vfncvt_f_f_w_f16m1(v_res, vl);
__riscv_vse16_v_f16m1((_Float16 *)(dst + i), v_out, vl);

i += (dim_t)vl;
}
delete[] tmp_dst;
}
}
#endif

} // namespace

status_t rvv_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
@@ -121,6 +179,70 @@ status_t rvv_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
});
}
} break;
#if defined(DNNL_RISCV_USE_ZVFH_INTRINSICS)
case data_type::f16: {
const auto *src_f16
= static_cast<const dnnl::impl::float16_t *>(src);
auto *dst_f16 = static_cast<dnnl::impl::float16_t *>(dst);

const dim_t outer_stride = pd()->axis_size(true) * rsp.inner_size;
const int nthr = pd()->nthr_;

if (rsp.inner_size == 1) {
parallel_nd(rsp.outer_size, [&](dim_t outer) {
const dim_t base = outer * outer_stride;
compute_softmax_f16_rvv(src_f16 + base, dst_f16 + base,
rsp.axis_size, rsp.is_logsoftmax);
});
} else {
auto scratch = ctx.get_scratchpad_grantor().template get<char>(
memory_tracking::names::key_softmax_interim_store);

parallel(nthr, [&](int ithr, int nthr) {
auto *tmp
= reinterpret_cast<dnnl::impl::float16_t *>(scratch)
+ static_cast<size_t>(ithr)
* static_cast<size_t>(rsp.axis_size);

const dim_t work_amount = rsp.outer_size * rsp.inner_size;
dim_t start {0}, end {0};
balance211(work_amount, nthr, ithr, start, end);

size_t stride_bytes = rsp.inner_size * sizeof(_Float16);

for (dim_t idx = start; idx < end; ++idx) {
const dim_t outer = idx / rsp.inner_size;
const dim_t i = idx % rsp.inner_size;
const dim_t base = outer * outer_stride + i;

for (dim_t a = 0; a < rsp.axis_size;) {
size_t vl = __riscv_vsetvl_e16m1(rsp.axis_size - a);
vfloat16m1_t v = __riscv_vlse16_v_f16m1(
(const _Float16 *)(src_f16 + base
+ a * rsp.inner_size),
stride_bytes, vl);
__riscv_vse16_v_f16m1((_Float16 *)(tmp + a), v, vl);
a += (dim_t)vl;
}

compute_softmax_f16_rvv(
tmp, tmp, rsp.axis_size, rsp.is_logsoftmax);

for (dim_t a = 0; a < rsp.axis_size;) {
size_t vl = __riscv_vsetvl_e16m1(rsp.axis_size - a);
vfloat16m1_t v = __riscv_vle16_v_f16m1(
(const _Float16 *)(tmp + a), vl);
__riscv_vsse16_v_f16m1(
(_Float16 *)(dst_f16 + base
+ a * rsp.inner_size),
stride_bytes, v, vl);
a += (dim_t)vl;
}
}
});
}
} break;
#endif
default: return status::unimplemented;
}

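For reference, here is a minimal scalar sketch of the numerics the f16 kernel above vectorizes (illustrative only, not part of the patch; the helper name is hypothetical). The RVV kernel widens f16 lanes to f32, applies the same max-subtraction and exp-sum steps, and narrows the result back to f16.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Scalar reference for the softmax / log-softmax computed by
// compute_softmax_f16_rvv, written in plain f32 for clarity.
void softmax_ref_f32(const float *src, float *dst, std::int64_t len,
        bool is_logsoftmax) {
    // Pass 1: find the row maximum so the exponentials stay in range.
    float max_val = std::numeric_limits<float>::lowest();
    for (std::int64_t i = 0; i < len; ++i)
        max_val = std::max(max_val, src[i]);

    // Pass 2: accumulate the shifted exponentials.
    float sum_exp = 0.f;
    for (std::int64_t i = 0; i < len; ++i)
        sum_exp += std::exp(src[i] - max_val);

    // Pass 3: normalize (softmax) or subtract the log of the sum (log-softmax).
    for (std::int64_t i = 0; i < len; ++i)
        dst[i] = is_logsoftmax
                ? (src[i] - max_val) - std::log(sum_exp)
                : std::exp(src[i] - max_val) / sum_exp;
}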
17 changes: 14 additions & 3 deletions src/cpu/rv64/rvv_softmax.hpp
@@ -21,6 +21,7 @@
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "cpu/cpu_softmax_pd.hpp"
#include "cpu/rv64/cpu_isa_traits.hpp"

namespace dnnl {
namespace impl {
@@ -61,8 +62,17 @@ struct rvv_softmax_fwd_t : public primitive_t {
rsp_.outer_size
= src_d.nelems(true) / (rsp_.inner_size * axis_size(true));

VDISPATCH_SOFTMAX(rsp_.data_type == data_type::f32
&& dst_md()->data_type == rsp_.data_type,
const bool is_f16 = src_md()->data_type == data_type::f16;
VDISPATCH_SOFTMAX(utils::one_of(src_md()->data_type, data_type::f32,
data_type::f16),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SOFTMAX(src_md()->data_type == dst_md()->data_type,
VERBOSE_UNSUPPORTED_DT);
if (is_f16) {
VDISPATCH_SOFTMAX(mayiuse(zvfh), VERBOSE_UNSUPPORTED_ISA);
}
VDISPATCH_SOFTMAX(
platform::has_data_type_support(src_md()->data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SOFTMAX(
check_layouts(src_d, dst_d), VERBOSE_UNSUPPORTED_TAG);
@@ -83,10 +93,11 @@ struct rvv_softmax_fwd_t : public primitive_t {
void init_scratchpad() {
auto scratchpad = scratchpad_registry().registrar();
nthr_ = rsp_.inner_size > 1 ? dnnl_get_max_threads() : 1;
const size_t dt_size = types::data_type_size(rsp_.data_type);
if (rsp_.inner_size > 1) {
scratchpad.template book<char>(
memory_tracking::names::key_softmax_interim_store,
static_cast<size_t>(axis_size(true)) * sizeof(float)
static_cast<size_t>(axis_size(true)) * dt_size
* static_cast<size_t>(nthr_));
}
}
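Note on the scratchpad sizing change: the interim store is now booked as axis_size(true) * data_type_size * nthr_ bytes rather than always using sizeof(float), so the f16 path gets one axis-length f16 row per thread. A hedged sketch of the corresponding per-thread offset arithmetic (hypothetical helper, mirroring the ithr * axis_size indexing in rvv_softmax.cpp):

#include <cstddef>

// Sketch only: locate thread ithr's axis-length row inside the interim
// scratchpad booked by init_scratchpad() (total = axis_size * dt_size * nthr).
inline char *interim_row(char *scratch, std::size_t ithr, std::size_t axis_size,
        std::size_t dt_size) {
    return scratch + ithr * axis_size * dt_size; // e.g. 2 bytes per element for f16
}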