Skip to content

Commit 392b78f

Browse files
Add support for CMSIS-NN int8 transpose and padding operators (#2757)
BUG=New operator support
1 parent b15beb9 commit 392b78f

File tree

11 files changed

+526
-171
lines changed

11 files changed

+526
-171
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ tflm_kernel_cc_library(
278278
"neg.cc",
279279
"pack.cc",
280280
"pad.cc",
281+
"pad_common.cc",
281282
"pooling.cc",
282283
"pooling_common.cc",
283284
"prelu.cc",
@@ -311,6 +312,7 @@ tflm_kernel_cc_library(
311312
"svdf_common.cc",
312313
"tanh.cc",
313314
"transpose.cc",
315+
"transpose_common.cc",
314316
"transpose_conv.cc",
315317
"unidirectional_sequence_lstm.cc",
316318
"unpack.cc",
@@ -347,6 +349,7 @@ tflm_kernel_cc_library(
347349
"strided_slice.h",
348350
"sub.h",
349351
"svdf.h",
352+
"transpose.h",
350353
"transpose_conv.h",
351354
"unidirectional_sequence_lstm.h",
352355
] + select({
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/kernels/internal/reference/pad.h"
16+
17+
#include <limits>
18+
19+
#include "Include/arm_nn_types.h"
20+
#include "Include/arm_nnfunctions.h"
21+
#include "tensorflow/lite/c/common.h"
22+
#include "tensorflow/lite/kernels/kernel_util.h"
23+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
24+
#include "tensorflow/lite/micro/kernels/pad.h"
25+
26+
namespace tflite {
27+
namespace {
28+
29+
// Fast path for int8 Pad using the CMSIS-NN arm_pad_s8 kernel.
// NOTE(review): Dims(0..3) assumes a 4-D input shape — confirm PadPrepare
// guarantees this before the CMSIS-NN registration is selected.
TfLiteStatus PadEvalInt8(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, /*index=*/0);
  // Optional third input (PadV2) supplies an explicit padding constant.
  const TfLiteEvalTensor* constant_values =
      NumInputs(node) == 3
          ? tflite::micro::GetEvalInput(context, node, /*index=*/2)
          : nullptr;
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, /*index=*/0);

  // Fill value for the padded region: the explicit constant when present,
  // otherwise the output zero point (the quantized representation of 0).
  int8_t pad_value;
  if (constant_values == nullptr) {
    // Fix: cast straight to int8_t. The previous static_cast<uint8_t>
    // funneled the signed zero point through an unsigned intermediate and
    // relied on implementation-defined narrowing back into int8_t.
    pad_value = static_cast<int8_t>(data->output_zero_point);
  } else {
    pad_value = *tflite::micro::GetTensorData<int8_t>(constant_values);
  }
  const int8_t* input_ptr = tflite::micro::GetTensorData<int8_t>(input);
  int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output);

  const RuntimeShape d = tflite::micro::GetTensorShape(input);
  const cmsis_nn_dims input_size = {d.Dims(0), d.Dims(1), d.Dims(2), d.Dims(3)};

  const PadParams p = data->params;
  const cmsis_nn_dims pre_pad = {p.left_padding[0], p.left_padding[1],
                                 p.left_padding[2], p.left_padding[3]};
  const cmsis_nn_dims post_pad = {p.right_padding[0], p.right_padding[1],
                                  p.right_padding[2], p.right_padding[3]};

  arm_pad_s8(input_ptr, output_ptr, pad_value, &input_size, &pre_pad,
             &post_pad);

  return kTfLiteOk;
}
65+
66+
// Type-dispatching Pad entry point. Float/int16/int32 use the reference
// implementation; int8 is routed to the CMSIS-NN fast path.
TfLiteStatus PadEval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData* data = static_cast<const OpData*>(node->user_data);

  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, /*index=*/0);
  // Optional third input (PadV2) supplies an explicit padding constant.
  const TfLiteEvalTensor* constant_values =
      NumInputs(node) == 3
          ? tflite::micro::GetEvalInput(context, node, /*index=*/2)
          : nullptr;
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, /*index=*/0);

  switch (input->type) {
    case kTfLiteFloat32: {
      float pad_value =
          constant_values == nullptr
              ? 0.f
              : *tflite::micro::GetTensorData<float>(constant_values);
      if (data->params.resizing_category == ResizingCategory::kImageStyle) {
        reference_ops::PadImageStyle(
            data->params, tflite::micro::GetTensorShape(input),
            tflite::micro::GetTensorData<float>(input), &pad_value,
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<float>(output));
      } else {
        reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                           tflite::micro::GetTensorData<float>(input),
                           &pad_value, tflite::micro::GetTensorShape(output),
                           tflite::micro::GetTensorData<float>(output));
      }
    } break;
    case kTfLiteInt8:
      // Fix: propagate the CMSIS-NN path's status instead of discarding it.
      return PadEvalInt8(context, node);
    case kTfLiteInt16: {
      int16_t pad_value =
          constant_values == nullptr
              ? 0
              : *tflite::micro::GetTensorData<int16_t>(constant_values);
      reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<int16_t>(input),
                         &pad_value, tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<int16_t>(output));
    } break;
    case kTfLiteInt32: {
      int32_t pad_value =
          constant_values == nullptr
              ? 0
              : *tflite::micro::GetTensorData<int32_t>(constant_values);
      reference_ops::Pad(data->params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<int32_t>(input),
                         &pad_value, tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<int32_t>(output));
    } break;
    default:
      MicroPrintf("Type %s not currently supported by Pad.",
                  TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
129+
130+
} // namespace
131+
132+
// Registration for the type-dispatching Pad kernel.
TFLMRegistration Register_PAD() {
  return tflite::micro::RegisterOp(PadInit, PadPrepare, PadEval);
}

// PadV2 shares the Pad implementation; only the registered op name differs.
TFLMRegistration Register_PADV2() {
  return tflite::micro::RegisterOp(PadInit, PadPrepare, PadEval);
}

// Specialized registration that always selects the CMSIS-NN int8 path.
TFLMRegistration Register_PAD_INT8() {
  return tflite::micro::RegisterOp(PadInit, PadPrepare, PadEvalInt8);
}
144+
145+
} // namespace tflite
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
16+
17+
#include "Include/arm_nnfunctions.h"
18+
#include "tensorflow/lite/c/common.h"
19+
#include "tensorflow/lite/kernels/kernel_util.h"
20+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
21+
#include "tensorflow/lite/micro/kernels/transpose.h"
22+
23+
namespace tflite {
24+
namespace {
25+
26+
// Fast path for int8 Transpose using the CMSIS-NN arm_transpose_s8 kernel.
// NOTE(review): DimsData()[0..3] assumes a 4-D shape even though only
// size <= 4 is enforced; confirm TransposePrepare pads ranks < 4 — TODO.
TfLiteStatus TransposeEvalInt8(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* perm_tensor =
      tflite::micro::GetEvalInput(context, node, kTransposePermTensor);
  const int size = perm_tensor->dims->data[0];
  TF_LITE_ENSURE(context, size <= 4);
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kTransposeInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kTransposeOutputTensor);
  const cmsis_nn_transpose_params transpose_params = {
      size, reinterpret_cast<const uint32_t*>(perm_tensor->data.i32)};
  cmsis_nn_dims input_dims = {
      tflite::micro::GetTensorShape(input).DimsData()[0],
      tflite::micro::GetTensorShape(input).DimsData()[1],
      tflite::micro::GetTensorShape(input).DimsData()[2],
      tflite::micro::GetTensorShape(input).DimsData()[3]};
  cmsis_nn_dims output_dims = {
      tflite::micro::GetTensorShape(output).DimsData()[0],
      tflite::micro::GetTensorShape(output).DimsData()[1],
      tflite::micro::GetTensorShape(output).DimsData()[2],
      tflite::micro::GetTensorShape(output).DimsData()[3]};

