Skip to content

Commit 740cef3

Browse files
authored
CMSIS-NN Min Max int8 support (#2753)
* Moves common functions to a new shared header, maximum_minimum.h
* Creates cmsis-nn/maximum_minimum.cc with an int8 CMSIS-NN fast path
BUG=#2752
Change-Id: Ifbb3fedf53043b2f8d4c48d73c2ca44c7f0f87ca
1 parent 9b79b9f commit 740cef3

File tree

5 files changed

+367
-58
lines changed

5 files changed

+367
-58
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ tflm_kernel_cc_library(
333333
"logistic.h",
334334
"lstm_eval.h",
335335
"lstm_shared.h",
336+
"maximum_minimum.h",
336337
"micro_ops.h",
337338
"mul.h",
338339
"pad.h",
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
17+
18+
#include "Include/arm_nnfunctions.h"
19+
#include "tensorflow/lite/c/builtin_op_data.h"
20+
#include "tensorflow/lite/c/common.h"
21+
#include "tensorflow/lite/kernels/internal/common.h"
22+
#include "tensorflow/lite/kernels/internal/quantization_util.h"
23+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24+
#include "tensorflow/lite/kernels/kernel_util.h"
25+
#include "tensorflow/lite/kernels/op_macros.h"
26+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
27+
#include "tensorflow/lite/micro/micro_log.h"
28+
29+
namespace tflite {
30+
31+
namespace {
32+
33+
// Packs a tensor's dimensions into the 4-D (N, H, W, C) layout expected by
// the CMSIS-NN kernels, left-padding missing leading dimensions with 1.
//
// rank:        number of valid entries in `tensor_dims`. Ranks above 4 are
//              not representable and fall through to the scalar shape
//              {1, 1, 1, 1} — callers are expected to pass rank <= 4.
// tensor_dims: pointer to `rank` dimension extents, outermost first.
//
// Returns a cmsis_nn_dims with the same element count for ranks 0-4.
cmsis_nn_dims FillVariableShape(int32_t rank, int32_t* tensor_dims) {
  if (rank == 4) {
    return {tensor_dims[0], tensor_dims[1], tensor_dims[2], tensor_dims[3]};
  } else if (rank == 3) {
    return {1, tensor_dims[0], tensor_dims[1], tensor_dims[2]};
  } else if (rank == 2) {
    return {1, 1, tensor_dims[0], tensor_dims[1]};
  } else if (rank == 1) {
    // Fix: a rank-1 tensor previously collapsed to {1, 1, 1, 1}, discarding
    // all but one element. Keep its true length in the innermost dimension.
    return {1, 1, 1, tensor_dims[0]};
  } else {
    return {1, 1, 1, 1};
  }
}
44+
45+
// Evaluates MAXIMUM over all supported types. Int8 tensors are handed to the
// optimized CMSIS-NN kernel (arm_maximum_s8); float32/int16/int32/int64 use
// the shared broadcast-slow reference path (TFLiteOperation). Any other type
// logs an error and returns kTfLiteError.
TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape lhs_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape rhs_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape out_shape = tflite::micro::GetTensorShape(output);

  // CMSIS-NN takes fixed 4-D shapes; pad lower-rank tensors with leading 1s.
  cmsis_nn_dims lhs_dims =
      FillVariableShape(lhs_shape.DimensionsCount(), lhs_shape.DimsData());
  cmsis_nn_dims rhs_dims =
      FillVariableShape(rhs_shape.DimensionsCount(), rhs_shape.DimsData());
  cmsis_nn_dims out_dims =
      FillVariableShape(out_shape.DimensionsCount(), out_shape.DimsData());

  switch (op_context.output->type) {
    case kTfLiteInt8: {
      // No scratch buffer is provided to the CMSIS-NN call.
      cmsis_nn_context ctx;
      ctx.buf = nullptr;
      ctx.size = 0;
      arm_maximum_s8(&ctx, tflite::micro::GetTensorData<int8_t>(input1),
                     &lhs_dims, tflite::micro::GetTensorData<int8_t>(input2),
                     &rhs_dims, tflite::micro::GetTensorData<int8_t>(output),
                     &out_dims);
      break;
    }
    case kTfLiteFloat32:
      TFLiteOperation<float, MaximumOp>(context, node, op_context);
      break;
    case kTfLiteInt16:
      TFLiteOperation<int16_t, MaximumOp>(context, node, op_context);
      break;
    case kTfLiteInt32:
      TFLiteOperation<int32_t, MaximumOp>(context, node, op_context);
      break;
    case kTfLiteInt64:
      TFLiteOperation<int64_t, MaximumOp>(context, node, op_context);
      break;
    default:
      MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
                  TfLiteTypeGetName(op_context.output->type),
                  op_context.output->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}
96+
97+
// Int8-only MAXIMUM evaluation: routes directly to the CMSIS-NN kernel and
// rejects every other tensor type with an error log.
TfLiteStatus EvalMaximumInt8(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape lhs_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape rhs_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape out_shape = tflite::micro::GetTensorShape(output);

  // CMSIS-NN takes fixed 4-D shapes; pad lower-rank tensors with leading 1s.
  cmsis_nn_dims lhs_dims =
      FillVariableShape(lhs_shape.DimensionsCount(), lhs_shape.DimsData());
  cmsis_nn_dims rhs_dims =
      FillVariableShape(rhs_shape.DimensionsCount(), rhs_shape.DimsData());
  cmsis_nn_dims out_dims =
      FillVariableShape(out_shape.DimensionsCount(), out_shape.DimsData());

  // Guard clause: this registration supports int8 only.
  if (op_context.output->type != kTfLiteInt8) {
    MicroPrintf("Type %s (%d) is not supported by Maximum Int8 Registration.",
                TfLiteTypeGetName(op_context.output->type),
                op_context.output->type);
    return kTfLiteError;
  }

  // No scratch buffer is provided to the CMSIS-NN call.
  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;
  arm_maximum_s8(&ctx, tflite::micro::GetTensorData<int8_t>(input1), &lhs_dims,
                 tflite::micro::GetTensorData<int8_t>(input2), &rhs_dims,
                 tflite::micro::GetTensorData<int8_t>(output), &out_dims);
  return kTfLiteOk;
}
136+
137+
// Evaluates MINIMUM over all supported types. Int8 tensors are handed to the
// optimized CMSIS-NN kernel (arm_minimum_s8); float32/int16/int32/int64 use
// the shared broadcast-slow reference path (TFLiteOperation). Any other type
// logs an error and returns kTfLiteError.
TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape lhs_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape rhs_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape out_shape = tflite::micro::GetTensorShape(output);

  // CMSIS-NN takes fixed 4-D shapes; pad lower-rank tensors with leading 1s.
  cmsis_nn_dims lhs_dims =
      FillVariableShape(lhs_shape.DimensionsCount(), lhs_shape.DimsData());
  cmsis_nn_dims rhs_dims =
      FillVariableShape(rhs_shape.DimensionsCount(), rhs_shape.DimsData());
  cmsis_nn_dims out_dims =
      FillVariableShape(out_shape.DimensionsCount(), out_shape.DimsData());

  switch (op_context.output->type) {
    case kTfLiteInt8: {
      // No scratch buffer is provided to the CMSIS-NN call.
      cmsis_nn_context ctx;
      ctx.buf = nullptr;
      ctx.size = 0;
      arm_minimum_s8(&ctx, tflite::micro::GetTensorData<int8_t>(input1),
                     &lhs_dims, tflite::micro::GetTensorData<int8_t>(input2),
                     &rhs_dims, tflite::micro::GetTensorData<int8_t>(output),
                     &out_dims);
      break;
    }
    case kTfLiteFloat32:
      TFLiteOperation<float, MinimumOp>(context, node, op_context);
      break;
    case kTfLiteInt16:
      TFLiteOperation<int16_t, MinimumOp>(context, node, op_context);
      break;
    case kTfLiteInt32:
      TFLiteOperation<int32_t, MinimumOp>(context, node, op_context);
      break;
    case kTfLiteInt64:
      TFLiteOperation<int64_t, MinimumOp>(context, node, op_context);
      break;
    default:
      MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
                  TfLiteTypeGetName(op_context.output->type),
                  op_context.output->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}
188+
189+
// Int8-only MINIMUM evaluation: routes directly to the CMSIS-NN kernel and
// rejects every other tensor type with an error log.
TfLiteStatus EvalMinimumInt8(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  RuntimeShape lhs_shape = tflite::micro::GetTensorShape(input1);
  RuntimeShape rhs_shape = tflite::micro::GetTensorShape(input2);
  RuntimeShape out_shape = tflite::micro::GetTensorShape(output);

  // CMSIS-NN takes fixed 4-D shapes; pad lower-rank tensors with leading 1s.
  cmsis_nn_dims lhs_dims =
      FillVariableShape(lhs_shape.DimensionsCount(), lhs_shape.DimsData());
  cmsis_nn_dims rhs_dims =
      FillVariableShape(rhs_shape.DimensionsCount(), rhs_shape.DimsData());
  cmsis_nn_dims out_dims =
      FillVariableShape(out_shape.DimensionsCount(), out_shape.DimsData());

  // Guard clause: this registration supports int8 only.
  if (op_context.output->type != kTfLiteInt8) {
    MicroPrintf("Type %s (%d) is not supported by Minimum Int8 registration.",
                TfLiteTypeGetName(op_context.output->type),
                op_context.output->type);
    return kTfLiteError;
  }

  // No scratch buffer is provided to the CMSIS-NN call.
  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;
  arm_minimum_s8(&ctx, tflite::micro::GetTensorData<int8_t>(input1), &lhs_dims,
                 tflite::micro::GetTensorData<int8_t>(input2), &rhs_dims,
                 tflite::micro::GetTensorData<int8_t>(output), &out_dims);
  return kTfLiteOk;
}
228+
229+
} // namespace
230+
231+
// Registers the all-types MAXIMUM kernel; no init or prepare step is needed.
TFLMRegistration Register_MAXIMUM() {
  return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/nullptr,
                                   EvalMaximum);
}
234+
235+
// Registers the all-types MINIMUM kernel; no init or prepare step is needed.
TFLMRegistration Register_MINIMUM() {
  return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/nullptr,
                                   EvalMinimum);
}
238+
239+
// Registers the int8-only MAXIMUM kernel; no init or prepare step is needed.
TFLMRegistration Register_MAXIMUM_INT8() {
  return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/nullptr,
                                   EvalMaximumInt8);
}
242+
243+
// Registers the int8-only MINIMUM kernel; no init or prepare step is needed.
TFLMRegistration Register_MINIMUM_INT8() {
  return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/nullptr,
                                   EvalMinimumInt8);
}
246+
247+
} // namespace tflite

