diff --git a/.gitignore b/.gitignore index 1d1c4dec664..c8c180931da 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ *audio_frontend* *google* *__pycache__* +.venv venv gen diff --git a/tensorflow/lite/micro/examples/micro_speech/Makefile.inc b/tensorflow/lite/micro/examples/micro_speech/Makefile.inc index a1b5b565cf5..67ce420a6a1 100644 --- a/tensorflow/lite/micro/examples/micro_speech/Makefile.inc +++ b/tensorflow/lite/micro/examples/micro_speech/Makefile.inc @@ -61,4 +61,4 @@ list_micro_speech_example_sources: @echo $(MICRO_SPEECH_SRCS) list_micro_speech_example_headers: - @echo $(MICRO_SPEECH_HDRS) + @echo $(MICRO_SPEECH_HDRS) \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/conv.cc b/tensorflow/lite/micro/kernels/riscv_vector/conv.cc new file mode 100644 index 00000000000..ac732310e16 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/conv.cc @@ -0,0 +1,199 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/conv.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/reference/conv.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" + +#include "tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.h" + +namespace tflite { +namespace { + +TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kConvInputTensor); + const TfLiteEvalTensor* filter = + tflite::micro::GetEvalInput(context, node, kConvWeightsTensor); + const TfLiteEvalTensor* bias = + (NumInputs(node) == 3) + ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor) + : nullptr; + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kConvOutputTensor); + + TFLITE_DCHECK(node->builtin_data != nullptr); + const auto& params = + *(reinterpret_cast(node->builtin_data)); + TFLITE_DCHECK(node->user_data != nullptr); + const auto& data = *(static_cast(node->user_data)); + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, kConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + + switch (input->type) { // Already know in/out types are same. 
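    // Dispatch summary for this RVV-specialized conv kernel (descriptive note; mirrors the cases below):
    //   * kTfLiteFloat32             -> reference float Conv.
    //   * kTfLiteInt16               -> reference per-channel Conv (int32 or int64 bias).
    //   * kTfLiteInt8 + int4 filter  -> filter unpacked to int8 in a scratch buffer, then reference per-channel Conv.
    //   * kTfLiteInt8 + int8 filter  -> ConvPerChannelRVV, the RISC-V Vector path from conv_rvv.cc.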
+ case kTfLiteFloat32: { + tflite::reference_ops::Conv( + ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), + tflite::micro::GetTensorShape(nullptr), nullptr); + break; + } + case kTfLiteInt16: { + if (bias == nullptr || bias->type == kTfLiteInt32) { + reference_integer_ops::ConvPerChannel( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else if (bias->type == kTfLiteInt64) { + reference_integer_ops::ConvPerChannel( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else { + MicroPrintf("Bias type %s (%d) not supported.", + TfLiteTypeGetName(bias->type), bias->type); + return kTfLiteError; + } + break; + } + case kTfLiteInt8: { + switch (filter->type) { + case kTfLiteInt4: { + int8_t* unpacked_filter_data = static_cast( + context->GetScratchBuffer(context, data.filter_buffer_index)); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(filter).FlatSize(), + unpacked_filter_data); + reference_integer_ops::ConvPerChannel( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), unpacked_filter_data, + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt8: { + ConvPerChannelRVV( + 
ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: + MicroPrintf("Weight type %s (%d) not supported.", + TfLiteTypeGetName(filter->type), filter->type); + return kTfLiteError; + } + break; + } + default: + MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type), + input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_CONV_2D() { + return tflite::micro::RegisterOp(ConvInit, ConvPrepare, ConvEval); +} + +} // namespace tflite \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.cc b/tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.cc new file mode 100644 index 00000000000..e0cf889ecf8 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.cc @@ -0,0 +1,362 @@ +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/micro/micro_log.h" + +#include "requantize_rvv.h" + +using namespace tflite; + +void ConvPerChannelRVV(const ConvParams& params, + const int32_t* output_multiplier, + const int32_t* output_shift, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, + const int32_t* bias_data, + const RuntimeShape& output_shape, + int8_t* output_data) +{ + // Extract convolution parameters + const int32_t input_offset = params.input_offset; + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + + // Extract shape dimensions + const int input_batches = input_shape.Dims(0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int filter_input_depth = filter_shape.Dims(3); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int output_depth = output_shape.Dims(3); + + // Calculate grouping parameters + const int groups = input_depth / filter_input_depth; + const int filters_per_group = output_depth / groups; + + // Calculate tensor strides + const int input_ch_stride = 1; + const int input_w_stride = input_depth; + const int 
input_h_stride = input_width * input_w_stride; + const int input_b_stride = input_height * input_h_stride; + const int filter_ch_stride = 1; + const int filter_w_stride = filter_input_depth; + const int filter_h_stride = filter_width * filter_w_stride; + const int filter_o_stride = filter_height * filter_h_stride; + const int output_ch_stride = 1; + const int output_w_stride = output_depth; + const int output_h_stride = output_width * output_w_stride; + const int output_b_stride = output_height * output_h_stride; + + // Prepare scalar constants + const int16_t s_input_offset_s16 = static_cast(input_offset); + const int32_t s_output_offset_s32 = output_offset; + const int32_t s_output_activation_min_s32 = output_activation_min; + const int32_t s_output_activation_max_s32 = output_activation_max; + + // Loop over batches + for (int batch = 0; batch < input_batches; ++batch) + { + const int8_t* input_batch_base = input_data + batch * input_b_stride; + int8_t* output_batch_base = output_data + batch * output_b_stride; + + // Loop over output height + for (int out_y = 0; out_y < output_height; ++out_y) + { + const int in_y_origin = (out_y * stride_height) - pad_height; + int8_t* output_row_base = output_batch_base + out_y * output_h_stride; + + // Loop over output channels + for (int out_channel = 0; out_channel < output_depth; ++out_channel) + { + // Calculate group and filter parameters for this output channel + const int group = out_channel / filters_per_group; + const int group_start_input_channel = group * filter_input_depth; + const int8_t* filter_oc_base = filter_data + out_channel * filter_o_stride; + + // Get per-channel requantization parameters + const int32_t scalar_multiplier = output_multiplier[out_channel]; + const int32_t scalar_shift = output_shift[out_channel]; + const int effective_right_shift = 31 - scalar_shift; + + // Get bias value for this output channel + const int32_t bias_val = bias_data ? bias_data[out_channel] : 0; + + // Calculate output pointer and stride for this channel row + int8_t* output_channel_base = output_row_base + out_channel * output_ch_stride; + const ptrdiff_t output_x_stride_bytes = output_w_stride * sizeof(int8_t); + + // Process output width in vector chunks + size_t current_out_x = 0; + while (current_out_x < static_cast(output_width)) + { + // Set vector length for this iteration (LMUL=2 optimization) + size_t vl = __riscv_vsetvl_e32m2(output_width - current_out_x); + + // Initialize accumulator vector with bias + vint32m2_t v_acc_s32 = bias_data ? 
__riscv_vmv_v_x_i32m2(bias_val, vl) + : __riscv_vmv_v_x_i32m2(0, vl); + + // Calculate base input x coordinates for the vector lanes + vuint32m2_t v_idx = __riscv_vid_v_u32m2(vl); + vint32m2_t v_out_x = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vadd_vx_u32m2(v_idx, static_cast(current_out_x), vl)); + vint32m2_t v_in_x_origin_base = __riscv_vsub_vx_i32m2(__riscv_vmul_vx_i32m2(v_out_x, stride_width, vl), pad_width, vl); + + // Loop over filter height + for (int filter_y = 0; filter_y < filter_height; ++filter_y) + { + const int in_y = in_y_origin + dilation_height_factor * filter_y; + if (in_y < 0 || in_y >= input_height) continue; + + const int8_t* filter_y_base = filter_oc_base + (filter_y * filter_h_stride); + + // Loop over filter width + for (int filter_x = 0; filter_x < filter_width; ++filter_x) + { + const int in_x_offset = dilation_width_factor * filter_x; + const int8_t* filter_patch_base = filter_y_base + (filter_x * filter_w_stride); + vint32m2_t v_in_x = __riscv_vadd_vx_i32m2(v_in_x_origin_base, in_x_offset, vl); + + // Create mask for valid input coordinates + vbool16_t v_mask_ge_zero = __riscv_vmsge_vx_i32m2_b16(v_in_x, 0, vl); + vbool16_t v_mask_lt_width = __riscv_vmslt_vx_i32m2_b16(v_in_x, input_width, vl); + vbool16_t v_active_lane_mask = __riscv_vmand_mm_b16(v_mask_ge_zero, v_mask_lt_width, vl); + + // Calculate base input pointer and stride for vector load + int32_t base_in_x_for_vector0 = static_cast(current_out_x) * stride_width - pad_width + in_x_offset; + const int8_t* input_base_for_y_x_patch = input_batch_base + (in_y * input_h_stride) + (base_in_x_for_vector0 * input_w_stride) + + (group_start_input_channel * input_ch_stride); + ptrdiff_t input_x_stride_bytes = static_cast(stride_width) * input_w_stride * sizeof(int8_t); + + // Loop over input channels for this filter tap + for (int ic = 0; ic < filter_input_depth; ++ic) + { + int8_t s_filter_val_s8 = filter_patch_base[ic * filter_ch_stride]; + int16_t s_filter_val_s16 = static_cast(s_filter_val_s8); + const int8_t* input_ic_ptr = input_base_for_y_x_patch + (ic * input_ch_stride); + + // Load inputs: Use mf2 to match m2 element count (32bit vs 8bit ratio is 4) + vint8mf2_t v_input_s8 = __riscv_vlse8_v_i8mf2_m(v_active_lane_mask, input_ic_ptr, input_x_stride_bytes, vl); + + // Widen to 16-bit (m1) + vint16m1_t v_input_s16 = __riscv_vsext_vf2_i16m1_m(v_active_lane_mask, v_input_s8, vl); + vint16m1_t v_input_plus_offset_s16 = __riscv_vadd_vx_i16m1_m(v_active_lane_mask, v_input_s16, s_input_offset_s16, vl); + + // Widen accumulate into 32-bit (m2) + v_acc_s32 = __riscv_vwmacc_vx_i32m2_m(v_active_lane_mask, v_acc_s32, s_filter_val_s16, v_input_plus_offset_s16, vl); + } + } + } + + // Requantize the accumulated values (vint32m2_t) + vint32m2_t v_res32 = RequantizeVectorPerTensorS32( + v_acc_s32, + scalar_multiplier, + effective_right_shift, + s_output_offset_s32, + s_output_activation_min_s32, + s_output_activation_max_s32, + vl); + + // Narrow result to int16 (m1) and then int8 (mf2) with saturation + vint16m1_t v_res16 = __riscv_vnclip_wx_i16m1(v_res32, 0, __RISCV_VXRM_RNU, vl); + vint8mf2_t v_out_s8 = __riscv_vnclip_wx_i8mf2(v_res16, 0, __RISCV_VXRM_RNU, vl); + + // Store results vector (strided) + int8_t* output_strip_base_ptr = output_channel_base + current_out_x * output_w_stride; + __riscv_vsse8_v_i8mf2(output_strip_base_ptr, output_x_stride_bytes, v_out_s8, vl); + + // Advance output x pointer + current_out_x += vl; + } + } + } + } +} + +void DepthwiseConvPerChannelRVV(const DepthwiseParams& params, + const 
int32_t* output_multiplier, + const int32_t* output_shift, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, + const int32_t* bias_data, + const RuntimeShape& output_shape, + int8_t* output_data) +{ + // Extract depthwise convolution parameters + const int32_t input_offset = params.input_offset; + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int depth_multiplier = params.depth_multiplier; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + + // Extract shape dimensions + const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); + const int input_batches = input_shape.Dims(0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = input_shape.Dims(3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + // Calculate tensor strides + const int input_ch_stride = 1; + const int input_w_stride = input_depth; + const int input_h_stride = input_width * input_w_stride; + const int input_b_stride = input_height * input_h_stride; + const int filter_ch_stride = 1; + const int filter_w_stride = output_depth; + const int filter_h_stride = filter_width * filter_w_stride; + const int output_ch_stride = 1; + const int output_w_stride = output_depth; + const int output_h_stride = output_width * output_w_stride; + const int output_b_stride = output_height * output_h_stride; + + // Prepare scalar constants + const int16_t s_input_offset_s16 = static_cast(input_offset); + const int32_t s_output_offset_s32 = output_offset; + const int32_t s_output_activation_min_s32 = output_activation_min; + const int32_t s_output_activation_max_s32 = output_activation_max; + + // Loop over batches + for (int batch = 0; batch < input_batches; ++batch) + { + const int8_t* input_batch_base = input_data + batch * input_b_stride; + int8_t* output_batch_base = output_data + batch * output_b_stride; + + // Loop over output height + for (int out_y = 0; out_y < output_height; ++out_y) + { + const int in_y_origin = (out_y * stride_height) - pad_height; + + // Loop over input channels (depthwise) + for (int in_channel = 0; in_channel < input_depth; ++in_channel) + { + // Loop over depth multiplier + for (int m = 0; m < depth_multiplier; ++m) + { + // Calculate the current output channel + const int output_channel = m + in_channel * depth_multiplier; + + // Get per-channel requantization parameters + const int32_t scalar_multiplier = output_multiplier[output_channel]; + const int32_t scalar_shift = output_shift[output_channel]; + const int effective_right_shift = 31 - scalar_shift; + + // Get bias value for this output channel + const int32_t bias_val = bias_data ? 
bias_data[output_channel] : 0; + + // Calculate output pointer and stride for this channel row + int8_t* output_channel_row_base = output_batch_base + out_y * output_h_stride + output_channel * output_ch_stride; + const ptrdiff_t output_x_stride_bytes = output_w_stride * sizeof(int8_t); + + // Process output width in vector chunks + size_t current_out_x = 0; + while (current_out_x < static_cast(output_width)) + { + // Set vector length for this iteration (LMUL=2) + size_t vl = __riscv_vsetvl_e32m2(output_width - current_out_x); + + // Initialize accumulator vector with bias + vint32m2_t v_acc_s32 = bias_data ? __riscv_vmv_v_x_i32m2(bias_val, vl) + : __riscv_vmv_v_x_i32m2(0, vl); + + // Calculate base input x coordinates for the vector lanes + vuint32m2_t v_idx = __riscv_vid_v_u32m2(vl); + vint32m2_t v_out_x = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vadd_vx_u32m2(v_idx, static_cast(current_out_x), vl)); + vint32m2_t v_in_x_origin_base = __riscv_vsub_vx_i32m2(__riscv_vmul_vx_i32m2(v_out_x, stride_width, vl), pad_width, vl); + + // Loop over filter height + for (int filter_y = 0; filter_y < filter_height; ++filter_y) + { + const int in_y = in_y_origin + dilation_height_factor * filter_y; + if (in_y < 0 || in_y >= input_height) continue; + + const int8_t* filter_y_base = filter_data + filter_y * filter_h_stride; + + // Loop over filter width + for (int filter_x = 0; filter_x < filter_width; ++filter_x) + { + const int in_x_offset = dilation_width_factor * filter_x; + vint32m2_t v_in_x = __riscv_vadd_vx_i32m2(v_in_x_origin_base, in_x_offset, vl); + + // Create mask for valid input coordinates + vbool16_t v_mask_ge_zero = __riscv_vmsge_vx_i32m2_b16(v_in_x, 0, vl); + vbool16_t v_mask_lt_width = __riscv_vmslt_vx_i32m2_b16(v_in_x, input_width, vl); + vbool16_t v_active_lane_mask = __riscv_vmand_mm_b16(v_mask_ge_zero, v_mask_lt_width, vl); + + // Optimization: skip MAC if all lanes are masked off + if (__riscv_vfirst_m_b16(v_active_lane_mask, vl) == -1) continue; + + const int8_t* filter_ptr = filter_y_base + filter_x * filter_w_stride + output_channel * filter_ch_stride; + int16_t s_filter_val_s16 = static_cast(*filter_ptr); + + int32_t base_in_x_for_vector0 = static_cast(current_out_x) * stride_width - pad_width + in_x_offset; + const int8_t* input_base_ptr = + input_batch_base + in_y * input_h_stride + base_in_x_for_vector0 * input_w_stride + in_channel * input_ch_stride; + ptrdiff_t input_x_stride_bytes = static_cast(stride_width) * input_w_stride * sizeof(int8_t); + + // Load input: mf2 -> m1 -> m2 accumulate + vint8mf2_t v_input_s8 = __riscv_vlse8_v_i8mf2_m(v_active_lane_mask, input_base_ptr, input_x_stride_bytes, vl); + vint16m1_t v_input_s16 = __riscv_vsext_vf2_i16m1_m(v_active_lane_mask, v_input_s8, vl); + vint16m1_t v_input_plus_offset_s16 = __riscv_vadd_vx_i16m1_m(v_active_lane_mask, v_input_s16, s_input_offset_s16, vl); + v_acc_s32 = __riscv_vwmacc_vx_i32m2_m(v_active_lane_mask, v_acc_s32, s_filter_val_s16, v_input_plus_offset_s16, vl); + } + } + + // Requantize the accumulated values in a single function call. 
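                        // RequantizeVectorPerTensorS32() comes from requantize_rvv.h and its body is
                        // not part of this diff. A hedged scalar sketch of what it is assumed to compute
                        // per lane (standard TFLite fixed-point requantization, with the Q31 multiply and
                        // the per-channel shift folded into one rounded right shift by
                        // effective_right_shift = 31 - output_shift):
                        //
                        //   int64_t prod   = (int64_t)acc * scalar_multiplier;
                        //   int32_t scaled = (int32_t)((prod + (1LL << (effective_right_shift - 1)))
                        //                              >> effective_right_shift);
                        //   int32_t out    = std::clamp(scaled + s_output_offset_s32,
                        //                               s_output_activation_min_s32,
                        //                               s_output_activation_max_s32);
                        //
                        // The two vnclip narrowing steps below then take the clamped int32 result down
                        // to int8 with saturation.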
+ vint32m2_t v_res32 = RequantizeVectorPerTensorS32( + v_acc_s32, + scalar_multiplier, + effective_right_shift, + s_output_offset_s32, + s_output_activation_min_s32, + s_output_activation_max_s32, + vl); + + // Narrow result to int16 and then int8 with saturation + vint16m1_t v_res16 = __riscv_vnclip_wx_i16m1(v_res32, 0, __RISCV_VXRM_RNU, vl); + vint8mf2_t v_out_s8 = __riscv_vnclip_wx_i8mf2(v_res16, 0, __RISCV_VXRM_RNU, vl); + + // Store results vector (strided) + int8_t* output_strip_base_ptr = output_channel_row_base + current_out_x * output_w_stride; + __riscv_vsse8_v_i8mf2(output_strip_base_ptr, output_x_stride_bytes, v_out_s8, vl); + + // Advance output x pointer + current_out_x += vl; + } + } + } + } + } +} \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.h new file mode 100644 index 00000000000..0dfdebddf3e --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.h @@ -0,0 +1,31 @@ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_CONV_RVV_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_CONV_RVV_H_ + +#include +#include + +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" + +using namespace tflite; + +void ConvPerChannelRVV( + const ConvParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data); + +void DepthwiseConvPerChannelRVV(const DepthwiseParams& params, + const int32_t* output_multiplier, + const int32_t* output_shift, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, + const RuntimeShape& output_shape, int8_t* output_data); + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_CONV_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/depthwise_conv.cc b/tensorflow/lite/micro/kernels/riscv_vector/depthwise_conv.cc new file mode 100644 index 00000000000..f2a1ce9ec9a --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/depthwise_conv.cc @@ -0,0 +1,192 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/depthwise_conv.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" + +#include "tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.h" + +namespace tflite { +namespace { + +void* DepthwiseConvInit(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpDataConv)); +} + +TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(node->user_data != nullptr); + TFLITE_DCHECK(node->builtin_data != nullptr); + + auto& params = + *(reinterpret_cast(node->builtin_data)); + const OpDataConv& data = *(static_cast(node->user_data)); + + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kDepthwiseConvOutputTensor); + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kDepthwiseConvInputTensor); + const TfLiteEvalTensor* filter = + tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor); + const TfLiteEvalTensor* bias = + (NumInputs(node) == 3) + ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor) + : nullptr; + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* filter_comp_td = + micro_context->GetTensorCompressionData(node, + kDepthwiseConvWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor); + +#endif // USE_TFLM_COMPRESSION + + switch (input->type) { // Already know in/out types are same. 
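    // Dispatch summary, mirroring the conv kernel (descriptive note):
    //   * kTfLiteFloat32             -> reference float DepthwiseConv.
    //   * kTfLiteInt8 + int4 filter  -> filter unpacked to int8, then reference per-channel kernel.
    //   * kTfLiteInt8 + int8 filter  -> DepthwiseConvPerChannelRVV, the RISC-V Vector path.
    //   * kTfLiteInt16 + int8 filter -> reference per-channel kernel.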
+ case kTfLiteFloat32: { + tflite::reference_ops::DepthwiseConv( + DepthwiseConvParamsFloat(params, data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt8: { + switch (filter->type) { + case kTfLiteInt4: { + int8_t* unpacked_filter_data = static_cast( + context->GetScratchBuffer(context, data.filter_buffer_index)); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(filter).FlatSize(), + unpacked_filter_data); + reference_integer_ops::DepthwiseConvPerChannel( + DepthwiseConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), unpacked_filter_data, + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt8: { + DepthwiseConvPerChannelRVV( + DepthwiseConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: + MicroPrintf("Filter type %s (%d) for input type %s not supported.", + TfLiteTypeGetName(filter->type), filter->type, + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + break; + } + case kTfLiteInt16: { + switch (filter->type) { + case kTfLiteInt8: { + reference_integer_ops::DepthwiseConvPerChannel( + DepthwiseConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + filter_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + 
tflite::micro::GetTensorData(output)); + break; + } + default: + MicroPrintf("Filter type %s (%d) for input type %s not supported.", + TfLiteTypeGetName(filter->type), filter->type, + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + break; + } + default: + MicroPrintf("Input type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_DEPTHWISE_CONV_2D() { + return tflite::micro::RegisterOp(DepthwiseConvInit, DepthwiseConvPrepare, + DepthwiseConvEval); +} + +} // namespace tflite \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/fully_connected.cc b/tensorflow/lite/micro/kernels/riscv_vector/fully_connected.cc new file mode 100644 index 00000000000..32fb6b9fb44 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/fully_connected.cc @@ -0,0 +1,365 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/fully_connected.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" + +#include "tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.h" + +namespace tflite { +namespace { + +void* FullyConnectedInit(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, + sizeof(OpDataFullyConnected)); +} + +TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + + TFLITE_DCHECK(node->user_data != nullptr); + TFLITE_DCHECK(node->builtin_data != nullptr); + + auto* data = static_cast(node->user_data); + const auto params = + static_cast(node->builtin_data); + + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kFullyConnectedInputTensor); + TF_LITE_ENSURE(context, input != nullptr); + TfLiteTensor* filter = micro_context->AllocateTempInputTensor( + node, kFullyConnectedWeightsTensor); + TF_LITE_ENSURE(context, filter != nullptr); + TfLiteTensor* bias = + micro_context->AllocateTempInputTensor(node, kFullyConnectedBiasTensor); + TfLiteTensor* output = micro_context->AllocateTempOutputTensor( + node, kFullyConnectedOutputTensor); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); + + if ((input->type == kTfLiteFloat32 && filter->type != kTfLiteFloat32) || + (input->type == kTfLiteInt8 && + (filter->type != kTfLiteInt8 && filter->type != 
kTfLiteInt4)) || + (input->type == kTfLiteInt16 && filter->type != kTfLiteInt8)) { + MicroPrintf("Input type: %s with filter type: %s not supported.", + TfLiteTypeGetName(input->type), + TfLiteTypeGetName(filter->type)); + return kTfLiteError; + } + + if (filter->type == kTfLiteInt4) { + int filter_size = + RuntimeShape(filter->dims->size, + reinterpret_cast(filter->dims->data)) + .FlatSize(); + context->RequestScratchBufferInArena(context, filter_size, + &data->filter_buffer_index); + } + + TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected( + context, params->activation, input->type, + input, filter, bias, output, data)); + +#ifdef USE_TFLM_COMPRESSION + + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) && + filter->type == kTfLiteInt4) { + MicroPrintf("Compression not supported with INT4 tensors"); + return kTfLiteError; + } + data->weights_scratch_index = + micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedWeightsTensor); + data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer( + node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(filter); + if (bias != nullptr) { + micro_context->DeallocateTempTfLiteTensor(bias); + } + micro_context->DeallocateTempTfLiteTensor(output); + return kTfLiteOk; +} + +TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(node->builtin_data != nullptr); + const auto* params = + static_cast(node->builtin_data); + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor); + const TfLiteEvalTensor* filter = + tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor); + const TfLiteEvalTensor* bias = + tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor); + +#ifdef USE_TFLM_COMPRESSION + + MicroContext* micro_context = GetMicroContext(context); + + const CompressionTensorData* weights_comp_td = + micro_context->GetTensorCompressionData(node, + kFullyConnectedWeightsTensor); + const CompressionTensorData* bias_comp_td = + micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor); + +#endif // USE_TFLM_COMPRESSION + + TFLITE_DCHECK(node->user_data != nullptr); + const auto& data = + *(static_cast(node->user_data)); + + // Checks in Prepare ensure input, output and filter types are all the same. 
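  // Dispatch summary (descriptive note):
  //   * kTfLiteFloat32             -> reference float FullyConnected.
  //   * kTfLiteInt8 + int4 filter  -> filter unpacked to int8, then reference quantized kernel.
  //   * kTfLiteInt8 + int8 filter  -> FullyConnectedPerChannelRVV or FullyConnectedRVV,
  //                                   selected by data.is_per_channel.
  //   * kTfLiteInt16 + int8 filter -> reference kernels, with int32 or int64 bias.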
+ switch (input->type) { + case kTfLiteFloat32: { + tflite::reference_ops::FullyConnected( + FullyConnectedParamsFloat(params->activation), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(micro_context, filter, + weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + + case kTfLiteInt8: { + switch (filter->type) { + case kTfLiteInt4: { + int8_t* unpacked_filter_data = static_cast( + context->GetScratchBuffer(context, data.filter_buffer_index)); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(filter).FlatSize(), + unpacked_filter_data); + tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), unpacked_filter_data, + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt8: { + data.is_per_channel + ? FullyConnectedPerChannelRVV( + FullyConnectedParamsQuantized(data), + data.per_channel_output_multiplier, + reinterpret_cast(data.per_channel_output_shift), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)) + : FullyConnectedRVV( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: { + MicroPrintf("Filter type %s (%d) not supported.", + TfLiteTypeGetName(filter->type), input->type); + return kTfLiteError; + } + } + break; + } + + case kTfLiteInt16: { + switch (filter->type) { + case kTfLiteInt8: { + if (bias == nullptr || bias->type == kTfLiteInt32) { + data.is_per_channel + ? 
tflite::reference_integer_ops::FullyConnectedPerChannel( + FullyConnectedParamsQuantized(data), + data.per_channel_output_multiplier, + reinterpret_cast( + data.per_channel_output_shift), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)) + : tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else if (bias->type == kTfLiteInt64) { + data.is_per_channel + ? tflite::reference_integer_ops::FullyConnectedPerChannel( + FullyConnectedParamsQuantized(data), + data.per_channel_output_multiplier, + reinterpret_cast( + data.per_channel_output_shift), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)) + : tflite::reference_integer_ops::FullyConnected( + FullyConnectedParamsQuantized(data), + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), +#ifdef USE_TFLM_COMPRESSION + tflite::micro::GetTensorData( + micro_context, filter, weights_comp_td, + data.weights_scratch_index), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData( + micro_context, bias, bias_comp_td, + data.bias_scratch_index), +#else // USE_TFLM_COMPRESSION + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), +#endif // USE_TFLM_COMPRESSION + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } + break; + } + default: { + MicroPrintf("Filter type %s (%d) not supported.", + TfLiteTypeGetName(filter->type), input->type); + return kTfLiteError; + } + } + break; + } + + default: { + MicroPrintf("Input type %s (%d) not supported.", + 
TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_FULLY_CONNECTED() { + return tflite::micro::RegisterOp(FullyConnectedInit, FullyConnectedPrepare, + FullyConnectedEval); +} + +TFLMInferenceRegistration RegisterInference_FULLY_CONNECTED() { + return tflite::micro::RegisterOp(FullyConnectedEval); +} + +} // namespace tflite \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.cc b/tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.cc new file mode 100644 index 00000000000..cfdfc12b893 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.cc @@ -0,0 +1,175 @@ +#include + +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/micro/micro_log.h" + +#include "requantize_rvv.h" + +using namespace tflite; + +void FullyConnectedPerChannelRVV(const FullyConnectedParams& params, + const int32_t* output_multiplier, + const int* output_shift, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, + const int32_t* bias_data, + const RuntimeShape& output_shape, + int8_t* output_data) +{ + // Extract quantization parameters + const int32_t input_offset = params.input_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + + // Extract shape dimensions + const int batches = FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1); + const int output_depth = output_shape.Dims(output_shape.DimensionsCount() - 1); + const int accum_depth = filter_shape.Dims(filter_shape.DimensionsCount() - 1); + + // Prepare scalar constants + const int16_t s_input_offset_s16 = static_cast(input_offset); + + // Loop over batches + for (int b = 0; b < batches; ++b) + { + const int8_t* input_batch_ptr = input_data + b * accum_depth; + int8_t* output_batch_ptr = output_data + b * output_depth; + + // Vectorized loop over output channels + size_t current_out_c = 0; + while (current_out_c < static_cast(output_depth)) + { + // Set vector length for this iteration (LMUL=2) + size_t vl = __riscv_vsetvl_e32m2(output_depth - current_out_c); + + // Initialize accumulator vector with biases + vint32m2_t v_acc_s32 = bias_data + ? 
__riscv_vle32_v_i32m2(bias_data + current_out_c, vl) + : __riscv_vmv_v_x_i32m2(0, vl); + + // Main MAC loop to compute dot products + for (int d = 0; d < accum_depth; ++d) + { + int16_t s_input_val_s16 = static_cast(input_batch_ptr[d]) + s_input_offset_s16; + const int8_t* filter_col_ptr = filter_data + d + current_out_c * accum_depth; + ptrdiff_t filter_stride = accum_depth * sizeof(int8_t); + + // Load filter: mf2 (matches element count of m2 32-bit) + vint8mf2_t v_filter_s8 = __riscv_vlse8_v_i8mf2(filter_col_ptr, filter_stride, vl); + + // Widen to m1 + vint16m1_t v_filter_s16 = __riscv_vsext_vf2_i16m1(v_filter_s8, vl); + + // Widen accumulate to m2 + v_acc_s32 = __riscv_vwmacc_vx_i32m2(v_acc_s32, s_input_val_s16, v_filter_s16, vl); + } + + // Load per-channel requantization parameters into vectors + vint32m2_t v_multiplier = __riscv_vle32_v_i32m2(output_multiplier + current_out_c, vl); + vint32m2_t v_shift = __riscv_vle32_v_i32m2( + reinterpret_cast(output_shift) + current_out_c, vl); + + // Requantize + vint32m2_t v_res32 = RequantizeVectorPerChannelS32( + v_acc_s32, v_multiplier, v_shift, + output_offset, output_activation_min, output_activation_max, vl); + + // Narrow result + vint16m1_t v_res16 = __riscv_vnclip_wx_i16m1(v_res32, 0, __RISCV_VXRM_RNU, vl); + vint8mf2_t v_out_s8 = __riscv_vnclip_wx_i8mf2(v_res16, 0, __RISCV_VXRM_RNU, vl); + + // Store result + __riscv_vse8_v_i8mf2(output_batch_ptr + current_out_c, v_out_s8, vl); + + // Advance to the next block of output channels + current_out_c += vl; + } + } +} + +void FullyConnectedRVV(const FullyConnectedParams& params, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, + const int32_t* bias_data, + const RuntimeShape& output_shape, + int8_t* output_data) +{ + // Extract quantization parameters + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + + // Extract shape dimensions + const int filter_dim_count = filter_shape.DimensionsCount(); + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); + const int accum_depth = filter_shape.Dims(filter_dim_count - 1); + + // Prepare scalar constants for vector operations + const int16_t s_input_offset_s16 = static_cast(input_offset); + const int16_t s_filter_offset_s16 = static_cast(filter_offset); + + // Loop over batches + for (int b = 0; b < batches; ++b) + { + const int8_t* input_batch_ptr = input_data + b * accum_depth; + int8_t* output_batch_ptr = output_data + b * output_depth; + + // Vectorized loop over output channels + size_t current_out_c = 0; + while (current_out_c < static_cast(output_depth)) + { + // Set vector length for processing multiple output channels (LMUL=2) + size_t vl = __riscv_vsetvl_e32m2(output_depth - current_out_c); + + // Initialize accumulator vector with biases + vint32m2_t v_acc_s32 = bias_data + ? 
__riscv_vle32_v_i32m2(bias_data + current_out_c, vl) + : __riscv_vmv_v_x_i32m2(0, vl); + + // Loop over accumulation depth to compute dot products in parallel + for (int d = 0; d < accum_depth; ++d) + { + int16_t s_input_val_s16 = static_cast(input_batch_ptr[d]) + s_input_offset_s16; + const int8_t* filter_col_ptr = filter_data + current_out_c * accum_depth + d; + ptrdiff_t filter_stride = accum_depth * sizeof(int8_t); + + // Load: mf2 -> m1 -> m1 (+offset) -> m2 (accumulate) + vint8mf2_t v_filter_s8 = __riscv_vlse8_v_i8mf2(filter_col_ptr, filter_stride, vl); + vint16m1_t v_filter_s16 = __riscv_vsext_vf2_i16m1(v_filter_s8, vl); + vint16m1_t v_filter_plus_offset_s16 = __riscv_vadd_vx_i16m1(v_filter_s16, s_filter_offset_s16, vl); + v_acc_s32 = __riscv_vwmacc_vx_i32m2(v_acc_s32, s_input_val_s16, v_filter_plus_offset_s16, vl); + } + + const int effective_right_shift = 31 - output_shift; + vint32m2_t v_res32 = RequantizeVectorPerTensorS32( + v_acc_s32, + output_multiplier, + effective_right_shift, + output_offset, + output_activation_min, + output_activation_max, + vl); + + // Narrow result + vint16m1_t v_res16 = __riscv_vnclip_wx_i16m1(v_res32, 0, __RISCV_VXRM_RNU, vl); + vint8mf2_t v_out_s8 = __riscv_vnclip_wx_i8mf2(v_res16, 0, __RISCV_VXRM_RNU, vl); + __riscv_vse8_v_i8mf2(output_batch_ptr + current_out_c, v_out_s8, vl); + + // Advance to the next block of output channels + current_out_c += vl; + } + } +} \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.h new file mode 100644 index 00000000000..48d00580b7a --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.h @@ -0,0 +1,33 @@ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_FULLY_CONNECTED_RVV_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_FULLY_CONNECTED_RVV_H_ + +#include "tensorflow/lite/micro/kernels/fully_connected.h" +#include "tensorflow/lite/c/common.h" + +using namespace tflite; + +void FullyConnectedPerChannelRVV( + const FullyConnectedParams& params, + const int32_t* output_multiplier, + const int* output_shift, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, + const int32_t* bias_data, + const RuntimeShape& output_shape, + int8_t* output_data); + +void FullyConnectedRVV( + const FullyConnectedParams& params, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, + const int32_t* bias_data, + const RuntimeShape& output_shape, + int8_t* output_data); + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_FULLY_CONNECTED_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/pooling.cc b/tensorflow/lite/micro/kernels/riscv_vector/pooling.cc new file mode 100644 index 00000000000..934f526d82f --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/pooling.cc @@ -0,0 +1,141 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/pooling.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/pooling.h" +#include "tensorflow/lite/micro/micro_log.h" + +#include "pooling_rvv.h" + +namespace tflite { + +namespace { + +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(node->builtin_data != nullptr); + auto* params = reinterpret_cast(node->builtin_data); + + TFLITE_DCHECK(node->user_data != nullptr); + const OpDataPooling* data = + static_cast(node->user_data); + + const TfLiteEvalTensor* input = + micro::GetEvalInput(context, node, kPoolingInputTensor); + TfLiteEvalTensor* output = + micro::GetEvalOutput(context, node, kPoolingOutputTensor); + + // Inputs and outputs share the same type, guaranteed by the converter. + switch (input->type) { + case kTfLiteFloat32: + AveragePoolingEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteInt8: + AveragePoolingEvalQuantized(context, node, params, data, input, + output); + break; + case kTfLiteInt16: + AveragePoolingEvalQuantized(context, node, params, data, input, + output); + break; + default: + MicroPrintf("Input type %s is not currently supported", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(node->builtin_data != nullptr); + auto* params = reinterpret_cast(node->builtin_data); + + TFLITE_DCHECK(node->user_data != nullptr); + const OpDataPooling* data = + static_cast(node->user_data); + + const TfLiteEvalTensor* input = + micro::GetEvalInput(context, node, kPoolingInputTensor); + TfLiteEvalTensor* output = + micro::GetEvalOutput(context, node, kPoolingOutputTensor); + + switch (input->type) { + case kTfLiteFloat32: + MaxPoolingEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteInt8: + { + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = data->activation_min; + op_params.quantized_activation_max = data->activation_max; + + MaxPool8BitRVV(op_params, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } + break; + case kTfLiteInt16: + { + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = 
data->padding.width; + op_params.quantized_activation_min = data->activation_min; + op_params.quantized_activation_max = data->activation_max; + + MaxPool16BitRVV(op_params, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } + break; + default: + MicroPrintf("Type %s not currently supported.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +void* PoolInit(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpDataPooling)); +} + +} // namespace + +TFLMRegistration Register_AVERAGE_POOL_2D() { + return tflite::micro::RegisterOp(PoolInit, PoolingPrepare, AverageEval); +} + +TFLMRegistration Register_MAX_POOL_2D() { + return tflite::micro::RegisterOp(PoolInit, PoolingPrepare, MaxEval); +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.cc b/tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.cc new file mode 100644 index 00000000000..2986fdb95ac --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.cc @@ -0,0 +1,202 @@ +#include + +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/micro/micro_log.h" + +using namespace tflite; + +void MaxPool8BitRVV(const PoolParams& params, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& output_shape, + int8_t* output_data) +{ + // Extract pooling parameters + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; + const int filter_height = params.filter_height; + const int filter_width = params.filter_width; + const int pad_height = params.padding_values.height; + const int pad_width = params.padding_values.width; + const int8_t output_activation_min = params.quantized_activation_min; + const int8_t output_activation_max = params.quantized_activation_max; + + // Extract shape dimensions + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + // Calculate tensor strides for direct pointer arithmetic + const int input_y_stride = input_width * depth; + const int input_b_stride = input_height * input_y_stride; + const int output_y_stride = output_width * depth; + const int output_b_stride = output_height * output_y_stride; + + // Loop over batches + for (int batch = 0; batch < batches; ++batch) { + const int8_t* input_batch_base = input_data + batch * input_b_stride; + int8_t* output_batch_base = output_data + batch * output_b_stride; + + // Loop over output spatial dimensions (y, x) + for (int out_y = 0; out_y < output_height; ++out_y) + { + for (int out_x = 0; out_x < output_width; ++out_x) + { + // Vectorized loop over channels (depth) + size_t current_channel = 0; + while (current_channel < static_cast(depth)) + { + // Set vector length. For `zvl128b`, VLEN=128. With SEW=8 (int8_t), + // VLMAX is 16 * LMUL. Using LMUL=4 provides a good balance, allowing + // up to 64 channels to be processed per iteration. 
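+          //
+          // Illustrative stripmining walk-through (assumes VLEN=128, i.e.
+          // VLMAX=64 at e8/m4): for depth=80 the loop typically completes in
+          // two passes (e.g. 64 channels, then the remaining 16). vsetvl never
+          // returns more than the channels still outstanding, so no scalar
+          // tail loop is needed when depth is not a multiple of VLMAX.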
+ size_t vl = __riscv_vsetvl_e8m4(depth - current_channel); + + // Initialize the accumulator vector with the smallest possible int8_t value. + vint8m4_t v_max_s8 = __riscv_vmv_v_x_i8m4(std::numeric_limits::lowest(), vl); + + // Loop over the filter window dimensions (y, x) + for (int f_y = 0; f_y < filter_height; ++f_y) + { + for (int f_x = 0; f_x < filter_width; ++f_x) + { + // Calculate corresponding input coordinates for this filter tap + const int in_y = (out_y * stride_height) + f_y - pad_height; + const int in_x = (out_x * stride_width) + f_x - pad_width; + + // Handle padding by checking if the input coordinates are valid + if (in_y >= 0 && in_y < input_height && in_x >= 0 && in_x < input_width) + { + // If valid, calculate the pointer to the input vector + const int8_t* input_ptr = input_batch_base + + (in_y * input_y_stride) + + (in_x * depth) + + current_channel; + + // Load a vector of input values (unit-stride access) + vint8m4_t v_input_s8 = __riscv_vle8_v_i8m4(input_ptr, vl); + + // Perform the vector max operation + v_max_s8 = __riscv_vmax_vv_i8m4(v_max_s8, v_input_s8, vl); + } + } + } + + // After iterating through the filter window, apply activation clamping + v_max_s8 = __riscv_vmax_vx_i8m4(v_max_s8, output_activation_min, vl); + v_max_s8 = __riscv_vmin_vx_i8m4(v_max_s8, output_activation_max, vl); + + // Calculate the output pointer + int8_t* output_ptr = output_batch_base + + (out_y * output_y_stride) + + (out_x * depth) + + current_channel; + + // Store the final vector of maximum values (unit-stride access) + __riscv_vse8_v_i8m4(output_ptr, v_max_s8, vl); + + // Advance to the next block of channels + current_channel += vl; + } + } + } + } +} + +void MaxPool16BitRVV(const PoolParams& params, const RuntimeShape& input_shape, + const int16_t* input_data, const RuntimeShape& output_shape, + int16_t* output_data) +{ + // Extract pooling parameters + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; + const int filter_height = params.filter_height; + const int filter_width = params.filter_width; + const int pad_height = params.padding_values.height; + const int pad_width = params.padding_values.width; + const int16_t output_activation_min = params.quantized_activation_min; + const int16_t output_activation_max = params.quantized_activation_max; + + // Extract shape dimensions + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + // Calculate tensor strides for direct pointer arithmetic + const int input_y_stride = input_width * depth; + const int input_b_stride = input_height * input_y_stride; + const int output_y_stride = output_width * depth; + const int output_b_stride = output_height * output_y_stride; + + // Loop over batches + for (int batch = 0; batch < batches; ++batch) + { + const int16_t* input_batch_base = input_data + batch * input_b_stride; + int16_t* output_batch_base = output_data + batch * output_b_stride; + + // Loop over output spatial dimensions (y, x) + for (int out_y = 0; out_y < output_height; ++out_y) + { + for (int out_x = 0; out_x < output_width; ++out_x) + { + // Vectorized loop over channels (depth) + size_t current_channel = 0; + while (current_channel < static_cast(depth)) + { + // Set vector length. SEW is now 16 bits. 
With VLEN=128, VLMAX is 8 * LMUL. + // LMUL=4 still provides a good balance, processing up to 32 channels. + size_t vl = __riscv_vsetvl_e16m4(depth - current_channel); + + // Initialize the accumulator vector with the smallest possible int16_t value. + vint16m4_t v_max_s16 = __riscv_vmv_v_x_i16m4(std::numeric_limits::lowest(), vl); + + // Loop over the filter window dimensions (y, x) + for (int f_y = 0; f_y < filter_height; ++f_y) + { + for (int f_x = 0; f_x < filter_width; ++f_x) + { + // Calculate corresponding input coordinates for this filter tap + const int in_y = (out_y * stride_height) + f_y - pad_height; + const int in_x = (out_x * stride_width) + f_x - pad_width; + + // Handle padding by checking if the input coordinates are valid + if (in_y >= 0 && in_y < input_height && in_x >= 0 && in_x < input_width) + { + // If valid, calculate the pointer to the input vector + const int16_t* input_ptr = input_batch_base + + (in_y * input_y_stride) + + (in_x * depth) + + current_channel; + + // Load a vector of input values (unit-stride access) + vint16m4_t v_input_s16 = __riscv_vle16_v_i16m4(input_ptr, vl); + + // Perform the vector max operation + v_max_s16 = __riscv_vmax_vv_i16m4(v_max_s16, v_input_s16, vl); + } + } + } + + // After iterating through the filter window, apply activation clamping + v_max_s16 = __riscv_vmax_vx_i16m4(v_max_s16, output_activation_min, vl); + v_max_s16 = __riscv_vmin_vx_i16m4(v_max_s16, output_activation_max, vl); + + // Calculate the output pointer + int16_t* output_ptr = output_batch_base + + (out_y * output_y_stride) + + (out_x * depth) + + current_channel; + + // Store the final vector of maximum values (unit-stride access) + __riscv_vse16_v_i16m4(output_ptr, v_max_s16, vl); + + // Advance to the next block of channels + current_channel += vl; + } + } + } + } +} \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.h new file mode 100644 index 00000000000..69c05065106 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.h @@ -0,0 +1,17 @@ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_POOLING_RVV_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_POOLING_RVV_H_ + +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/micro/micro_log.h" + +using namespace tflite; + +void MaxPool8BitRVV(const PoolParams& params, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& output_shape, + int8_t* output_data); + +void MaxPool16BitRVV(const PoolParams& params, const RuntimeShape& input_shape, + const int16_t* input_data, const RuntimeShape& output_shape, + int16_t* output_data); + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_POOLING_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/requantize_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/requantize_rvv.h new file mode 100644 index 00000000000..d1812fcd85e --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/requantize_rvv.h @@ -0,0 +1,173 @@ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_REQUANTIZE_RVV_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_REQUANTIZE_RVV_H_ + +inline vint32m2_t RequantizeVectorPerTensorS32( + vint32m2_t v_acc, const int32_t multiplier, const int effective_right_shift, + const int32_t output_offset, const int32_t activation_min, + const int32_t activation_max, const size_t vl) +{ + // Calculate rounding constants for the 64-bit shift + 
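+  //
+  // Worked example (illustrative): rounding adds 2^(shift - 1) before the
+  // shift. For effective_right_shift = 40 that constant is 2^39, which does
+  // not fit in 32 bits, so it is split into rounding_lo = 0 and
+  // rounding_hi = 2^(39 - 32) = 128; for effective_right_shift = 10 it is
+  // simply rounding_lo = 512, rounding_hi = 0.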
const int64_t rounding_val = + (effective_right_shift > 0) + ? (INT64_C(1) << (effective_right_shift - 1)) + : 0; + const int32_t rounding_lo = static_cast(rounding_val); + const int32_t rounding_hi = static_cast((rounding_val >> 32)); + + // Multiply accumulator by scalar multiplier (results in 64b intermediate) + // Uses m2 intrinsics + vint32m2_t v_prod_lo = __riscv_vmul_vx_i32m2(v_acc, multiplier, vl); + vint32m2_t v_prod_hi = __riscv_vmulh_vx_i32m2(v_acc, multiplier, vl); + + // Add 64b rounding value using 32b operations with carry + vuint32m2_t v_prod_lo_u = __riscv_vreinterpret_v_i32m2_u32m2(v_prod_lo); + vuint32m2_t v_sum_lo_u = __riscv_vadd_vx_u32m2(v_prod_lo_u, rounding_lo, vl); + vbool16_t v_carry = __riscv_vmsltu_vx_u32m2_b16(v_sum_lo_u, rounding_lo, vl); + vint32m2_t v_rounded_hi = __riscv_vadd_vx_i32m2(v_prod_hi, rounding_hi, vl); + v_rounded_hi = __riscv_vadd_vx_i32m2_m(v_carry, v_rounded_hi, 1, vl); + vint32m2_t v_rounded_lo = __riscv_vreinterpret_v_u32m2_i32m2(v_sum_lo_u); + + // Perform 64b arithmetic right shift using 32b vector shifts + vint32m2_t v_res32; + if (effective_right_shift == 0) + { + v_res32 = v_rounded_lo; + } + else if (effective_right_shift > 0 && effective_right_shift < 32) + { + vuint32m2_t v_lo_usrl = __riscv_vsrl_vx_u32m2( + __riscv_vreinterpret_v_i32m2_u32m2(v_rounded_lo), + effective_right_shift, vl); + vint32m2_t v_hi_sll = __riscv_vsll_vx_i32m2( + v_rounded_hi, 32 - effective_right_shift, vl); + v_res32 = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vor_vv_u32m2( + v_lo_usrl, __riscv_vreinterpret_v_i32m2_u32m2(v_hi_sll), vl)); + } + else + { + const int shift_hi = std::min(31, effective_right_shift - 32); + v_res32 = __riscv_vsra_vx_i32m2(v_rounded_hi, shift_hi, vl); + } + + // Add output offset + v_res32 = __riscv_vadd_vx_i32m2(v_res32, output_offset, vl); + + // Clamp to activation bounds + v_res32 = __riscv_vmax_vx_i32m2(v_res32, activation_min, vl); + v_res32 = __riscv_vmin_vx_i32m2(v_res32, activation_max, vl); + + return v_res32; +} + +inline vint32m2_t RequantizeVectorPerChannelS32( + vint32m2_t v_acc, vint32m2_t v_multiplier, vint32m2_t v_shift, + const int32_t output_offset, const int32_t activation_min, + const int32_t activation_max, const size_t vl) +{ + // Perform 32x32 -> 64-bit multiplication + vint32m2_t v_prod_hi = __riscv_vmulh_vv_i32m2(v_acc, v_multiplier, vl); + vint32m2_t v_prod_lo = __riscv_vmul_vv_i32m2(v_acc, v_multiplier, vl); + + // Calculate effective right shift + vint32m2_t v_effective_shift = __riscv_vrsub_vx_i32m2(v_shift, 31, vl); + + // Create masks + vbool16_t v_mask_right_shift = + __riscv_vmsgt_vx_i32m2_b16(v_effective_shift, 0, vl); + vbool16_t v_mask_left_shift = __riscv_vmnot_m_b16(v_mask_right_shift, vl); + + // Path 1: Right Shift + // Initialize to 0 to avoid "maybe-uninitialized" warnings + vint32m2_t v_res_right = __riscv_vmv_v_x_i32m2(0, vl); + + // Optimization: check if any lane needs right shift + if (__riscv_vfirst_m_b16(v_mask_right_shift, vl) >= 0) + { + vint32m2_t v_shift_minus_1 = __riscv_vsub_vx_i32m2_m( + v_mask_right_shift, v_effective_shift, 1, vl); + vuint32m2_t v_shift_minus_1_u = + __riscv_vreinterpret_v_i32m2_u32m2(v_shift_minus_1); + vbool16_t v_mask_round_lt_32 = __riscv_vmsltu_vx_u32m2_b16_m( + v_mask_right_shift, v_shift_minus_1_u, 32, vl); + vbool16_t v_mask_round_ge_32 = __riscv_vmandn_mm_b16( + v_mask_right_shift, v_mask_round_lt_32, vl); + vuint32m2_t v_one_u = __riscv_vmv_v_x_u32m2(1, vl); + vuint32m2_t v_zero_u = __riscv_vmv_v_x_u32m2(0, vl); + vuint32m2_t v_rounding_lo_u = 
__riscv_vmerge_vvm_u32m2( + v_zero_u, + __riscv_vsll_vv_u32m2_m(v_mask_round_lt_32, v_one_u, + v_shift_minus_1_u, vl), + v_mask_round_lt_32, vl); + vuint32m2_t v_rounding_hi_u = __riscv_vmerge_vvm_u32m2( + v_zero_u, + __riscv_vsll_vv_u32m2_m( + v_mask_round_ge_32, v_one_u, + __riscv_vsub_vx_u32m2_m(v_mask_round_ge_32, v_shift_minus_1_u, + 32, vl), + vl), + v_mask_round_ge_32, vl); + + vuint32m2_t v_prod_lo_u = __riscv_vreinterpret_v_i32m2_u32m2(v_prod_lo); + vuint32m2_t v_sum_lo_u = __riscv_vadd_vv_u32m2_m( + v_mask_right_shift, v_prod_lo_u, v_rounding_lo_u, vl); + vbool16_t v_carry = __riscv_vmsltu_vv_u32m2_b16_m( + v_mask_right_shift, v_sum_lo_u, v_prod_lo_u, vl); + vint32m2_t v_rounded_hi = __riscv_vadd_vv_i32m2_m( + v_mask_right_shift, v_prod_hi, + __riscv_vreinterpret_v_u32m2_i32m2(v_rounding_hi_u), vl); + v_rounded_hi = __riscv_vadd_vx_i32m2_m(v_carry, v_rounded_hi, 1, vl); + + vbool16_t v_mask_shift_lt_32 = __riscv_vmslt_vx_i32m2_b16_m( + v_mask_right_shift, v_effective_shift, 32, vl); + vbool16_t v_mask_shift_ge_32 = __riscv_vmandn_mm_b16( + v_mask_right_shift, v_mask_shift_lt_32, vl); + vuint32m2_t v_shift_u = + __riscv_vreinterpret_v_i32m2_u32m2(v_effective_shift); + vuint32m2_t v_lo_part = __riscv_vsrl_vv_u32m2_m( + v_mask_shift_lt_32, v_sum_lo_u, v_shift_u, vl); + vuint32m2_t v_hi_part = __riscv_vsll_vv_u32m2_m( + v_mask_shift_lt_32, + __riscv_vreinterpret_v_i32m2_u32m2(v_rounded_hi), + __riscv_vrsub_vx_u32m2_m(v_mask_shift_lt_32, v_shift_u, 32, vl), + vl); + vint32m2_t v_res_lt_32 = __riscv_vreinterpret_v_u32m2_i32m2( + __riscv_vor_vv_u32m2_m(v_mask_shift_lt_32, v_lo_part, v_hi_part, vl)); + vint32m2_t v_res_ge_32 = __riscv_vsra_vv_i32m2_m( + v_mask_shift_ge_32, v_rounded_hi, + __riscv_vreinterpret_v_i32m2_u32m2(__riscv_vsub_vx_i32m2_m( + v_mask_shift_ge_32, v_effective_shift, 32, vl)), + vl); + v_res_right = __riscv_vmerge_vvm_i32m2(v_res_ge_32, v_res_lt_32, + v_mask_shift_lt_32, vl); + } + + // Path 2: Left Shift + // Initialize to 0 to avoid "maybe-uninitialized" warnings + vint32m2_t v_res_left = __riscv_vmv_v_x_i32m2(0, vl); + + if (__riscv_vfirst_m_b16(v_mask_left_shift, vl) >= 0) + { + vint32m2_t v_left_shift_amount = + __riscv_vneg_v_i32m2_m(v_mask_left_shift, v_effective_shift, vl); + + v_res_left = __riscv_vsll_vv_i32m2_m( + v_mask_left_shift, v_prod_lo, + __riscv_vreinterpret_v_i32m2_u32m2(v_left_shift_amount), vl); + } + + // Merge results + // Lanes with mask_right=1 take v_res_right, mask_right=0 (left) take v_res_left + vint32m2_t v_res32 = + __riscv_vmerge_vvm_i32m2(v_res_left, v_res_right, v_mask_right_shift, vl); + + // Add output offset + v_res32 = __riscv_vadd_vx_i32m2(v_res32, output_offset, vl); + + // Clamp to activation bounds + v_res32 = __riscv_vmax_vx_i32m2(v_res32, activation_min, vl); + v_res32 = __riscv_vmin_vx_i32m2(v_res32, activation_max, vl); + + return v_res32; +} + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_REQUANTIZE_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank.cc b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank.cc new file mode 100644 index 00000000000..150373b7456 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank.cc @@ -0,0 +1,178 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.h" + +#include + +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/flatbuffer_utils.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_context.h" +#include "tensorflow/lite/micro/micro_utils.h" + +namespace tflite { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kWeightTensor = 1; +constexpr int kUnweightTensor = 2; +constexpr int kChFreqStartsTensor = 3; +constexpr int kChWeightStartsTensor = 4; +constexpr int kChannelWidthsTensor = 5; +constexpr int kOutputTensor = 0; + +// Indices into the init flexbuffer's vector. +// The parameter's name is in the comment that follows. +// Elements in the vectors are ordered alphabetically by parameter name. +constexpr int kNumChannelsIndex = 0; // 'num_channels' + +struct TFLMSignalFilterBankParams { + FilterbankConfig config; + uint64_t* work_area; +}; + +void* FilterBankInit(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + + auto* params = static_cast( + context->AllocatePersistentBuffer(context, + sizeof(TFLMSignalFilterBankParams))); + if (params == nullptr) { + return nullptr; + } + + tflite::FlexbufferWrapper fbw(reinterpret_cast(buffer), + length); + params->config.num_channels = fbw.ElementAsInt32(kNumChannelsIndex); + + params->work_area = static_cast(context->AllocatePersistentBuffer( + context, (params->config.num_channels + 1) * sizeof(uint64_t))); + + if (params->work_area == nullptr) { + return nullptr; + } + + return params; +} + +TfLiteStatus FilterBankPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 6); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MicroContext* micro_context = GetMicroContext(context); + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kInputTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32); + micro_context->DeallocateTempTfLiteTensor(input); + + input = micro_context->AllocateTempInputTensor(node, kWeightTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16); + micro_context->DeallocateTempTfLiteTensor(input); + + input = micro_context->AllocateTempInputTensor(node, kUnweightTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16); + micro_context->DeallocateTempTfLiteTensor(input); + + input = micro_context->AllocateTempInputTensor(node, kChFreqStartsTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + 
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16); + micro_context->DeallocateTempTfLiteTensor(input); + + input = micro_context->AllocateTempInputTensor(node, kChWeightStartsTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16); + micro_context->DeallocateTempTfLiteTensor(input); + + input = micro_context->AllocateTempInputTensor(node, kChannelWidthsTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16); + micro_context->DeallocateTempTfLiteTensor(input); + + TfLiteTensor* output = + micro_context->AllocateTempOutputTensor(node, kOutputTensor); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt64); + micro_context->DeallocateTempTfLiteTensor(output); + + return kTfLiteOk; +} + +TfLiteStatus FilterBankEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->user_data); + + const TfLiteEvalTensor* input0 = + tflite::micro::GetEvalInput(context, node, kInputTensor); + const TfLiteEvalTensor* input1 = + tflite::micro::GetEvalInput(context, node, kWeightTensor); + const TfLiteEvalTensor* input2 = + tflite::micro::GetEvalInput(context, node, kUnweightTensor); + const TfLiteEvalTensor* input3 = + tflite::micro::GetEvalInput(context, node, kChFreqStartsTensor); + const TfLiteEvalTensor* input4 = + tflite::micro::GetEvalInput(context, node, kChWeightStartsTensor); + const TfLiteEvalTensor* input5 = + tflite::micro::GetEvalInput(context, node, kChannelWidthsTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + params->config.weights = tflite::micro::GetTensorData(input1); + params->config.unweights = tflite::micro::GetTensorData(input2); + params->config.channel_frequency_starts = + tflite::micro::GetTensorData(input3); + params->config.channel_weight_starts = + tflite::micro::GetTensorData(input4); + params->config.channel_widths = tflite::micro::GetTensorData(input5); + + const uint32_t* input_data = tflite::micro::GetTensorData(input0); + uint64_t* output_data = tflite::micro::GetTensorData(output); + + FilterbankAccumulateChannelsRVV(¶ms->config, input_data, + params->work_area); + + size_t output_size; + TfLiteTypeSizeOf(output->type, &output_size); + output_size *= ElementCount(*output->dims); + // Discard channel 0, which is just scratch + memcpy(output_data, params->work_area + 1, output_size); + return kTfLiteOk; +} + +} // namespace + +namespace tflm_signal { + +TFLMRegistration* Register_FILTER_BANK() { + static TFLMRegistration r = tflite::micro::RegisterOp( + FilterBankInit, FilterBankPrepare, FilterBankEval); + return &r; +} + +} // namespace tflm_signal + +} // namespace tflite \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log.cc b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log.cc new file mode 100644 index 00000000000..eeee7cc2797 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log.cc @@ -0,0 +1,114 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/flatbuffer_utils.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_context.h" +#include "tensorflow/lite/micro/micro_utils.h" + +#include "tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.h" + +namespace tflite { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +// Indices into the init flexbuffer's vector. +// The parameter's name is in the comment that follows. +// Elements in the vectors are ordered alphabetically by parameter name. +constexpr int kInputCorrectionBitsIndex = 0; // 'input_correction_bits' +constexpr int kOutputScaleIndex = 1; // 'output_scale' + +struct TFLMSignalLogParams { + int input_correction_bits; + int output_scale; +}; + +void* FilterBankLogInit(TfLiteContext* context, const char* buffer, + size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + + auto* params = static_cast( + context->AllocatePersistentBuffer(context, sizeof(TFLMSignalLogParams))); + + if (params == nullptr) { + return nullptr; + } + tflite::FlexbufferWrapper fbw(reinterpret_cast(buffer), + length); + + params->input_correction_bits = fbw.ElementAsInt32(kInputCorrectionBitsIndex); + params->output_scale = fbw.ElementAsInt32(kOutputScaleIndex); + return params; +} + +TfLiteStatus FilterBankLogPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MicroContext* micro_context = GetMicroContext(context); + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kInputTensor); + TfLiteTensor* output = + micro_context->AllocateTempOutputTensor(node, kOutputTensor); + TF_LITE_ENSURE(context, input != nullptr); + TF_LITE_ENSURE(context, output != nullptr); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1); + + TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16); + + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(output); + return kTfLiteOk; +} + +TfLiteStatus FilterBankLogEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->user_data); + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + const uint32_t* input_data = tflite::micro::GetTensorData(input); + int16_t* output_data = tflite::micro::GetTensorData(output); + int num_channels = input->dims->data[0]; + FilterbankLogRVV(input_data, num_channels, params->output_scale, + params->input_correction_bits, output_data); + return kTfLiteOk; +} + +} // namespace + +namespace tflm_signal { 
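+// Usage sketch (assumption: `op_resolver` is a MicroMutableOpResolver and the
+// custom-op name must match whatever the converter emitted for this op,
+// commonly "SignalFilterBankLog"):
+//
+//   op_resolver.AddCustom("SignalFilterBankLog",
+//                         tflite::tflm_signal::Register_FILTER_BANK_LOG());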
+ +TFLMRegistration* Register_FILTER_BANK_LOG() { + static TFLMRegistration r = tflite::micro::RegisterOp( + FilterBankLogInit, FilterBankLogPrepare, FilterBankLogEval); + return &r; +} + +} // namespace tflm_signal + +} // namespace tflite \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.cc b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.cc new file mode 100644 index 00000000000..235513e46b7 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.cc @@ -0,0 +1,160 @@ +#include + +#include "tensorflow/lite/kernels/internal/common.h" + +constexpr uint16_t kLogCoeff = 45426; + +const uint16_t kLogLut[] = +{ + 0, 224, 442, 654, 861, 1063, 1259, 1450, 1636, 1817, 1992, 2163, + 2329, 2490, 2646, 2797, 2944, 3087, 3224, 3358, 3487, 3611, 3732, 3848, + 3960, 4068, 4172, 4272, 4368, 4460, 4549, 4633, 4714, 4791, 4864, 4934, + 5001, 5063, 5123, 5178, 5231, 5280, 5326, 5368, 5408, 5444, 5477, 5507, + 5533, 5557, 5578, 5595, 5610, 5622, 5631, 5637, 5640, 5641, 5638, 5633, + 5626, 5615, 5602, 5586, 5568, 5547, 5524, 5498, 5470, 5439, 5406, 5370, + 5332, 5291, 5249, 5203, 5156, 5106, 5054, 5000, 4944, 4885, 4825, 4762, + 4697, 4630, 4561, 4490, 4416, 4341, 4264, 4184, 4103, 4020, 3935, 3848, + 3759, 3668, 3575, 3481, 3384, 3286, 3186, 3084, 2981, 2875, 2768, 2659, + 2549, 2437, 2323, 2207, 2090, 1971, 1851, 1729, 1605, 1480, 1353, 1224, + 1094, 963, 830, 695, 559, 421, 282, 142, 0, 0 +}; + +// Calculate Integer Log2 using binary search (SIMD compatible). +// This manual implementation is required because the target architecture +// (rv32imc_zve32x_zvl128b) does not support the 'zvbb' extension +// which provides the hardware '__riscv_vclz' instruction. 
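+//
+// For reference, a scalar sketch of the same binary-search floor(log2) that
+// the vector routine below computes lane-wise (illustrative only, not used by
+// the kernel):
+//
+//   static inline uint32_t ScalarLog2Floor(uint32_t x) {  // caller ensures x != 0
+//     uint32_t r = 0;
+//     if (x >> 16) { r += 16; x >>= 16; }
+//     if (x >> 8)  { r += 8;  x >>= 8;  }
+//     if (x >> 4)  { r += 4;  x >>= 4;  }
+//     if (x >> 2)  { r += 2;  x >>= 2;  }
+//     if (x >> 1)  { r += 1; }
+//     return r;  // e.g. ScalarLog2Floor(1000) == 9
+//   }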
+inline vuint32m4_t VectorLog2Int_Zve32x(vuint32m4_t v_in, size_t vl) +{ + // Initialize variables + vuint32m4_t v_result = __riscv_vmv_v_x_u32m4(0, vl); + vuint32m4_t v_tmp; + vbool8_t v_mask; + + // Check bit 16 and update result and input + v_tmp = __riscv_vsrl_vx_u32m4(v_in, 16, vl); + v_mask = __riscv_vmsne_vx_u32m4_b8(v_tmp, 0, vl); + v_result = __riscv_vadd_vx_u32m4_mu(v_mask, v_result, v_result, 16, vl); + v_in = __riscv_vmerge_vvm_u32m4(v_in, v_tmp, v_mask, vl); + + // Check bit 8 and update result and input + v_tmp = __riscv_vsrl_vx_u32m4(v_in, 8, vl); + v_mask = __riscv_vmsne_vx_u32m4_b8(v_tmp, 0, vl); + v_result = __riscv_vadd_vx_u32m4_mu(v_mask, v_result, v_result, 8, vl); + v_in = __riscv_vmerge_vvm_u32m4(v_in, v_tmp, v_mask, vl); + + // Check bit 4 and update result and input + v_tmp = __riscv_vsrl_vx_u32m4(v_in, 4, vl); + v_mask = __riscv_vmsne_vx_u32m4_b8(v_tmp, 0, vl); + v_result = __riscv_vadd_vx_u32m4_mu(v_mask, v_result, v_result, 4, vl); + v_in = __riscv_vmerge_vvm_u32m4(v_in, v_tmp, v_mask, vl); + + // Check bit 2 and update result and input + v_tmp = __riscv_vsrl_vx_u32m4(v_in, 2, vl); + v_mask = __riscv_vmsne_vx_u32m4_b8(v_tmp, 0, vl); + v_result = __riscv_vadd_vx_u32m4_mu(v_mask, v_result, v_result, 2, vl); + v_in = __riscv_vmerge_vvm_u32m4(v_in, v_tmp, v_mask, vl); + + // Check bit 1 and update result + v_tmp = __riscv_vsrl_vx_u32m4(v_in, 1, vl); + v_mask = __riscv_vmsne_vx_u32m4_b8(v_tmp, 0, vl); + v_result = __riscv_vadd_vx_u32m4_mu(v_mask, v_result, v_result, 1, vl); + + return v_result; +} + +void FilterbankLogRVV(const uint32_t* input, int num_channels, + int32_t output_scale, uint32_t correction_bits, + int16_t* output) +{ + const uint32_t* p_src = input; + int16_t* p_dst = output; + int remaining = num_channels; + + while (remaining > 0) + { + // Set vector length and load input + size_t vl = __riscv_vsetvl_e32m4(remaining); + vuint32m4_t v_input = __riscv_vle32_v_u32m4(p_src, vl); + vuint32m4_t v_scaled = __riscv_vsll_vx_u32m4(v_input, correction_bits, vl); + vbool8_t v_active = __riscv_vmsgtu_vx_u32m4_b8(v_scaled, 1, vl); + + // Calculate integer part of log2 + vuint32m4_t v_integer = VectorLog2Int_Zve32x(v_scaled, vl); + + // Normalize mantissa to [1.0, 2.0) in Q16 + vuint32m4_t v_shift_norm = __riscv_vrsub_vx_u32m4(v_integer, 31, vl); + vuint32m4_t v_norm = __riscv_vsll_vv_u32m4(v_scaled, v_shift_norm, vl); + vuint32m4_t v_frac = __riscv_vsrl_vx_u32m4(v_norm, 15, vl); + v_frac = __riscv_vand_vx_u32m4(v_frac, 0xFFFF, vl); + + // Calculate base segment index and offsets for LUT access + vuint32m4_t v_base_seg = __riscv_vsrl_vx_u32m4(v_frac, 9, vl); + vuint16m2_t v_base_seg_u16 = __riscv_vncvt_x_x_w_u16m2(v_base_seg, vl); + vuint16m2_t v_offset = __riscv_vsll_vx_u16m2(v_base_seg_u16, 1, vl); + + // Gather LUT coefficients using 16-bit element width + size_t vl_u16 = __riscv_vsetvl_e16m2(vl); + vuint16m2_t v_c0_u16 = __riscv_vluxei16_v_u16m2(kLogLut, v_offset, vl_u16); + v_offset = __riscv_vadd_vx_u16m2(v_offset, 2, vl); + vuint16m2_t v_c1_u16 = __riscv_vluxei16_v_u16m2(kLogLut, v_offset, vl_u16); + + // Calculate interpolation distance and difference + vint16m2_t v_diff = __riscv_vsub_vv_i16m2( + __riscv_vreinterpret_v_u16m2_i16m2(v_c1_u16), + __riscv_vreinterpret_v_u16m2_i16m2(v_c0_u16), vl_u16); + vuint16m2_t v_frac_u16 = __riscv_vncvt_x_x_w_u16m2(v_frac, vl); + vuint16m2_t v_seg_base = __riscv_vand_vx_u16m2(v_frac_u16, 0xFE00, vl_u16); + vuint16m2_t v_dist = __riscv_vsub_vv_u16m2(v_frac_u16, v_seg_base, vl_u16); + + // Restore vector length and widen 
for interpolation + vl = __riscv_vsetvl_e32m4(vl); + vint32m4_t v_rel_pos = __riscv_vwmul_vv_i32m4( + v_diff, __riscv_vreinterpret_v_u16m2_i16m2(v_dist), vl); + v_rel_pos = __riscv_vsra_vx_i32m4(v_rel_pos, 16, vl); + + // Combine interpolated result with base coefficient and fraction + vint32m4_t v_tmp = __riscv_vwadd_wv_i32m4( + v_rel_pos, __riscv_vreinterpret_v_u16m2_i16m2(v_c0_u16), vl); + vint32m4_t v_final_frac_part = __riscv_vadd_vv_i32m4( + v_tmp, __riscv_vreinterpret_v_u32m4_i32m4(v_frac), vl); + + // Convert Log2 to LogE using fixed point multiplication + vuint32m4_t v_term1 = __riscv_vmul_vx_u32m4(v_integer, kLogCoeff, vl); + vuint32m4_t v_frac_u32 = __riscv_vreinterpret_v_i32m4_u32m4(v_final_frac_part); + vuint32m4_t v_term2_u = __riscv_vmul_vx_u32m4(v_frac_u32, kLogCoeff, vl); + v_term2_u = __riscv_vadd_vx_u32m4(v_term2_u, 32768, vl); + v_term2_u = __riscv_vsrl_vx_u32m4(v_term2_u, 16, vl); + vuint32m4_t v_loge = __riscv_vadd_vv_u32m4(v_term1, v_term2_u, vl); + + // Apply output scaling using signed arithmetic + vint32m4_t v_loge_i = __riscv_vreinterpret_v_u32m4_i32m4(v_loge); + vint32m4_t v_lo = __riscv_vmul_vx_i32m4(v_loge_i, output_scale, vl); + vint32m4_t v_hi = __riscv_vmulh_vx_i32m4(v_loge_i, output_scale, vl); + + // Add rounding constant and propagate carry + vint32m4_t v_lo_rounded = __riscv_vadd_vx_i32m4(v_lo, 32768, vl); + vbool8_t v_carry = __riscv_vmsltu_vx_u32m4_b8( + __riscv_vreinterpret_v_i32m4_u32m4(v_lo_rounded), 32768, vl); + v_hi = __riscv_vadd_vx_i32m4_mu(v_carry, v_hi, v_hi, 1, vl); + + // Combine high shifted left and low shifted right + vint32m4_t v_res = __riscv_vor_vv_i32m4( + __riscv_vsll_vx_i32m4(v_hi, 16, vl), + __riscv_vreinterpret_v_u32m4_i32m4( + __riscv_vsrl_vx_u32m4( + __riscv_vreinterpret_v_i32m4_u32m4(v_lo_rounded), 16, vl)), + vl); + + // Saturate result to 16-bit range + vint16m2_t v_res_i16 = __riscv_vnclip_wx_i16m2(v_res, 0, __RISCV_VXRM_RNU, vl); + + // Zero out inactive elements and store result + vint16m2_t v_zero = __riscv_vmv_v_x_i16m2(0, vl); + vint16m2_t v_final = __riscv_vmerge_vvm_i16m2(v_zero, v_res_i16, v_active, vl); + __riscv_vse16_v_i16m2(p_dst, v_final, vl); + + p_src += vl; + p_dst += vl; + remaining -= vl; + } +} \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.h new file mode 100644 index 00000000000..cff55f8c932 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.h @@ -0,0 +1,10 @@ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_FILTER_BANK_LOG_RVV_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_FILTER_BANK_LOG_RVV_H_ + +#include "tensorflow/lite/kernels/internal/common.h" + +void FilterbankLogRVV(const uint32_t* input, int num_channels, + int32_t output_scale, uint32_t correction_bits, + int16_t* output); + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_FILTER_BANK_LOG_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.cc b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.cc new file mode 100644 index 00000000000..a314361cd27 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.cc @@ -0,0 +1,120 @@ +#include + +#include "tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.h" +#include "tensorflow/lite/micro/micro_log.h" + +void FilterbankAccumulateChannelsRVV(const FilterbankConfig* 
config, + const uint32_t* input, uint64_t* output) +{ + // Initialize unweighted accumulator for the first channel + uint64_t unweight_accumulator = 0; + + // Loop over each channel + for (int i = 0; i < config->num_channels + 1; i++) + { + // Get parameters for the current channel + const int16_t freq_start = config->channel_frequency_starts[i]; + const int16_t weight_start = config->channel_weight_starts[i]; + const int16_t channel_width = config->channel_widths[i]; + + // Initialize scalar accumulators for this channel + uint64_t channel_w_acc = unweight_accumulator; + uint64_t channel_uw_acc = 0; + + // Process channel only if it has non-zero width + if (channel_width > 0) + { + // Optimization: Use LMUL=2 to fit all variables in registers and avoid spilling + size_t vl_max = __riscv_vsetvl_e32m2(channel_width); + + // Initialize vector accumulators for 64-bit sums + vuint32m2_t v_acc_w_low = __riscv_vmv_v_x_u32m2(0, vl_max); + vuint32m2_t v_acc_w_high = __riscv_vmv_v_x_u32m2(0, vl_max); + vuint32m2_t v_acc_uw_low = __riscv_vmv_v_x_u32m2(0, vl_max); + vuint32m2_t v_acc_uw_high = __riscv_vmv_v_x_u32m2(0, vl_max); + + // Process the channel width in vector-sized chunks + int j = 0; + while (j < channel_width) + { + // Set vector length for the current strip + size_t vl = __riscv_vsetvl_e32m2(channel_width - j); + + // Load vector of input data + vuint32m2_t v_input = __riscv_vle32_v_u32m2(&input[freq_start + j], vl); + + // Load Weights and Unweights + vint16m1_t v_weights16 = __riscv_vle16_v_i16m1( + reinterpret_cast(&config->weights[weight_start + j]), vl); + vint16m1_t v_unweights16 = __riscv_vle16_v_i16m1( + reinterpret_cast(&config->unweights[weight_start + j]), vl); + + // Sign-extend weights to 32-bit + vint32m2_t v_weights32 = __riscv_vsext_vf2_i32m2(v_weights16, vl); + vint32m2_t v_unweights32 = __riscv_vsext_vf2_i32m2(v_unweights16, vl); + + // Reinterpret weights as unsigned bits for vmul + vuint32m2_t v_weights32_u = __riscv_vreinterpret_v_i32m2_u32m2(v_weights32); + vuint32m2_t v_unweights32_u = __riscv_vreinterpret_v_i32m2_u32m2(v_unweights32); + + // Low part multiply + vuint32m2_t v_prod_w_low = __riscv_vmul_vv_u32m2(v_input, v_weights32_u, vl); + vuint32m2_t v_prod_uw_low = __riscv_vmul_vv_u32m2(v_input, v_unweights32_u, vl); + + // High part multiply + vint32m2_t v_prod_w_high_i = __riscv_vmulhsu_vv_i32m2(v_weights32, v_input, vl); + vint32m2_t v_prod_uw_high_i = __riscv_vmulhsu_vv_i32m2(v_unweights32, v_input, vl); + vuint32m2_t v_prod_w_high = __riscv_vreinterpret_v_i32m2_u32m2(v_prod_w_high_i); + vuint32m2_t v_prod_uw_high = __riscv_vreinterpret_v_i32m2_u32m2(v_prod_uw_high_i); + + // Accumulate Low part + vuint32m2_t v_next_acc_w_low = __riscv_vadd_vv_u32m2(v_acc_w_low, v_prod_w_low, vl); + vuint32m2_t v_next_acc_uw_low = __riscv_vadd_vv_u32m2(v_acc_uw_low, v_prod_uw_low, vl); + + // Detect Carries (if result < accumulator, we wrapped) + vbool16_t v_carry_w = __riscv_vmsltu_vv_u32m2_b16(v_next_acc_w_low, v_acc_w_low, vl); + vbool16_t v_carry_uw = __riscv_vmsltu_vv_u32m2_b16(v_next_acc_uw_low, v_acc_uw_low, vl); + + // Accumulate High part + v_acc_w_high = __riscv_vadd_vv_u32m2(v_acc_w_high, v_prod_w_high, vl); + v_acc_uw_high = __riscv_vadd_vv_u32m2(v_acc_uw_high, v_prod_uw_high, vl); + + // Apply Carry: Add 1 to high accumulator where carry is set + v_acc_w_high = __riscv_vadd_vx_u32m2_mu(v_carry_w, v_acc_w_high, v_acc_w_high, 1, vl); + v_acc_uw_high = __riscv_vadd_vx_u32m2_mu(v_carry_uw, v_acc_uw_high, v_acc_uw_high, 1, vl); + + // Update low accumulator + 
v_acc_w_low = v_next_acc_w_low; + v_acc_uw_low = v_next_acc_uw_low; + + // Advance stripmining index + j += vl; + } + + // Initialize a zero vector for reduction + vuint32m1_t v_zero = __riscv_vmv_v_x_u32m1(0, vl_max); + + // Reduce the 32-bit vector accumulators to scalar sums + vuint32m1_t v_sum_w_low = __riscv_vredsum_vs_u32m2_u32m1(v_acc_w_low, v_zero, vl_max); + vuint32m1_t v_sum_w_high = __riscv_vredsum_vs_u32m2_u32m1(v_acc_w_high, v_zero, vl_max); + vuint32m1_t v_sum_uw_low = __riscv_vredsum_vs_u32m2_u32m1(v_acc_uw_low, v_zero, vl_max); + vuint32m1_t v_sum_uw_high = __riscv_vredsum_vs_u32m2_u32m1(v_acc_uw_high, v_zero, vl_max); + + // Extract scalar results + uint32_t final_w_low = __riscv_vmv_x_s_u32m1_u32(v_sum_w_low); + uint32_t final_w_high = __riscv_vmv_x_s_u32m1_u32(v_sum_w_high); + uint32_t final_uw_low = __riscv_vmv_x_s_u32m1_u32(v_sum_uw_low); + uint32_t final_uw_high = __riscv_vmv_x_s_u32m1_u32(v_sum_uw_high); + + // Reconstruct the final 64-bit sum and add to channel accumulator + channel_w_acc += ((uint64_t)final_w_high << 32) | final_w_low; + channel_uw_acc += ((uint64_t)final_uw_high << 32) | final_uw_low; + } + + // Store the final weighted result for this channel + output[i] = channel_w_acc; + + // The unweighted sum from this channel becomes the starting accumulator for the next + unweight_accumulator = channel_uw_acc; + } +} \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.h new file mode 100644 index 00000000000..c513e24dbea --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.h @@ -0,0 +1,23 @@ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_FILTER_BANK_RVV_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_FILTER_BANK_RVV_H_ + +#include + +#include "tensorflow/lite/kernels/internal/common.h" + +struct FilterbankConfig { + int32_t num_channels; + const int16_t* channel_frequency_starts; + const int16_t* channel_weight_starts; + const int16_t* channel_widths; + const int16_t* weights; + const int16_t* unweights; + int32_t output_scale; + + int32_t input_correction_bits; +}; + +void FilterbankAccumulateChannelsRVV(const FilterbankConfig* config, + const uint32_t* input, uint64_t* output); + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_FILTER_BANK_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft.cc b/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft.cc new file mode 100644 index 00000000000..b50f082faa5 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft.cc @@ -0,0 +1,241 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "signal/src/rfft.h" + +#include +#include +#include + +#include "signal/micro/kernels/rfft.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/flatbuffer_utils.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" + +#include "rfft_int16_rvv.h" + +namespace tflite { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +// Indices into the init flexbuffer's vector. +// The parameter's name is in the comment that follows. +// Elements in the vectors are ordered alphabetically by parameter name. +// 'T' is added implicitly by the TensorFlow framework when the type is resolved +// during graph construction. +// constexpr int kTypeIndex = 0; // 'T' (unused) +constexpr int kFftLengthIndex = 1; // 'fft_length' + +template +struct TfLiteAudioFrontendRfftParams { + int32_t fft_length; + int32_t input_size; + int32_t input_length; + int32_t output_length; + TfLiteType fft_type; + T* work_area; + int scratch_buffer_index; + int8_t* state; +}; + +template +void* RfftInit(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + + const uint8_t* buffer_t = reinterpret_cast(buffer); + auto* params = static_cast*>( + context->AllocatePersistentBuffer( + context, sizeof(TfLiteAudioFrontendRfftParams))); + + tflite::FlexbufferWrapper fbw(buffer_t, length); + params->fft_length = fbw.ElementAsInt32(kFftLengthIndex); + params->fft_type = typeToTfLiteType(); + + size_t state_size = (*get_needed_memory_func)(params->fft_length); + params->state = static_cast( + context->AllocatePersistentBuffer(context, state_size * sizeof(int8_t))); + (*init_func)(params->fft_length, params->state, state_size); + return params; +} + +template +TfLiteStatus RfftPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MicroContext* micro_context = GetMicroContext(context); + TfLiteTensor* input = + micro_context->AllocateTempInputTensor(node, kInputTensor); + TF_LITE_ENSURE(context, input != nullptr); + TfLiteTensor* output = + micro_context->AllocateTempOutputTensor(node, kOutputTensor); + TF_LITE_ENSURE(context, output != nullptr); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output)); + + TF_LITE_ENSURE_TYPES_EQ(context, input->type, TfLiteTypeEnum); + TF_LITE_ENSURE_TYPES_EQ(context, output->type, TfLiteTypeEnum); + + auto* params = + reinterpret_cast*>(node->user_data); + RuntimeShape input_shape = GetTensorShape(input); + RuntimeShape output_shape = GetTensorShape(output); + params->input_length = input_shape.Dims(input_shape.DimensionsCount() - 1); + params->input_size = input_shape.FlatSize(); + // Divide by 2 because output is complex. 
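+  // (Illustrative: a 512-point RFFT is expected to yield 512/2 + 1 = 257
+  // complex bins, stored as 514 interleaved real/imaginary values in the last
+  // output dimension, hence the division by 2 below.)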
+ params->output_length = + output_shape.Dims(output_shape.DimensionsCount() - 1) / 2; + + context->RequestScratchBufferInArena(context, params->fft_length * sizeof(T), + ¶ms->scratch_buffer_index); + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(output); + return kTfLiteOk; +} + +template *)> +TfLiteStatus RfftEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast*>(node->user_data); + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + + const T* input_data = tflite::micro::GetTensorData(input); + + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + Complex* output_data = tflite::micro::GetTensorData>(output); + + T* work_area = static_cast( + context->GetScratchBuffer(context, params->scratch_buffer_index)); + + for (int input_idx = 0, output_idx = 0; input_idx < params->input_size; + input_idx += params->input_length, output_idx += params->output_length) { + memcpy(work_area, &input_data[input_idx], sizeof(T) * params->input_length); + // Zero pad input to FFT length + memset(&work_area[params->input_length], 0, + sizeof(T) * (params->fft_length - params->input_length)); + + (*apply_func)(params->state, work_area, &output_data[output_idx]); + } + return kTfLiteOk; +} + +void* RfftInitAll(TfLiteContext* context, const char* buffer, size_t length) { + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + auto tensor_type = static_cast(m["T"].AsInt32()); + + switch (tensor_type) { + case TensorType_INT16: { + return RfftInit(context, buffer, length); + } + case TensorType_INT32: { + return RfftInit(context, buffer, length); + } + case TensorType_FLOAT32: { + return RfftInit(context, buffer, length); + } + default: + return nullptr; + } +} + +TfLiteStatus RfftPrepareAll(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast*>(node->user_data); + + switch (params->fft_type) { + case kTfLiteInt16: { + return RfftPrepare(context, node); + } + case kTfLiteInt32: { + return RfftPrepare(context, node); + } + case kTfLiteFloat32: { + return RfftPrepare(context, node); + } + default: + return kTfLiteError; + } +} + +TfLiteStatus RfftEvalAll(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast*>(node->user_data); + + switch (params->fft_type) { + case kTfLiteInt16: { + return RfftEval(context, node); + } + case kTfLiteInt32: { + return RfftEval(context, node); + } + case kTfLiteFloat32: { + return RfftEval(context, node); + } + default: + return kTfLiteError; + } +} +} // namespace + +// TODO(b/286250473): remove namespace once de-duped libraries +namespace tflm_signal { + +TFLMRegistration* Register_RFFT() { + static TFLMRegistration r = + tflite::micro::RegisterOp(RfftInitAll, RfftPrepareAll, RfftEvalAll); + return &r; +} + +TFLMRegistration* Register_RFFT_FLOAT() { + static TFLMRegistration r = tflite::micro::RegisterOp( + RfftInit, + RfftPrepare, + RfftEval); + return &r; +} + +TFLMRegistration* Register_RFFT_INT16() { + static TFLMRegistration r = tflite::micro::RegisterOp( + RfftInit, + RfftPrepare, + RfftEval); + return &r; +} + +TFLMRegistration* Register_RFFT_INT32() { + static TFLMRegistration r = tflite::micro::RegisterOp( + RfftInit, + RfftPrepare, + RfftEval); + return &r; +} + +} // namespace tflm_signal +} // namespace tflite \ No newline at end of file diff --git 
a/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.cc b/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.cc new file mode 100644 index 00000000000..740730a02a8 --- /dev/null +++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.cc @@ -0,0 +1,869 @@ +#include + +#include "signal/src/complex.h" +#include "signal/src/kiss_fft_wrappers/kiss_fft_int16.h" +#include "signal/src/rfft.h" +#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h" + +#define FIXED_POINT 16 + +#include "kiss_fft.h" +#include "tools/kiss_fftr.h" + +namespace kiss_fft_fixed16 { +#include "_kiss_fft_guts.h" +struct kiss_fftr_state{ + kiss_fft_cfg substate; + kiss_fft_cpx * tmpbuf; + kiss_fft_cpx * super_twiddles; +#ifdef USE_SIMD + void * pad; +#endif +}; +} + +static void kf_bfly2_rvv(kiss_fft_fixed16::kiss_fft_cpx* Fout, + const size_t fstride, + const kiss_fft_fixed16::kiss_fft_cfg st, size_t m) +{ + // Initialize pointers and constants + kiss_fft_fixed16::kiss_fft_cpx* Fout2 = Fout + m; + const int16_t* tw1_base = (const int16_t*)st->twiddles; + int16_t* Fout_base = (int16_t*)Fout; + int16_t* Fout2_base = (int16_t*)Fout2; + ptrdiff_t cpx_stride = sizeof(kiss_fft_fixed16::kiss_fft_cpx); + ptrdiff_t tw_stride = fstride * cpx_stride; + const int16_t scale = 16383; + const int32_t round_const = 16384; + + // Main processing loop + size_t k = 0; + while (k < m) + { + // Set the vector length for this iteration (LMUL=2) + size_t vl = __riscv_vsetvl_e16m2(m - k); + + // Load input data vectors + vint16m2_t v_fout_r = + __riscv_vlse16_v_i16m2(Fout_base + 2 * k, cpx_stride, vl); + vint16m2_t v_fout_i = + __riscv_vlse16_v_i16m2(Fout_base + 2 * k + 1, cpx_stride, vl); + vint16m2_t v_fout2_r = + __riscv_vlse16_v_i16m2(Fout2_base + 2 * k, cpx_stride, vl); + vint16m2_t v_fout2_i = + __riscv_vlse16_v_i16m2(Fout2_base + 2 * k + 1, cpx_stride, vl); + + // Load twiddle factor vectors + vint16m2_t v_tw_r = + __riscv_vlse16_v_i16m2(tw1_base + (k * fstride * 2), tw_stride, vl); + vint16m2_t v_tw_i = + __riscv_vlse16_v_i16m2(tw1_base + (k * fstride * 2) + 1, tw_stride, vl); + + // Perform rounding division by 2 on input data + vint32m4_t v_fout_r_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fout_r, scale, vl), + round_const, vl), + 15, vl); + vint32m4_t v_fout_i_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fout_i, scale, vl), + round_const, vl), + 15, vl); + vint16m2_t v_fout_r_div2 = + __riscv_vnclip_wx_i16m2(v_fout_r_32, 0, __RISCV_VXRM_RNU, vl); + vint16m2_t v_fout_i_div2 = + __riscv_vnclip_wx_i16m2(v_fout_i_32, 0, __RISCV_VXRM_RNU, vl); + vint32m4_t v_fout2_r_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fout2_r, scale, vl), + round_const, vl), + 15, vl); + vint32m4_t v_fout2_i_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fout2_i, scale, vl), + round_const, vl), + 15, vl); + vint16m2_t v_fout2_r_div2 = + __riscv_vnclip_wx_i16m2(v_fout2_r_32, 0, __RISCV_VXRM_RNU, vl); + vint16m2_t v_fout2_i_div2 = + __riscv_vnclip_wx_i16m2(v_fout2_i_32, 0, __RISCV_VXRM_RNU, vl); + + // Perform complex multiplication: t = Fout2 * tw + vint32m4_t v_ac = __riscv_vwmul_vv_i32m4(v_fout2_r_div2, v_tw_r, vl); + vint32m4_t v_bd = __riscv_vwmul_vv_i32m4(v_fout2_i_div2, v_tw_i, vl); + vint32m4_t v_ad = __riscv_vwmul_vv_i32m4(v_fout2_r_div2, v_tw_i, vl); + vint32m4_t v_bc = __riscv_vwmul_vv_i32m4(v_fout2_i_div2, v_tw_r, vl); + vint32m4_t v_t_r_32 = __riscv_vssra_vx_i32m4( + 
__riscv_vsub_vv_i32m4(v_ac, v_bd, vl), 15, __RISCV_VXRM_RNU, vl); + vint32m4_t v_t_i_32 = __riscv_vssra_vx_i32m4( + __riscv_vadd_vv_i32m4(v_ad, v_bc, vl), 15, __RISCV_VXRM_RNU, vl); + vint16m2_t v_t_r = __riscv_vnclip_wx_i16m2(v_t_r_32, 0, __RISCV_VXRM_RNU, vl); + vint16m2_t v_t_i = __riscv_vnclip_wx_i16m2(v_t_i_32, 0, __RISCV_VXRM_RNU, vl); + + // Calculate butterfly outputs: Fout = Fout + t and Fout2 = Fout - t + vint16m2_t v_res_fout2_r = __riscv_vsub_vv_i16m2(v_fout_r_div2, v_t_r, vl); + vint16m2_t v_res_fout2_i = __riscv_vsub_vv_i16m2(v_fout_i_div2, v_t_i, vl); + vint16m2_t v_res_fout_r = __riscv_vadd_vv_i16m2(v_fout_r_div2, v_t_r, vl); + vint16m2_t v_res_fout_i = __riscv_vadd_vv_i16m2(v_fout_i_div2, v_t_i, vl); + + // Store results + __riscv_vsse16_v_i16m2(Fout_base + 2 * k, cpx_stride, v_res_fout_r, vl); + __riscv_vsse16_v_i16m2(Fout_base + 2 * k + 1, cpx_stride, v_res_fout_i, vl); + __riscv_vsse16_v_i16m2(Fout2_base + 2 * k, cpx_stride, v_res_fout2_r, vl); + __riscv_vsse16_v_i16m2(Fout2_base + 2 * k + 1, cpx_stride, v_res_fout2_i, vl); + + // Advance loop counter + k += vl; + } +} + +static void kf_bfly4_rvv(kiss_fft_fixed16::kiss_fft_cpx* Fout, + const size_t fstride, + const kiss_fft_fixed16::kiss_fft_cfg st, + const size_t m) +{ + // Initialize pointers and constants + const size_t m2 = 2 * m; + const size_t m3 = 3 * m; + + int16_t* Fout0_base = (int16_t*)(Fout); + int16_t* Fout1_base = (int16_t*)(Fout + m); + int16_t* Fout2_base = (int16_t*)(Fout + m2); + int16_t* Fout3_base = (int16_t*)(Fout + m3); + const int16_t* tw_base = (const int16_t*)st->twiddles; + + ptrdiff_t cpx_stride = sizeof(kiss_fft_fixed16::kiss_fft_cpx); + ptrdiff_t tw1_stride = fstride * cpx_stride; + ptrdiff_t tw2_stride = fstride * 2 * cpx_stride; + ptrdiff_t tw3_stride = fstride * 3 * cpx_stride; + + const int16_t scale = 8191; + const int32_t round_const = 16384; + + // Main processing loop + size_t k = 0; + while (k < m) + { + // Set the vector length for this iteration (LMUL=1) + size_t vl = __riscv_vsetvl_e16m1(m - k); + + // Load input data vectors + vint16m1_t v_f0_r = + __riscv_vlse16_v_i16m1(Fout0_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f0_i = + __riscv_vlse16_v_i16m1(Fout0_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f1_r = + __riscv_vlse16_v_i16m1(Fout1_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f1_i = + __riscv_vlse16_v_i16m1(Fout1_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f2_r = + __riscv_vlse16_v_i16m1(Fout2_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f2_i = + __riscv_vlse16_v_i16m1(Fout2_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f3_r = + __riscv_vlse16_v_i16m1(Fout3_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f3_i = + __riscv_vlse16_v_i16m1(Fout3_base + 2 * k + 1, cpx_stride, vl); + + // Perform rounding division by 4 on input data + vint16m1_t v_f0d_r = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f0_r, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f0d_i = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f0_i, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f1d_r = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f1_r, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f1d_i = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f1_i, scale, vl), 
round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f2d_r = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f2_r, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f2d_i = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f2_i, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f3d_r = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f3_r, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_f3d_i = __riscv_vnclip_wx_i16m1( + __riscv_vsra_vx_i32m2( + __riscv_vadd_vx_i32m2( + __riscv_vwmul_vx_i32m2(v_f3_i, scale, vl), round_const, vl), + 15, vl), + 0, __RISCV_VXRM_RNU, vl); + + // Load twiddle factor vectors + vint16m1_t v_tw1_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 2), tw1_stride, vl); + vint16m1_t v_tw1_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 2) + 1, tw1_stride, vl); + vint16m1_t v_tw2_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 4), tw2_stride, vl); + vint16m1_t v_tw2_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 4) + 1, tw2_stride, vl); + vint16m1_t v_tw3_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 6), tw3_stride, vl); + vint16m1_t v_tw3_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 6) + 1, tw3_stride, vl); + + // Perform complex multiplications + vint16m1_t v_s0_r, v_s0_i, v_s1_r, v_s1_i, v_s2_r, v_s2_i; + do + { + vint32m2_t ac = __riscv_vwmul_vv_i32m2(v_f1d_r, v_tw1_r, vl); + vint32m2_t bd = __riscv_vwmul_vv_i32m2(v_f1d_i, v_tw1_i, vl); + vint32m2_t ad = __riscv_vwmul_vv_i32m2(v_f1d_r, v_tw1_i, vl); + vint32m2_t bc = __riscv_vwmul_vv_i32m2(v_f1d_i, v_tw1_r, vl); + v_s0_r = __riscv_vnclip_wx_i16m1(__riscv_vssra_vx_i32m2( + __riscv_vsub_vv_i32m2(ac, bd, vl), 15, __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + v_s0_i = __riscv_vnclip_wx_i16m1(__riscv_vssra_vx_i32m2( + __riscv_vadd_vv_i32m2(ad, bc, vl), 15, __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + } while (0); + + do + { + vint32m2_t ac = __riscv_vwmul_vv_i32m2(v_f2d_r, v_tw2_r, vl); + vint32m2_t bd = __riscv_vwmul_vv_i32m2(v_f2d_i, v_tw2_i, vl); + vint32m2_t ad = __riscv_vwmul_vv_i32m2(v_f2d_r, v_tw2_i, vl); + vint32m2_t bc = __riscv_vwmul_vv_i32m2(v_f2d_i, v_tw2_r, vl); + v_s1_r = __riscv_vnclip_wx_i16m1(__riscv_vssra_vx_i32m2( + __riscv_vsub_vv_i32m2(ac, bd, vl), 15, __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + v_s1_i = __riscv_vnclip_wx_i16m1(__riscv_vssra_vx_i32m2( + __riscv_vadd_vv_i32m2(ad, bc, vl), 15, __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + } while (0); + + do + { + vint32m2_t ac = __riscv_vwmul_vv_i32m2(v_f3d_r, v_tw3_r, vl); + vint32m2_t bd = __riscv_vwmul_vv_i32m2(v_f3d_i, v_tw3_i, vl); + vint32m2_t ad = __riscv_vwmul_vv_i32m2(v_f3d_r, v_tw3_i, vl); + vint32m2_t bc = __riscv_vwmul_vv_i32m2(v_f3d_i, v_tw3_r, vl); + v_s2_r = __riscv_vnclip_wx_i16m1(__riscv_vssra_vx_i32m2( + __riscv_vsub_vv_i32m2(ac, bd, vl), 15, __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + v_s2_i = __riscv_vnclip_wx_i16m1(__riscv_vssra_vx_i32m2( + __riscv_vadd_vv_i32m2(ad, bc, vl), 15, __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + } while (0); + + // Calculate intermediate butterfly values + vint16m1_t v_s5_r = __riscv_vsub_vv_i16m1(v_f0d_r, v_s1_r, vl); + vint16m1_t v_s5_i = __riscv_vsub_vv_i16m1(v_f0d_i, v_s1_i, vl); + vint16m1_t v_f0d_plus_s1_r = 
__riscv_vadd_vv_i16m1(v_f0d_r, v_s1_r, vl); + vint16m1_t v_f0d_plus_s1_i = __riscv_vadd_vv_i16m1(v_f0d_i, v_s1_i, vl); + vint16m1_t v_s3_r = __riscv_vadd_vv_i16m1(v_s0_r, v_s2_r, vl); + vint16m1_t v_s3_i = __riscv_vadd_vv_i16m1(v_s0_i, v_s2_i, vl); + vint16m1_t v_s4_r = __riscv_vsub_vv_i16m1(v_s0_r, v_s2_r, vl); + vint16m1_t v_s4_i = __riscv_vsub_vv_i16m1(v_s0_i, v_s2_i, vl); + vint16m1_t v_res_f0_r = __riscv_vadd_vv_i16m1(v_f0d_plus_s1_r, v_s3_r, vl); + vint16m1_t v_res_f0_i = __riscv_vadd_vv_i16m1(v_f0d_plus_s1_i, v_s3_i, vl); + vint16m1_t v_res_f2_r = __riscv_vsub_vv_i16m1(v_f0d_plus_s1_r, v_s3_r, vl); + vint16m1_t v_res_f2_i = __riscv_vsub_vv_i16m1(v_f0d_plus_s1_i, v_s3_i, vl); + + // Calculate final results, handling inverse case + vint16m1_t v_res_f1_r, v_res_f1_i, v_res_f3_r, v_res_f3_i; + if (st->inverse) + { + v_res_f1_r = __riscv_vsub_vv_i16m1(v_s5_r, v_s4_i, vl); + v_res_f1_i = __riscv_vadd_vv_i16m1(v_s5_i, v_s4_r, vl); + v_res_f3_r = __riscv_vadd_vv_i16m1(v_s5_r, v_s4_i, vl); + v_res_f3_i = __riscv_vsub_vv_i16m1(v_s5_i, v_s4_r, vl); + } + else + { + v_res_f1_r = __riscv_vadd_vv_i16m1(v_s5_r, v_s4_i, vl); + v_res_f1_i = __riscv_vsub_vv_i16m1(v_s5_i, v_s4_r, vl); + v_res_f3_r = __riscv_vsub_vv_i16m1(v_s5_r, v_s4_i, vl); + v_res_f3_i = __riscv_vadd_vv_i16m1(v_s5_i, v_s4_r, vl); + } + + // Store final results + __riscv_vsse16_v_i16m1(Fout0_base + 2 * k, cpx_stride, v_res_f0_r, vl); + __riscv_vsse16_v_i16m1(Fout0_base + 2 * k + 1, cpx_stride, v_res_f0_i, vl); + __riscv_vsse16_v_i16m1(Fout1_base + 2 * k, cpx_stride, v_res_f1_r, vl); + __riscv_vsse16_v_i16m1(Fout1_base + 2 * k + 1, cpx_stride, v_res_f1_i, vl); + __riscv_vsse16_v_i16m1(Fout2_base + 2 * k, cpx_stride, v_res_f2_r, vl); + __riscv_vsse16_v_i16m1(Fout2_base + 2 * k + 1, cpx_stride, v_res_f2_i, vl); + __riscv_vsse16_v_i16m1(Fout3_base + 2 * k, cpx_stride, v_res_f3_r, vl); + __riscv_vsse16_v_i16m1(Fout3_base + 2 * k + 1, cpx_stride, v_res_f3_i, vl); + + // Advance loop counter + k += vl; + } +} + +static void kf_bfly3_rvv(kiss_fft_fixed16::kiss_fft_cpx* Fout, + const size_t fstride, + const kiss_fft_fixed16::kiss_fft_cfg st, size_t m) +{ + // Initialize pointers and constants + kiss_fft_fixed16::kiss_fft_cpx* Fout1 = Fout + m; + kiss_fft_fixed16::kiss_fft_cpx* Fout2 = Fout + m * 2; + const int16_t* tw1_base = (const int16_t*)st->twiddles; + const int16_t* tw2_base = tw1_base; + const int16_t tw3i = -28378; // Q15 value for sin(-2*pi/3) + int16_t* Fout0_base = (int16_t*)Fout; + int16_t* Fout1_base = (int16_t*)Fout1; + int16_t* Fout2_base = (int16_t*)Fout2; + ptrdiff_t cpx_stride = sizeof(kiss_fft_fixed16::kiss_fft_cpx); + ptrdiff_t tw1_stride = fstride * cpx_stride; + ptrdiff_t tw2_stride = fstride * 2 * cpx_stride; + + // Main processing loop + size_t k = 0; + while (k < m) + { + // Set the vector length for this iteration (LMUL=1) + size_t vl = __riscv_vsetvl_e16m1(m - k); + + // Load input data vectors + vint16m1_t v_f0_r = + __riscv_vlse16_v_i16m1(Fout0_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f0_i = + __riscv_vlse16_v_i16m1(Fout0_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f1_r = + __riscv_vlse16_v_i16m1(Fout1_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f1_i = + __riscv_vlse16_v_i16m1(Fout1_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f2_r = + __riscv_vlse16_v_i16m1(Fout2_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f2_i = + __riscv_vlse16_v_i16m1(Fout2_base + 2 * k + 1, cpx_stride, vl); + + // Load twiddle factor vectors + vint16m1_t v_tw1_r = + __riscv_vlse16_v_i16m1(tw1_base + (k * fstride * 2), 
tw1_stride, vl); + vint16m1_t v_tw1_i = + __riscv_vlse16_v_i16m1(tw1_base + (k * fstride * 2) + 1, tw1_stride, vl); + vint16m1_t v_tw2_r = + __riscv_vlse16_v_i16m1(tw2_base + (k * fstride * 4), tw2_stride, vl); + vint16m1_t v_tw2_i = + __riscv_vlse16_v_i16m1(tw2_base + (k * fstride * 4) + 1, tw2_stride, vl); + + // Perform complex multiplications: v_s0 = v_f1 * v_tw1 + vint32m2_t v_ac0 = __riscv_vwmul_vv_i32m2(v_f1_r, v_tw1_r, vl); + vint32m2_t v_bd0 = __riscv_vwmul_vv_i32m2(v_f1_i, v_tw1_i, vl); + vint32m2_t v_ad0 = __riscv_vwmul_vv_i32m2(v_f1_r, v_tw1_i, vl); + vint32m2_t v_bc0 = __riscv_vwmul_vv_i32m2(v_f1_i, v_tw1_r, vl); + vint16m1_t v_s0_r = __riscv_vnclip_wx_i16m1( + __riscv_vssra_vx_i32m2(__riscv_vsub_vv_i32m2(v_ac0, v_bd0, vl), 15, + __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_s0_i = __riscv_vnclip_wx_i16m1( + __riscv_vssra_vx_i32m2(__riscv_vadd_vv_i32m2(v_ad0, v_bc0, vl), 15, + __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + + // Perform complex multiplications + vint32m2_t v_ac1 = __riscv_vwmul_vv_i32m2(v_f2_r, v_tw2_r, vl); + vint32m2_t v_bd1 = __riscv_vwmul_vv_i32m2(v_f2_i, v_tw2_i, vl); + vint32m2_t v_ad1 = __riscv_vwmul_vv_i32m2(v_f2_r, v_tw2_i, vl); + vint32m2_t v_bc1 = __riscv_vwmul_vv_i32m2(v_f2_i, v_tw2_r, vl); + vint16m1_t v_s1_r = __riscv_vnclip_wx_i16m1( + __riscv_vssra_vx_i32m2(__riscv_vsub_vv_i32m2(v_ac1, v_bd1, vl), 15, + __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + vint16m1_t v_s1_i = __riscv_vnclip_wx_i16m1( + __riscv_vssra_vx_i32m2(__riscv_vadd_vv_i32m2(v_ad1, v_bc1, vl), 15, + __RISCV_VXRM_RNU, vl), + 0, __RISCV_VXRM_RNU, vl); + + // Calculate intermediate butterfly values + vint16m1_t v_s_add_r = __riscv_vadd_vv_i16m1(v_s0_r, v_s1_r, vl); + vint16m1_t v_s_add_i = __riscv_vadd_vv_i16m1(v_s0_i, v_s1_i, vl); + vint16m1_t v_s_sub_r = __riscv_vsub_vv_i16m1(v_s0_r, v_s1_r, vl); + vint16m1_t v_s_sub_i = __riscv_vsub_vv_i16m1(v_s0_i, v_s1_i, vl); + + // Calculate Fout0 = Fout0 + s_add + vint16m1_t v_res_f0_r = __riscv_vadd_vv_i16m1(v_f0_r, v_s_add_r, vl); + vint16m1_t v_res_f0_i = __riscv_vadd_vv_i16m1(v_f0_i, v_s_add_i, vl); + + // Calculate remaining outputs using rotations + vint16m1_t v_s_add_r_neg_half = + __riscv_vneg_v_i16m1(__riscv_vsra_vx_i16m1(v_s_add_r, 1, vl), vl); + vint16m1_t v_s_add_i_neg_half = + __riscv_vneg_v_i16m1(__riscv_vsra_vx_i16m1(v_s_add_i, 1, vl), vl); + vint32m2_t v_s_sub_i_mul_tw3i = __riscv_vwmul_vx_i32m2(v_s_sub_i, tw3i, vl); + vint32m2_t v_s_sub_r_mul_tw3i = __riscv_vwmul_vx_i32m2(v_s_sub_r, tw3i, vl); + vint16m1_t v_s_sub_i_scaled = __riscv_vnclip_wx_i16m1( + __riscv_vssra_vx_i32m2(v_s_sub_i_mul_tw3i, 15, __RISCV_VXRM_RNU, vl), 0, + __RISCV_VXRM_RNU, vl); + vint16m1_t v_s_sub_r_scaled = __riscv_vnclip_wx_i16m1( + __riscv_vssra_vx_i32m2(v_s_sub_r_mul_tw3i, 15, __RISCV_VXRM_RNU, vl), 0, + __RISCV_VXRM_RNU, vl); + vint16m1_t v_tmp_r1 = __riscv_vadd_vv_i16m1(v_f0_r, v_s_add_r_neg_half, vl); + vint16m1_t v_res_f1_r = __riscv_vsub_vv_i16m1(v_tmp_r1, v_s_sub_i_scaled, vl); + vint16m1_t v_tmp_i1 = __riscv_vadd_vv_i16m1(v_f0_i, v_s_add_i_neg_half, vl); + vint16m1_t v_res_f1_i = __riscv_vadd_vv_i16m1(v_tmp_i1, v_s_sub_r_scaled, vl); + vint16m1_t v_res_f2_r = __riscv_vadd_vv_i16m1(v_tmp_r1, v_s_sub_i_scaled, vl); + vint16m1_t v_res_f2_i = __riscv_vsub_vv_i16m1(v_tmp_i1, v_s_sub_r_scaled, vl); + + // Store results + __riscv_vsse16_v_i16m1(Fout0_base + 2 * k, cpx_stride, v_res_f0_r, vl); + __riscv_vsse16_v_i16m1(Fout0_base + 2 * k + 1, cpx_stride, v_res_f0_i, vl); + __riscv_vsse16_v_i16m1(Fout1_base + 2 * k, 
cpx_stride, v_res_f1_r, vl); + __riscv_vsse16_v_i16m1(Fout1_base + 2 * k + 1, cpx_stride, v_res_f1_i, vl); + __riscv_vsse16_v_i16m1(Fout2_base + 2 * k, cpx_stride, v_res_f2_r, vl); + __riscv_vsse16_v_i16m1(Fout2_base + 2 * k + 1, cpx_stride, v_res_f2_i, vl); + + // Advance loop counter + k += vl; + } +} + +static void kf_bfly5_rvv(kiss_fft_fixed16::kiss_fft_cpx* Fout, + const size_t fstride, + const kiss_fft_fixed16::kiss_fft_cfg st, size_t m) +{ + // Initialize pointers and constants + kiss_fft_fixed16::kiss_fft_cpx *Fout0, *Fout1, *Fout2, *Fout3, *Fout4; + const int16_t* tw_base = (const int16_t*)st->twiddles; + const int16_t ya1 = 19021; // Q15 value for cos(2*pi/5) + const int16_t yb1 = 31164; // Q15 value for sin(2*pi/5) + const int16_t ya2 = -30777; // Q15 value for cos(4*pi/5) + const int16_t yb2 = 19021; // Q15 value for sin(4*pi/5) + + Fout0 = Fout; + Fout1 = Fout + m; + Fout2 = Fout + 2 * m; + Fout3 = Fout + 3 * m; + Fout4 = Fout + 4 * m; + + int16_t* Fout0_base = (int16_t*)Fout0; + int16_t* Fout1_base = (int16_t*)Fout1; + int16_t* Fout2_base = (int16_t*)Fout2; + int16_t* Fout3_base = (int16_t*)Fout3; + int16_t* Fout4_base = (int16_t*)Fout4; + + ptrdiff_t cpx_stride = sizeof(kiss_fft_fixed16::kiss_fft_cpx); + ptrdiff_t tw1_stride = fstride * cpx_stride; + ptrdiff_t tw2_stride = 2 * tw1_stride; + ptrdiff_t tw3_stride = 3 * tw1_stride; + ptrdiff_t tw4_stride = 4 * tw1_stride; + + // Main processing loop + size_t k = 0; + while (k < m) + { + // Set the vector length for this iteration + size_t vl = __riscv_vsetvl_e16m1(m - k); + + // Load input data vectors + vint16m1_t v_f0_r = + __riscv_vlse16_v_i16m1(Fout0_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f0_i = + __riscv_vlse16_v_i16m1(Fout0_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f1_r = + __riscv_vlse16_v_i16m1(Fout1_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f1_i = + __riscv_vlse16_v_i16m1(Fout1_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f2_r = + __riscv_vlse16_v_i16m1(Fout2_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f2_i = + __riscv_vlse16_v_i16m1(Fout2_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f3_r = + __riscv_vlse16_v_i16m1(Fout3_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f3_i = + __riscv_vlse16_v_i16m1(Fout3_base + 2 * k + 1, cpx_stride, vl); + vint16m1_t v_f4_r = + __riscv_vlse16_v_i16m1(Fout4_base + 2 * k, cpx_stride, vl); + vint16m1_t v_f4_i = + __riscv_vlse16_v_i16m1(Fout4_base + 2 * k + 1, cpx_stride, vl); + + // Load twiddle factor vectors + vint16m1_t v_tw1_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 2), tw1_stride, vl); + vint16m1_t v_tw1_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 2) + 1, tw1_stride, vl); + vint16m1_t v_tw2_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 4), tw2_stride, vl); + vint16m1_t v_tw2_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 4) + 1, tw2_stride, vl); + vint16m1_t v_tw3_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 6), tw3_stride, vl); + vint16m1_t v_tw3_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 6) + 1, tw3_stride, vl); + vint16m1_t v_tw4_r = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 8), tw4_stride, vl); + vint16m1_t v_tw4_i = + __riscv_vlse16_v_i16m1(tw_base + (k * fstride * 8) + 1, tw4_stride, vl); + +// Macro for complex multiplication, wrapped in do-while(0) to prevent scope issues +#define C_MUL_VEC(res_r, res_i, f_r, f_i, tw_r, tw_i) \ + do \ + { \ + vint32m2_t ac = __riscv_vwmul_vv_i32m2(f_r, tw_r, vl); \ + vint32m2_t bd = __riscv_vwmul_vv_i32m2(f_i, tw_i, vl); \ + vint32m2_t ad = 
__riscv_vwmul_vv_i32m2(f_r, tw_i, vl); \ + vint32m2_t bc = __riscv_vwmul_vv_i32m2(f_i, tw_r, vl); \ + res_r = __riscv_vnclip_wx_i16m1( \ + __riscv_vssra_vx_i32m2(__riscv_vsub_vv_i32m2(ac, bd, vl), 15, \ + __RISCV_VXRM_RNU, vl), \ + 0, __RISCV_VXRM_RNU, vl); \ + res_i = __riscv_vnclip_wx_i16m1( \ + __riscv_vssra_vx_i32m2(__riscv_vadd_vv_i32m2(ad, bc, vl), 15, \ + __RISCV_VXRM_RNU, vl), \ + 0, __RISCV_VXRM_RNU, vl); \ + } while (0) + + // Perform complex multiplications + vint16m1_t v_s0_r, v_s0_i, v_s1_r, v_s1_i, v_s2_r, v_s2_i, v_s3_r, v_s3_i; + C_MUL_VEC(v_s0_r, v_s0_i, v_f1_r, v_f1_i, v_tw1_r, v_tw1_i); + C_MUL_VEC(v_s1_r, v_s1_i, v_f2_r, v_f2_i, v_tw2_r, v_tw2_i); + C_MUL_VEC(v_s2_r, v_s2_i, v_f3_r, v_f3_i, v_tw3_r, v_tw3_i); + C_MUL_VEC(v_s3_r, v_s3_i, v_f4_r, v_f4_i, v_tw4_r, v_tw4_i); +#undef C_MUL_VEC + + // Calculate intermediate butterfly values + vint16m1_t v_s03_add_r = __riscv_vadd_vv_i16m1(v_s0_r, v_s3_r, vl); + vint16m1_t v_s03_add_i = __riscv_vadd_vv_i16m1(v_s0_i, v_s3_i, vl); + vint16m1_t v_s03_sub_r = __riscv_vsub_vv_i16m1(v_s0_r, v_s3_r, vl); + vint16m1_t v_s03_sub_i = __riscv_vsub_vv_i16m1(v_s0_i, v_s3_i, vl); + vint16m1_t v_s12_add_r = __riscv_vadd_vv_i16m1(v_s1_r, v_s2_r, vl); + vint16m1_t v_s12_add_i = __riscv_vadd_vv_i16m1(v_s1_i, v_s2_i, vl); + vint16m1_t v_s12_sub_r = __riscv_vsub_vv_i16m1(v_s1_r, v_s2_r, vl); + vint16m1_t v_s12_sub_i = __riscv_vsub_vv_i16m1(v_s1_i, v_s2_i, vl); + + // Calculate Fout0 = f0 + s03_add + s12_add + vint16m1_t v_res_f0_r = __riscv_vadd_vv_i16m1( + v_f0_r, __riscv_vadd_vv_i16m1(v_s03_add_r, v_s12_add_r, vl), vl); + vint16m1_t v_res_f0_i = __riscv_vadd_vv_i16m1( + v_f0_i, __riscv_vadd_vv_i16m1(v_s03_add_i, v_s12_add_i, vl), vl); + +// Macro for scalar multiplication, wrapped in do-while(0) to prevent scope issues +#define S_MUL_VX(res, val, const_val) \ + do \ + { \ + vint32m2_t tmp_mul = __riscv_vwmul_vx_i32m2(val, const_val, vl); \ + res = __riscv_vnclip_wx_i16m1( \ + __riscv_vssra_vx_i32m2(tmp_mul, 15, __RISCV_VXRM_RNU, vl), 0, \ + __RISCV_VXRM_RNU, vl); \ + } while (0) + + // Perform final rotations + vint16m1_t v_tmp1_r, v_tmp1_i, v_tmp2_r, v_tmp2_i; + S_MUL_VX(v_tmp1_r, v_s03_add_r, ya1); + S_MUL_VX(v_tmp1_i, v_s03_add_i, ya1); + S_MUL_VX(v_tmp2_r, v_s12_add_r, ya2); + S_MUL_VX(v_tmp2_i, v_s12_add_i, ya2); + vint16m1_t v_r_part1 = __riscv_vadd_vv_i16m1( + v_f0_r, __riscv_vadd_vv_i16m1(v_tmp1_r, v_tmp2_r, vl), vl); + vint16m1_t v_i_part1 = __riscv_vadd_vv_i16m1( + v_f0_i, __riscv_vadd_vv_i16m1(v_tmp1_i, v_tmp2_i, vl), vl); + S_MUL_VX(v_tmp1_r, v_s03_sub_i, yb1); + S_MUL_VX(v_tmp1_i, v_s03_sub_r, yb1); + S_MUL_VX(v_tmp2_r, v_s12_sub_i, yb2); + S_MUL_VX(v_tmp2_i, v_s12_sub_r, yb2); + vint16m1_t v_r_part2 = __riscv_vsub_vv_i16m1(v_tmp1_r, v_tmp2_r, vl); + vint16m1_t v_i_part2 = __riscv_vadd_vv_i16m1(v_tmp1_i, v_tmp2_i, vl); + + // Calculate final butterfly outputs + vint16m1_t v_res_f1_r = __riscv_vadd_vv_i16m1(v_r_part1, v_r_part2, vl); + vint16m1_t v_res_f1_i = __riscv_vadd_vv_i16m1(v_i_part1, v_i_part2, vl); + vint16m1_t v_res_f4_r = __riscv_vsub_vv_i16m1(v_r_part1, v_r_part2, vl); + vint16m1_t v_res_f4_i = __riscv_vsub_vv_i16m1(v_i_part1, v_i_part2, vl); + v_r_part2 = __riscv_vadd_vv_i16m1(v_tmp1_r, v_tmp2_r, vl); + v_i_part2 = __riscv_vsub_vv_i16m1(v_tmp1_i, v_tmp2_i, vl); + vint16m1_t v_res_f2_r = __riscv_vsub_vv_i16m1(v_r_part1, v_r_part2, vl); + vint16m1_t v_res_f2_i = __riscv_vadd_vv_i16m1(v_i_part1, v_i_part2, vl); + vint16m1_t v_res_f3_r = __riscv_vadd_vv_i16m1(v_r_part1, v_r_part2, vl); + vint16m1_t v_res_f3_i = 
__riscv_vsub_vv_i16m1(v_i_part1, v_i_part2, vl);
+#undef S_MUL_VX
+
+    // Store results
+    __riscv_vsse16_v_i16m1(Fout0_base + 2 * k, cpx_stride, v_res_f0_r, vl);
+    __riscv_vsse16_v_i16m1(Fout0_base + 2 * k + 1, cpx_stride, v_res_f0_i, vl);
+    __riscv_vsse16_v_i16m1(Fout1_base + 2 * k, cpx_stride, v_res_f1_r, vl);
+    __riscv_vsse16_v_i16m1(Fout1_base + 2 * k + 1, cpx_stride, v_res_f1_i, vl);
+    __riscv_vsse16_v_i16m1(Fout2_base + 2 * k, cpx_stride, v_res_f2_r, vl);
+    __riscv_vsse16_v_i16m1(Fout2_base + 2 * k + 1, cpx_stride, v_res_f2_i, vl);
+    __riscv_vsse16_v_i16m1(Fout3_base + 2 * k, cpx_stride, v_res_f3_r, vl);
+    __riscv_vsse16_v_i16m1(Fout3_base + 2 * k + 1, cpx_stride, v_res_f3_i, vl);
+    __riscv_vsse16_v_i16m1(Fout4_base + 2 * k, cpx_stride, v_res_f4_r, vl);
+    __riscv_vsse16_v_i16m1(Fout4_base + 2 * k + 1, cpx_stride, v_res_f4_i, vl);
+
+    // Advance loop counter
+    k += vl;
+  }
+}
+
+// Generic radix implementation copy/pasted from kissfft (kiss_fft.c)
+static void kf_bfly_generic(
+    kiss_fft_fixed16::kiss_fft_cpx * Fout,
+    const size_t fstride,
+    const kiss_fft_fixed16::kiss_fft_cfg st,
+    int m,
+    int p
+    )
+{
+  int u,k,q1,q;
+  kiss_fft_fixed16::kiss_fft_cpx * twiddles = st->twiddles;
+  kiss_fft_fixed16::kiss_fft_cpx t;
+  int Norig = st->nfft;
+
+  kiss_fft_fixed16::kiss_fft_cpx * scratch = (kiss_fft_fixed16::kiss_fft_cpx*)KISS_FFT_TMP_ALLOC(sizeof(kiss_fft_fixed16::kiss_fft_cpx)*p);
+  if (scratch == NULL){
+    return;
+  }
+
+  for ( u=0; u<m; ++u ) {
+    k=u;
+    for ( q1=0 ; q1<p ; ++q1 ) {
+      scratch[q1] = Fout[ k ];
+      C_FIXDIV(scratch[q1],p);
+      k += m;
+    }
+
+    k=u;
+    for ( q1=0 ; q1<p ; ++q1 ) {
+      int twidx=0;
+      Fout[ k ] = scratch[0];
+      for (q=1;q<p;++q ) {
+        twidx += fstride * k;
+        if (twidx>=Norig) twidx-=Norig;
+        C_MUL(t,scratch[q] , twiddles[twidx] );
+        C_ADDTO( Fout[ k ] ,t);
+      }
+      k += m;
+    }
+  }
+  KISS_FFT_TMP_FREE(scratch);
+}
+
+static void kf_work_rvv(kiss_fft_fixed16::kiss_fft_cpx* Fout,
+                        const kiss_fft_fixed16::kiss_fft_cpx* f,
+                        const size_t fstride, int in_stride, int* factors,
+                        const kiss_fft_fixed16::kiss_fft_cfg st)
+{
+  // Decompose the problem into factors p and m
+  const int p = *factors++;
+  const int m = *factors++;
+  kiss_fft_fixed16::kiss_fft_cpx* Fout_beg = Fout;
+  const kiss_fft_fixed16::kiss_fft_cpx* Fout_end = Fout + p * m;
+
+  // Perform recursion for the m-point DFTs
+  if (m == 1)
+  {
+    do
+    {
+      *Fout = *f;
+      f += fstride * in_stride;
+    } while (++Fout != Fout_end);
+  }
+  else
+  {
+    do
+    {
+      kf_work_rvv(Fout, f, fstride * p, in_stride, factors, st);
+      f += fstride * in_stride;
+    } while ((Fout += m) != Fout_end);
+  }
+
+  // Perform the p-point butterfly operations
+  Fout = Fout_beg;
+  switch (p)
+  {
+    case 2:
+      kf_bfly2_rvv(Fout, fstride, st, m);
+      break;
+    case 3:
+      kf_bfly3_rvv(Fout, fstride, st, m);
+      break;
+    case 4:
+      kf_bfly4_rvv(Fout, fstride, st, m);
+      break;
+    case 5:
+      kf_bfly5_rvv(Fout, fstride, st, m);
+      break;
+    default: kf_bfly_generic(Fout, fstride, st, m, p); break;
+  }
+}
+
+void kiss_fft_stride_rvv(kiss_fft_fixed16::kiss_fft_cfg st, const kiss_fft_fixed16::kiss_fft_cpx* fin,
+                         kiss_fft_fixed16::kiss_fft_cpx* fout, int in_stride)
+{
+  // Handle in-place transform
+  if (fin == fout)
+  {
+    if (fout == NULL)
+    {
+      return;
+    }
+
+    kiss_fft_fixed16::kiss_fft_cpx* tmpbuf =
+        (kiss_fft_fixed16::kiss_fft_cpx*)KISS_FFT_TMP_ALLOC(
+            sizeof(kiss_fft_fixed16::kiss_fft_cpx) * st->nfft);
+
+    if (tmpbuf == NULL)
+    {
+      return;
+    }
+
+    kf_work_rvv(tmpbuf, fin, 1, in_stride, st->factors, st);
+
+    memcpy(fout, tmpbuf, sizeof(kiss_fft_fixed16::kiss_fft_cpx) * st->nfft);
+
+    KISS_FFT_TMP_FREE(tmpbuf);
+  }
+  else
+  {
+    // Handle out-of-place transform
+    kf_work_rvv(fout, fin, 1, in_stride, st->factors, st);
+  }
+}
+
+void kiss_fft_rvv(kiss_fft_fixed16::kiss_fft_cfg cfg, const kiss_fft_fixed16::kiss_fft_cpx* fin,
kiss_fft_fixed16::kiss_fft_cpx* fout) +{ + kiss_fft_stride_rvv(cfg, fin, fout, 1); +} + +void kiss_fftr_rvv(kiss_fft_fixed16::kiss_fftr_cfg st, const kiss_fft_scalar* timedata, + kiss_fft_fixed16::kiss_fft_cpx* freqdata) +{ + // Handle inverse FFT case and perform the initial complex FFT + if (st->substate->inverse) + { + return; + } + kiss_fft_rvv(st->substate, (const kiss_fft_fixed16::kiss_fft_cpx*)timedata, st->tmpbuf); + + // Process DC and Nyquist bins separately (scalar operations) + const int ncfft = st->substate->nfft; + kiss_fft_fixed16::kiss_fft_cpx tdc; + tdc.r = st->tmpbuf[0].r; + tdc.i = st->tmpbuf[0].i; + C_FIXDIV(tdc, 2); + freqdata[0].r = tdc.r + tdc.i; + freqdata[ncfft].r = tdc.r - tdc.i; + freqdata[0].i = 0; + freqdata[ncfft].i = 0; + + // Initialize pointers and loop variables + size_t k = 1; + const size_t loop_end = ncfft / 2; + const int16_t* tmpbuf_base_ptr = (const int16_t*)st->tmpbuf; + const int16_t* twiddles_base_ptr = (const int16_t*)st->super_twiddles; + int16_t* freqdata_base_ptr = (int16_t*)freqdata; + + // Stride for complex numbers (R, I) is 4 bytes (2 * int16) + ptrdiff_t stride = sizeof(kiss_fft_fixed16::kiss_fft_cpx); + ptrdiff_t neg_stride = -stride; + + // Main loop to process FFT bins in vector chunks + while (k <= loop_end) + { + // Set the vector length (vl) for the current iteration + // Optimization: Reduced to m2 to prevent register spilling + size_t vl = __riscv_vsetvl_e16m2(loop_end - k + 1); + + // fpk indices: k, k+1, ... + vint16m2_t v_fpk_r = __riscv_vlse16_v_i16m2(&tmpbuf_base_ptr[2 * k], stride, vl); + vint16m2_t v_fpk_i = __riscv_vlse16_v_i16m2(&tmpbuf_base_ptr[2 * k + 1], stride, vl); + + // fpnk indices: N-k, N-(k+1), ... + const int16_t* fpnk_ptr = &tmpbuf_base_ptr[2 * (ncfft - k)]; + vint16m2_t v_fpnk_r_raw = __riscv_vlse16_v_i16m2(fpnk_ptr, neg_stride, vl); + vint16m2_t v_fpnk_i_raw = __riscv_vlse16_v_i16m2(fpnk_ptr + 1, neg_stride, vl); + + // Twiddle indices: k-1, k, ... 
+ // Must use strided load to extract only Reals or only Imags from the interleaved array + const int16_t* tw_ptr = &twiddles_base_ptr[2 * (k - 1)]; + vint16m2_t v_tw_r = __riscv_vlse16_v_i16m2(tw_ptr, stride, vl); + vint16m2_t v_tw_i = __riscv_vlse16_v_i16m2(tw_ptr + 1, stride, vl); + + // Perform high-precision rounding division on fpk + const int16_t scale = 16383; + const int32_t round_const = 16384; + vint32m4_t v_fpk_r_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fpk_r, scale, vl), round_const, vl), 15, vl); + vint32m4_t v_fpk_i_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fpk_i, scale, vl), round_const, vl), 15, vl); + vint16m2_t v_fpk_r_div2 = __riscv_vnclip_wx_i16m2(v_fpk_r_32, 0, __RISCV_VXRM_RNU, vl); + vint16m2_t v_fpk_i_div2 = __riscv_vnclip_wx_i16m2(v_fpk_i_32, 0, __RISCV_VXRM_RNU, vl); + + // Perform high-precision rounding division on fpnk (with negated imaginary part) + vint16m2_t v_fpnk_i_neg = __riscv_vneg_v_i16m2(v_fpnk_i_raw, vl); + vint32m4_t v_fpnk_r_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fpnk_r_raw, scale, vl), round_const, vl), 15, vl); + vint32m4_t v_fpnk_i_32 = __riscv_vsra_vx_i32m4( + __riscv_vadd_vx_i32m4(__riscv_vwmul_vx_i32m4(v_fpnk_i_neg, scale, vl), round_const, vl), 15, vl); + vint16m2_t v_fpnk_r_div2 = __riscv_vnclip_wx_i16m2(v_fpnk_r_32, 0, __RISCV_VXRM_RNU, vl); + vint16m2_t v_fpnk_i_div2 = __riscv_vnclip_wx_i16m2(v_fpnk_i_32, 0, __RISCV_VXRM_RNU, vl); + + // Calculate intermediate values f1k (add) and f2k (subtract) + vint16m2_t v_f1k_r = __riscv_vadd_vv_i16m2(v_fpk_r_div2, v_fpnk_r_div2, vl); + vint16m2_t v_f1k_i = __riscv_vadd_vv_i16m2(v_fpk_i_div2, v_fpnk_i_div2, vl); + vint16m2_t v_f2k_r = __riscv_vsub_vv_i16m2(v_fpk_r_div2, v_fpnk_r_div2, vl); + vint16m2_t v_f2k_i = __riscv_vsub_vv_i16m2(v_fpk_i_div2, v_fpnk_i_div2, vl); + + // Perform complex multiplication + vint32m4_t v_ac = __riscv_vwmul_vv_i32m4(v_f2k_r, v_tw_r, vl); + vint32m4_t v_bd = __riscv_vwmul_vv_i32m4(v_f2k_i, v_tw_i, vl); + vint32m4_t v_ad = __riscv_vwmul_vv_i32m4(v_f2k_r, v_tw_i, vl); + vint32m4_t v_bc = __riscv_vwmul_vv_i32m4(v_f2k_i, v_tw_r, vl); + vint32m4_t v_tw_res_r_32 = __riscv_vssra_vx_i32m4(__riscv_vsub_vv_i32m4(v_ac, v_bd, vl), 15, __RISCV_VXRM_RNU, vl); + vint32m4_t v_tw_res_i_32 = __riscv_vssra_vx_i32m4(__riscv_vadd_vv_i32m4(v_ad, v_bc, vl), 15, __RISCV_VXRM_RNU, vl); + vint16m2_t v_tw_res_r = __riscv_vnclip_wx_i16m2(v_tw_res_r_32, 0, __RISCV_VXRM_RNU, vl); + vint16m2_t v_tw_res_i = __riscv_vnclip_wx_i16m2(v_tw_res_i_32, 0, __RISCV_VXRM_RNU, vl); + + // Calculate final output vectors + vint16m2_t v_out_k_r = __riscv_vsra_vx_i16m2(__riscv_vadd_vv_i16m2(v_f1k_r, v_tw_res_r, vl), 1, vl); + vint16m2_t v_out_k_i = __riscv_vsra_vx_i16m2(__riscv_vadd_vv_i16m2(v_f1k_i, v_tw_res_i, vl), 1, vl); + vint16m2_t v_out_nk_r = __riscv_vsra_vx_i16m2(__riscv_vsub_vv_i16m2(v_f1k_r, v_tw_res_r, vl), 1, vl); + vint16m2_t v_out_nk_i = __riscv_vsra_vx_i16m2(__riscv_vsub_vv_i16m2(v_tw_res_i, v_f1k_i, vl), 1, vl); + + // Store the results using a strided store (Forward) + __riscv_vsse16_v_i16m2(&freqdata_base_ptr[2 * k], stride, v_out_k_r, vl); + __riscv_vsse16_v_i16m2(&freqdata_base_ptr[2 * k + 1], stride, v_out_k_i, vl); + + // Store the results using a strided store (Reverse) + int16_t* out_nk_ptr = &freqdata_base_ptr[2 * (ncfft - k)]; + __riscv_vsse16_v_i16m2(out_nk_ptr, neg_stride, v_out_nk_r, vl); + __riscv_vsse16_v_i16m2(out_nk_ptr + 1, neg_stride, v_out_nk_i, vl); + + // Advance to the 
next vector chunk
+    k += vl;
+  }
+}
+
+size_t RfftInt16GetNeededMemory(int32_t fft_length) {
+  size_t state_size = 0;
+  kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 0, nullptr, &state_size);
+  return state_size;
+}
+
+void* RfftInt16Init(int32_t fft_length, void* state, size_t state_size) {
+  return kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 0, state, &state_size);
+}
+
+void RfftInt16ApplyRVV(void* state, const int16_t* input,
+                       Complex<int16_t>* output) {
+  kiss_fftr_rvv(
+      static_cast<kiss_fft_fixed16::kiss_fftr_cfg>(state),
+      reinterpret_cast<const kiss_fft_scalar*>(input),
+      reinterpret_cast<kiss_fft_fixed16::kiss_fft_cpx*>(output));
+}
\ No newline at end of file
diff --git a/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.h
new file mode 100644
index 00000000000..dc9bef662e9
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.h
@@ -0,0 +1,13 @@
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_RFFT_RVV_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_RFFT_RVV_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+size_t RfftInt16GetNeededMemory(int32_t fft_length);
+
+void* RfftInt16Init(int32_t fft_length, void* state, size_t state_size);
+
+void RfftInt16ApplyRVV(void* state, const int16_t* input,
+                       Complex<int16_t>* output);
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SIGNAL_RFFT_RVV_H_
\ No newline at end of file
diff --git a/tensorflow/lite/micro/kernels/riscv_vector/softmax.cc b/tensorflow/lite/micro/kernels/riscv_vector/softmax.cc
new file mode 100644
index 00000000000..c66afe58a65
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/riscv_vector/softmax.cc
@@ -0,0 +1,93 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/softmax.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+#include "tensorflow/lite/micro/kernels/riscv_vector/softmax_rvv.h"
+
+namespace tflite {
+
+namespace {
+
+void SoftmaxQuantized(const TfLiteEvalTensor* input, TfLiteEvalTensor* output,
+                      const SoftmaxParams& op_data) {
+  if (input->type == kTfLiteInt8) {
+    if (output->type == kTfLiteInt16) {
+      SoftmaxRVV(
+          op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int16_t>(output));
+    } else {
+      SoftmaxRVV(
+          op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+    }
+  } else {
+    tflite::reference_ops::SoftmaxInt16(
+        op_data, tflite::micro::GetTensorShape(input),
+        tflite::micro::GetTensorData<int16_t>(input),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int16_t>(output));
+  }
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
+  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  SoftmaxParams op_data = *static_cast<SoftmaxParams*>(node->user_data);
+
+  switch (input->type) {
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Softmax(
+          op_data, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output));
+      return kTfLiteOk;
+    }
+    case kTfLiteInt8:
+    case kTfLiteInt16: {
+      SoftmaxQuantized(input, output, op_data);
+      return kTfLiteOk;
+    }
+    default:
+      MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
+                  input->type);
+      return kTfLiteError;
+  }
+}
+}  // namespace
+
+TFLMRegistration Register_SOFTMAX() {
+  return tflite::micro::RegisterOp(SoftmaxInit, SoftmaxPrepare, SoftmaxEval);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/riscv_vector/softmax_rvv.h b/tensorflow/lite/micro/kernels/riscv_vector/softmax_rvv.h
new file mode 100644
index 00000000000..28f8fed7500
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/riscv_vector/softmax_rvv.h
@@ -0,0 +1,365 @@
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SOFTMAX_RVV_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SOFTMAX_RVV_H_
+
+#include <riscv_vector.h>
+
+#include <algorithm>
+#include <limits>
+#include <type_traits>
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/cppmath.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/micro/kernels/softmax.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+inline vint32m2_t SaturatingLeftShift_vx_i32m2(vint32m2_t v_in, int shift, size_t vl)
+{
+  // Return early if shift is zero or negative
+  if (shift <= 0) return v_in;
+
+  // Handle extreme shifts that always saturate
+  if
(shift >= 31) + { + vbool16_t v_neg = __riscv_vmslt_vx_i32m2_b16(v_in, 0, vl); + vint32m2_t v_max = __riscv_vmv_v_x_i32m2(INT32_MAX, vl); + return __riscv_vmerge_vxm_i32m2(v_max, INT32_MIN, v_neg, vl); + } + + // Perform the logical left shift + vint32m2_t v_shifted = __riscv_vsll_vx_i32m2(v_in, shift, vl); + + // Verify overflow by shifting back and comparing + vint32m2_t v_unshifted = __riscv_vsra_vx_i32m2(v_shifted, shift, vl); + vbool16_t v_no_overflow = __riscv_vmseq_vv_i32m2_b16(v_in, v_unshifted, vl); + + // Select saturating constants based on sign + vbool16_t v_neg = __riscv_vmslt_vx_i32m2_b16(v_in, 0, vl); + vint32m2_t v_sat = __riscv_vmerge_vxm_i32m2( + __riscv_vmv_v_x_i32m2(INT32_MAX, vl), INT32_MIN, v_neg, vl); + + // Merge valid results with saturated results + return __riscv_vmerge_vvm_i32m2(v_sat, v_shifted, v_no_overflow, vl); +} + +inline vint32m2_t MultiplyByQuantizedMultiplierGreaterThanOne_32bit_vx_i32m2( + vint32m2_t v_x, int32_t multiplier, int left_shift, size_t vl) +{ + // Calculate low 32 bits of product + vint32m2_t v_lo = __riscv_vmul_vx_i32m2(v_x, multiplier, vl); + + // Calculate high 32 bits of product + vint32m2_t v_hi = __riscv_vmulh_vx_i32m2(v_x, multiplier, vl); + + // Determine effective right shift amount + int total_right_shift = 31 - left_shift; + + // Calculate rounding nudge + int32_t nudge = 1 << (total_right_shift - 1); + + // Add nudge to low part treating as unsigned + vuint32m2_t v_lo_u = __riscv_vreinterpret_v_i32m2_u32m2(v_lo); + vuint32m2_t v_lo_plus_nudge = __riscv_vadd_vx_u32m2(v_lo_u, nudge, vl); + + // Detect carry from low part addition + vbool16_t v_carry = __riscv_vmsltu_vx_u32m2_b16(v_lo_plus_nudge, nudge, vl); + + // Apply carry to high part + vint32m2_t v_hi_rounded = __riscv_vadd_vx_i32m2_m(v_carry, v_hi, 1, vl); + + // Calculate shift amounts for recombination + int shift_hi = left_shift + 1; + int shift_lo = total_right_shift; + + // Shift high part (handling mod 32 behavior) + vint32m2_t v_res_from_hi; + if (shift_hi < 32) + { + v_res_from_hi = __riscv_vsll_vx_i32m2(v_hi_rounded, shift_hi, vl); + } + else + { + v_res_from_hi = __riscv_vmv_v_x_i32m2(0, vl); + } + + // Shift low part + vuint32m2_t v_res_from_lo = __riscv_vsrl_vx_u32m2(v_lo_plus_nudge, shift_lo, vl); + + // Combine results + return __riscv_vor_vv_i32m2( + v_res_from_hi, __riscv_vreinterpret_v_u32m2_i32m2(v_res_from_lo), vl); +} + +inline vint32m2_t SRMPOT_vx_i32m2(vint32m2_t v_vec, int shift, size_t vl) +{ + // Return early if shift is zero + if (shift == 0) return v_vec; + + // Handle positive shifts using saturating left shift + if (shift > 0) + { + return SaturatingLeftShift_vx_i32m2(v_vec, shift, vl); + } + else + { + // Perform rounding arithmetic right shift + return __riscv_vssra_vx_i32m2(v_vec, -shift, __RISCV_VXRM_RNU, vl); + } +} + +vint32m2_t vectorized_exp_on_negative_values(vint32m2_t v_a_q5_26, size_t vl) +{ + // Define fixed-point constants + const int kInputFractionalBits = 26; + const int kOutputFractionalBits = 31; + const int32_t s_kOneQuarter_q5_26 = INT32_C(1) << (kInputFractionalBits - 2); + const int32_t s_mask_val = s_kOneQuarter_q5_26 - 1; + + // Define Taylor Series Constants (Q0.31) + const int32_t s_result_one_q0_31 = INT32_MAX; + const int32_t s_exp_neg_1_8_q0_31 = 1895147668; + const int32_t s_one_third_q0_31 = 715827883; + const int32_t s_one_24th_q0_31 = 89478485; + const int32_t s_one_eighth_q0_31 = INT32_C(1) << (kOutputFractionalBits - 3); + + // Perform range reduction masking + vint32m2_t v_a_masked = 
__riscv_vand_vx_i32m2(v_a_q5_26, s_mask_val, vl);
+
+  // Subtract quarter constant
+  vint32m2_t v_a_mod_q_m_q_q5_26 = __riscv_vsub_vx_i32m2(v_a_masked, s_kOneQuarter_q5_26, vl);
+
+  // Rescale from Q5.26 to Q0.31
+  const int rescale_shift = kOutputFractionalBits - kInputFractionalBits;
+  vint32m2_t v_a_input_taylor_q0_31 = SRMPOT_vx_i32m2(v_a_mod_q_m_q_q5_26, rescale_shift, vl);
+
+  // Center input around -1/8
+  vint32m2_t v_y = __riscv_vadd_vx_i32m2(v_a_input_taylor_q0_31, s_one_eighth_q0_31, vl);
+
+  // Calculate polynomial terms using 32-bit saturating multiply
+  vint32m2_t v_y2 = __riscv_vsmul_vv_i32m2(v_y, v_y, __RISCV_VXRM_RNU, vl);
+  vint32m2_t v_y3 = __riscv_vsmul_vv_i32m2(v_y2, v_y, __RISCV_VXRM_RNU, vl);
+  vint32m2_t v_y4 = __riscv_vsmul_vv_i32m2(v_y2, v_y2, __RISCV_VXRM_RNU, vl);
+
+  // Calculate coefficients
+  vint32m2_t v_term_y2_over_2 = SRMPOT_vx_i32m2(v_y2, -1, vl);
+  vint32m2_t v_term_y3_over_3 = __riscv_vsmul_vx_i32m2(v_y3, s_one_third_q0_31, __RISCV_VXRM_RNU, vl);
+  vint32m2_t v_term_y3_over_6 = SRMPOT_vx_i32m2(v_term_y3_over_3, -1, vl);
+  vint32m2_t v_term_y4_over_24 = __riscv_vsmul_vx_i32m2(v_y4, s_one_24th_q0_31, __RISCV_VXRM_RNU, vl);
+
+  // Sum polynomial terms
+  vint32m2_t v_poly_sum = __riscv_vadd_vv_i32m2(v_y, v_term_y2_over_2, vl);
+  v_poly_sum = __riscv_vadd_vv_i32m2(v_poly_sum, v_term_y3_over_6, vl);
+  v_poly_sum = __riscv_vadd_vv_i32m2(v_poly_sum, v_term_y4_over_24, vl);
+
+  // Apply constant term
+  vint32m2_t v_mul_term = __riscv_vsmul_vx_i32m2(v_poly_sum, s_exp_neg_1_8_q0_31, __RISCV_VXRM_RNU, vl);
+  vint32m2_t v_current_result = __riscv_vadd_vx_i32m2(v_mul_term, s_exp_neg_1_8_q0_31, vl);
+
+  // Calculate remainder for barrel shifter
+  vint32m2_t v_remainder_q5_26 = __riscv_vsub_vv_i32m2(v_a_mod_q_m_q_q5_26, v_a_q5_26, vl);
+
+  // Multipliers for reconstruction
+  const int32_t multipliers[] = {1672461947, 1302514674, 790015084, 290630308, 39332535, 720401, 242};
+
+  // Apply barrel shifter using unrolled loop
+  for (int i = 0; i < 7; ++i)
+  {
+    int exponent = i - 2;
+    int shift_amount = 26 + exponent;
+    if (shift_amount >= 0 && shift_amount < 32)
+    {
+      int32_t mask = 1 << shift_amount;
+      int32_t mult = multipliers[i];
+
+      vint32m2_t v_rem_masked = __riscv_vand_vx_i32m2(v_remainder_q5_26, mask, vl);
+      vbool16_t v_apply = __riscv_vmsne_vx_i32m2_b16(v_rem_masked, 0, vl);
+
+      vint32m2_t v_multiplied = __riscv_vsmul_vx_i32m2(v_current_result, mult, __RISCV_VXRM_RNU, vl);
+      v_current_result = __riscv_vmerge_vvm_i32m2(v_current_result, v_multiplied, v_apply, vl);
+    }
+  }
+
+  // Handle zero input case
+  vbool16_t v_zero_mask = __riscv_vmseq_vx_i32m2_b16(v_a_q5_26, 0, vl);
+  return __riscv_vmerge_vxm_i32m2(v_current_result, s_result_one_q0_31, v_zero_mask, vl);
+}
+
+template <typename InputT, typename OutputT>
+void SoftmaxRVV(const tflite::SoftmaxParams& params,
+                const tflite::RuntimeShape& input_shape,
+                const InputT* input_data,
+                const tflite::RuntimeShape& output_shape, OutputT* output_data)
+{
+  // Extract quantization parameters
+  const int32_t input_beta_multiplier = params.input_multiplier;
+  const int32_t input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
+
+  // Define fixed-point constants
+  static const int kAccumulationIntegerBits = 12;
+  static const int kAccumulationFractionalBits = 32 - 1 - kAccumulationIntegerBits;
+  static const int kExpOutputFractionalBits = 31;
+
+  // Extract shape dimensions
+  const int trailing_dim = input_shape.DimensionsCount() - 1;
+  const int outer_size = tflite::MatchingFlatSizeSkipDim(input_shape, trailing_dim,
output_shape); + const int depth = tflite::MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const size_t depth_sz = static_cast(depth); + + // Loop over outer dimensions + for (int i = 0; i < outer_size; ++i) + { + const InputT* current_input_data = input_data + i * depth; + OutputT* current_output_data = output_data + i * depth; + + // Find maximum value in the row + InputT max_in_row = std::numeric_limits::min(); + const InputT* ptr_max = current_input_data; + size_t n = depth_sz; + while (n > 0) + { + // Keep m1 for Max finding (low register pressure) + size_t vl = __riscv_vsetvl_e8m1(n); + if constexpr (std::is_signed_v) + { + vint8m1_t v_in = __riscv_vle8_v_i8m1(reinterpret_cast(ptr_max), vl); + vint8m1_t v_red = __riscv_vredmax_vs_i8m1_i8m1(v_in, __riscv_vmv_v_x_i8m1(max_in_row, vl), vl); + max_in_row = std::max(max_in_row, __riscv_vmv_x_s_i8m1_i8(v_red)); + } + else + { + vuint8m1_t v_in = __riscv_vle8_v_u8m1(reinterpret_cast(ptr_max), vl); + vuint8m1_t v_red = __riscv_vredmaxu_vs_u8m1_u8m1(v_in, __riscv_vmv_v_x_u8m1(max_in_row, vl), vl); + max_in_row = std::max(max_in_row, (InputT)__riscv_vmv_x_s_u8m1_u8(v_red)); + } + ptr_max += vl; + n -= vl; + } + const int32_t max_in_row_s32 = static_cast(max_in_row); + + // Accumulate sum of exponentials + size_t current_c = 0; + vint32m1_t v_sum_acc = __riscv_vmv_v_x_i32m1(0, 1); + + while (current_c < depth_sz) + { + // OPT: Use m2 to reduce register pressure in the exp() call + size_t vl = __riscv_vsetvl_e32m2(depth_sz - current_c); + + // Load and widen input without 64-bit instructions + vint32m2_t v_input_s32; + if constexpr (std::is_signed_v) + { + // Load mf2 (8-bit) matches m2 (32-bit) element count + vint8mf2_t v_in = __riscv_vle8_v_i8mf2(reinterpret_cast(current_input_data + current_c), vl); + vint16m1_t v_in_16 = __riscv_vsext_vf2_i16m1(v_in, vl); + v_input_s32 = __riscv_vsext_vf2_i32m2(v_in_16, vl); + } + else + { + vuint8mf2_t v_in = __riscv_vle8_v_u8mf2(reinterpret_cast(current_input_data + current_c), vl); + vuint16m1_t v_in_16 = __riscv_vzext_vf2_u16m1(v_in, vl); + vuint32m2_t v_in_32 = __riscv_vzext_vf2_u32m2(v_in_16, vl); + v_input_s32 = __riscv_vreinterpret_v_u32m2_i32m2(v_in_32); + } + + // Calculate difference from max + vint32m2_t v_diff = __riscv_vsub_vx_i32m2(v_input_s32, max_in_row_s32, vl); + vbool16_t v_mask = __riscv_vmsge_vx_i32m2_b16(v_diff, diff_min, vl); + + // Scale difference using custom 32-bit implementation + vint32m2_t v_diff_scaled = MultiplyByQuantizedMultiplierGreaterThanOne_32bit_vx_i32m2( + v_diff, input_beta_multiplier, input_beta_left_shift, vl); + + // Calculate exponential + vint32m2_t v_exp = vectorized_exp_on_negative_values(v_diff_scaled, vl); + + // Rescale result + vint32m2_t v_exp_rescaled = __riscv_vssra_vx_i32m2(v_exp, kExpOutputFractionalBits - kAccumulationFractionalBits, __RISCV_VXRM_RNU, vl); + + // Merge and accumulate + vint32m2_t v_add_val = __riscv_vmerge_vvm_i32m2(__riscv_vmv_v_x_i32m2(0, vl), v_exp_rescaled, v_mask, vl); + + // Reduce m2 vector to scalar + v_sum_acc = __riscv_vredsum_vs_i32m2_i32m1(v_add_val, v_sum_acc, vl); + + current_c += vl; + } + int32_t sum_of_exps = __riscv_vmv_x_s_i32m1_i32(v_sum_acc); + + // Calculate reciprocal + int num_bits_over_unit; + int32_t reciprocal = tflite::GetReciprocal(sum_of_exps, kAccumulationIntegerBits, &num_bits_over_unit); + const int exponent = num_bits_over_unit + 31 - (sizeof(OutputT) * 8); + const int32_t output_min = static_cast(std::numeric_limits::min()); + const int32_t output_max = 
static_cast(std::numeric_limits::max()); + + // Compute final output + current_c = 0; + while (current_c < depth_sz) + { + // OPT: m2 + size_t vl = __riscv_vsetvl_e32m2(depth_sz - current_c); + + // Reload and widen input + vint32m2_t v_input_s32; + if constexpr (std::is_signed_v) + { + vint8mf2_t v_in = __riscv_vle8_v_i8mf2(reinterpret_cast(current_input_data + current_c), vl); + v_input_s32 = __riscv_vsext_vf2_i32m2(__riscv_vsext_vf2_i16m1(v_in, vl), vl); + } + else + { + vuint8mf2_t v_in = __riscv_vle8_v_u8mf2(reinterpret_cast(current_input_data + current_c), vl); + v_input_s32 = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vzext_vf2_u32m2(__riscv_vzext_vf2_u16m1(v_in, vl), vl)); + } + + // Recompute difference and mask + vint32m2_t v_diff = __riscv_vsub_vx_i32m2(v_input_s32, max_in_row_s32, vl); + vbool16_t v_mask = __riscv_vmsge_vx_i32m2_b16(v_diff, diff_min, vl); + + // Scale and exponentiate + vint32m2_t v_diff_scaled = MultiplyByQuantizedMultiplierGreaterThanOne_32bit_vx_i32m2( + v_diff, input_beta_multiplier, input_beta_left_shift, vl); + vint32m2_t v_exp = vectorized_exp_on_negative_values(v_diff_scaled, vl); + + // Multiply by reciprocal using 32-bit saturating multiply + vint32m2_t v_prod = __riscv_vsmul_vx_i32m2(v_exp, reciprocal, __RISCV_VXRM_RNU, vl); + + // Perform final shift and add offset + vint32m2_t v_out_shifted = __riscv_vssra_vx_i32m2(v_prod, exponent, __RISCV_VXRM_RNU, vl); + vint32m2_t v_out_final = __riscv_vadd_vx_i32m2(v_out_shifted, output_min, vl); + + // Clamp result + v_out_final = __riscv_vmax_vx_i32m2(v_out_final, output_min, vl); + v_out_final = __riscv_vmin_vx_i32m2(v_out_final, output_max, vl); + + // Apply mask using vector merge + v_out_final = __riscv_vmerge_vvm_i32m2(__riscv_vmv_v_x_i32m2(output_min, vl), v_out_final, v_mask, vl); + + // Narrow and store result + if constexpr (sizeof(OutputT) == 1) + { + if constexpr (std::is_signed_v) + { + // Narrow: m2 -> m1 -> mf2 + vint8mf2_t v_store = __riscv_vncvt_x_x_w_i8mf2(__riscv_vncvt_x_x_w_i16m1(v_out_final, vl), vl); + __riscv_vse8_v_i8mf2(reinterpret_cast(current_output_data + current_c), v_store, vl); + } + else + { + vuint8mf2_t v_store = __riscv_vncvt_x_x_w_u8mf2(__riscv_vncvt_x_x_w_u16m1(__riscv_vreinterpret_v_i32m2_u32m2(v_out_final), vl), vl); + __riscv_vse8_v_u8mf2(reinterpret_cast(current_output_data + current_c), v_store, vl); + } + } + else + { + vint16m1_t v_store = __riscv_vncvt_x_x_w_i16m1(v_out_final, vl); + __riscv_vse16_v_i16m1(reinterpret_cast(current_output_data + current_c), v_store, vl); + } + current_c += vl; + } + } +} + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_RISCV_VECTOR_SOFTMAX_RVV_H_ \ No newline at end of file diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 21f21a1ce05..60b8181087c 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -620,8 +620,6 @@ include $(MAKEFILE_DIR)/additional_kernels.inc MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_BENCHMARK_SRCS), $(MICROLITE_CC_SRCS)) - - # The download scripts require that the downloads directory already exist for # improved error checking. To accomodate that, we first create a downloads # directory. 
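Side note on the saturating shift used by softmax_rvv.h above: SaturatingLeftShift_vx_i32m2 detects overflow by shifting left, arithmetically shifting back, and comparing with the original lanes. The scalar sketch below mirrors that check element by element; it is an illustration only, not part of the patch, and the helper name saturating_left_shift_ref is made up.

#include <stdint.h>

// Scalar mirror of the vector routine: saturate to INT32_MIN/INT32_MAX whenever
// shifting left and arithmetically shifting back does not round-trip.
static inline int32_t saturating_left_shift_ref(int32_t x, int shift) {
  if (shift <= 0) return x;                               // no-op for zero/negative shifts
  if (shift >= 31) return x < 0 ? INT32_MIN : INT32_MAX;  // extreme shifts always saturate
  int32_t shifted = (int32_t)((uint32_t)x << shift);      // shift via unsigned to avoid UB
  if ((shifted >> shift) != x) {                          // round trip failed: overflow
    return x < 0 ? INT32_MIN : INT32_MAX;
  }
  return shifted;
}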
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/riscv_vector.inc b/tensorflow/lite/micro/tools/make/ext_libs/riscv_vector.inc new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/lite/micro/tools/make/targets/riscv32_vector_makefile.inc b/tensorflow/lite/micro/tools/make/targets/riscv32_vector_makefile.inc new file mode 100644 index 00000000000..ba3e4a2cc33 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/riscv32_vector_makefile.inc @@ -0,0 +1,92 @@ +# Settings for RISCV 32-bit toolchain. +TARGET_ARCH := riscv32 +TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf- + +RISCV_ARCH := rv32imc_zve32x_zvl128b +RISCV_ABI := ilp32 +RISCV_CODE_MODEL := medany + +# Allow additional flags on the command line for debugging. +RISCV_EXTRA_CFLAGS := + +TARGET_DEFAULT_TOOLCHAIN_ROOT := $(DOWNLOADS_DIR)/riscv_toolchain/bin/ +TARGET_TOOLCHAIN_ROOT := $(TARGET_DEFAULT_TOOLCHAIN_ROOT) +ifeq ($(TARGET_TOOLCHAIN_ROOT), $(TARGET_DEFAULT_TOOLCHAIN_ROOT)) + $(eval $(call add_third_party_download,$(RISCV_TOOLCHAIN_URL),$(RISCV_TOOLCHAIN_MD5),riscv_toolchain,)) +endif + +export PATH := $(TARGET_TOOLCHAIN_ROOT):$(PATH) + +PLATFORM_FLAGS = \ + -march=$(RISCV_ARCH) \ + -mabi=$(RISCV_ABI) \ + -mcmodel=$(RISCV_CODE_MODEL) \ + -mexplicit-relocs \ + -fno-builtin-printf \ + -DTF_LITE_MCU_DEBUG_LOG \ + -DTF_LITE_USE_GLOBAL_CMATH_FUNCTIONS \ + -funsigned-char \ + -fno-delete-null-pointer-checks \ + -fomit-frame-pointer \ + -DTFLM_USE_RISCV_VECTOR + +CXXFLAGS += $(PLATFORM_FLAGS) \ + -fpermissive \ + -fno-use-cxa-atexit \ + -DTF_LITE_USE_GLOBAL_MIN \ + -DTF_LITE_USE_GLOBAL_MAX + +CCFLAGS += $(PLATFORM_FLAGS) + +BUILD_TYPE := micro + +LDFLAGS += --specs=nano.specs + +# See http://b/15851472 for why memory arena threshold test is disabled. +EXCLUDED_TESTS := \ + $(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_arena_threshold_test.cc + +MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) + +CCFLAGS += $(RISCV_EXTRA_CFLAGS) +CXXFLAGS += $(RISCV_EXTRA_CFLAGS) + +# This disables the "linker relaxation" optimization, which produced incorrect code. +# TODO(b/279805615): Check whether this is fixed in newer versions of the toolchain. 
+LDFLAGS += -mno-relax +TEST_SCRIPT := $(TENSORFLOW_ROOT)tensorflow/lite/micro/testing/test_with_spike.sh +SIZE_SCRIPT := ${TENSORFLOW_ROOT}tensorflow/lite/micro/testing/size_riscv32_binary.sh + +include $(MAKEFILE_DIR)/ext_libs/eyalroz_printf.inc + +MICROLITE_CC_SRCS += \ + tensorflow/lite/micro/kernels/riscv_vector/conv_rvv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/fully_connected_rvv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/conv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/depthwise_conv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/fully_connected.cc \ + tensorflow/lite/micro/kernels/riscv_vector/pooling.cc \ + tensorflow/lite/micro/kernels/riscv_vector/pooling_rvv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/signal/rfft.cc\ + tensorflow/lite/micro/kernels/riscv_vector/signal/rfft_int16_rvv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank.cc \ + tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_rvv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log.cc \ + tensorflow/lite/micro/kernels/riscv_vector/signal/filter_bank_log_rvv.cc \ + tensorflow/lite/micro/kernels/riscv_vector/softmax.cc \ + +EXCLUDED_SRCS := \ + tensorflow/lite/micro/kernels/conv.cc \ + tensorflow/lite/micro/kernels/depthwise_conv.cc \ + tensorflow/lite/micro/kernels/fully_connected.cc \ + tensorflow/lite/micro/kernels/pooling.cc\ + signal/micro/kernels/rfft.cc \ + signal/src/rfft_int16.cc \ + signal/src/kiss_fft_wrappers/kiss_fft_int16.cc \ + signal/micro/kernels/filter_bank.cc \ + signal/src/filter_bank.cc \ + signal/src/filter_bank_log.cc \ + signal/micro/kernels/filter_bank_log.cc \ + tensorflow/lite/micro/kernels/softmax.cc \ + +
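For reference, the fixed-point complex multiply that every butterfly in conv-free FFT code above repeats (and that kiss_fftr_rvv applies to the twiddle rotation) follows one pattern: widen the Q15 operands to 32 bits with vwmul, combine the four partial products, round-shift right by 15 with vssra, and narrow back to int16 with vnclip. The standalone sketch below shows just that pattern on deinterleaved real/imaginary arrays; it is an illustration only, not code from the patch, and the function name q15_complex_multiply_rvv is made up.

#include <riscv_vector.h>
#include <stddef.h>
#include <stdint.h>

// Multiplies two Q15 complex vectors held as separate real/imag arrays,
// producing Q15 results, using the same widen -> round-shift -> clip sequence
// as the butterflies above.
static inline void q15_complex_multiply_rvv(const int16_t* a_r, const int16_t* a_i,
                                            const int16_t* b_r, const int16_t* b_i,
                                            int16_t* out_r, int16_t* out_i, size_t n) {
  size_t i = 0;
  while (i < n) {
    size_t vl = __riscv_vsetvl_e16m1(n - i);
    vint16m1_t va_r = __riscv_vle16_v_i16m1(a_r + i, vl);
    vint16m1_t va_i = __riscv_vle16_v_i16m1(a_i + i, vl);
    vint16m1_t vb_r = __riscv_vle16_v_i16m1(b_r + i, vl);
    vint16m1_t vb_i = __riscv_vle16_v_i16m1(b_i + i, vl);

    // (a_r + j*a_i) * (b_r + j*b_i): four widening products in Q30.
    vint32m2_t ac = __riscv_vwmul_vv_i32m2(va_r, vb_r, vl);
    vint32m2_t bd = __riscv_vwmul_vv_i32m2(va_i, vb_i, vl);
    vint32m2_t ad = __riscv_vwmul_vv_i32m2(va_r, vb_i, vl);
    vint32m2_t bc = __riscv_vwmul_vv_i32m2(va_i, vb_r, vl);

    // Round-shift the Q30 sums back to Q15 and clip to int16.
    vint16m1_t vr = __riscv_vnclip_wx_i16m1(
        __riscv_vssra_vx_i32m2(__riscv_vsub_vv_i32m2(ac, bd, vl), 15,
                               __RISCV_VXRM_RNU, vl),
        0, __RISCV_VXRM_RNU, vl);
    vint16m1_t vi = __riscv_vnclip_wx_i16m1(
        __riscv_vssra_vx_i32m2(__riscv_vadd_vv_i32m2(ad, bc, vl), 15,
                               __RISCV_VXRM_RNU, vl),
        0, __RISCV_VXRM_RNU, vl);

    __riscv_vse16_v_i16m1(out_r + i, vr, vl);
    __riscv_vse16_v_i16m1(out_i + i, vi, vl);
    i += vl;
  }
}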