  // Fix: keep the kernel invocation out of the assertion macro and report
  // failures to the caller; previously a failure was only DCHECK-ed and the
  // function returned kTfLiteOk regardless.
  const arm_cmsis_nn_status status =
      arm_transpose_s8(tflite::micro::GetTensorData<int8_t>(input),
                       tflite::micro::GetTensorData<int8_t>(output),
                       &input_dims, &output_dims, &transpose_params);
  TF_LITE_ENSURE_EQ(context, status, ARM_CMSIS_NN_SUCCESS);

  return kTfLiteOk;
}
56+
57+
// Type-dispatching Transpose entry point. Float/int16 use the reference
// implementation; int8 is routed to the CMSIS-NN fast path.
TfLiteStatus TransposeEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* perm_tensor =
      tflite::micro::GetEvalInput(context, node, kTransposePermTensor);
  const int32_t* perm_data = perm_tensor->data.i32;
  const int size = perm_tensor->dims->data[0];
  TransposeParams params;
  params.perm_count = size;
  for (int i = 0; i < size; ++i) {
    params.perm[i] = perm_data[i];
  }

  // Transpose kernel only does rearranging values not numeric evaluations
  // on each cell. It's safe to implement per size of scalar type and this
  // trick keeps the total code size in a reasonable range.
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kTransposeInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kTransposeOutputTensor);
  switch (input->type) {
    case kTfLiteFloat32:
      reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
                               tflite::micro::GetTensorData<float>(input),
                               tflite::micro::GetTensorShape(output),
                               tflite::micro::GetTensorData<float>(output));
      break;
    case kTfLiteInt8:
      // Fix: propagate the CMSIS-NN path's status instead of discarding it.
      return TransposeEvalInt8(context, node);
    case kTfLiteInt16:
      reference_ops::Transpose(params, tflite::micro::GetTensorShape(input),
                               tflite::micro::GetTensorData<int16_t>(input),
                               tflite::micro::GetTensorShape(output),
                               tflite::micro::GetTensorData<int16_t>(output));
      break;
    default:
      MicroPrintf(
          "Type %s is currently not supported by Transpose. "
          "Only float32, int8 and int16 is supported",
          TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }

  return kTfLiteOk;
}
101+
102+
} // namespace
103+
104+
// Registration for the type-dispatching Transpose kernel (no init data).
TFLMRegistration Register_TRANSPOSE() {
  return tflite::micro::RegisterOp(nullptr, TransposePrepare, TransposeEval);
}

// Specialized registration that always selects the CMSIS-NN int8 path.
TFLMRegistration Register_TRANSPOSE_INT8() {
  return tflite::micro::RegisterOp(nullptr, TransposePrepare,
                                   TransposeEvalInt8);
}
111+
112+
} // namespace tflite

0 commit comments

Comments
 (0)