2222 * Title: muriscv_nn_support_functions.h
2323 * Description: Public header file of support functions for MURISCV NN Library
2424 *
25- * $Date: 27 May 2024
26- * $Revision: V.22.1.0
25+ * $Date: 08 October 2024
26+ * $Revision: V.22.4.0
2727 *
2828 * Target : Arm(R) M-Profile Architecture
2929 * -------------------------------------------------------------------- */
@@ -849,7 +849,48 @@ muriscv_nn_status muriscv_nn_vec_mat_mult_t_s8(const int8_t *lhs,
849849 const int32_t rhs_offset );
850850
851851/**
852- * @brief s16 Vector by Matrix (transposed) multiplication
852+ * @brief s8 Vector by Matrix (transposed) multiplication using per channel quantization for output
853+ *
854+ * @param[in] lhs Input left-hand side vector
855+ * @param[in] rhs Input right-hand side matrix (transposed)
856+ * @param[in] kernel_sum Kernel sums of the kernels (rhs). See muriscv_nn_vector_sum_s8 for more info.
857+ * @param[in] bias Input bias
858+ * @param[out] dst Output vector
859+ * @param[in] lhs_offset Offset to be added to the input values of the left-hand side vector.
860+ * Range: -127 to 128
861+ * @param[in] dst_offset Offset to be added to the output values. Range: -127 to 128
862+ * @param[in] dst_multiplier Per-channel output multipliers
863+ * @param[in] dst_shift Per-channel output shifts
864+ * @param[in] rhs_cols Number of columns in the right-hand side input matrix
865+ * @param[in] rhs_rows Number of rows in the right-hand side input matrix
866+ * @param[in] activation_min Minimum value to clamp the output to. Range: int8
867+ * @param[in] activation_max Maximum value to clamp the output to. Range: int8
868+ * @param[in] address_offset Memory position offset for dst. First output is stored at 'dst', the
869+ * second at 'dst + address_offset' and so on. Default value is typically 1.
870+ * @param[in] rhs_offset Offset to be added to the input values of the right-hand side matrix.
871+ * Range: -127 to 128
872+ *
873+ * @return The function returns <code>MURISCV_NN_SUCCESS</code>
874+ *
875+ */
876+ muriscv_nn_status muriscv_nn_vec_mat_mult_t_per_ch_s8 (const int8_t * lhs ,
877+ const int8_t * rhs ,
878+ const int32_t * kernel_sum ,
879+ const int32_t * bias ,
880+ int8_t * dst ,
881+ const int32_t lhs_offset ,
882+ const int32_t dst_offset ,
883+ const int32_t * dst_multiplier ,
884+ const int32_t * dst_shift ,
885+ const int32_t rhs_cols ,
886+ const int32_t rhs_rows ,
887+ const int32_t activation_min ,
888+ const int32_t activation_max ,
889+ const int32_t address_offset ,
890+ const int32_t rhs_offset );
891+
892+ /**
893+ * @brief s16 Vector by s8 Matrix (transposed) multiplication
853894 *
854895 * @param[in] lhs Input left-hand side vector
855896 * @param[in] rhs Input right-hand side matrix (transposed)
@@ -876,6 +917,34 @@ muriscv_nn_status muriscv_nn_vec_mat_mult_t_s16(const int16_t *lhs,
876917 const int32_t activation_min ,
877918 const int32_t activation_max );
878919
920+ /**
921+ * @brief s16 Vector by s16 Matrix (transposed) multiplication
922+ *
923+ * @param[in] lhs Input left-hand side vector
924+ * @param[in] rhs Input right-hand side matrix (transposed)
925+ * @param[in] bias Input bias
926+ * @param[out] dst Output vector
927+ * @param[in] dst_multiplier Output multiplier
928+ * @param[in] dst_shift Output shift
929+ * @param[in] rhs_cols Number of columns in the right-hand side input matrix
930+ * @param[in] rhs_rows Number of rows in the right-hand side input matrix
931+ * @param[in] activation_min Minimum value to clamp the output to. Range: int16
932+ * @param[in] activation_max Maximum value to clamp the output to. Range: int16
933+ *
934+ * @return The function returns <code>MURISCV_NN_SUCCESS</code>
935+ *
936+ */
937+ muriscv_nn_status muriscv_nn_vec_mat_mult_t_s16_s16 (const int16_t * lhs ,
938+ const int16_t * rhs ,
939+ const int64_t * bias ,
940+ int16_t * dst ,
941+ const int32_t dst_multiplier ,
942+ const int32_t dst_shift ,
943+ const int32_t rhs_cols ,
944+ const int32_t rhs_rows ,
945+ const int32_t activation_min ,
946+ const int32_t activation_max );
947+
879948/**
880949 * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
881950 *
@@ -2488,6 +2557,26 @@ muriscv_nn_status muriscv_nn_elementwise_mul_acc_s16(const int16_t *input_1_vect
24882557 const int32_t out_activation_max ,
24892558 const int32_t block_size );
24902559
2560+ /**
2561+ * @brief Check if a broadcast is required between 2 muriscv_nn_dims.
2562+ * @param[in] shape_1 pointer to input tensor 1
2563+ * @param[in] shape_2 pointer to input tensor 2
2564+ * @return The function returns 1 if a broadcast is required, or 0 if not.
2565+ *
2566+ * @details Compares each dimension and returns 1 if any dimension does not match.
2567+ * This function does not check that broadcast rules are met.
2568+ */
2569+ __STATIC_FORCEINLINE int32_t muriscv_nn_check_broadcast_required (const muriscv_nn_dims * shape_1 , const muriscv_nn_dims * shape_2 )
2570+ {
2571+ if ((shape_1 -> n != shape_2 -> n ) || (shape_1 -> h != shape_2 -> h ) || (shape_1 -> w != shape_2 -> w ) ||
2572+ (shape_1 -> c != shape_2 -> c ))
2573+ {
2574+ return 1 ;
2575+ }
2576+
2577+ return 0 ;
2578+ }
2579+
24912580#ifdef __cplusplus
24922581}
24932582#endif
0 commit comments