Skip to content
38 changes: 28 additions & 10 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
return kTfLiteOk;
}

DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context,
uint8_t type) {
const MicroContext* mc = GetMicroContext(context);
auto registrations = mc->GetCustomDecodeRegistrations();
if (registrations == nullptr) {
return nullptr;
}

for (auto& reg : *registrations) {
if (reg.type == type && reg.func != nullptr) {
return reg.func(context, mc->GetAlternateProfiler());
}
}

return nullptr;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
const size_t num_outputs = NumOutputs(node);
Expand Down Expand Up @@ -113,21 +130,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
dsp = DecodeState::CreateDecodeStateHuffman(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeCustom:
MicroPrintf("Custom decode type not yet supported");
break;
default:
MicroPrintf("unsupported decode type %u",
DecodeState::Type(*ancillary));
uint32_t type = DecodeState::Type(*ancillary);
if (type >= DecodeState::kDcmTypeCustomFirst &&
type <= DecodeState::kDcmTypeCustomLast) {
dsp = GetDecodeStateFromCustomRegistration(context, type);
} else {
MicroPrintf("unsupported decode type %u", type);
}
break;
}

status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}

if (dsp != nullptr) {
status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}
status = dsp->Setup(*input, *ancillary, *output);
if (status != kTfLiteOk) {
break;
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/kernels/decode_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class DecodeState {
static constexpr uint8_t kDcmTypeLUT = 0;
static constexpr uint8_t kDcmTypeHuffman = 1;
static constexpr uint8_t kDcmTypePrune = 2;
static constexpr uint8_t kDcmTypeCustom = 127;
static constexpr uint8_t kDcmTypeCustomFirst = 128;
static constexpr uint8_t kDcmTypeCustomLast = 255;

static constexpr size_t kDcmSizeInBytes = 16;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) {
tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/kernels/decode_state_prune_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
tflite::testing::TestDecode<kEncodes.size() + kAncillaries.size(),
kOutputs.size()>(
kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TESTS_END
129 changes: 129 additions & 0 deletions tensorflow/lite/micro/kernels/decode_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};

//
// Custom DECODE test data
//
constexpr int kDecodeTypeCustom = 200;

constexpr int8_t kAncillaryDataCustom[] = {0x42};

constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = {
kDecodeTypeCustom, // type: custom
1, // DCM version: 1
};

// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46,
0x4A, 0x52, 0x62, 0x02};

// Tensor shapes as TfLiteIntArray
constexpr int kOutputShapeCustom[] = {1, 8};
constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)};

constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04,
0x08, 0x10, 0x20, 0x40};

class DecodeStateCustom : public tflite::DecodeState {
public:
DecodeStateCustom() = delete;

DecodeStateCustom(const TfLiteContext* context,
tflite::MicroProfilerInterface* profiler)
: DecodeState(context, profiler) {}

virtual TfLiteStatus Setup(const TfLiteTensor& input,
const TfLiteTensor& ancillary,
const TfLiteTensor& output) override {
return kTfLiteOk;
}

virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
const TfLiteEvalTensor& ancillary,
const TfLiteEvalTensor& output) override {
const uint8_t* inp = tflite::micro::GetTensorData<uint8_t>(&input);
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), inp != nullptr);
uint8_t* outp = tflite::micro::GetTensorData<uint8_t>(
const_cast<TfLiteEvalTensor*>(&output));
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), outp != nullptr);
const uint8_t* vp = tflite::micro::GetTensorData<uint8_t>(&ancillary);
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), vp != nullptr);
vp += kDcmSizeInBytes;

// simple XOR de-obfuscation
std::transform(inp, inp + input.dims->data[0], outp,
[vp](uint8_t i) { return i ^ *vp; });

return kTfLiteOk;
}

static DecodeState* CreateDecodeStateCustom(
const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) {
alignas(4) static uint8_t buffer[sizeof(DecodeStateCustom)];
DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler);
return instance;
}

protected:
virtual ~DecodeStateCustom() = default;

private:
TF_LITE_REMOVE_VIRTUAL_DELETE
};

} // namespace

TF_LITE_MICRO_TESTS_BEGIN
Expand Down Expand Up @@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) {
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr);
}

TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) {
// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) int8_t output_data[std::size(kExpectCustom)] = {};
alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataCustom)>
kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}};

constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)};

