@@ -16,22 +16,12 @@ limitations under the License.
16 16 #ifndef TENSORFLOW_LITE_MICRO_KERNELS_BATCH_MATMUL_H_
17 17 #define TENSORFLOW_LITE_MICRO_KERNELS_BATCH_MATMUL_H_
18 18
19    - #include <cstdint>
20    -
21 19 #include "tensorflow/lite/c/builtin_op_data.h"
22    - #include "tensorflow/lite/kernels/internal/reference/transpose.h"
23 20 #include "tensorflow/lite/kernels/internal/types.h"
24    - #include "tensorflow/lite/kernels/kernel_util.h"
25    - #include "tensorflow/lite/micro/kernels/kernel_util.h"
26 21 #include "tensorflow/lite/micro/micro_common.h"
27    - #include "tensorflow/lite/micro/micro_log.h"
28 22
29 23 namespace tflite {
30 24
31    - extern constexpr int kBatchMatmulInputLhsTensor = 0;
32    - extern constexpr int kBatchMatmulInputRhsTensor = 1;
33    - extern constexpr int kBatchMatmulOutputTensor = 0;
34    -
35 25 struct QuantizationOpDataBatchMatmul {
36 26   // The scaling factor from input to output (aka the 'real multiplier') can
37 27   // be represented as a fixed point multiplier plus a left shift.
@@ -59,98 +49,29 @@ struct OpDataBatchMatmul {
59 49   bool rhs_is_constant_tensor;
60 50 };
61 51
   52 + extern const int kBatchMatmulInputLhsTensor;
   53 + extern const int kBatchMatmulInputRhsTensor;
   54 + extern const int kBatchMatmulOutputTensor;
   55 +
62 56 TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
63 57                                  const RuntimeShape& extended_lhs_shape,
64 58                                  const RuntimeShape& extended_rhs_shape,
65 59                                  bool adj_x, bool adj_y, int output_rank,
66    -                                  TfLiteTensor* output) {
67    -   int64_t orig_size = NumElements(output);
68    -
69    -   // make sure the new output dims rank does not exceed the original rank
70    -   TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));
71    -
72    -   // make sure output tensor dims are not in the FlatBuffer
73    -   TfLiteEvalTensor* output_eval =
74    -       tflite::micro::GetEvalOutput(context, node, kBatchMatmulOutputTensor);
75    -   TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
76    -                                  context, output, output_eval));
77    -
78    -   // Fill in any broadcast dimensions.
79    -   for (int i = 0; i < output_rank - 2; ++i) {
80    -     const int lhs_dim = extended_lhs_shape.Dims(i);
81    -     const int rhs_dim = extended_rhs_shape.Dims(i);
82    -     int broadcast_dim = lhs_dim;
83    -     if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
84    -       broadcast_dim = rhs_dim;
85    -     }
86    -     output->dims->data[i] = broadcast_dim;
87    -   }
88    -   // Fill in the matmul dimensions.
89    -   int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
90    -   int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
91    -
92    -   output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
93    -   output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
94    -   output->dims->size = output_rank;
95    -
96    -   // Check that output tensor has not been resized
97    -   // since TFLM doesn't support tensor resizing.
98    -   TF_LITE_ENSURE_EQ(context, orig_size, NumElements(output));
99    -
100   -   return kTfLiteOk;
101   - }
   60 +                                  TfLiteTensor* output);
10261
103 62 template <typename T>
104 63 void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
105   -                               TfLiteEvalTensor* tensor_out) {
106   -   const T* input = tflite::micro::GetTensorData<T>(&tensor_in);
107   -   T* output = tflite::micro::GetTensorData<T>(tensor_out);
108   -   RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
109   -   RuntimeShape shape(transposed_shape);
110   -   TransposeParams params;
111   -   const int rank = shape.DimensionsCount();
112   -   params.perm_count = rank;
113   -   for (int i = 0; i < rank - 2; ++i) {
114   -     params.perm[i] = i;
115   -   }
116   -   // Transpose the last two dimensions.
117   -   params.perm[rank - 2] = rank - 1;
118   -   params.perm[rank - 1] = rank - 2;
119   -   transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
120   -   transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
121   -   reference_ops::Transpose(params, shape, input, transposed_shape, output);
122   - }
    64 +                               TfLiteEvalTensor* tensor_out);
12365
124 66 TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
125   -                                   TfLiteEvalTensor* tensor_out) {
126   -   if (tensor_in.type == kTfLiteFloat32) {
127   -     TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
128   -     return kTfLiteOk;
129   -   } else if (tensor_in.type == kTfLiteInt8) {
130   -     TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
131   -     return kTfLiteOk;
132   -   } else if (tensor_in.type == kTfLiteInt16) {
133   -     TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
134   -     return kTfLiteOk;
135   -   } else {
136   -     MicroPrintf(
137   -         "BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 "
138   -         "type.");
139   -   }
140   -   return kTfLiteError;
141   - }
    67 +                                   TfLiteEvalTensor* tensor_out);
14268
143   - RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
144   -   RuntimeShape swapped_shape(shape);
145   -   const int32_t dims = shape.DimensionsCount();
146   -   swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
147   -   swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
148   -   return swapped_shape;
149   - }
    69 + RuntimeShape SwapRowColumnDims(const RuntimeShape& shape);
15070
151 71 TFLMRegistration Register_BATCH_MATMUL();
152 72
153 73 #if defined(CMSIS_NN)
    74 +
154 75 // Returns a TFLMRegistration struct for kernel variant that only supports
155 76 // int8 matrix multiplication and uses the latency optimized
156 77 // implementations.
0 commit comments