tensorflow/lite/micro/kernels/maximum_minimum.cc

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -23,59 +23,13 @@ limitations under the License.
2323
#include "tensorflow/lite/kernels/kernel_util.h"
2424
#include "tensorflow/lite/kernels/op_macros.h"
2525
#include "tensorflow/lite/micro/kernels/kernel_util.h"
26+
#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
2627
#include "tensorflow/lite/micro/micro_log.h"
2728

2829
namespace tflite {
2930

3031
namespace {
3132

32-
// This file has a reference implementation of TFMaximum/TFMinimum.
33-
enum KernelType {
34-
kReference,
35-
};
36-
37-
constexpr int kInputTensor1 = 0;
38-
constexpr int kInputTensor2 = 1;
39-
constexpr int kOutputTensor = 0;
40-
41-
struct OpContext {
42-
OpContext(TfLiteContext* context, TfLiteNode* node) {
43-
input1 = tflite::micro::GetEvalInput(context, node, kInputTensor1);
44-
input2 = tflite::micro::GetEvalInput(context, node, kInputTensor2);
45-
output = tflite::micro::GetEvalOutput(context, node, kOutputTensor);
46-
}
47-
const TfLiteEvalTensor* input1;
48-
const TfLiteEvalTensor* input2;
49-
TfLiteEvalTensor* output;
50-
};
51-
52-
struct MaximumOp {
53-
template <typename data_type>
54-
static data_type op(data_type el1, data_type el2) {
55-
return el1 > el2 ? el1 : el2;
56-
}
57-
};
58-
59-
struct MinimumOp {
60-
template <typename data_type>
61-
static data_type op(data_type el1, data_type el2) {
62-
return el1 < el2 ? el1 : el2;
63-
}
64-
};
65-
66-
template <typename data_type, typename op_type>
67-
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
68-
const OpContext& op_context) {
69-
reference_ops::MaximumMinimumBroadcastSlow(
70-
tflite::micro::GetTensorShape(op_context.input1),
71-
tflite::micro::GetTensorData<data_type>(op_context.input1),
72-
tflite::micro::GetTensorShape(op_context.input2),
73-
tflite::micro::GetTensorData<data_type>(op_context.input2),
74-
tflite::micro::GetTensorShape(op_context.output),
75-
tflite::micro::GetTensorData<data_type>(op_context.output),
76-
op_type::template op<data_type>);
77-
}
78-
7933
template <KernelType kernel_type, typename OpType>
8034
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
8135
OpContext op_context(context, node);

0 commit comments

Comments
 (0)