Skip to content

Commit aa6e625

Browse files
authored
PRelu Int16x8 Ref C support (#3191)
@tensorflow/micro Add PReLu int16x8 support in reference C bug=fixes #3168
1 parent 93ebaa9 commit aa6e625

File tree

3 files changed

+75
-28
lines changed

3 files changed

+75
-28
lines changed

tensorflow/lite/micro/kernels/prelu.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,18 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
6161
tflite::micro::GetTensorData<int8_t>(output));
6262
return kTfLiteOk;
6363
} break;
64+
case kTfLiteInt16: {
65+
reference_ops::BroadcastPrelu4DSlow(
66+
params, tflite::micro::GetTensorShape(input),
67+
tflite::micro::GetTensorData<int16_t>(input),
68+
tflite::micro::GetTensorShape(alpha),
69+
tflite::micro::GetTensorData<int8_t>(alpha),
70+
tflite::micro::GetTensorShape(output),
71+
tflite::micro::GetTensorData<int16_t>(output));
72+
return kTfLiteOk;
73+
} break;
6474
default:
65-
MicroPrintf("Only float32 and uint8_t are supported currently, got %d.",
75+
MicroPrintf("Input type '%s' is not supported.",
6676
TfLiteTypeGetName(input->type));
6777
return kTfLiteError;
6878
}

tensorflow/lite/micro/kernels/prelu_common.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
9696
TF_LITE_ENSURE_OK(context,
9797
CalculatePreluParams(input, alpha, output, params));
9898

99+
if (output->type == kTfLiteInt16) {
100+
// Make sure alpha type is Int8 when Output is Int16
101+
TF_LITE_ENSURE(context, alpha->type == kTfLiteInt8);
102+
}
103+
99104
micro_context->DeallocateTempTfLiteTensor(input);
100105
micro_context->DeallocateTempTfLiteTensor(alpha);
101106
micro_context->DeallocateTempTfLiteTensor(output);

tensorflow/lite/micro/kernels/prelu_test.cc

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,20 @@ namespace tflite {
2323
namespace testing {
2424
namespace {
2525

26-
template <typename T>
27-
void ValidatePreluGoldens(TfLiteTensor* tensors, int tensors_size,
28-
const T* golden, const int output_length,
29-
T* output_data) {
26+
const float kQuantizedTolerance = 2 * (1. / 256);
27+
28+
void ExecutePReluTest(const int tensors_count, TfLiteTensor* tensors) {
3029
int inputs_array_data[] = {2, 0, 1};
3130
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
3231
int outputs_array_data[] = {1, 2};
3332
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
3433

3534
const TFLMRegistration registration = tflite::Register_PRELU();
36-
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
37-
outputs_array,
38-
/*builtin_data=*/nullptr);
35+
micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
36+
outputs_array, /*builtin_data=*/nullptr);
3937

4038
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
4139
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
42-
43-
for (int i = 0; i < output_length; ++i) {
44-
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], 1e-5f);
45-
}
4640
}
4741

4842
void TestPreluFloat(int* input_dims_data, const float* input_data,
@@ -62,19 +56,22 @@ void TestPreluFloat(int* input_dims_data, const float* input_data,
6256
CreateTensor(output_data, output_dims),
6357
};
6458

65-
ValidatePreluGoldens(tensors, tensors_size, expected_output_data,
66-
output_dims_count, output_data);
59+
ExecutePReluTest(tensors_size, tensors);
60+
61+
for (int i = 0; i < output_dims_count; i++) {
62+
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
63+
}
6764
}
6865