const TfLiteIntArray* const encoded_dims =
tflite::testing::IntArrayFromInts(kEncodedShapeCustom);
static const TensorInDatum tid_encode = {
kEncodedCustom,
*encoded_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> encodes = {
&tid_encode,
};

const TfLiteIntArray* const ancillary_dims =
tflite::testing::IntArrayFromInts(kAncillaryShapeCustom);
static const TensorInDatum tid_ancillary = {
&kAncillaryData,
*ancillary_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> ancillaries = {
&tid_ancillary};

const TfLiteIntArray* const output_dims =
tflite::testing::IntArrayFromInts(kOutputShapeCustom);
constexpr int kOutputZeroPointsData[] = {0};
const TfLiteIntArray* const kOutputZeroPoints =
tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
static const TensorOutDatum tod = {
output_data, *output_dims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
0, {},
};
static constexpr std::initializer_list<const TensorOutDatum*> outputs = {
&tod};

const std::initializer_list<const void*> expected = {kExpectCustom};

const std::initializer_list<tflite::MicroContext::CustomDecodeRegistration>
cdr = {
{
kDecodeTypeCustom,
0, // reserved
0, // reserved
0, // reserved
DecodeStateCustom::CreateDecodeStateCustom,
},
};

tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, &cdr);
}

TF_LITE_MICRO_TESTS_END
9 changes: 8 additions & 1 deletion tensorflow/lite/micro/kernels/decode_test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest(
TfLiteTensor* tensors, const TFLMRegistration& registration,
const std::initializer_list<const void*>& expected,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr) {
int kInputArrayData[kNumInputs + 1] = {kNumInputs};
for (size_t i = 0; i < kNumInputs; i++) {
Expand All @@ -104,6 +106,9 @@ TfLiteStatus ExecuteDecodeTest(
if (amr != nullptr) {
runner.GetFakeMicroContext()->SetDecompressionMemory(*amr);
}
if (cdr != nullptr) {
runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(*cdr);
}

if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
return kTfLiteError;
Expand Down Expand Up @@ -149,6 +154,8 @@ void TestDecode(
const TFLMRegistration& registration,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr,
const TfLiteStatus expected_status = kTfLiteOk) {
TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};

Expand Down Expand Up @@ -182,7 +189,7 @@ void TestDecode(
}

TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
tensors, registration, expected, amr);
tensors, registration, expected, amr, cdr);
TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"

#ifdef USE_TFLM_COMPRESSION

Expand Down
32 changes: 31 additions & 1 deletion tensorflow/lite/micro/micro_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace tflite {
// TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus.
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(15);

class DecodeState; // can't use decode_state.h due to circular include

// MicroContext is eventually going to become the API between TFLM and the
// kernels, replacing all the functions in TfLiteContext. The end state is code
// kernels to have code like:
Expand Down Expand Up @@ -136,7 +138,7 @@ class MicroContext {
};

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
virtual TfLiteStatus SetDecompressionMemory(
const std::initializer_list<AlternateMemoryRegion>& regions);

Expand Down Expand Up @@ -169,12 +171,40 @@ class MicroContext {
return nullptr;
}

struct CustomDecodeRegistration {
uint8_t type; // custom decode type
uint8_t reserved1; // reserved
uint8_t reserved2; // reserved
uint8_t reserved3; // reserved
tflite::DecodeState* (*func)(const TfLiteContext*, MicroProfilerInterface*);
};

// Set the custom DECODE operator registrations.
// Can only be called during the kInit state.
virtual TfLiteStatus SetCustomDecodeRegistrations(
const std::initializer_list<CustomDecodeRegistration>& registrations) {
if (custom_decode_registrations_ != nullptr) {
return kTfLiteError;
}
custom_decode_registrations_ = &registrations;
return kTfLiteOk;
}

// Get the custom decompression registrations.
virtual const std::initializer_list<CustomDecodeRegistration>*
GetCustomDecodeRegistrations() const {
return custom_decode_registrations_;
}

private:
const std::initializer_list<AlternateMemoryRegion>* decompress_regions_ =
nullptr;
// array of size_t elements with length equal to decompress_regions_.size()
size_t* decompress_regions_allocations_ = nullptr;

const std::initializer_list<CustomDecodeRegistration>*
custom_decode_registrations_ = nullptr;

TF_LITE_REMOVE_VIRTUAL_DELETE
};

Expand Down
13 changes: 12 additions & 1 deletion tensorflow/lite/micro/micro_interpreter_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class MicroInterpreterContext : public MicroContext {
#endif // USE_TFLM_COMPRESSION

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
TfLiteStatus SetDecompressionMemory(
const std::initializer_list<AlternateMemoryRegion>& regions) override;

Expand All @@ -159,6 +159,17 @@ class MicroInterpreterContext : public MicroContext {
// decompression subsystem.
MicroProfilerInterface* GetAlternateProfiler() const override;

// Set the custom DECODE operator registrations.
// Can only be called during the kInit state.
virtual TfLiteStatus SetCustomDecodeRegistrations(
const std::initializer_list<CustomDecodeRegistration>& registrations)
override {
if (state_ != InterpreterState::kInit) {
return kTfLiteError;
}
return MicroContext::SetCustomDecodeRegistrations(registrations);
}

private:
MicroAllocator& allocator_;
MicroInterpreterGraph& graph_;
Expand Down
Loading