Skip to content

Commit 7a7a3de

Browse files
authored
PACK/UNPACK update (#3175)
@tensorflow/micro Update PACK with INT16 support. Add INT16 unit tests. This is a copy of PR #2737 bug=fixes #2736
1 parent 55a3cbe commit 7a7a3de

File tree

4 files changed

+67
-57
lines changed

4 files changed

+67
-57
lines changed

tensorflow/lite/micro/kernels/pack.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -85,6 +85,10 @@ TfLiteStatus PackEval(TfLiteContext* context, TfLiteNode* node) {
8585
return PackImpl<int8_t>(context, node, output, data->values_count,
8686
data->axis);
8787
}
88+
case kTfLiteInt16: {
89+
return PackImpl<int16_t>(context, node, output, data->values_count,
90+
data->axis);
91+
}
8892
case kTfLiteInt32: {
8993
return PackImpl<int32_t>(context, node, output, data->values_count,
9094
data->axis);

tensorflow/lite/micro/kernels/pack_test.cc

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -110,44 +110,11 @@ void TestPackThreeInputsFloat(int* input1_dims_data, const float* input1_data,
110110
1e-5f, output_data);
111111
}
112112

113-
void TestPackTwoInputsQuantized(
114-
int* input1_dims_data, const int8_t* input1_data, int* input2_dims_data,
115-
const int8_t* input2_data, int axis, int* output_dims_data,
116-
const int8_t* expected_output_data, int8_t* output_data) {
117-
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
118-
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
119-
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
120-
const int output_dims_count = ElementCount(*output_dims);
121-
122-
constexpr int input_size = 2;
123-
constexpr int output_size = 1;
124-
constexpr int tensors_size = input_size + output_size;
125-
TfLiteTensor tensors[tensors_size] = {
126-
// CreateQuantizedTensor needs scale/zero_point values as input, but these
127-
// values don't matter as to the functionality of PACK, so just set as 1.0
128-
// and 128.
129-
CreateQuantizedTensor(input1_data, input1_dims, 1.0, 128),
130-
CreateQuantizedTensor(input2_data, input2_dims, 1.0, 128),
131-
CreateQuantizedTensor(output_data, output_dims, 1.0, 128)};
132-
133-
TfLitePackParams builtin_data = {
134-
.values_count = 2,
135-
.axis = axis,
136-
};
137-
int inputs_array_data[] = {2, 0, 1};
138-
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
139-
int outputs_array_data[] = {1, 2};
140-
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
141-
142-
ValidatePackGoldens(tensors, tensors_size, builtin_data, inputs_array,
143-
outputs_array, expected_output_data, output_dims_count,
144-
1e-5f, output_data);
145-
}
146-
147-
void TestPackTwoInputsQuantized32(
148-
int* input1_dims_data, const int32_t* input1_data, int* input2_dims_data,
149-
const int32_t* input2_data, int axis, int* output_dims_data,
150-
const int32_t* expected_output_data, int32_t* output_data) {
113+
template <typename T>
114+
void TestPackTwoInputs(int* input1_dims_data, const T* input1_data,
115+
int* input2_dims_data, const T* input2_data, int axis,
116+
int* output_dims_data, const T* expected_output_data,
117+
T* output_data) {
151118
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
152119
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
153120
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
@@ -227,7 +194,7 @@ TF_LITE_MICRO_TEST(PackFloatThreeInputsNegativeAxis) {
227194
input3_values, axis, output_shape, golden, output_data);
228195
}
229196

230-
TF_LITE_MICRO_TEST(PackFloatMultilDimensions) {
197+
TF_LITE_MICRO_TEST(PackFloatMultiDimensions) {
231198
int input_shape[] = {2, 2, 3};
232199
int output_shape[] = {3, 2, 2, 3};
233200
const float input1_values[] = {1, 2, 3, 4, 5, 6};
@@ -242,7 +209,7 @@ TF_LITE_MICRO_TEST(PackFloatMultilDimensions) {
242209
output_shape, golden, output_data);
243210
}
244211

245-
TF_LITE_MICRO_TEST(PackQuantizedMultilDimensions) {
212+
TF_LITE_MICRO_TEST(PackInt8MultiDimensions) {
246213
int input_shape[] = {2, 2, 3};
247214
int output_shape[] = {3, 2, 2, 3};
248215
const int8_t input1_values[] = {1, 2, 3, 4, 5, 6};
@@ -252,12 +219,27 @@ TF_LITE_MICRO_TEST(PackQuantizedMultilDimensions) {
252219
constexpr int output_dims_count = 12;
253220
int8_t output_data[output_dims_count];
254221

255-
tflite::testing::TestPackTwoInputsQuantized(
222+
tflite::testing::TestPackTwoInputs<int8_t>(input_shape, input1_values,
223+
input_shape, input2_values, axis,
224+
output_shape, golden, output_data);
225+
}
226+
227+
TF_LITE_MICRO_TEST(PackInt16MultiDimensions) {
228+
int input_shape[] = {2, 2, 3};
229+
int output_shape[] = {3, 2, 2, 3};
230+
const int16_t input1_values[] = {1, 2, 3, 4, 5, 6};
231+
const int16_t input2_values[] = {7, 8, 9, 10, 11, 12};
232+
const int16_t golden[] = {1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12};
233+
const int axis = 1;
234+
constexpr int output_dims_count = 12;
235+
int16_t output_data[output_dims_count];
236+
237+
tflite::testing::TestPackTwoInputs<int16_t>(
256238
input_shape, input1_values, input_shape, input2_values, axis,
257239
output_shape, golden, output_data);
258240
}
259241

260-
TF_LITE_MICRO_TEST(PackQuantized32MultilDimensions) {
242+
TF_LITE_MICRO_TEST(PackInt32MultiDimensions) {
261243
int input_shape[] = {2, 2, 3};
262244
int output_shape[] = {3, 2, 2, 3};
263245
const int32_t input1_values[] = {1, 2, 3, 4, 5, 6};
@@ -267,7 +249,7 @@ TF_LITE_MICRO_TEST(PackQuantized32MultilDimensions) {
267249
constexpr int output_dims_count = 12;
268250
int32_t output_data[output_dims_count];
269251

270-
tflite::testing::TestPackTwoInputsQuantized32(
252+
tflite::testing::TestPackTwoInputs<int32_t>(
271253
input_shape, input1_values, input_shape, input2_values, axis,
272254
output_shape, golden, output_data);
273255
}

tensorflow/lite/micro/kernels/unpack.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ TfLiteStatus UnpackEval(TfLiteContext* context, TfLiteNode* node) {
8686
case kTfLiteInt32: {
8787
return UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
8888
}
89-
case kTfLiteInt8: {
90-
return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
91-
}
9289
case kTfLiteInt16: {
9390
return UnpackImpl<int16_t>(context, node, input, data->num, data->axis);
9491
}
92+
case kTfLiteInt8: {
93+
return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
94+
}
9595
default: {
9696
MicroPrintf("Type '%s' is not supported by unpack.",
9797
TfLiteTypeGetName(input->type));

tensorflow/lite/micro/kernels/unpack_test.cc

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -132,12 +132,15 @@ void TestUnpackOneOutputFloat(int* input_dims_data, const float* input_data,
132132
}
133133
}
134134

135-
void TestUnpackThreeOutputsQuantized32(
136-
int* input_dims_data, const int32_t* input_data, int axis,
137-
int* output1_dims_data, const int32_t* expected_output1_data,
138-
int* output2_dims_data, const int32_t* expected_output2_data,
139-
int* output3_dims_data, const int32_t* expected_output3_data,
140-
int32_t* output1_data, int32_t* output2_data, int32_t* output3_data) {
135+
template <typename T>
136+
void TestUnpackThreeOutputs(int* input_dims_data, const T* input_data, int axis,
137+
int* output1_dims_data,
138+
const T* expected_output1_data,
139+
int* output2_dims_data,
140+
const T* expected_output2_data,
141+
int* output3_dims_data,
142+
const T* expected_output3_data, T* output1_data,
143+
T* output2_data, T* output3_data) {
141144
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
142145
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
143146
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
@@ -257,7 +260,28 @@ TF_LITE_MICRO_TEST(UnpackFloatOneOutput) {
257260
output_shape, golden, output_data);
258261
}
259262

260-
TF_LITE_MICRO_TEST(UnpackQuantized32ThreeOutputs) {
263+
TF_LITE_MICRO_TEST(UnpackInt16ThreeOutputs) {
264+
int input_shape[] = {2, 3, 2};
265+
const int16_t input_values[] = {1, 2, 3, 4, 5, 6};
266+
int output1_shape[] = {1, 2};
267+
const int16_t output1_golden[] = {1, 2};
268+
int output2_shape[] = {1, 2};
269+
const int16_t output2_golden[] = {3, 4};
270+
int output3_shape[] = {1, 2};
271+
const int16_t output3_golden[] = {5, 6};
272+
constexpr int output1_dims_count = 2;
273+
constexpr int output2_dims_count = 2;
274+
constexpr int output3_dims_count = 2;
275+
int16_t output1_data[output1_dims_count];
276+
int16_t output2_data[output2_dims_count];
277+
int16_t output3_data[output3_dims_count];
278+
tflite::testing::TestUnpackThreeOutputs<int16_t>(
279+
input_shape, input_values, 0, output1_shape, output1_golden,
280+
output2_shape, output2_golden, output3_shape, output3_golden,
281+
output1_data, output2_data, output3_data);
282+
}
283+
284+
TF_LITE_MICRO_TEST(UnpackInt32ThreeOutputs) {
261285
int input_shape[] = {2, 3, 2};
262286
const int32_t input_values[] = {1, 2, 3, 4, 5, 6};
263287
int output1_shape[] = {1, 2};
@@ -272,7 +296,7 @@ TF_LITE_MICRO_TEST(UnpackQuantized32ThreeOutputs) {
272296
int32_t output1_data[output1_dims_count];
273297
int32_t output2_data[output2_dims_count];
274298
int32_t output3_data[output3_dims_count];
275-
tflite::testing::TestUnpackThreeOutputsQuantized32(
299+
tflite::testing::TestUnpackThreeOutputs<int32_t>(
276300
input_shape, input_values, 0, output1_shape, output1_golden,
277301
output2_shape, output2_golden, output3_shape, output3_golden,
278302
output1_data, output2_data, output3_data);

0 commit comments

Comments
 (0)