69-
template <typename T>
66+
template <typename T, typename Slope>
7067
void TestPreluQuantized(int* input_dims_data, const float* input_data,
7168
T* input_quantized, const float input_scale,
7269
const int input_zero_point, int* alpha_dims_data,
73-
const float* alpha_data, T* alpha_quantized,
70+
const float* alpha_data, Slope* alpha_quantized,
7471
const float alpha_scale, const int alpha_zero_point,
75-
const float* golden, T* golden_quantized,
76-
const float output_scale, const int output_zero_point,
77-
int* output_dims_data, T* output_data) {
72+
const float* golden, const float output_scale,
73+
const int output_zero_point, int* output_dims_data,
74+
T* output_quantized, float* output_data) {
7875
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
7976
TfLiteIntArray* alpha_dims = IntArrayFromInts(alpha_dims_data);
8077
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
@@ -87,15 +84,18 @@ void TestPreluQuantized(int* input_dims_data, const float* input_data,
8784
input_scale, input_zero_point),
8885
CreateQuantizedTensor(alpha_data, alpha_quantized, alpha_dims,
8986
alpha_scale, alpha_zero_point),
90-
CreateQuantizedTensor(output_data, output_dims, output_scale,
87+
CreateQuantizedTensor(output_quantized, output_dims, output_scale,
9188
output_zero_point),
9289
};
9390

94-
Quantize(golden, golden_quantized, output_dims_count, output_scale,
95-
output_zero_point);
91+
ExecutePReluTest(tensors_size, tensors);
92+
93+
Dequantize(output_quantized, output_dims_count, output_scale,
94+
output_zero_point, output_data);
9695

97-
ValidatePreluGoldens(tensors, tensors_size, golden_quantized,
98-
output_dims_count, output_data);
96+
for (int i = 0; i < output_dims_count; i++) {
97+
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], kQuantizedTolerance);
98+
}
9999
}
100100
} // namespace
101101
} // namespace testing
@@ -147,13 +147,45 @@ TF_LITE_MICRO_TEST(QuantizedInt8PreluActivationsOpTest) {
147147
const int dims_count = 12;
148148
int8_t input_quantized[dims_count];
149149
int8_t alpha_quantized[3];
150-
int8_t golden_quantized[dims_count];
151150
float scale = 2.0 / 255.0;
152151
int zero_point = 0;
153-
int8_t output_data[dims_count];
154-
tflite::testing::TestPreluQuantized(
152+
int8_t output_data_q[dims_count];
153+
float output_data_f[dims_count];
154+
tflite::testing::TestPreluQuantized<int8_t, int8_t>(
155155
input_shape, input_values, input_quantized, scale, zero_point,
156156
alpha_shape, alpha_values, alpha_quantized, scale, zero_point, golden,
157-
golden_quantized, scale, zero_point, output_shape, output_data);
157+
scale, zero_point, output_shape, output_data_q, output_data_f);
158+
}
159+
160+
TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) {
161+
int input_shape[] = {3, 2, 2, 3};
162+
const float input_values[] = {
163+
0.0f, 0.0f, 0.0f, // Row 1, Column 1
164+
0.5f, 0.5f, 0.5f, // Row 1, Column 2
165+
-1.0f, -1.0f, -1.0f, // Row 2, Column 1
166+
-0.25f, -0.25f, -0.25f, // Row 1, Column 2
167+
};
168+
int alpha_shape[] = {3, 1, 1, 3};
169+
const float alpha_values[] = {0.0f, 0.5f, -0.5f};
170+
int output_shape[] = {3, 2, 2, 3};
171+
const float golden[] = {
172+
0.0f, 0.0f, 0.0f, // Row 1, Column 1
173+
0.5f, 0.5f, 0.5f, // Row 1, Column 2
174+
0.0f, -0.5f, 0.5f, // Row 2, Column 1
175+
0.0f, -0.125f, 0.125f, // Row 1, Column 2
176+
};
177+
const int dims_count = 12;
178+
int16_t input_quantized[dims_count];
179+
int8_t alpha_quantized[3];
180+
float scale_input_output = 2.0 / 65535.0;
181+
float scale_alpha = 2.0 / 255.0;
182+
int zero_point = 0;
183+
int16_t output_data_q[dims_count];
184+
float output_data_f[dims_count];
185+
tflite::testing::TestPreluQuantized<int16_t, int8_t>(
186+
input_shape, input_values, input_quantized, scale_input_output,
187+
zero_point, alpha_shape, alpha_values, alpha_quantized, scale_alpha,
188+
zero_point, golden, scale_input_output, zero_point, output_shape,
189+
output_data_q, output_data_f);
158190
}
159191
TF_LITE_MICRO_TESTS_END

0 commit comments

Comments
 (0)