Skip to content

Commit b78bb7e

Browse files
authored
Add int16 support to RELU (#2727)
This PR adds int16 support to the RELU kernel. Bug: #2726
1 parent 392b78f commit b78bb7e

File tree

4 files changed

+103
-29
lines changed

4 files changed

+103
-29
lines changed

tensorflow/lite/micro/kernels/activations.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -54,14 +54,23 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
5454
return kTfLiteOk;
5555
}
5656
case kTfLiteInt8: {
57-
tflite::ReluQuantized(data, tflite::micro::GetTensorShape(input),
58-
tflite::micro::GetTensorShape(output),
59-
tflite::micro::GetTensorData<int8_t>(input),
60-
tflite::micro::GetTensorData<int8_t>(output));
57+
tflite::ReluQuantized<int8_t>(
58+
data, tflite::micro::GetTensorShape(input),
59+
tflite::micro::GetTensorShape(output),
60+
tflite::micro::GetTensorData<int8_t>(input),
61+
tflite::micro::GetTensorData<int8_t>(output));
62+
return kTfLiteOk;
63+
}
64+
case kTfLiteInt16: {
65+
tflite::ReluQuantized<int16_t>(
66+
data, tflite::micro::GetTensorShape(input),
67+
tflite::micro::GetTensorShape(output),
68+
tflite::micro::GetTensorData<int16_t>(input),
69+
tflite::micro::GetTensorData<int16_t>(output));
6170
return kTfLiteOk;
6271
}
6372
default: {
64-
MicroPrintf("Only float32 is supported currently, got %s",
73+
MicroPrintf("Only float32/int8/int16 is supported currently, got %s",
6574
TfLiteTypeGetName(input->type));
6675
return kTfLiteError;
6776
}
@@ -109,7 +118,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
109118
return kTfLiteOk;
110119
}
111120
default: {
112-
MicroPrintf("Only float32 is supported currently, got %s",
121+
MicroPrintf("Only float32/int8/int16 is supported currently, got %s",
113122
TfLiteTypeGetName(input->type));
114123
return kTfLiteError;
115124
}

tensorflow/lite/micro/kernels/activations.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "tensorflow/lite/c/builtin_op_data.h"
2222
#include "tensorflow/lite/c/common.h"
23+
#include "tensorflow/lite/kernels/internal/common.h"
2324
#include "tensorflow/lite/kernels/internal/types.h"
2425

2526
namespace tflite {
@@ -36,9 +37,23 @@ struct Relu6OpData {
3637
int32_t zero;
3738
};
3839

40+
template <typename T>
3941
void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape,
40-
const RuntimeShape& output_shape, const int8_t* input_data,
41-
int8_t* output_data);
42+
const RuntimeShape& output_shape, const T* input_data,
43+
T* output_data) {
44+
const int flat_size = MatchingFlatSize(input_shape, output_shape);
45+
for (int i = 0; i < flat_size; ++i) {
46+
const int32_t val = static_cast<int32_t>(input_data[i]);
47+
int32_t clamped =
48+
data.params.output_offset +
49+
MultiplyByQuantizedMultiplier(val - data.params.input_offset,
50+
data.params.output_multiplier,
51+
data.params.output_shift);
52+
clamped = std::max(data.params.quantized_activation_min, clamped);
53+
clamped = std::min(data.params.quantized_activation_max, clamped);
54+
output_data[i] = static_cast<T>(clamped);
55+
}
56+
}
4257

