diff --git a/tensorflow/lite/micro/kernels/decode.cc b/tensorflow/lite/micro/kernels/decode.cc
index 9f4d34cff15..30d54442e64 100644
--- a/tensorflow/lite/micro/kernels/decode.cc
+++ b/tensorflow/lite/micro/kernels/decode.cc
@@ -49,6 +49,23 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
   return kTfLiteOk;
 }
 
+DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context,
+                                                  uint8_t type) {
+  const MicroContext* mc = GetMicroContext(context);
+  auto registrations = mc->GetCustomDecodeRegistrations();
+  if (registrations == nullptr) {
+    return nullptr;
+  }
+
+  for (auto& reg : *registrations) {
+    if (reg.type == type && reg.func != nullptr) {
+      return reg.func(context, mc->GetAlternateProfiler());
+    }
+  }
+
+  return nullptr;
+}
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const size_t num_inputs = NumInputs(node);
   const size_t num_outputs = NumOutputs(node);
@@ -113,21 +130,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
         dsp = DecodeState::CreateDecodeStateHuffman(
             context, micro_context->GetAlternateProfiler());
         break;
-      case DecodeState::kDcmTypeCustom:
-        MicroPrintf("Custom decode type not yet supported");
-        break;
       default:
-        MicroPrintf("unsupported decode type %u",
-                    DecodeState::Type(*ancillary));
+        uint32_t type = DecodeState::Type(*ancillary);
+        if (type >= DecodeState::kDcmTypeCustomFirst &&
+            type <= DecodeState::kDcmTypeCustomLast) {
+          dsp = GetDecodeStateFromCustomRegistration(context, type);
+        } else {
+          MicroPrintf("unsupported decode type %u", type);
+        }
         break;
     }
 
-    status = SetOutputTensorData(context, node, i / 2, output);
-    if (status != kTfLiteOk) {
-      break;
-    }
-
     if (dsp != nullptr) {
+      status = SetOutputTensorData(context, node, i / 2, output);
+      if (status != kTfLiteOk) {
+        break;
+      }
       status = dsp->Setup(*input, *ancillary, *output);
       if (status != kTfLiteOk) {
         break;
diff --git a/tensorflow/lite/micro/kernels/decode_state.h b/tensorflow/lite/micro/kernels/decode_state.h
index 06f821dbc3c..9be36e32de3 100644
--- a/tensorflow/lite/micro/kernels/decode_state.h
+++ b/tensorflow/lite/micro/kernels/decode_state.h
@@ -72,7 +72,8 @@ class DecodeState {
   static constexpr uint8_t kDcmTypeLUT = 0;
   static constexpr uint8_t kDcmTypeHuffman = 1;
   static constexpr uint8_t kDcmTypePrune = 2;
-  static constexpr uint8_t kDcmTypeCustom = 127;
+  static constexpr uint8_t kDcmTypeCustomFirst = 128;
+  static constexpr uint8_t kDcmTypeCustomLast = 255;
 
   static constexpr size_t kDcmSizeInBytes = 16;
 
diff --git a/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc b/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc
index 0030b371d14..269bdd17e11 100644
--- a/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc
+++ b/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc
@@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) {
 
   tflite::testing::TestDecode(
       encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
-      nullptr, kTfLiteError);
+      nullptr, nullptr, kTfLiteError);
 }
 
 TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) {
diff --git a/tensorflow/lite/micro/kernels/decode_state_prune_test.cc b/tensorflow/lite/micro/kernels/decode_state_prune_test.cc
index 636a5d9a746..955c4008157 100644
--- a/tensorflow/lite/micro/kernels/decode_state_prune_test.cc
+++ b/tensorflow/lite/micro/kernels/decode_state_prune_test.cc
@@ -575,7 +575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
 
   tflite::testing::TestDecode(
       kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
-      nullptr, kTfLiteError);
+      nullptr, nullptr, kTfLiteError);
 }
 
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/decode_test.cc b/tensorflow/lite/micro/kernels/decode_test.cc
index b07afda1b14..fdf4c1477b8 100644
--- a/tensorflow/lite/micro/kernels/decode_test.cc
+++ b/tensorflow/lite/micro/kernels/decode_test.cc
@@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
 constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
 constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};
 
+//
+// Custom DECODE test data
+//
+constexpr int kDecodeTypeCustom = 200;
+
+constexpr int8_t kAncillaryDataCustom[] = {0x42};
+
+constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = {
+    kDecodeTypeCustom,  // type: custom
+    1,                  // DCM version: 1
+};
+
+// Align the tensor data the same as a Buffer in the TfLite schema
+alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46,
+                                              0x4A, 0x52, 0x62, 0x02};
+
+// Tensor shapes as TfLiteIntArray
+constexpr int kOutputShapeCustom[] = {1, 8};
+constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)};
+
+constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04,
+                                    0x08, 0x10, 0x20, 0x40};
+
+class DecodeStateCustom : public tflite::DecodeState {
+ public:
+  DecodeStateCustom() = delete;
+
+  DecodeStateCustom(const TfLiteContext* context,
+                    tflite::MicroProfilerInterface* profiler)
+      : DecodeState(context, profiler) {}
+
+  virtual TfLiteStatus Setup(const TfLiteTensor& input,
+                             const TfLiteTensor& ancillary,
+                             const TfLiteTensor& output) override {
+    return kTfLiteOk;
+  }
+
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) override {
+    const uint8_t* inp = tflite::micro::GetTensorData<uint8_t>(&input);
+    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), inp != nullptr);
+    uint8_t* outp = tflite::micro::GetTensorData<uint8_t>(
+        const_cast<TfLiteEvalTensor*>(&output));
+    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), outp != nullptr);
+    const uint8_t* vp = tflite::micro::GetTensorData<uint8_t>(&ancillary);
+    TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), vp != nullptr);
+    vp += kDcmSizeInBytes;
+
+    // simple XOR de-obfuscation
+    std::transform(inp, inp + input.dims->data[0], outp,
+                   [vp](uint8_t i) { return i ^ *vp; });
+
+    return kTfLiteOk;
+  }
+
+  static DecodeState* CreateDecodeStateCustom(
+      const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) {
+    alignas(4) static uint8_t buffer[sizeof(DecodeStateCustom)];
+    DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler);
+    return instance;
+  }
+
+ protected:
+  virtual ~DecodeStateCustom() = default;
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
 }  // namespace
 
 TF_LITE_MICRO_TESTS_BEGIN
@@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) {
       encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
       &amr);
 }
 
+TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int8_t output_data[std::size(kExpectCustom)] = {};
+  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataCustom)>
+      kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}};
+
+  constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)};
+
+  const TfLiteIntArray* const encoded_dims =
+      tflite::testing::IntArrayFromInts(kEncodedShapeCustom);
+  static const TensorInDatum tid_encode = {
+      kEncodedCustom,
+      *encoded_dims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> encodes = {
+      &tid_encode,
+  };
+
+  const TfLiteIntArray* const ancillary_dims =
+      tflite::testing::IntArrayFromInts(kAncillaryShapeCustom);
+  static const TensorInDatum tid_ancillary = {
+      &kAncillaryData,
+      *ancillary_dims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> ancillaries = {
+      &tid_ancillary};
+
+  const TfLiteIntArray* const output_dims =
+      tflite::testing::IntArrayFromInts(kOutputShapeCustom);
+  constexpr int kOutputZeroPointsData[] = {0};
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum tod = {
+      output_data, *output_dims, kTfLiteInt8, kOutputScales,
+      *kOutputZeroPoints, 0, {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> outputs = {
+      &tod};
+
+  const std::initializer_list<const void*> expected = {kExpectCustom};
+
+  const std::initializer_list<tflite::MicroContext::CustomDecodeRegistration>
+      cdr = {
+          {
+              kDecodeTypeCustom,
+              0,  // reserved
+              0,  // reserved
+              0,  // reserved
+              DecodeStateCustom::CreateDecodeStateCustom,
+          },
+      };
+
+  tflite::testing::TestDecode(
+      encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
+      nullptr, &cdr);
+}
+
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/decode_test_helpers.h b/tensorflow/lite/micro/kernels/decode_test_helpers.h
index c171dacaa46..1588cf2dc62 100644
--- a/tensorflow/lite/micro/kernels/decode_test_helpers.h
+++ b/tensorflow/lite/micro/kernels/decode_test_helpers.h
@@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest(
     TfLiteTensor* tensors, const TFLMRegistration& registration,
     const std::initializer_list<const void*>& expected,
     const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
+        nullptr,
+    const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
         nullptr) {
   int kInputArrayData[kNumInputs + 1] = {kNumInputs};
   for (size_t i = 0; i < kNumInputs; i++) {
@@ -104,6 +106,9 @@ TfLiteStatus ExecuteDecodeTest(
   if (amr != nullptr) {
     runner.GetFakeMicroContext()->SetDecompressionMemory(*amr);
   }
+  if (cdr != nullptr) {
+    runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(*cdr);
+  }
   if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
     return kTfLiteError;
   }
@@ -149,6 +154,8 @@ void TestDecode(
     const TFLMRegistration& registration,
     const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
         nullptr,
+    const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
+        nullptr,
     const TfLiteStatus expected_status = kTfLiteOk) {
   TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};
 
@@ -182,7 +189,7 @@ void TestDecode(
   }
 
   TfLiteStatus s = ExecuteDecodeTest(
-      tensors, registration, expected, amr);
+      tensors, registration, expected, amr, cdr);
 
   TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
 }
diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h
index 5cb71af7953..84e541f735d 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.h
+++ b/tensorflow/lite/micro/kernels/kernel_util.h
@@ -23,7 +23,9 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/micro/micro_common.h"
 #include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_graph.h"
 
 #ifdef USE_TFLM_COMPRESSION
 
diff --git a/tensorflow/lite/micro/micro_context.h b/tensorflow/lite/micro/micro_context.h
index f9a70bf5d32..96844ad1a18 100644
--- a/tensorflow/lite/micro/micro_context.h
+++ b/tensorflow/lite/micro/micro_context.h
@@ -33,6 +33,8 @@ namespace tflite {
 // TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus.
 const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(15);
 
+class DecodeState;  // can't use decode_state.h due to circular include
+
 // MicroContext is eventually going to become the API between TFLM and the
 // kernels, replacing all the functions in TfLiteContext. The end state is code
 // kernels to have code like:
@@ -136,7 +138,7 @@ class MicroContext {
   };
 
   // Set the alternate decompression memory regions.
-  // Can only be called during the MicroInterpreter kInit state.
+  // Can only be called during the kInit state.
   virtual TfLiteStatus SetDecompressionMemory(
       const std::initializer_list<AlternateMemoryRegion>& regions);
 
@@ -169,12 +171,40 @@ class MicroContext {
     return nullptr;
   }
 
+  struct CustomDecodeRegistration {
+    uint8_t type;       // custom decode type
+    uint8_t reserved1;  // reserved
+    uint8_t reserved2;  // reserved
+    uint8_t reserved3;  // reserved
+    tflite::DecodeState* (*func)(const TfLiteContext*, MicroProfilerInterface*);
+  };
+
+  // Set the custom DECODE operator registrations.
+  // Can only be called during the kInit state.
+  virtual TfLiteStatus SetCustomDecodeRegistrations(
+      const std::initializer_list<CustomDecodeRegistration>& registrations) {
+    if (custom_decode_registrations_ != nullptr) {
+      return kTfLiteError;
+    }
+    custom_decode_registrations_ = &registrations;
+    return kTfLiteOk;
+  }
+
+  // Get the custom decompression registrations.
+  virtual const std::initializer_list<CustomDecodeRegistration>*
+  GetCustomDecodeRegistrations() const {
+    return custom_decode_registrations_;
+  }
+
  private:
   const std::initializer_list<AlternateMemoryRegion>* decompress_regions_ =
       nullptr;
 
   // array of size_t elements with length equal to decompress_regions_.size()
   size_t* decompress_regions_allocations_ = nullptr;
 
+  const std::initializer_list<CustomDecodeRegistration>*
+      custom_decode_registrations_ = nullptr;
+
   TF_LITE_REMOVE_VIRTUAL_DELETE
 };
diff --git a/tensorflow/lite/micro/micro_interpreter_context.h b/tensorflow/lite/micro/micro_interpreter_context.h
index 5f17c1efac6..b9693cbbd17 100644
--- a/tensorflow/lite/micro/micro_interpreter_context.h
+++ b/tensorflow/lite/micro/micro_interpreter_context.h
@@ -134,7 +134,7 @@ class MicroInterpreterContext : public MicroContext {
 #endif  // USE_TFLM_COMPRESSION
 
   // Set the alternate decompression memory regions.
-  // Can only be called during the MicroInterpreter kInit state.
+  // Can only be called during the kInit state.
   TfLiteStatus SetDecompressionMemory(
       const std::initializer_list<AlternateMemoryRegion>& regions) override;
 
@@ -159,6 +159,17 @@ class MicroInterpreterContext : public MicroContext {
   // decompression subsystem.
   MicroProfilerInterface* GetAlternateProfiler() const override;
 
+  // Set the custom DECODE operator registrations.
+  // Can only be called during the kInit state.
+  virtual TfLiteStatus SetCustomDecodeRegistrations(
+      const std::initializer_list<CustomDecodeRegistration>& registrations)
+      override {
+    if (state_ != InterpreterState::kInit) {
+      return kTfLiteError;
+    }
+    return MicroContext::SetCustomDecodeRegistrations(registrations);
+  }
+
 private:
   MicroAllocator& allocator_;
   MicroInterpreterGraph& graph_;
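
Usage sketch of the new hooks, mirroring the DecodeStateCustom fixture in decode_test.cc above: subclass DecodeState, expose a factory matching MicroContext::CustomDecodeRegistration::func, and register it while the context is still in the kInit state. The MyDecodeState name, the type value 200, and the wiring that obtains the context are hypothetical, not part of this change.

// Hypothetical custom decoder; follows the pattern of DecodeStateCustom above.
#include <cstdint>
#include <initializer_list>
#include <new>

#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/decode_state.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_profiler_interface.h"

class MyDecodeState : public tflite::DecodeState {
 public:
  MyDecodeState(const TfLiteContext* context,
                tflite::MicroProfilerInterface* profiler)
      : DecodeState(context, profiler) {}

  TfLiteStatus Setup(const TfLiteTensor& input, const TfLiteTensor& ancillary,
                     const TfLiteTensor& output) override {
    return kTfLiteOk;  // validate shapes and DCM fields here
  }

  TfLiteStatus Decode(const TfLiteEvalTensor& input,
                      const TfLiteEvalTensor& ancillary,
                      const TfLiteEvalTensor& output) override {
    // Application-specific decompression of `input` into `output`; custom
    // parameters live in the ancillary data past kDcmSizeInBytes.
    return kTfLiteOk;
  }

  // Factory with the signature expected by CustomDecodeRegistration::func.
  static tflite::DecodeState* Create(const TfLiteContext* context,
                                     tflite::MicroProfilerInterface* profiler) {
    alignas(4) static uint8_t buffer[sizeof(MyDecodeState)];
    return new (buffer) MyDecodeState(context, profiler);
  }

 protected:
  ~MyDecodeState() override = default;

 private:
  TF_LITE_REMOVE_VIRTUAL_DELETE
};

// 200 is an arbitrary value inside [kDcmTypeCustomFirst, kDcmTypeCustomLast];
// the registration list must outlive the interpreter.
static constexpr std::initializer_list<
    tflite::MicroContext::CustomDecodeRegistration>
    kCustomDecodeRegistrations = {{200, 0, 0, 0, MyDecodeState::Create}};

// While the context is still in the kInit state (before prepare/allocation):
//   micro_context->SetCustomDecodeRegistrations(kCustomDecodeRegistrations);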