Commit d96fe66

ddavis-2015 and veblush authored
Support for DECODE operator (#3132)
* Support for DECODE operator

  @tensorflow/micro

  Add initial support for DECODE operator.
  Add reference implementation.
  Add LUT decompression support.
  Update op resolvers.
  Update Makefiles and Bazel BUILD files.
  Add kernel unit test.

  bug=fixes #3131

* update copyright

* Don't use constructors with global objects (bluepill will not call them). Cleanup unit test.

* return error if DecodeState cannot be created.

* Address review issues.

---------

Co-authored-by: Esun Kim <[email protected]>
1 parent 37fb41f commit d96fe66
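The op resolver updates in this commit expose the new kernel through an AddDecode() entry point. Below is a minimal registration sketch, assuming MicroMutableOpResolver gained a matching AddDecode() method as part of the "Update op resolvers" change; the RegisterOps() helper and the second operator are illustrative only.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Sketch: register DECODE (backed by Register_DECODE below) alongside one
// illustrative operator. AddDecode() is assumed to mirror
// PythonOpsResolver::AddDecode() from this commit.
TfLiteStatus RegisterOps(tflite::MicroMutableOpResolver<2>& resolver) {
  TF_LITE_ENSURE_STATUS(resolver.AddDecode());
  TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());
  return kTfLiteOk;
}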

File tree

13 files changed: +1367 -0 lines


python/tflite_micro/python_ops_resolver.cc

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ PythonOpsResolver::PythonOpsResolver() {
  AddConv2D();
  AddCos();
  AddCumSum();
+  AddDecode();
  AddDelay();
  AddDepthToSpace();
  AddDepthwiseConv2D();

tensorflow/lite/micro/kernels/BUILD

Lines changed: 20 additions & 0 deletions
@@ -236,6 +236,9 @@ tflm_kernel_cc_library(
        "conv.cc",
        "conv_common.cc",
        "cumsum.cc",
+        "decode.cc",
+        "decode_state.cc",
+        "decode_state_lut.cc",
        "depth_to_space.cc",
        "depthwise_conv.cc",
        "depthwise_conv_common.cc",
@@ -327,6 +330,8 @@ tflm_kernel_cc_library(
        "batch_matmul.h",
        "circular_buffer.h",
        "conv.h",
+        "decode_state.h",
+        "decode_state_lut.h",
        "depthwise_conv.h",
        "dequantize.h",
        "ethosu.h",
@@ -643,6 +648,21 @@ tflm_cc_test(
    ],
)

+tflm_cc_test(
+    name = "decode_test",
+    srcs = [
+        "decode_test.cc",
+    ],
+    deps = [
+        ":kernel_runner",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:debug_log",
+        "//tensorflow/lite/micro:op_resolvers",
+        "//tensorflow/lite/micro:test_helpers",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
+
tflm_cc_test(
    name = "decompress_test",
    srcs = [

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
+$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \

tensorflow/lite/micro/kernels/decode.cc

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/decode_state.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
namespace {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const size_t num_inputs = NumInputs(node);
  const size_t num_outputs = NumOutputs(node);
  TF_LITE_ENSURE(context, num_outputs > 0);
  TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs * 2);

  MicroContext* const micro_context = GetMicroContext(context);

  node->user_data = micro_context->AllocatePersistentBuffer(
      num_outputs * sizeof(DecodeState*));
  TF_LITE_ENSURE(context, node->user_data != nullptr);
  DecodeState** const dsp_arr =
      reinterpret_cast<DecodeState**>(node->user_data);

  TfLiteTensor* input = nullptr;
  TfLiteTensor* ancillary = nullptr;
  TfLiteTensor* output = nullptr;
  TfLiteStatus status = kTfLiteOk;

  for (size_t i = 0; i < num_inputs; i += 2) {
    input = micro_context->AllocateTempInputTensor(node, i);
    if (input == nullptr) {
      MicroPrintf("failed to allocate input tensor %u", i);
      status = kTfLiteError;
      break;
    }
    ancillary = micro_context->AllocateTempInputTensor(node, i + 1);
    if (ancillary == nullptr) {
      MicroPrintf("failed to allocate ancillary tensor %u", i + 1);
      status = kTfLiteError;
      break;
    }
    output = micro_context->AllocateTempOutputTensor(node, i / 2);
    if (output == nullptr) {
      MicroPrintf("failed to allocate output tensor %u", i / 2);
      status = kTfLiteError;
      break;
    }

    if (DecodeState::Version(*ancillary) != 1) {
      MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
      status = kTfLiteError;
      break;
    }

    DecodeState* dsp = nullptr;
    switch (DecodeState::Type(*ancillary)) {
      case DecodeState::kDcmTypeLUT:
        dsp = DecodeState::CreateDecodeStateLUT(
            context, micro_context->GetAlternateProfiler());
        break;
      case DecodeState::kDcmTypeCustom:
        MicroPrintf("Custom decode type not yet supported");
        break;
      default:
        MicroPrintf("unsupported decode type %u",
                    DecodeState::Type(*ancillary));
        break;
    }

    if (dsp != nullptr) {
      status = dsp->Setup(*input, *ancillary, *output);
      if (status != kTfLiteOk) {
        break;
      }
      dsp_arr[i / 2] = dsp;
    } else {
      MicroPrintf("failed to allocate DecodeState[%u]", i / 2);
      status = kTfLiteError;
      break;
    }

    micro_context->DeallocateTempTfLiteTensor(input);
    micro_context->DeallocateTempTfLiteTensor(ancillary);
    micro_context->DeallocateTempTfLiteTensor(output);
    input = nullptr;
    ancillary = nullptr;
    output = nullptr;
  }

  if (input != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(input);
  }
  if (ancillary != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(ancillary);
  }
  if (output != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(output);
  }

  return status;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const size_t num_inputs = NumInputs(node);
  DecodeState** const dsp_arr =
      reinterpret_cast<DecodeState**>(node->user_data);

  for (size_t i = 0; i < num_inputs; i += 2) {
    const TfLiteEvalTensor* input =
        tflite::micro::GetEvalInput(context, node, i);
    TF_LITE_ENSURE(context, input != nullptr);
    const TfLiteEvalTensor* ancillary =
        tflite::micro::GetEvalInput(context, node, i + 1);
    TF_LITE_ENSURE(context, ancillary != nullptr);
    const TfLiteEvalTensor* output =
        tflite::micro::GetEvalOutput(context, node, i / 2);
    TF_LITE_ENSURE(context, output != nullptr);

    TfLiteStatus status = dsp_arr[i / 2]->Decode(*input, *ancillary, *output);
    TF_LITE_ENSURE(context, status == kTfLiteOk);
  }

  return kTfLiteOk;
}

}  // namespace

TFLMRegistration Register_DECODE() {
  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

}  // namespace tflite
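
The Prepare() and Eval() loops above assume a fixed tensor layout: the node's inputs are (encoded, ancillary) pairs and each pair produces one output, so pair k occupies input indices 2k and 2k + 1 and writes output k. A standalone sketch of that index arithmetic, with a hypothetical output count:

#include <cstddef>
#include <cstdio>

int main() {
  // inputs  = {encoded_0, ancillary_0, encoded_1, ancillary_1, ...}
  // outputs = {decoded_0, decoded_1, ...}
  const std::size_t num_outputs = 3;          // hypothetical: three decoded tensors
  const std::size_t num_inputs = num_outputs * 2;  // enforced by TF_LITE_ENSURE_EQ in Prepare()
  for (std::size_t i = 0; i < num_inputs; i += 2) {
    std::printf("encoded input %zu + ancillary input %zu -> output %zu\n", i,
                i + 1, i / 2);
  }
  return 0;
}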

tensorflow/lite/micro/kernels/decode_state.cc

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/kernels/decode_state.h"

#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {

DecodeState* DecodeState::CreateDecodeStateLUT(
    const TfLiteContext* context, MicroProfilerInterface* profiler) {
  MicroContext* const micro_context = GetMicroContext(context);
  void* buffer =
      micro_context->AllocatePersistentBuffer(sizeof(DecodeStateLut));
  if (buffer == nullptr) {
    return nullptr;
  }
  DecodeState* dsp = new (buffer) DecodeStateLut(context, profiler);

  return dsp;
}

}  // namespace tflite
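
CreateDecodeStateLUT() builds the DecodeStateLut object with placement new inside a persistent arena buffer rather than relying on static construction, which matches the commit note that bluepill targets never run constructors of global objects. A minimal sketch of the same pattern, using a hypothetical arena and payload type in place of MicroContext and DecodeStateLut:

#include <cstddef>
#include <new>  // placement new

// Hypothetical stand-ins for MicroContext::AllocatePersistentBuffer() and the
// decode-state object being constructed.
struct Arena {
  alignas(alignof(std::max_align_t)) unsigned char storage[256];
  std::size_t used = 0;
  void* Allocate(std::size_t bytes) {
    if (used + bytes > sizeof(storage)) return nullptr;
    void* p = storage + used;
    used += bytes;
    return p;
  }
};

struct LutState {
  explicit LutState(int table_size) : size(table_size) {}
  int size;
};

LutState* CreateLutState(Arena& arena, int table_size) {
  void* buffer = arena.Allocate(sizeof(LutState));
  if (buffer == nullptr) {
    return nullptr;  // mirrors the nullptr check in CreateDecodeStateLUT()
  }
  // Construct in place: no heap allocation and no global constructor needed.
  return new (buffer) LutState(table_size);
}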

tensorflow/lite/micro/kernels/decode_state.h

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_

#include <cstdint>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_profiler_interface.h"

namespace tflite {

class DecodeState {
 public:
  DecodeState() = delete;

  DecodeState(const TfLiteContext* context, MicroProfilerInterface* profiler)
      : context_(context), micro_profiler_(profiler) {}

  virtual TfLiteStatus Setup(const TfLiteTensor& input,
                             const TfLiteTensor& ancillary,
                             const TfLiteTensor& output) = 0;
  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
                              const TfLiteEvalTensor& ancillary,
                              const TfLiteEvalTensor& output) = 0;

  static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
                                           MicroProfilerInterface* profiler);

  static uint8_t Type(const TfLiteTensor& ancillary) {
    return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
  }

  static uint8_t Type(const TfLiteEvalTensor& ancillary) {
    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
  }

  static uint8_t Version(const TfLiteTensor& ancillary) {
    return GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
  }

  static uint8_t Version(const TfLiteEvalTensor& ancillary) {
    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
  }

 protected:
  virtual ~DecodeState() = default;

  // Decode Common Metadata constants
 public:
  static constexpr uint8_t kDcmTypeLUT = 0;
  static constexpr uint8_t kDcmTypeCustom = 127;

  static constexpr size_t kDcmSizeInBytes = 16;

 private:
  static constexpr size_t kDcmDecodeTypeOffset = 0;
  static constexpr size_t kDcmVersionOffset = 1;

  // DecodeState vars
 protected:
  const TfLiteContext* context_;
  MicroProfilerInterface* micro_profiler_;

 private:
  TF_LITE_REMOVE_VIRTUAL_DELETE
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
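
The accessors above read a small Decode Common Metadata (DCM) header from the start of the ancillary tensor: decode type at byte offset 0, version at byte offset 1, and a total header size of kDcmSizeInBytes (16). A sketch of an ancillary header those accessors would accept; the bytes past offset 1 and the type-specific payload that follows the header are not specified by this diff and are left zero-filled here:

#include <cstddef>
#include <cstdint>

// Local mirrors of the DecodeState constants (the offsets are private members).
constexpr std::size_t kDcmSizeInBytes = 16;
constexpr std::size_t kDcmDecodeTypeOffset = 0;
constexpr std::size_t kDcmVersionOffset = 1;
constexpr std::uint8_t kDcmTypeLUT = 0;

// First 16 bytes of a LUT ancillary tensor: type, version, then reserved bytes.
constexpr std::uint8_t kAncillaryDcm[kDcmSizeInBytes] = {
    kDcmTypeLUT,  // offset 0: decode type
    1,            // offset 1: DCM version (Prepare() rejects anything but 1)
};

static_assert(kAncillaryDcm[kDcmDecodeTypeOffset] == kDcmTypeLUT,
              "LUT decode type");
static_assert(kAncillaryDcm[kDcmVersionOffset] == 1,
              "Prepare() requires DCM version 1");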