4358
template <typename T>
4459
void CalculateReluOpData(const TfLiteTensor* input, TfLiteTensor* output,

tensorflow/lite/micro/kernels/activations_common.cc

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -33,23 +33,6 @@ namespace tflite {
3333
const int kActivationsInputTensor = 0;
3434
const int kActivationsOutputTensor = 0;
3535

36-
void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape,
37-
const RuntimeShape& output_shape, const int8_t* input_data,
38-
int8_t* output_data) {
39-
const int flat_size = MatchingFlatSize(input_shape, output_shape);
40-
for (int i = 0; i < flat_size; ++i) {
41-
const int32_t val = static_cast<int32_t>(input_data[i]);
42-
int32_t clamped =
43-
data.params.output_offset +
44-
MultiplyByQuantizedMultiplier(val - data.params.input_offset,
45-
data.params.output_multiplier,
46-
data.params.output_shift);
47-
clamped = std::max(data.params.quantized_activation_min, clamped);
48-
clamped = std::min(data.params.quantized_activation_max, clamped);
49-
output_data[i] = static_cast<int8_t>(clamped);
50-
}
51-
}
52-
5336
template <typename T>
5437
void CalculateReluOpData(const TfLiteTensor* input, TfLiteTensor* output,
5538
ReluOpData* data) {
@@ -116,6 +99,10 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
11699

117100
if (input->type == kTfLiteInt8) {
118101
CalculateReluOpData<int8_t>(input, output, data);
102+
} else if (input->type == kTfLiteInt16) {
103+
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
104+
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
105+
CalculateReluOpData<int16_t>(input, output, data);
119106
}
120107

121108
micro_context->DeallocateTempTfLiteTensor(input);

tensorflow/lite/micro/kernels/activations_test.cc

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -129,6 +129,46 @@ void TestReluInt8(int* input_dims_data, const float* input_data,
129129
}
130130
}
131131

132+
void TestReluInt16(int* input_dims_data, const float* input_data,
133+
int16_t* input_data_quantized, const float input_scale,
134+
const int input_zero_point, const float* golden,
135+
int16_t* golden_quantized, int* output_dims_data,
136+
const float output_scale, const int output_zero_point,
137+
int16_t* output_data) {
138+
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
139+
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
140+
const int output_elements_count = ElementCount(*output_dims);
141+
constexpr int inputs_size = 1;
142+
constexpr int outputs_size = 1;
143+
constexpr int tensors_size = inputs_size + outputs_size;
144+
TfLiteTensor tensors[tensors_size] = {
145+
CreateQuantizedTensor(input_data, input_data_quantized, input_dims,
146+
input_scale, input_zero_point),
147+
CreateQuantizedTensor(output_data, output_dims, output_scale,
148+
output_zero_point),
149+
};
150+
151+
int inputs_array_data[] = {1, 0};
152+
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
153+
int outputs_array_data[] = {1, 1};
154+
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
155+
156+
const TFLMRegistration registration = Register_RELU();
157+
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
158+
outputs_array,
159+
/*builtin_data=*/nullptr);
160+
161+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
162+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
163+
164+
Quantize(golden, golden_quantized, output_elements_count, output_scale,
165+
output_zero_point);
166+
167+
for (int i = 0; i < output_elements_count; ++i) {
168+
TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]);
169+
}
170+
}
171+
132172
void TestRelu6Int8(int* input_dims_data, const float* input_data,
133173
int8_t* input_data_quantized, const float input_scale,
134174
const int input_zero_point, const float* golden,
@@ -265,6 +305,29 @@ TF_LITE_MICRO_TEST(SimpleReluTestInt8) {
265305
output_zero_point, output_data);
266306
}
267307

308+
TF_LITE_MICRO_TEST(SimpleReluTestInt16) {
309+
const int elements_count = 10;
310+
311+
int input_shape[] = {2, 2, 5};
312+
const float input_data[] = {256, 257, 258, 259, 260,
313+
-256, -257, -258, -259, -260};
314+
int16_t input_quantized[elements_count];
315+
int output_shape[] = {2, 2, 5};
316+
const float golden[] = {256, 257, 258, 259, 260, 0, 0, 0, 0, 0};
317+
int16_t golden_quantized[elements_count];
318+
int16_t output_data[elements_count];
319+
320+
const float input_scale = 0.5f;
321+
const int input_zero_point = 0;
322+
const float output_scale = 0.5f;
323+
const int output_zero_point = 0;
324+
325+
tflite::testing::TestReluInt16(input_shape, input_data, input_quantized,
326+
input_scale, input_zero_point, golden,
327+
golden_quantized, output_shape, output_scale,
328+
output_zero_point, output_data);
329+
}
330+
268331
TF_LITE_MICRO_TEST(SimpleRelu6TestInt8) {
269332
const int elements_count = 10;
270333

0 commit comments

Comments
 (0)