Skip to content

Commit 59589c6

Browse files
committed
Support for DECODE operator
@tensorflow/micro Add support for alternate decompression memory to the DECODE operator. Add additional unit tests. Update the generic benchmark application and the Makefile. bug=fixes #3212
1 parent 7ba714a commit 59589c6

35 files changed

+4535
-85
lines changed

python/tflite_micro/python_ops_resolver.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ PythonOpsResolver::PythonOpsResolver() {
4040
AddConv2D();
4141
AddCos();
4242
AddCumSum();
43+
AddDecode();
4344
AddDelay();
4445
AddDepthToSpace();
4546
AddDepthwiseConv2D();

tensorflow/lite/micro/kernels/BUILD

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ tflm_kernel_cc_library(
236236
"conv.cc",
237237
"conv_common.cc",
238238
"cumsum.cc",
239+
"decode.cc",
240+
"decode_state.cc",
241+
"decode_state_huffman.cc",
242+
"decode_state_lut.cc",
243+
"decode_state_prune.cc",
239244
"depth_to_space.cc",
240245
"depthwise_conv.cc",
241246
"depthwise_conv_common.cc",
@@ -327,6 +332,10 @@ tflm_kernel_cc_library(
327332
"batch_matmul.h",
328333
"circular_buffer.h",
329334
"conv.h",
335+
"decode_state.h",
336+
"decode_state_huffman.h",
337+
"decode_state_lut.h",
338+
"decode_state_prune.h",
330339
"depthwise_conv.h",
331340
"dequantize.h",
332341
"ethosu.h",
@@ -643,6 +652,21 @@ tflm_cc_test(
643652
],
644653
)
645654

655+
tflm_cc_test(
656+
name = "decode_test",
657+
srcs = [
658+
"decode_test.cc",
659+
],
660+
deps = [
661+
":kernel_runner",
662+
"//tensorflow/lite/c:common",
663+
"//tensorflow/lite/micro:debug_log",
664+
"//tensorflow/lite/micro:op_resolvers",
665+
"//tensorflow/lite/micro:test_helpers",
666+
"//tensorflow/lite/micro/testing:micro_test",
667+
],
668+
)
669+
646670
tflm_cc_test(
647671
name = "decompress_test",
648672
srcs = [

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
123123
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
124124
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
125125
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
126+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
126127
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
127128
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
128129
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/c/common.h"
17+
#include "tensorflow/lite/kernels/internal/compatibility.h"
18+
#include "tensorflow/lite/kernels/kernel_util.h"
19+
#include "tensorflow/lite/micro/kernels/decode_state.h"
20+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
21+
#include "tensorflow/lite/micro/micro_arena_constants.h"
22+
#include "tensorflow/lite/micro/micro_context.h"
23+
#include "tensorflow/lite/micro/micro_log.h"
24+
25+
namespace tflite {
26+
namespace {
27+
28+
TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
                                 size_t tensor_output_index,
                                 TfLiteTensor* output) {
  // When alternate decompression memory has been supplied, point the output
  // tensor at it now so the memory planner does not allocate arena space for
  // this tensor.
  void* const alt_mem =
      GetMicroContext(context)->AllocateDecompressionMemory(
          output->bytes, MicroArenaBufferAlignment());
  if (alt_mem == nullptr) {
    // No alternate memory available: leave allocation to the memory planner.
    return kTfLiteOk;
  }

  TfLiteEvalTensor* const output_eval =
      tflite::micro::GetEvalOutput(context, node, tensor_output_index);
  TF_LITE_ENSURE(context, output_eval != nullptr);
  // Keep the eval tensor and the temp TfLiteTensor view consistent: both must
  // reference the same alternate buffer.
  output_eval->data.data = alt_mem;
  output->data.data = alt_mem;

  return kTfLiteOk;
}
46+
47+
// Prepares the DECODE node: validates the (encoded, ancillary) input pairs,
// creates one DecodeState per output, optionally redirects each output into
// alternate decompression memory, and runs per-pair decoder setup.
//
// Inputs arrive in pairs — input[2k] is the encoded tensor and input[2k + 1]
// is its ancillary (DCM) tensor — producing output[k]; hence the
// num_inputs == num_outputs * 2 invariant checked below.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const size_t num_inputs = NumInputs(node);
  const size_t num_outputs = NumOutputs(node);
  TF_LITE_ENSURE(context, num_outputs > 0);
  TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs * 2);

  MicroContext* const micro_context = GetMicroContext(context);

  // Persistent array of DecodeState pointers, one per output; retrieved again
  // in Eval via node->user_data.
  node->user_data = micro_context->AllocatePersistentBuffer(
      num_outputs * sizeof(DecodeState*));
  TF_LITE_ENSURE(context, node->user_data != nullptr);
  DecodeState** const dsp_arr =
      reinterpret_cast<DecodeState**>(node->user_data);

  // Declared outside the loop so the cleanup code after the loop can release
  // any temp tensors still held when an error breaks out mid-iteration.
  TfLiteTensor* input = nullptr;
  TfLiteTensor* ancillary = nullptr;
  TfLiteTensor* output = nullptr;
  TfLiteStatus status = kTfLiteOk;

  micro_context->ResetDecompressionMemoryAllocations();

  for (size_t i = 0; i < num_inputs; i += 2) {
    input = micro_context->AllocateTempInputTensor(node, i);
    if (input == nullptr) {
      MicroPrintf("failed to allocate input tensor %u", i);
      status = kTfLiteError;
      break;
    }
    ancillary = micro_context->AllocateTempInputTensor(node, i + 1);
    if (ancillary == nullptr) {
      MicroPrintf("failed to allocate ancillary tensor %u", i + 1);
      status = kTfLiteError;
      break;
    }
    output = micro_context->AllocateTempOutputTensor(node, i / 2);
    if (output == nullptr) {
      MicroPrintf("failed to allocate output tensor %u", i / 2);
      status = kTfLiteError;
      break;
    }

    // Encoded data and its ancillary metadata must be compile-time constants.
    TF_LITE_ENSURE(context, IsConstantTensor(input));
    TF_LITE_ENSURE(context, IsConstantTensor(ancillary));

    // Only DCM version 1 is supported.
    if (DecodeState::Version(*ancillary) != 1) {
      MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
      status = kTfLiteError;
      break;
    }

    // Select the decoder implementation from the ancillary tensor's type
    // field. On an unsupported type, dsp stays nullptr and the error is
    // reported after SetOutputTensorData below.
    DecodeState* dsp = nullptr;
    switch (DecodeState::Type(*ancillary)) {
      case DecodeState::kDcmTypeLUT:
        dsp = DecodeState::CreateDecodeStateLUT(
            context, micro_context->GetAlternateProfiler());
        break;
      case DecodeState::kDcmTypePrune:
        dsp = DecodeState::CreateDecodeStatePrune(
            context, micro_context->GetAlternateProfiler());
        break;
      case DecodeState::kDcmTypeHuffman:
        dsp = DecodeState::CreateDecodeStateHuffman(
            context, micro_context->GetAlternateProfiler());
        break;
      case DecodeState::kDcmTypeCustom:
        MicroPrintf("Custom decode type not yet supported");
        break;
      default:
        MicroPrintf("unsupported decode type %u",
                    DecodeState::Type(*ancillary));
        break;
    }

    // Redirect the output into alternate decompression memory, if available.
    status = SetOutputTensorData(context, node, i / 2, output);
    if (status != kTfLiteOk) {
      break;
    }

    if (dsp != nullptr) {
      status = dsp->Setup(*input, *ancillary, *output);
      if (status != kTfLiteOk) {
        break;
      }
      dsp_arr[i / 2] = dsp;
    } else {
      MicroPrintf("failed to allocate DecodeState[%u]", i / 2);
      status = kTfLiteError;
      break;
    }

    // Normal end of iteration: release this pair's temp tensors and clear the
    // pointers so the post-loop cleanup does not double-deallocate them.
    micro_context->DeallocateTempTfLiteTensor(input);
    micro_context->DeallocateTempTfLiteTensor(ancillary);
    micro_context->DeallocateTempTfLiteTensor(output);
    input = nullptr;
    ancillary = nullptr;
    output = nullptr;
  }

  // Error-path cleanup: release whichever temp tensors were still live when
  // the loop broke out early.
  if (input != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(input);
  }
  if (ancillary != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(ancillary);
  }
  if (output != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(output);
  }

  return status;
}
157+
158+
// Runs every decoder prepared for this node: each (encoded, ancillary) input
// pair is decoded into its corresponding output tensor.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const size_t num_inputs = NumInputs(node);
  // Per-output DecodeState array created during Prepare.
  DecodeState** const states =
      reinterpret_cast<DecodeState**>(node->user_data);

  for (size_t i = 0; i < num_inputs; i += 2) {
    const size_t output_index = i / 2;

    const TfLiteEvalTensor* const encoded =
        tflite::micro::GetEvalInput(context, node, i);
    TF_LITE_ENSURE(context, encoded != nullptr);

    const TfLiteEvalTensor* const ancillary =
        tflite::micro::GetEvalInput(context, node, i + 1);
    TF_LITE_ENSURE(context, ancillary != nullptr);

    const TfLiteEvalTensor* const decoded =
        tflite::micro::GetEvalOutput(context, node, output_index);
    TF_LITE_ENSURE(context, decoded != nullptr);

    const TfLiteStatus status =
        states[output_index]->Decode(*encoded, *ancillary, *decoded);
    TF_LITE_ENSURE(context, status == kTfLiteOk);
  }

  return kTfLiteOk;
}
180+
181+
} // namespace
182+
183+
// Registration entry point for the DECODE operator. No init function is
// supplied; per-node state is created in Prepare.
TFLMRegistration Register_DECODE() {
  return tflite::micro::RegisterOp(/*init=*/nullptr, Prepare, Eval);
}
186+
187+
} // namespace tflite
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/micro/kernels/decode_state.h"
17+
18+
#include "tensorflow/lite/micro/kernels/decode_state_huffman.h"
19+
#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
20+
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
21+
#include "tensorflow/lite/micro/micro_context.h"
22+
23+
namespace tflite {
24+
25+
DecodeState* DecodeState::CreateDecodeStateLUT(
    const TfLiteContext* context, MicroProfilerInterface* profiler) {
  // Construct a DecodeStateLUT in persistent buffer memory via placement new.
  // Returns nullptr if the allocation fails.
  void* const storage = GetMicroContext(context)->AllocatePersistentBuffer(
      sizeof(DecodeStateLUT));
  if (storage == nullptr) {
    return nullptr;
  }
  return new (storage) DecodeStateLUT(context, profiler);
}
37+
38+
DecodeState* DecodeState::CreateDecodeStatePrune(
    const TfLiteContext* context, MicroProfilerInterface* profiler) {
  // Construct a DecodeStatePrune in persistent buffer memory via placement
  // new. Returns nullptr if the allocation fails.
  void* const storage = GetMicroContext(context)->AllocatePersistentBuffer(
      sizeof(DecodeStatePrune));
  if (storage == nullptr) {
    return nullptr;
  }
  return new (storage) DecodeStatePrune(context, profiler);
}
50+
51+
DecodeState* DecodeState::CreateDecodeStateHuffman(
    const TfLiteContext* context, MicroProfilerInterface* profiler) {
  // Construct a DecodeStateHuffman in persistent buffer memory via placement
  // new. Returns nullptr if the allocation fails.
  void* const storage = GetMicroContext(context)->AllocatePersistentBuffer(
      sizeof(DecodeStateHuffman));
  if (storage == nullptr) {
    return nullptr;
  }
  return new (storage) DecodeStateHuffman(context, profiler);
}
63+
64+
} // namespace tflite

0 commit comments

Comments
 (0)