diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 77ef336d9de..13f160ca53c 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -40,6 +40,7 @@ PythonOpsResolver::PythonOpsResolver() { AddConv2D(); AddCos(); AddCumSum(); + AddDecode(); AddDelay(); AddDepthToSpace(); AddDepthwiseConv2D(); diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index f1d12f04634..f5accf27ca9 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -236,6 +236,11 @@ tflm_kernel_cc_library( "conv.cc", "conv_common.cc", "cumsum.cc", + "decode.cc", + "decode_state.cc", + "decode_state_huffman.cc", + "decode_state_lut.cc", + "decode_state_prune.cc", "depth_to_space.cc", "depthwise_conv.cc", "depthwise_conv_common.cc", @@ -327,6 +332,10 @@ tflm_kernel_cc_library( "batch_matmul.h", "circular_buffer.h", "conv.h", + "decode_state.h", + "decode_state_huffman.h", + "decode_state_lut.h", + "decode_state_prune.h", "depthwise_conv.h", "dequantize.h", "ethosu.h", @@ -643,6 +652,21 @@ tflm_cc_test( ], ) +tflm_cc_test( + name = "decode_test", + srcs = [ + "decode_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:debug_log", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + tflm_cc_test( name = "decompress_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index 11684278801..62e9324995e 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \ diff --git a/tensorflow/lite/micro/kernels/decode.cc b/tensorflow/lite/micro/kernels/decode.cc new file mode 100644 index 00000000000..55f40a2f0ca --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode.cc @@ -0,0 +1,187 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_arena_constants.h" +#include "tensorflow/lite/micro/micro_context.h" +#include "tensorflow/lite/micro/micro_log.h" + +namespace tflite { +namespace { + +TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node, + size_t tensor_output_index, + TfLiteTensor* output) { + // If alternate decompression memory is available, set the tensor data + // pointer now to preclude allocation by the memory planner. + void* alternate_decompress_mem = + GetMicroContext(context)->AllocateDecompressionMemory( + output->bytes, MicroArenaBufferAlignment()); + if (alternate_decompress_mem != nullptr) { + TfLiteEvalTensor* output_eval = + tflite::micro::GetEvalOutput(context, node, tensor_output_index); + TF_LITE_ENSURE(context, output_eval != nullptr); + output_eval->data.data = alternate_decompress_mem; + output->data.data = alternate_decompress_mem; + } + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const size_t num_inputs = NumInputs(node); + const size_t num_outputs = NumOutputs(node); + TF_LITE_ENSURE(context, num_outputs > 0); + TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs * 2); + + MicroContext* const micro_context = GetMicroContext(context); + + node->user_data = micro_context->AllocatePersistentBuffer( + num_outputs * sizeof(DecodeState*)); + TF_LITE_ENSURE(context, node->user_data != nullptr); + DecodeState** const dsp_arr = + reinterpret_cast(node->user_data); + + TfLiteTensor* input = nullptr; + TfLiteTensor* ancillary = nullptr; + TfLiteTensor* output = nullptr; + TfLiteStatus status = kTfLiteOk; + + micro_context->ResetDecompressionMemoryAllocations(); + + for (size_t i = 0; i < num_inputs; i += 2) { + input = micro_context->AllocateTempInputTensor(node, i); + if (input == nullptr) { + MicroPrintf("failed to allocate input tensor %u", i); + status = kTfLiteError; + break; + } + ancillary = micro_context->AllocateTempInputTensor(node, i + 1); + if (ancillary == nullptr) { + MicroPrintf("failed to allocate ancillary tensor %u", i + 1); + status = kTfLiteError; + break; + } + output = micro_context->AllocateTempOutputTensor(node, i / 2); + if (output == nullptr) { + MicroPrintf("failed to allocate output tensor %u", i / 2); + status = kTfLiteError; + break; + } + + TF_LITE_ENSURE(context, IsConstantTensor(input)); + TF_LITE_ENSURE(context, IsConstantTensor(ancillary)); + + if (DecodeState::Version(*ancillary) != 1) { + MicroPrintf("version %u != 1", DecodeState::Version(*ancillary)); + status = kTfLiteError; + break; + } + + DecodeState* dsp = nullptr; + switch (DecodeState::Type(*ancillary)) { + case DecodeState::kDcmTypeLUT: + dsp = DecodeState::CreateDecodeStateLUT( + context, micro_context->GetAlternateProfiler()); + break; + case DecodeState::kDcmTypePrune: + dsp = DecodeState::CreateDecodeStatePrune( + context, micro_context->GetAlternateProfiler()); + break; + case DecodeState::kDcmTypeHuffman: + dsp = DecodeState::CreateDecodeStateHuffman( + context, micro_context->GetAlternateProfiler()); + break; + case DecodeState::kDcmTypeCustom: + MicroPrintf("Custom decode type not yet supported"); + break; + default: + MicroPrintf("unsupported 
decode type %u", + DecodeState::Type(*ancillary)); + break; + } + + status = SetOutputTensorData(context, node, i / 2, output); + if (status != kTfLiteOk) { + break; + } + + if (dsp != nullptr) { + status = dsp->Setup(*input, *ancillary, *output); + if (status != kTfLiteOk) { + break; + } + dsp_arr[i / 2] = dsp; + } else { + MicroPrintf("failed to allocate DecodeState[%u]", i / 2); + status = kTfLiteError; + break; + } + + micro_context->DeallocateTempTfLiteTensor(input); + micro_context->DeallocateTempTfLiteTensor(ancillary); + micro_context->DeallocateTempTfLiteTensor(output); + input = nullptr; + ancillary = nullptr; + output = nullptr; + } + + if (input != nullptr) { + micro_context->DeallocateTempTfLiteTensor(input); + } + if (ancillary != nullptr) { + micro_context->DeallocateTempTfLiteTensor(ancillary); + } + if (output != nullptr) { + micro_context->DeallocateTempTfLiteTensor(output); + } + + return status; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const size_t num_inputs = NumInputs(node); + DecodeState** const dsp_arr = + reinterpret_cast(node->user_data); + + for (size_t i = 0; i < num_inputs; i += 2) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, i); + TF_LITE_ENSURE(context, input != nullptr); + const TfLiteEvalTensor* ancillary = + tflite::micro::GetEvalInput(context, node, i + 1); + TF_LITE_ENSURE(context, ancillary != nullptr); + const TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, i / 2); + TF_LITE_ENSURE(context, output != nullptr); + + TfLiteStatus status = dsp_arr[i / 2]->Decode(*input, *ancillary, *output); + TF_LITE_ENSURE(context, status == kTfLiteOk); + } + + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_DECODE() { + return tflite::micro::RegisterOp(nullptr, Prepare, Eval); +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state.cc b/tensorflow/lite/micro/kernels/decode_state.cc new file mode 100644 index 00000000000..8895ee5f4a4 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state.cc @@ -0,0 +1,64 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state.h" + +#include "tensorflow/lite/micro/kernels/decode_state_huffman.h" +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" +#include "tensorflow/lite/micro/micro_context.h" + +namespace tflite { + +DecodeState* DecodeState::CreateDecodeStateLUT( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); + void* buffer = + micro_context->AllocatePersistentBuffer(sizeof(DecodeStateLUT)); + if (buffer == nullptr) { + return nullptr; + } + DecodeState* dsp = new (buffer) DecodeStateLUT(context, profiler); + + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStatePrune( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); + void* buffer = + micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune)); + if (buffer == nullptr) { + return nullptr; + } + DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler); + + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStateHuffman( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); + void* buffer = + micro_context->AllocatePersistentBuffer(sizeof(DecodeStateHuffman)); + if (buffer == nullptr) { + return nullptr; + } + DecodeState* dsp = new (buffer) DecodeStateHuffman(context, profiler); + + return dsp; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state.h b/tensorflow/lite/micro/kernels/decode_state.h new file mode 100644 index 00000000000..d4aa25b4278 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state.h @@ -0,0 +1,93 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/c/c_api_types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_profiler_interface.h"
+
+namespace tflite {
+
+struct DecodeState {
+  DecodeState() = delete;
+
+  DecodeState(const TfLiteContext* context, MicroProfilerInterface* profiler)
+      : context_(context), micro_profiler_(profiler) {}
+
+  virtual TfLiteStatus Setup(const TfLiteTensor& input,
+                             const TfLiteTensor& ancillary,
+                             const TfLiteTensor& output) = 0;
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) = 0;
+
+  static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
+                                           MicroProfilerInterface* profiler);
+  static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context,
+                                             MicroProfilerInterface* profiler);
+  static DecodeState* CreateDecodeStateHuffman(
+      const TfLiteContext* context, MicroProfilerInterface* profiler);
+
+  static uint8_t Type(const TfLiteTensor& ancillary) {
+    return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
+  }
+
+  static uint8_t Type(const TfLiteEvalTensor& ancillary) {
+    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
+  }
+
+  static uint8_t Version(const TfLiteTensor& ancillary) {
+    return GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
+  }
+
+  static uint8_t Version(const TfLiteEvalTensor& ancillary) {
+    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
+  }
+
+ protected:
+  virtual ~DecodeState() = default;
+
+  // Decode Common Metadata constants
+ public:
+  static constexpr uint8_t kDcmTypeLUT = 0;
+  static constexpr uint8_t kDcmTypeHuffman = 1;
+  static constexpr uint8_t kDcmTypePrune = 2;
+  static constexpr uint8_t kDcmTypeCustom = 127;
+
+  static constexpr size_t kDcmSizeInBytes = 16;
+
+ private:
+  static constexpr size_t kDcmDecodeTypeOffset = 0;
+  static constexpr size_t kDcmVersionOffset = 1;
+
+  // DecodeState vars
+ protected:
+  const TfLiteContext* context_;
+  MicroProfilerInterface* micro_profiler_;
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
diff --git a/tensorflow/lite/micro/kernels/decode_state_huffman.cc b/tensorflow/lite/micro/kernels/decode_state_huffman.cc
new file mode 100644
index 00000000000..8a30ac66ad2
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_state_huffman.cc
@@ -0,0 +1,167 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state_huffman.h" + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus DecodeStateHuffman::Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) { + const uint8_t* const ancillary_data = GetTensorData(&ancillary); + if (ancillary_data[kDcmVersionOffset] != 1) { + MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]); + return kTfLiteError; + } + + compressed_codewords_ = GetTensorData(&input); + count_codewords_ = NumElements(&output); + huffman_tables_ = &ancillary_data[kDcmSizeInBytes]; + use_32bit_table_ = + (ancillary_data[kDcmTableSizeOffset] & kDcmTableSize32BitsMask) != 0; + initial_table_size_ = + (ancillary_data[kDcmTableSizeOffset] & kDcmTableSizeInitialMask) >> + kDcmTableSizeInitialShift; + + if (!use_32bit_table_) { + TF_LITE_ENSURE_TYPES_EQ(const_cast(context_), output.type, + kTfLiteInt8); + } + + return kTfLiteOk; +} + +TfLiteStatus DecodeStateHuffman::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteInt8: + if (use_32bit_table_) { + DecompressToBufferWith32BitTable(static_cast(buffer)); + } else { + DecompressToBufferWith16BitTable(static_cast(buffer)); + } + break; + case kTfLiteInt16: + DecompressToBufferWith32BitTable(static_cast(buffer)); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +void DecodeStateHuffman::DecompressToBufferWith16BitTable(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + size_t remaining = count_codewords_; + const size_t initial_table_size = initial_table_size_ + 1; + const uint16_t* huffman_tables = + static_cast(huffman_tables_); + uint32_t head_offset = 0; // codewords bitstring state + uint32_t head_hold = 0; // codewords bitstring state + const uint32_t* head_next = nullptr; // codewords bitstring state + uint16_t table_value = 0; + + InitNextBits(head_offset, head_hold, head_next); + + while (remaining--) { + size_t last_used_bits = initial_table_size; + uint32_t current_index = + GetNextBits(last_used_bits, head_offset, head_hold, head_next); + size_t table_offset = current_index; + table_value = huffman_tables[table_offset]; + + while (!(table_value & kTable16BitSymbolFoundMask)) { + last_used_bits = + ((table_value & kTable16BitCountMask) >> kTable16BitCountShift) + 1; + current_index = + GetNextBits(last_used_bits, head_offset, head_hold, head_next); + const size_t next_table_offset = table_value & kTable16BitValueMask; + table_offset += next_table_offset + current_index; + table_value = huffman_tables[table_offset]; + } + + *buffer++ = table_value; + + const size_t symbol_residual_bits = + (table_value & kTable16BitCountMask) >> kTable16BitCountShift; + if (last_used_bits > symbol_residual_bits) { + PutBackBits(last_used_bits - symbol_residual_bits, head_offset, head_hold, + head_next); + } + } +} + +template +void 
DecodeStateHuffman::DecompressToBufferWith32BitTable(T* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + size_t remaining = count_codewords_; + const size_t initial_table_size = initial_table_size_ + 1; + const uint32_t* huffman_tables = + static_cast(huffman_tables_); + uint32_t head_offset = 0; // codewords bitstring state + uint32_t head_hold = 0; // codewords bitstring state + const uint32_t* head_next = nullptr; // codewords bitstring state + uint32_t table_value = 0; + + InitNextBits(head_offset, head_hold, head_next); + + while (remaining--) { + size_t last_used_bits = initial_table_size; + uint32_t current_index = + GetNextBits(last_used_bits, head_offset, head_hold, head_next); + size_t table_offset = current_index; + table_value = huffman_tables[table_offset]; + + while (!(table_value & kTable32BitSymbolFoundMask)) { + last_used_bits = + ((table_value & kTable32BitCountMask) >> kTable32BitCountShift) + 1; + current_index = + GetNextBits(last_used_bits, head_offset, head_hold, head_next); + const size_t next_table_offset = table_value & kTable32BitValueMask; + table_offset += next_table_offset + current_index; + table_value = huffman_tables[table_offset]; + } + + *buffer++ = table_value; + + const size_t symbol_residual_bits = + (table_value & kTable32BitCountMask) >> kTable32BitCountShift; + if (last_used_bits > symbol_residual_bits) { + PutBackBits(last_used_bits - symbol_residual_bits, head_offset, head_hold, + head_next); + } + } +} + +template void DecodeStateHuffman::DecompressToBufferWith32BitTable( + int8_t*); +template void DecodeStateHuffman::DecompressToBufferWith32BitTable( + int16_t*); + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state_huffman.h b/tensorflow/lite/micro/kernels/decode_state_huffman.h new file mode 100644 index 00000000000..087144badae --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_huffman.h @@ -0,0 +1,150 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_HUFFMAN_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_HUFFMAN_H_ + +#include +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" + +namespace tflite { + +struct DecodeStateHuffman : public DecodeState { + DecodeStateHuffman() = delete; + + DecodeStateHuffman(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeState(context, profiler) {} + + virtual TfLiteStatus Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) override; + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + // Huffman table element constants + static constexpr uint16_t kTable16BitSymbolFoundMask = 0x8000; + static constexpr uint16_t kTable16BitCountMask = 0x7800; + static constexpr size_t kTable16BitCountShift = 11; + static constexpr uint16_t kTable16BitValueMask = 0x07FF; + static constexpr uint32_t kTable32BitSymbolFoundMask = 0x8000'0000; + static constexpr uint32_t kTable32BitCountMask = 0x7800'0000; + static constexpr size_t kTable32BitCountShift = 27; + static constexpr uint32_t kTable32BitValueMask = 0x07FF'FFFF; + + // + // Huffman Decode Common Metadata constants + // + static constexpr size_t kDcmVersionOffset = 4; + static constexpr size_t kDcmTableSizeOffset = 5; + // 32 bit table element if set, 16 bit otherwise + static constexpr uint8_t kDcmTableSize32BitsMask = 0x01; + // Initial table size of N elements where value is log2(N) - 1 + static constexpr uint8_t kDcmTableSizeInitialMask = 0xF0; + static constexpr size_t kDcmTableSizeInitialShift = 4; + + private: + inline bool IsLittleEndian() const { + int i = 1; + return (reinterpret_cast(&i)[0] == 1); + } + + inline uint32_t Swap32(const uint32_t x) const { + return (x << 24) | ((x & 0xFF00) << 8) | ((x >> 8) & 0xFF00) | (x >> 24); + } + + protected: + virtual ~DecodeStateHuffman() = default; + + template + void DecompressToBufferWith32BitTable(T* buffer); + + void DecompressToBufferWith16BitTable(int8_t* buffer); + + void InitNextBits(uint32_t& head_offset, uint32_t& head_hold, + const uint32_t*& head_next) const { + if (count_codewords_ > 0) { + head_offset = 32; + head_next = compressed_codewords_; + head_hold = *head_next++; + if (IsLittleEndian()) { + head_hold = Swap32(head_hold); + } + } + } + + inline uint32_t GetNextBits(size_t count, uint32_t& head_offset, + uint32_t& head_hold, + const uint32_t*& head_next) const { + TFLITE_DCHECK_LE(count, 31); // avoid 64 bit shift for below + TFLITE_DCHECK_GT(count, 0); + + uint32_t output = 0; + + if (count > head_offset) { + // reset head + const uint32_t mask = (1 << head_offset) - 1; + output = (head_hold & mask) << (count - head_offset); + count -= head_offset; + head_hold = *head_next++; + if (IsLittleEndian()) { + head_hold = Swap32(head_hold); + } + head_offset = 32; + } + + const uint32_t mask = (1 << count) - 1; + const size_t shift = head_offset - count; + output |= (head_hold >> shift) & mask; + head_offset -= count; + + return output; + } + + inline void PutBackBits(size_t count, uint32_t& head_offset, + uint32_t& head_hold, + const uint32_t*& head_next) const { + TFLITE_DCHECK_LE(count, 31); + + head_offset += count; + if (head_offset > 32) { + head_offset -= 32; + head_next--; + head_hold = *(head_next - 1); + if 
(IsLittleEndian()) { + head_hold = Swap32(head_hold); + } + } + } + + protected: + const uint32_t* compressed_codewords_ = nullptr; + size_t count_codewords_ = 0; + const void* huffman_tables_ = nullptr; + bool use_32bit_table_ = false; + uint8_t initial_table_size_ = 0; // log2(N) - 1 where N is + // number of table elements + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_HUFFMAN_H_ diff --git a/tensorflow/lite/micro/kernels/decode_state_lut.cc b/tensorflow/lite/micro/kernels/decode_state_lut.cc new file mode 100644 index 00000000000..477c21d80a7 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_lut.cc @@ -0,0 +1,630 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus DecodeStateLUT::Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) { + const uint8_t* const ancillary_data = GetTensorData(&ancillary); + if (ancillary_data[kDcmVersionOffset] != 1) { + MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]); + return kTfLiteError; + } + + // resolve num_channels_ and use_alternate_axis_ + if (output.quantization.type == kTfLiteAffineQuantization && + output.quantization.params != nullptr) { + const TfLiteAffineQuantization* quantization = + reinterpret_cast(output.quantization.params); + num_channels_ = quantization->scale->size; + if ((quantization->quantized_dimension == output.dims->size - 1) && + num_channels_ > 1) { + use_alternate_axis_ = true; + } else if (quantization->quantized_dimension != 0) { + MicroPrintf("unsupported quantization axis %u", + quantization->quantized_dimension); + return kTfLiteError; + } + } + + compressed_indices_ = GetTensorData(&input); + count_indices_ = NumElements(&output); + elements_per_channel_ = + use_alternate_axis_ ? 
1 : count_indices_ / num_channels_; + value_table_ = &ancillary_data[kDcmSizeInBytes]; + value_table_channel_stride_ = ancillary_data[kDcmValueTableStrideOffset]; + compressed_bit_width_ = + ancillary_data[kDcmParamsOffset] & kDcmParamsBitWidthMask; + + return kTfLiteOk; +} + +TfLiteStatus DecodeStateLUT::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBuffer(buffer); + break; + case kTfLiteFloat32: + DecompressToBuffer(buffer); + break; + case kTfLiteInt8: + DecompressToBuffer(buffer); + break; + case kTfLiteInt16: + DecompressToBuffer(buffer); + break; + case kTfLiteInt32: + DecompressToBuffer(buffer); + break; + case kTfLiteInt64: + DecompressToBuffer(buffer); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +template +T* DecodeStateLUT::DecompressToBuffer(void* buffer) { + TFLITE_DCHECK(compressed_bit_width_ <= kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + if (std::is_same::value && compressed_bit_width_ == 4 && + !use_alternate_axis_) { + DecompressToBufferWidth4_16(static_cast(buffer)); + } else if (std::is_same::value && compressed_bit_width_ == 3 && + !use_alternate_axis_) { + DecompressToBufferWidth3_32(static_cast(buffer)); + } else if (std::is_same::value && compressed_bit_width_ == 2 && + !use_alternate_axis_) { + DecompressToBufferWidth2_16(static_cast(buffer)); + } else { + DecompressToBufferWidthAny(static_cast(buffer)); + } + + return static_cast(buffer); +} + +template bool* DecodeStateLUT::DecompressToBuffer(void*); +template float* DecodeStateLUT::DecompressToBuffer(void*); +template int8_t* DecodeStateLUT::DecompressToBuffer(void*); +template int16_t* DecodeStateLUT::DecompressToBuffer(void*); +template int32_t* DecodeStateLUT::DecompressToBuffer(void*); +template int64_t* DecodeStateLUT::DecompressToBuffer(void*); + +void DecodeStateLUT::DecompressToBufferWidth4_16(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = value_table_channel_stride_; + const uint8_t* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint64_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint64_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 1]); + + while (count >= 16) { + count -= 16; + uint64_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 4) & 0x0F]); + value |= static_cast(value_table[index & 0x0F]) << 8; + value |= static_cast(value_table[(index >> 12) & 0x0F]) << 16; + value |= static_cast(value_table[(index >> 8) & 0x0F]) << 24; + value |= static_cast(value_table[(index >> 20) & 0x0F]) << 32; + value |= static_cast(value_table[(index >> 16) & 0x0F]) << 40; + value |= static_cast(value_table[(index >> 28) & 0x0F]) << 48; + value |= static_cast(value_table[(index 
>> 24) & 0x0F]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 36) & 0x0F]); + value2 |= static_cast(value_table[(index >> 32) & 0x0F]) << 8; + value2 |= static_cast(value_table[(index >> 44) & 0x0F]) + << 16; + value2 |= static_cast(value_table[(index >> 40) & 0x0F]) + << 24; + value2 |= static_cast(value_table[(index >> 52) & 0x0F]) + << 32; + value2 |= static_cast(value_table[(index >> 48) & 0x0F]) + << 40; + value2 |= static_cast(value_table[(index >> 60) & 0x0F]) + << 48; + value2 |= static_cast(value_table[(index >> 56) & 0x0F]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 1; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth4(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecodeStateLUT::DecompressToBufferWidth2_16(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = value_table_channel_stride_; + const uint8_t* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x0F)) { + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 16 + if (count >= 16) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[current_offset >> 2]); + + while (count >= 16) { + count -= 16; + uint32_t index = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index >> 6) & 0x03]); + value |= static_cast(value_table[(index >> 4) & 0x03]) << 8; + value |= static_cast(value_table[(index >> 2) & 0x03]) << 16; + value |= static_cast(value_table[index & 0x03]) << 24; + value |= static_cast(value_table[(index >> 14) & 0x03]) << 32; + value |= static_cast(value_table[(index >> 12) & 0x03]) << 40; + value |= static_cast(value_table[(index >> 10) & 0x03]) << 48; + value |= static_cast(value_table[(index >> 8) & 0x03]) << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index >> 22) & 0x03]); + value2 |= static_cast(value_table[(index >> 20) & 0x03]) << 8; + value2 |= static_cast(value_table[(index >> 18) & 0x03]) + << 16; + value2 |= static_cast(value_table[(index >> 16) & 0x03]) + << 24; + value2 |= static_cast(value_table[(index >> 30) & 0x03]) + << 32; + value2 |= static_cast(value_table[(index >> 28) & 0x03]) + << 40; + value2 |= static_cast(value_table[(index >> 26) & 0x03]) + << 48; + value2 |= static_cast(value_table[(index >> 24) & 0x03]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + buffer += 16; + } + + current_offset = + (reinterpret_cast(indices) - compressed_indices_) + << 2; + } + + // process remaining elements in current channel + while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth2(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +void DecodeStateLUT::DecompressToBufferWidth3_32(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const size_t stride = 
value_table_channel_stride_; + const uint8_t* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + // process elements at start of channel up to next uint32_t alignment of + // compressed_indices_ + while (count > 0 && (current_offset & 0x1F)) { + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + count -= 1; + } + + // process elements in current channel in groups of 32 + if (count >= 32) { + const uint32_t* indices = reinterpret_cast( + &compressed_indices_[(current_offset >> 5) * 12]); + + while (count >= 32) { + count -= 32; + uint32_t index0 = *indices++; + uint32_t index1 = *indices++; + uint32_t index2 = *indices++; + uint64_t value, value2; + + value = static_cast(value_table[(index0 >> 5) & 0x07]); + value |= static_cast(value_table[(index0 >> 2) & 0x07]) << 8; + value |= + static_cast( + value_table[((index0 << 1) & 0b110) | ((index0 >> 15) & 0b1)]) + << 16; + value |= static_cast(value_table[(index0 >> 12) & 0x07]) + << 24; + value |= static_cast(value_table[(index0 >> 9) & 0x07]) << 32; + value |= + static_cast( + value_table[((index0 >> 6) & 0b100) | ((index0 >> 22) & 0b11)]) + << 40; + value |= static_cast(value_table[(index0 >> 19) & 0x07]) + << 48; + value |= static_cast(value_table[(index0 >> 16) & 0x07]) + << 56; + + *reinterpret_cast(buffer) = value; + + value2 = static_cast(value_table[(index0 >> 29) & 0x07]); + value2 |= static_cast(value_table[(index0 >> 26) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index0 >> 23) & 0b110) | ((index1 >> 7) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index1 >> 4) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index1 >> 1) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index1 << 2) & 0b100) | ((index1 >> 14) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index1 >> 11) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index1 >> 8) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 8) = value2; + + value = static_cast(value_table[(index1 >> 21) & 0x07]); + value |= static_cast(value_table[(index1 >> 18) & 0x07]) << 8; + value |= + static_cast( + value_table[((index1 >> 15) & 0b110) | ((index1 >> 31) & 0b1)]) + << 16; + value |= static_cast(value_table[(index1 >> 28) & 0x07]) + << 24; + value |= static_cast(value_table[(index1 >> 25) & 0x07]) + << 32; + value |= + static_cast( + value_table[((index1 >> 22) & 0b100) | ((index2 >> 6) & 0b11)]) + << 40; + value |= static_cast(value_table[(index2 >> 3) & 0x07]) << 48; + value |= static_cast(value_table[(index2 >> 0) & 0x07]) << 56; + + *reinterpret_cast(buffer + 16) = value; + + value2 = static_cast(value_table[(index2 >> 13) & 0x07]); + value2 |= static_cast(value_table[(index2 >> 10) & 0x07]) + << 8; + value2 |= + static_cast( + value_table[((index2 >> 7) & 0b110) | ((index2 >> 23) & 0b1)]) + << 16; + value2 |= static_cast(value_table[(index2 >> 20) & 0x07]) + << 24; + value2 |= static_cast(value_table[(index2 >> 17) & 0x07]) + << 32; + value2 |= + static_cast( + value_table[((index2 >> 14) & 0b100) | ((index2 >> 30) & 0b11)]) + << 40; + value2 |= static_cast(value_table[(index2 >> 27) & 0x07]) + << 48; + value2 |= static_cast(value_table[(index2 >> 24) & 0x07]) + << 56; + + *reinterpret_cast(buffer + 24) = value2; + + buffer += 32; + current_offset += 32; + } + } + + // process remaining elements in current channel + 
while (count > 0) { + count -= 1; + const size_t index = GetNextTableIndexWidth3(current_offset++); + *buffer++ = value_table[index]; + } + + value_table += stride; + } +} + +// TODO(ddavis-2015): templating GetNextTableIndexWidth makes this method +// more than 2x faster, but with a large code size increase +template +void DecodeStateLUT::DecompressToBufferWidthAny(T* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + if (use_alternate_axis_) { + const size_t stride = value_table_channel_stride_; + size_t current_offset = 0; + size_t count = count_indices_; + + while (count > 0) { + const T* value_table = static_cast(value_table_); + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + value_table += stride; + } + count -= num_channels_; + } + } else { + const size_t stride = value_table_channel_stride_; + const T* value_table = static_cast(value_table_); + const size_t max_count = elements_per_channel_; + size_t current_offset = 0; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + + while (count-- > 0) { + size_t index; + switch (compressed_bit_width_) { + case 1: + index = GetNextTableIndexWidth1(current_offset); + break; + case 2: + index = GetNextTableIndexWidth2(current_offset); + break; + case 3: + index = GetNextTableIndexWidth3(current_offset); + break; + case 4: + index = GetNextTableIndexWidth4(current_offset); + break; + case 5: + index = GetNextTableIndexWidth5(current_offset); + break; + case 6: + index = GetNextTableIndexWidth6(current_offset); + break; + case 7: + index = GetNextTableIndexWidth7(current_offset); + break; + } + current_offset++; + *buffer++ = value_table[index]; + } + value_table += stride; + } + } +} + +template void DecodeStateLUT::DecompressToBufferWidthAny(bool*); +template void DecodeStateLUT::DecompressToBufferWidthAny(float*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int8_t*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int16_t*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int32_t*); +template void DecodeStateLUT::DecompressToBufferWidthAny(int64_t*); + +inline size_t DecodeStateLUT::GetNextTableIndexWidth7( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 7; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 1; + case 1: + return ((indices[0] & 0b1) << 6) | (indices[1] >> 2); + case 2: + return ((indices[1] & 0b11) << 5) | (indices[2] >> 3); + case 3: + return ((indices[2] & 0b111) << 4) | (indices[3] >> 4); + case 4: + return ((indices[3] & 0x0F) << 3) | (indices[4] >> 5); + case 5: + return ((indices[4] & 0x1F) << 2) | (indices[5] >> 6); + case 6: + return ((indices[5] & 0x3F) << 1) | (indices[6] >> 7); + case 7: + return indices[6] & 0x7F; + } + // NOTREACHED + return 0; +} + +inline size_t 
DecodeStateLUT::GetNextTableIndexWidth6( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 2) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b11) { + case 0: + return indices[0] >> 2; + case 1: + return ((indices[0] & 0b11) << 4) | (indices[1] >> 4); + case 2: + return ((indices[1] & 0x0F) << 2) | (indices[2] >> 6); + case 3: + return indices[2] & 0x3F; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth5( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 5; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 3; + case 1: + return ((indices[0] & 0b111) << 2) | (indices[1] >> 6); + case 2: + return (indices[1] >> 1) & 0x1F; + case 3: + return ((indices[1] & 0b1) << 4) | (indices[2] >> 4); + case 4: + return ((indices[2] & 0x0F) << 1) | (indices[3] >> 7); + case 5: + return (indices[3] >> 2) & 0x1F; + case 6: + return ((indices[3] & 0b11) << 3) | (indices[4] >> 5); + case 7: + return indices[4] & 0x1F; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth4( + const size_t current_offset) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 1] & 0x0F; + } else { + return compressed_indices_[current_offset >> 1] >> 4; + } +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth3( + const size_t current_offset) { + const size_t current_byte_index = (current_offset >> 3) * 3; + const uint8_t* indices = &compressed_indices_[current_byte_index]; + switch (current_offset & 0b111) { + case 0: + return indices[0] >> 5; + case 1: + return (indices[0] >> 2) & 0b111; + case 2: + return ((indices[0] & 0b11) << 1) | (indices[1] >> 7); + case 3: + return (indices[1] >> 4) & 0b111; + case 4: + return (indices[1] >> 1) & 0b111; + case 5: + return ((indices[1] & 0b1) << 2) | (indices[2] >> 6); + case 6: + return (indices[2] >> 3) & 0b111; + case 7: + return indices[2] & 0b111; + } + // NOTREACHED + return 0; +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth2( + const size_t current_offset) { + if (current_offset & 0b10) { + if (current_offset & 1) { + return compressed_indices_[current_offset >> 2] & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 2) & 0x03; + } + } else { + if (current_offset & 1) { + return (compressed_indices_[current_offset >> 2] >> 4) & 0x03; + } else { + return (compressed_indices_[current_offset >> 2] >> 6) & 0x03; + } + } +} + +inline size_t DecodeStateLUT::GetNextTableIndexWidth1( + const size_t current_offset) { + const size_t shift = ~current_offset & 0b111; + return (compressed_indices_[current_offset >> 3] >> shift) & 0b1; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state_lut.h b/tensorflow/lite/micro/kernels/decode_state_lut.h new file mode 100644 index 00000000000..dbb64683960 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_lut.h @@ -0,0 +1,92 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_LUT_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_LUT_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" + +namespace tflite { + +struct DecodeStateLUT : public DecodeState { + DecodeStateLUT() = delete; + + DecodeStateLUT(const TfLiteContext* context, MicroProfilerInterface* profiler) + : DecodeState(context, profiler) {} + + virtual TfLiteStatus Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) override; + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + // LUT compression constants + static constexpr size_t kMaxBitWidth = 7; + static constexpr size_t kMaxValueTableChannelStride = 128; + + private: + // LUT Decode Common Metadata constants + static constexpr size_t kDcmVersionOffset = 4; + static constexpr size_t kDcmParamsOffset = 5; + static constexpr uint8_t kDcmParamsBitWidthMask = 0x07; + static constexpr size_t kDcmValueTableStrideOffset = 6; + + protected: + virtual ~DecodeStateLUT() = default; + + template + T* DecompressToBuffer(void* buffer); + + // optimized C++ for INT8, use_alt_axis == false + void DecompressToBufferWidth4_16(int8_t* buffer); + void DecompressToBufferWidth3_32(int8_t* buffer); + void DecompressToBufferWidth2_16(int8_t* buffer); + + // generic C++ for any bit width and value table type + template + void DecompressToBufferWidthAny(T* buffer); + + // Optimized C++ table index fetch + inline size_t GetNextTableIndexWidth7(const size_t current_offset); + inline size_t GetNextTableIndexWidth6(const size_t current_offset); + inline size_t GetNextTableIndexWidth5(const size_t current_offset); + inline size_t GetNextTableIndexWidth4(const size_t current_offset); + inline size_t GetNextTableIndexWidth3(const size_t current_offset); + inline size_t GetNextTableIndexWidth2(const size_t current_offset); + inline size_t GetNextTableIndexWidth1(const size_t current_offset); + + protected: + const uint8_t* compressed_indices_ = nullptr; + size_t count_indices_ = 0; + size_t num_channels_ = 1; + size_t elements_per_channel_ = 0; // computed from use_alternate_axis_ + const void* value_table_ = nullptr; // Pointer into FlatBuffer values + uint8_t value_table_channel_stride_ = 0; // elements per channel + uint8_t compressed_bit_width_ = 0; // 1 to 7 bits + bool use_alternate_axis_ = false; // shape channel axis: + // false = first, true = last + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_LUT_H_ diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.cc b/tensorflow/lite/micro/kernels/decode_state_prune.cc new file mode 100644 index 00000000000..f5ff7ac6a58 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_prune.cc @@ -0,0 +1,199 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_context.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) { + const uint8_t* const ancillary_data = GetTensorData(&ancillary); + if (ancillary_data[kDcmVersionOffset] != 1) { + MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]); + return kTfLiteError; + } + + // resolve num_channels_, use_alternate_axis_, and zero points + if (output.quantization.type == kTfLiteAffineQuantization && + output.quantization.params != nullptr) { + const TfLiteAffineQuantization* quantization = + reinterpret_cast(output.quantization.params); + num_channels_ = quantization->scale->size; + if ((quantization->quantized_dimension == output.dims->size - 1) && + num_channels_ > 1) { + use_alternate_axis_ = true; + } else if (quantization->quantized_dimension != 0) { + MicroPrintf("unsupported quantization axis %u", + quantization->quantized_dimension); + return kTfLiteError; + } + + if (output.type != kTfLiteInt8) { + // make sure all zero points are 0 (zero) + for (size_t i = 0; i < num_channels_; i++) { + TF_LITE_ENSURE(const_cast(context_), + quantization->zero_point->data[i] == 0); + } + } + + if (num_channels_ > 1 && output.type == kTfLiteInt8) { + // copy zero points + MicroContext* micro_context = GetMicroContext(context_); + const size_t bufsize = num_channels_ * sizeof(*zero_points_); + zero_points_ = static_cast( + micro_context->AllocatePersistentBuffer(bufsize)); + if (zero_points_ == nullptr) { + MicroPrintf("unable to allocate zero_points_"); + return kTfLiteError; + } + std::copy_n(quantization->zero_point->data, num_channels_, zero_points_); + } else { + single_zero_point_ = quantization->zero_point->data[0]; + } + } + + compressed_indices_ = GetTensorData(&input); + count_indices_ = NumElements(&output); + elements_per_channel_ = + use_alternate_axis_ ? 
1 : count_indices_ / num_channels_; + value_table_ = &ancillary_data[kDcmSizeInBytes]; + + return kTfLiteOk; +} + +TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBuffer(buffer); + break; + case kTfLiteFloat32: + DecompressToBuffer(buffer); + break; + case kTfLiteInt8: + if (num_channels_ > 1) { + DecompressToBufferPerChannelInt8(buffer); + } else { + DecompressToBuffer(buffer); + } + break; + case kTfLiteInt16: + DecompressToBuffer(buffer); + break; + case kTfLiteInt32: + DecompressToBuffer(buffer); + break; + case kTfLiteInt64: + DecompressToBuffer(buffer); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +template +void DecodeStatePrune::DecompressToBuffer(void* vp) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + T* buffer = static_cast(vp); + const T* value_table = static_cast(value_table_); + const size_t max_count = count_indices_; + const uint8_t* const indices = compressed_indices_; + + for (size_t index = 0; index < max_count; index++) { + size_t shift = ~index & 0b111; + size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1; + + if (is_not_zp) { + *buffer++ = *value_table++; + } else { + *buffer++ = single_zero_point_; + } + } +} + +void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int8_t* buffer = static_cast(vp); + size_t current_offset = 0; + const uint8_t* const indices = compressed_indices_; + const int8_t* value_table = static_cast(value_table_); + + if (use_alternate_axis_) { + const size_t max_channels = num_channels_; + size_t count = count_indices_; + + while (count > 0) { + for (size_t channel = 0; channel < max_channels; channel++) { + const int8_t zp = zero_points_[channel]; + size_t shift = ~current_offset & 0b111; + size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1; + + if (is_not_zp) { + *buffer++ = *value_table++; + } else { + *buffer++ = zp; + } + current_offset++; + } + count -= max_channels; + } + } else { + const size_t max_count = elements_per_channel_; + + for (size_t channel = 0; channel < num_channels_; channel++) { + size_t count = max_count; + const int8_t zp = zero_points_[channel]; + + while (count-- > 0) { + size_t shift = ~current_offset & 0b111; + size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1; + + if (is_not_zp) { + *buffer++ = *value_table++; + } else { + *buffer++ = zp; + } + current_offset++; + } + } + } +} + +template void DecodeStatePrune::DecompressToBuffer(void*); +template void DecodeStatePrune::DecompressToBuffer(void*); +template void DecodeStatePrune::DecompressToBuffer(void*); +template void DecodeStatePrune::DecompressToBuffer(void*); + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.h b/tensorflow/lite/micro/kernels/decode_state_prune.h new file mode 100644 index 00000000000..de5ddd84249 --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_prune.h @@ -0,0 +1,69 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/decode_state.h"
+
+namespace tflite {
+
+struct DecodeStatePrune : public DecodeState {
+  DecodeStatePrune() = delete;
+
+  DecodeStatePrune(const TfLiteContext* context,
+                   MicroProfilerInterface* profiler)
+      : DecodeState(context, profiler) {}
+
+  virtual TfLiteStatus Setup(const TfLiteTensor& input,
+                             const TfLiteTensor& ancillary,
+                             const TfLiteTensor& output) override;
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) override;
+
+ private:
+  // Prune Decode Common Metadata constants
+  static constexpr size_t kDcmVersionOffset = 4;
+
+ protected:
+  virtual ~DecodeStatePrune() = default;
+
+  template <typename T>
+  void DecompressToBuffer(void* buffer);
+
+  void DecompressToBufferPerChannelInt8(void* buffer);
+
+ protected:
+  const uint8_t* compressed_indices_ = nullptr;
+  size_t count_indices_ = 0;
+  size_t num_channels_ = 1;
+  size_t elements_per_channel_ = 0;    // computed from use_alternate_axis_
+  const void* value_table_ = nullptr;  // original non-pruned values
+  int8_t* zero_points_ = nullptr;      // quantized per-channel zero points
+  int8_t single_zero_point_ = 0;       // single channel zero point
+  bool use_alternate_axis_ = false;    // shape channel axis:
+                                       //   false = first, true = last
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
diff --git a/tensorflow/lite/micro/kernels/decode_test.cc b/tensorflow/lite/micro/kernels/decode_test.cc
new file mode 100644
index 00000000000..9a9f7931f5e
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_test.cc
@@ -0,0 +1,1224 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/decode_state.h" +#include "tensorflow/lite/micro/kernels/decode_state_huffman.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +struct TensorInDatum { + const void* const data; + const TfLiteIntArray& dims; +}; + +struct TensorOutDatum { + void* const data; + const TfLiteIntArray& dims; + const TfLiteType type; + const TfLiteFloatArray& scales; + const TfLiteIntArray& zero_points; + const int quantized_dimension; + + // initialized by CreatePerChannelQuantizedTensor + const TfLiteAffineQuantization affine_quantization; +}; + +template +struct AncillaryData { + AncillaryData() = delete; + AncillaryData(const uint8_t (&dcm)[tflite::DecodeState::kDcmSizeInBytes], + const T (&values)[N]) { + std::copy(std::begin(dcm), std::end(dcm), std::begin(dcm_)); + std::copy(std::begin(values), std::end(values), std::begin(value_table_)); + } + + private: + uint8_t dcm_[tflite::DecodeState::kDcmSizeInBytes]; + T value_table_[N > 0 ? N : 1]; // assure not zero length +}; + +// +// LUT test data +// +constexpr int kBitWidthLUT = 2; + +constexpr int8_t kAncillaryDataLUT0[] = {1, 2, 3, 4}; +constexpr int16_t kAncillaryDataLUT1[] = {5, 6, 7, 8}; + +constexpr uint8_t kDcmLUT0[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypeLUT, // type: LUT + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // LUT version: 1 + kBitWidthLUT, // Parameters: bit-width 2 + std::size(kAncillaryDataLUT0), // channel stride +}; + +constexpr uint8_t kDcmLUT1[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypeLUT, // type: LUT + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // LUT version: 1 + kBitWidthLUT, // Parameters: bit-width 2 + std::size(kAncillaryDataLUT1), // channel stride +}; + +// Align the tensor data the same as a Buffer in the TfLite schema +alignas(16) const uint8_t kEncodedLUT[] = {0x1B, 0xE4}; + +// Tensor shapes as TfLiteIntArray +constexpr int kOutputShapeLUT[] = {3, 1, 2, 4}; +constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)}; + +constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1}; +constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5}; + +// +// Prune test data +// +constexpr int8_t kAncillaryDataPrune0[] = { + 1, 2, 3, 4, // 0 + 1, 2, 3, 4, // 1 + 1, 2, 3, 4, // 2 + 1, 2, 3, 4, // 3 + 1, 2, 3, 4 // 4 +}; +constexpr int16_t kAncillaryDataPrune1[] = { + 5, 6, 7, 8, // 0 + 5, 6, 7, 8, // 1 + 5, 6, 7, 8, // 2 + 5, 6, 7, 8, // 3 + 5, 6, 7, 8 // 4 +}; +constexpr float kAncillaryDataPrune2[] = { + 9.0f, 10.0f, 11.0f, 12.0f, // 0 + 9.0f, 10.0f, 11.0f, 12.0f, // 1 + 9.0f, 10.0f, 11.0f, 12.0f, // 2 + 9.0f, 10.0f, 11.0f, 12.0f, // 3 + 9.0f, 10.0f, 11.0f, 12.0f // 4 +}; +constexpr int8_t kAncillaryDataPrune3[] = { + 13, 14, 15, 16, // 0 + 13, 14, 15, 16, // 1 + 13, 14, 15, 16, // 2 + 13, 14, 15, 16, // 3 + 13, 14, 15, 16 // 4 +}; +constexpr int8_t kAncillaryDataPrune4[] = { + 17, 18, 19, 20, // 0 + 17, 18, 19, 20, // 1 + 17, 18, 19, 20, // 2 + 17, 18, 19, 20, // 3 + 17, 18, 19, 20 // 4 +}; + +constexpr uint8_t kDcmPrune[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypePrune, // type: Prune + 
1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // Prune version: 1 +}; + +// Align the tensor data the same as a Buffer in the TfLite schema +alignas(16) const uint8_t kEncodedPrune[] = {0xA5, 0xA5, 0xA5, 0xA5, 0xA5}; + +// Tensor shapes as TfLiteIntArray +constexpr int kOutputShapePrune[] = {3, 2, 5, 4}; +constexpr int kEncodedShapePrune[] = {1, sizeof(kEncodedPrune)}; + +// Quantization datum as TfLiteIntArray. +// Scales are modified by FloatArrayFromFloats. As globals they cannot be +// const without causing a processor exception. +float kScalesPrune0[] = {2, 1.0f, 1.0f}; +constexpr int kZeroPointsPrune0[] = {2, -128, -64}; +float kScalesPrune1[] = {4, 1.0f, 1.0f, 1.0f, 1.0f}; +constexpr int kZeroPointsPrune1[] = {4, 0, 0, 0, 0}; +float kScalesPrune4[] = {4, 1.0f, 1.0f, 1.0f, 1.0f}; +constexpr int kZeroPointsPrune4[] = {4, -126, -62, -30, -14}; + +constexpr int8_t kExpectPrune0[] = { + 1, -128, 2, -128, -128, 3, -128, 4, 1, -128, // 0 + 2, -128, -128, 3, -128, 4, 1, -128, 2, -128, // 0 + -64, 3, -64, 4, 1, -64, 2, -64, -64, 3, // 1 + -64, 4, 1, -64, 2, -64, -64, 3, -64, 4 // 1 +}; +constexpr int16_t kExpectPrune1[] = { + 5, 0, 6, 0, // 0 + 0, 7, 0, 8, // 1 + 5, 0, 6, 0, // 2 + 0, 7, 0, 8, // 3 + 5, 0, 6, 0, // 4 + 0, 7, 0, 8, // 5 + 5, 0, 6, 0, // 6 + 0, 7, 0, 8, // 7 + 5, 0, 6, 0, // 8 + 0, 7, 0, 8 // 9 +}; +constexpr float kExpectPrune2[] = { + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 0 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 1 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 2 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // 3 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f // 4 +}; +constexpr int8_t kExpectPrune3[] = { + 13, 0, 14, 0, 0, 15, 0, 16, // 0 + 13, 0, 14, 0, 0, 15, 0, 16, // 1 + 13, 0, 14, 0, 0, 15, 0, 16, // 2 + 13, 0, 14, 0, 0, 15, 0, 16, // 3 + 13, 0, 14, 0, 0, 15, 0, 16 // 4 +}; +constexpr int8_t kExpectPrune4[] = { + 17, -62, 18, -14, // 0 + -126, 19, -30, 20, // 1 + 17, -62, 18, -14, // 2 + -126, 19, -30, 20, // 3 + 17, -62, 18, -14, // 4 + -126, 19, -30, 20, // 5 + 17, -62, 18, -14, // 6 + -126, 19, -30, 20, // 7 + 17, -62, 18, -14, // 8 + -126, 19, -30, 20 // 9 +}; + +// +// Huffman test data +// +// Test data is based on an initial table size of 8 elements with the following +// codeword map: +// 'm' : '000', +// 'a' : '001', +// 'j' : '010', +// 'o' : '0110', +// 'h' : '0111', +// 'd' : '1000', +// 'i' : '10010', +// 'f' : '10011', +// 'c' : '1010', +// 'g' : '101100', +// 'l' : '101101', +// 'k' : '10111', +// 'e' : '1100', +// 'b' : '1101', +// 'p' : '1110', +// 'n' : '1111' +// + +constexpr int kHuffmanInitial = 2; // log2(8) - 1 +constexpr int kHuffmanShift = + tflite::DecodeStateHuffman::kDcmTableSizeInitialShift; +constexpr int kHuffman32 = tflite::DecodeStateHuffman::kDcmTableSize32BitsMask; + +constexpr uint8_t kDcmHuffman16[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypeHuffman, // type: Huffman + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // Huffman version: 1 + kHuffmanInitial << kHuffmanShift, // Table size: 8 16-bit elements +}; + +constexpr uint8_t kDcmHuffman32[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypeHuffman, // type: Huffman + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // Huffman version: 1 + (kHuffmanInitial << kHuffmanShift) | + kHuffman32, // Table size: 8 32-bit elements +}; + +// Align the tensor data the same as a Buffer in the TfLite schema +alignas(16) const uint8_t kEncodedHuffman[] = { +
0xF0, 0x0D, 0x79, 0xFC, 0x9C, 0x1A, 0x6E, 0x4C, 0x32, 0xAF, 0x29, + 0xB3, 0x5D, 0xF6, 0x36, 0x02, 0x50, 0x15, 0x8C, 0xA7, 0x95, 0xDB, + 0x29, 0x68, 0x3F, 0xBA, 0xB7, 0xED, 0xB1, 0x19, 0xE2, 0xE0, +}; + +// Tensor shapes as TfLiteIntArray +constexpr int kOutputShapeHuffman[] = {1, 64}; +constexpr int kEncodedShapeHuffman[] = {1, sizeof(kEncodedHuffman)}; + +constexpr int8_t kExpectHuffmanInt8[] = { + 'n', 'm', 'm', 'a', 'c', 'n', 'a', 'n', 'e', 'f', 'd', 'a', 'c', + 'o', 'p', 'j', 'o', 'm', 'e', 'c', 'k', 'i', 'f', 'o', 'o', 'k', + 'h', 'b', 'd', 'b', 'd', 'm', 'j', 'j', 'd', 'm', 'j', 'g', 'o', + 'j', 'f', 'e', 'c', 'p', 'b', 'i', 'i', 'b', 'm', 'a', 'n', 'b', + 'b', 'j', 'b', 'n', 'l', 'g', 'j', 'a', 'f', 'e', 'j', 'p', +}; +constexpr int16_t kExpectHuffmanInt16[] = { + 'n', 'm', 'm', 'a', 'c', 'n', 'a', 'n', 'e', 'f', 'd', 'a', 'c', + 'o', 'p', 'j', 'o', 'm', 'e', 'c', 'k', 'i', 'f', 'o', 'o', 'k', + 'h', 'b', 'd', 'b', 'd', 'm', 'j', 'j', 'd', 'm', 'j', 'g', 'o', + 'j', 'f', 'e', 'c', 'p', 'b', 'i', 'i', 'b', 'm', 'a', 'n', 'b', + 'b', 'j', 'b', 'n', 'l', 'g', 'j', 'a', 'f', 'e', 'j', 'p', +}; + +constexpr uint16_t kAncillaryDataHuffman16[] = { + // Table 0: + 0x986D, // [0]: size= 3 symbol=m + 0x9861, // [1]: size= 3 symbol=a + 0x986A, // [2]: size= 3 symbol=j + 0x0005, // [3]: size= 0 offset= 5 (@8) + 0x0806, // [4]: size= 1 offset= 6 (@10) + 0x0809, // [5]: size= 1 offset= 9 (@14) + 0x000E, // [6]: size= 0 offset= 14 (@20) + 0x000F, // [7]: size= 0 offset= 15 (@22) + // Table 1: + 0x886F, // [8]: size= 1 symbol=o + 0x8868, // [9]: size= 1 symbol=h + // Table 2: + 0x8864, // [10]: size= 1 symbol=d + 0x8864, // [11]: size= 1 symbol=d + 0x9069, // [12]: size= 2 symbol=i + 0x9066, // [13]: size= 2 symbol=f + // Table 3: + 0x8863, // [14]: size= 1 symbol=c + 0x8863, // [15]: size= 1 symbol=c + 0x0002, // [16]: size= 0 offset= 2 (@18) + 0x906B, // [17]: size= 2 symbol=k + // Table 4: + 0x8867, // [18]: size= 1 symbol=g + 0x886C, // [19]: size= 1 symbol=l + // Table 5: + 0x8865, // [20]: size= 1 symbol=e + 0x8862, // [21]: size= 1 symbol=b + // Table 6: + 0x8870, // [22]: size= 1 symbol=p + 0x886E, // [23]: size= 1 symbol=n +}; +constexpr uint32_t kAncillaryDataHuffman32[] = { + // Table 0: + 0x9800006D, // [0]: size= 3 symbol=m + 0x98000061, // [1]: size= 3 symbol=a + 0x9800006A, // [2]: size= 3 symbol=j + 0x00000005, // [3]: size= 0 offset= 5 (@8) + 0x08000006, // [4]: size= 1 offset= 6 (@10) + 0x08000009, // [5]: size= 1 offset= 9 (@14) + 0x0000000E, // [6]: size= 0 offset= 14 (@20) + 0x0000000F, // [7]: size= 0 offset= 15 (@22) + // Table 1: + 0x8800006F, // [8]: size= 1 symbol=o + 0x88000068, // [9]: size= 1 symbol=h + // Table 2: + 0x88000064, // [10]: size= 1 symbol=d + 0x88000064, // [11]: size= 1 symbol=d + 0x90000069, // [12]: size= 2 symbol=i + 0x90000066, // [13]: size= 2 symbol=f + // Table 3: + 0x88000063, // [14]: size= 1 symbol=c + 0x88000063, // [15]: size= 1 symbol=c + 0x00000002, // [16]: size= 0 offset= 2 (@18) + 0x9000006B, // [17]: size= 2 symbol=k + // Table 4: + 0x88000067, // [18]: size= 1 symbol=g + 0x8800006C, // [19]: size= 1 symbol=l + // Table 5: + 0x88000065, // [20]: size= 1 symbol=e + 0x88000062, // [21]: size= 1 symbol=b + // Table 6: + 0x88000070, // [22]: size= 1 symbol=p + 0x8800006E, // [23]: size= 1 symbol=n +}; + +template +TfLiteStatus CheckOutput(const TfLiteTensor& output, + const void* const expected) { + const T* const expected_data = reinterpret_cast(expected); + const T* const output_data = tflite::GetTensorData(&output); + + constexpr float 
kTolerance = 1e-5; + const size_t kOutputCount = tflite::NumElements(&output); + for (size_t i = 0; i < kOutputCount; i++) { + TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance); + TF_LITE_MICRO_CHECK_FAIL(); + } + + return kTfLiteOk; +} + +template +TfLiteStatus ExecuteDecodeTest( + TfLiteTensor* tensors, const TFLMRegistration& registration, + const std::initializer_list& expected, + const std::initializer_list* amr = + nullptr) { + int kInputArrayData[kNumInputs + 1] = {kNumInputs}; + for (size_t i = 0; i < kNumInputs; i++) { + kInputArrayData[i + 1] = i; + } + TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData); + + int kOutputArrayData[kNumOutputs + 1] = {kNumOutputs}; + for (size_t i = 0; i < kNumOutputs; i++) { + kOutputArrayData[i + 1] = i + kNumInputs; + } + TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData); + + micro::KernelRunner runner(registration, tensors, kNumInputs + kNumOutputs, + inputs_array, outputs_array, nullptr); + + if (amr != nullptr) { + runner.GetFakeMicroContext()->SetDecompressionMemory(*amr); + } + + if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) { + return kTfLiteError; + } + + const TfLiteTensor* const output_tensors = &tensors[kNumInputs]; + TfLiteStatus status = kTfLiteError; + for (size_t i = 0; i < kNumOutputs; i++) { + switch (output_tensors[i].type) { + case kTfLiteInt8: + status = CheckOutput(output_tensors[i], expected.begin()[i]); + break; + case kTfLiteInt16: + status = CheckOutput(output_tensors[i], expected.begin()[i]); + break; + case kTfLiteFloat32: + status = CheckOutput(output_tensors[i], expected.begin()[i]); + break; + default: + TF_LITE_MICRO_FAIL("unsupported tensor type in test"); + break; + } + } + + return status; +} + +template +void TestDecode( + const std::initializer_list& encodes, + const std::initializer_list& ancillaries, + const std::initializer_list& outputs, + const std::initializer_list& expected, + const TFLMRegistration& registration, + const std::initializer_list* amr = + nullptr, + const TfLiteStatus expected_status = kTfLiteOk) { + TfLiteTensor tensors[kNumInputs + kNumOutputs] = {}; + + for (size_t i = 0; i < kNumInputs; i += 2) { + const TensorInDatum& tid_encode = *encodes.begin()[i / 2]; + tensors[i] = CreateTensor(tid_encode.data, + const_cast(&tid_encode.dims), + false, kTfLiteUInt8); + // must be a const tensor + tensors[i].allocation_type = kTfLiteMmapRo; + const TensorInDatum& tid_ancillary = *ancillaries.begin()[i / 2]; + tensors[i + 1] = CreateTensor( + tid_ancillary.data, const_cast(&tid_ancillary.dims), + false, kTfLiteUInt8); + // must be a const tensor + tensors[i + 1].allocation_type = kTfLiteMmapRo; + } + for (size_t i = 0; i < kNumOutputs; i++) { + const TensorOutDatum& tod = *outputs.begin()[i]; + if (tod.scales.size == 0) { + tensors[i + kNumInputs] = CreateTensor( + tod.data, const_cast(&tod.dims), false, tod.type); + } else { + tensors[i + kNumInputs] = CreatePerChannelQuantizedTensor( + tod.data, const_cast(&tod.dims), + const_cast(&tod.scales), + const_cast(&tod.zero_points), + const_cast(&tod.affine_quantization), + tod.quantized_dimension, false, tod.type); + } + } + + TfLiteStatus s = ExecuteDecodeTest( + tensors, registration, expected, amr); + TF_LITE_MICRO_EXPECT_EQ(s, expected_status); +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +using tflite::testing::AncillaryData; +using tflite::testing::kAncillaryDataHuffman16; +using 
tflite::testing::kAncillaryDataHuffman32; +using tflite::testing::kAncillaryDataLUT0; +using tflite::testing::kAncillaryDataLUT1; +using tflite::testing::kAncillaryDataPrune0; +using tflite::testing::kAncillaryDataPrune1; +using tflite::testing::kAncillaryDataPrune2; +using tflite::testing::kAncillaryDataPrune3; +using tflite::testing::kAncillaryDataPrune4; +using tflite::testing::kDcmHuffman16; +using tflite::testing::kDcmHuffman32; +using tflite::testing::kDcmLUT0; +using tflite::testing::kDcmLUT1; +using tflite::testing::kDcmPrune; +using tflite::testing::kEncodedHuffman; +using tflite::testing::kEncodedLUT; +using tflite::testing::kEncodedPrune; +using tflite::testing::kEncodedShapeHuffman; +using tflite::testing::kEncodedShapeLUT; +using tflite::testing::kEncodedShapePrune; +using tflite::testing::kExpectHuffmanInt16; +using tflite::testing::kExpectHuffmanInt8; +using tflite::testing::kExpectLUT0; +using tflite::testing::kExpectLUT1; +using tflite::testing::kExpectPrune0; +using tflite::testing::kExpectPrune1; +using tflite::testing::kExpectPrune2; +using tflite::testing::kExpectPrune3; +using tflite::testing::kExpectPrune4; +using tflite::testing::kOutputShapeHuffman; +using tflite::testing::kOutputShapeLUT; +using tflite::testing::kOutputShapePrune; +using tflite::testing::kScalesPrune0; +using tflite::testing::kScalesPrune1; +using tflite::testing::kScalesPrune4; +using tflite::testing::kZeroPointsPrune0; +using tflite::testing::kZeroPointsPrune1; +using tflite::testing::kZeroPointsPrune4; +using tflite::testing::TensorInDatum; +using tflite::testing::TensorOutDatum; + +TF_LITE_MICRO_TEST(DecodeSingleTensor) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectLUT0)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmLUT0}, {kAncillaryDataLUT0}}; + + constexpr int kAncillaryShapeLUT[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeLUT); + static const TensorInDatum tid_encode = { + kEncodedLUT, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeLUT); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + output_data, + *output_dims, + kTfLiteInt8, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectLUT0}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodeTwoTensors) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data0[std::size(kExpectLUT0)] = {}; + alignas(16) int16_t output_data1[std::size(kExpectLUT1)] = {}; + alignas(16) const AncillaryData + kAncillaryData0 = 
{{kDcmLUT0}, {kAncillaryDataLUT0}}; + alignas(16) const AncillaryData + kAncillaryData1 = {{kDcmLUT1}, {kAncillaryDataLUT1}}; + + constexpr int kAncillaryShapeLUT0[] = {1, sizeof(kAncillaryData0)}; + constexpr int kAncillaryShapeLUT1[] = {1, sizeof(kAncillaryData1)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeLUT); + static const TensorInDatum tid_encode0 = { + kEncodedLUT, + *encoded_dims, + }; + static const TensorInDatum tid_encode1 = { + kEncodedLUT, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode0, &tid_encode1}; + + const TfLiteIntArray* const ancillary_dims0 = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT0); + static const TensorInDatum tid_ancillary0 = { + &kAncillaryData0, + *ancillary_dims0, + }; + const TfLiteIntArray* const ancillary_dims1 = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT1); + static const TensorInDatum tid_ancillary1 = { + &kAncillaryData1, + *ancillary_dims1, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary0, &tid_ancillary1}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeLUT); + constexpr float output_scales_data[] = {1, 1.0f}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {1, 0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod0 = { + output_data0, + *output_dims, + kTfLiteInt8, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static const TensorOutDatum tod1 = { + output_data1, + *output_dims, + kTfLiteInt16, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod0, &tod1}; + + const std::initializer_list expected = {kExpectLUT0, + kExpectLUT1}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneFloat) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) float output_data[std::size(kExpectPrune2)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune2}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + constexpr float kOutputScalesData[] = {0}; + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kOutputScalesData); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteFloat32, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + 
&kTOD}; + + const std::initializer_list kExpected = {kExpectPrune2}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune3)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune3}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + constexpr float kOutputScalesData[] = {0}; + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kOutputScalesData); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune3}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune0)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune0}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScalesPrune0); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune0); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune0}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedAltAxisInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t 
output_data[std::size(kExpectPrune4)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune4}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScalesPrune4); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune4); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + *kOutputScales, + *kOutputZeroPoints, + (kOutputDims->size - 1), + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune4}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedAltAxisInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScalesPrune1); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune1); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt16, + *kOutputScales, + *kOutputZeroPoints, + (kOutputDims->size - 1), + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune1}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr 
std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + float kScales[] = {2, 1.0f, 1.0f}; + const TfLiteFloatArray* const kOutputScales = + tflite::testing::FloatArrayFromFloats(kScales); + const int kZeroPoints[] = {2, 0, -1}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPoints); + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt16, + *kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune1}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(), + nullptr, kTfLiteError); +} + +TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectHuffmanInt8)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmHuffman16}, {kAncillaryDataHuffman16}}; + + constexpr int kAncillaryShapeHuffman[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeHuffman); + static const TensorInDatum tid_encode = { + kEncodedHuffman, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeHuffman); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeHuffman); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + output_data, + *output_dims, + kTfLiteInt8, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectHuffmanInt8}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectHuffmanInt16)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmHuffman16}, {kAncillaryDataHuffman16}}; + + constexpr int kAncillaryShapeHuffman[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeHuffman); + static const TensorInDatum tid_encode = { + kEncodedHuffman, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const 
ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeHuffman); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeHuffman); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + output_data, + *output_dims, + kTfLiteInt16, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectHuffmanInt16}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), + nullptr, kTfLiteError); +} + +TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectHuffmanInt8)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmHuffman32}, {kAncillaryDataHuffman32}}; + + constexpr int kAncillaryShapeHuffman[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeHuffman); + static const TensorInDatum tid_encode = { + kEncodedHuffman, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeHuffman); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeHuffman); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + output_data, + *output_dims, + kTfLiteInt8, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectHuffmanInt8}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectHuffmanInt16)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmHuffman32}, {kAncillaryDataHuffman32}}; + + constexpr int kAncillaryShapeHuffman[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeHuffman); + static const TensorInDatum tid_encode = { + kEncodedHuffman, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeHuffman); + static const TensorInDatum tid_ancillary 
= { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeHuffman); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + output_data, + *output_dims, + kTfLiteInt16, + *output_scales, + *output_zero_points, + 0, + {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectHuffmanInt16}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectLUT0)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmLUT0}, {kAncillaryDataLUT0}}; + + constexpr int kAncillaryShapeLUT[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeLUT); + static const TensorInDatum tid_encode = { + kEncodedLUT, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeLUT); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeLUT); + constexpr float output_scales_data[] = {0}; + const TfLiteFloatArray* const output_scales = + tflite::testing::FloatArrayFromFloats(output_scales_data); + constexpr int output_zero_points_data[] = {0}; + const TfLiteIntArray* const output_zero_points = + tflite::testing::IntArrayFromInts(output_zero_points_data); + static const TensorOutDatum tod = { + nullptr, // using alternate decompression memory + *output_dims, kTfLiteInt8, *output_scales, *output_zero_points, 0, {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectLUT0}; + + std::initializer_list amr = { + {output_data, sizeof(output_data)}}; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/kernel_runner.h b/tensorflow/lite/micro/kernels/kernel_runner.h index 8dbd7f8b015..aa644d01f32 100644 --- a/tensorflow/lite/micro/kernels/kernel_runner.h +++ b/tensorflow/lite/micro/kernels/kernel_runner.h @@ -67,6 +67,9 @@ class KernelRunner { // to stub out MicroGraph methods and track invocations on each subgraph. MockMicroGraph* GetMockGraph() { return &mock_micro_graph_; } + // Returns a pointer to the internal FakeMicroContext. + FakeMicroContext* GetFakeMicroContext() { return &fake_micro_context_; } + // Returns true if all temp buffer in tests are deallocated. // TODO(b/209453859): move this function to private after deallocation checks // are enabled for all kernel tests. 
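Note on the prune format exercised by the tests above: DecodeStatePrune treats the encoded tensor as a bitmap in which each bit, consumed MSB-first within every byte, selects either the next entry of the value table (bit set) or the zero point (bit clear). The following standalone sketch is not part of this patch; it uses a hypothetical helper name and simply reproduces that bit handling for the single-zero-point int8 case, checked against the kEncodedPrune / kAncillaryDataPrune3 data used by the DecodePruneInt8 test.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the bit handling in
// DecodeStatePrune::DecompressToBuffer for the single-zero-point case.
void PruneDecodeInt8(const uint8_t* bitmap, const int8_t* value_table,
                     int8_t zero_point, size_t count, int8_t* out) {
  for (size_t i = 0; i < count; ++i) {
    const size_t shift = ~i & 0b111;  // bit 7 first, then bit 6, ...
    const bool is_value = (bitmap[i >> 3] >> shift) & 0b1;
    out[i] = is_value ? *value_table++ : zero_point;
  }
}

int main() {
  const uint8_t bitmap[] = {0xA5};           // 0b1010'0101
  const int8_t values[] = {13, 14, 15, 16};  // one row of kAncillaryDataPrune3
  int8_t out[8] = {};
  PruneDecodeInt8(bitmap, values, /*zero_point=*/0, 8, out);
  for (int8_t v : out) {
    printf("%d ", v);  // prints: 13 0 14 0 0 15 0 16
  }
  printf("\n");
  return 0;
}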
diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index b3c9204b4d8..264af300a02 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -53,6 +53,7 @@ TFLMRegistration Register_CONCATENATION(); TFLMRegistration Register_CONV_2D(); TFLMRegistration Register_COS(); TFLMRegistration Register_CUMSUM(); +TFLMRegistration Register_DECODE(); TFLMRegistration Register_DEPTH_TO_SPACE(); TFLMRegistration Register_DEPTHWISE_CONV_2D(); TFLMRegistration Register_DEQUANTIZE(); diff --git a/tensorflow/lite/micro/kernels/xtensa/decode_state.cc b/tensorflow/lite/micro/kernels/xtensa/decode_state.cc new file mode 100644 index 00000000000..57cecea2353 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decode_state.cc @@ -0,0 +1,92 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state.h" + +#include "tensorflow/lite/micro/kernels/decode_state_huffman.h" +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" +#include "tensorflow/lite/micro/micro_context.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h" +#endif // HIFI5 + +namespace tflite { + +DecodeState* DecodeState::CreateDecodeStateLUT( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStateLUT); +#else + constexpr size_t kBufferSize = sizeof(DecodeStateLUT); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStateLUT(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStateLUT(context, profiler); +#endif // HIFI5 + + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStatePrune( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStatePrune); +#else + constexpr size_t kBufferSize = sizeof(DecodeStatePrune); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStatePrune(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler); +#endif // HIFI5 + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStateHuffman( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + 
MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStateHuffman); +#else + constexpr size_t kBufferSize = sizeof(DecodeStateHuffman); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStateHuffman(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStateHuffman(context, profiler); +#endif // HIFI5 + return dsp; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.cc new file mode 100644 index 00000000000..6c5014760f6 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.cc @@ -0,0 +1,119 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.h" + +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus XtensaDecodeStateHuffman::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteInt8: + if (use_32bit_table_) { + Decompress32BitTable_Xtensa(static_cast(buffer)); + } else { + Decompress16BitTable_Xtensa(static_cast(buffer)); + } + break; + case kTfLiteInt16: + Decompress32BitTable_Xtensa(static_cast(buffer)); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +void XtensaDecodeStateHuffman::Decompress16BitTable_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + size_t remaining = count_codewords_; + const uint16_t* huffman_tables = + static_cast(huffman_tables_); + const uint16_t* __restrict p_stream = + reinterpret_cast(compressed_codewords_); + + WAE_BITPTR(15); + WAE_BITSUSED(1); + // byte swap the preload half-word + WAE_BITHEAD(p_stream[0] << 8 | p_stream[0] >> 8); + WAE_SEARCHDONE(1); + WAE_FIRST_TS(initial_table_size_); + AE_VLDL16C(p_stream); + + while (remaining--) { + xtbool complete = 0; + unsigned long int symbol; + + while (!complete) { + AE_VLDL16T(complete, symbol, huffman_tables); + AE_VLDL16C(p_stream); + } + + *buffer++ = symbol; + } +} + +template +void XtensaDecodeStateHuffman::Decompress32BitTable_Xtensa(T* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + size_t 
remaining = count_codewords_; + const uint32_t* huffman_tables = + static_cast(huffman_tables_); + const uint16_t* __restrict p_stream = + reinterpret_cast(compressed_codewords_); + + WAE_BITPTR(15); + WAE_BITSUSED(1); + // byte swap the preload half-word + WAE_BITHEAD(p_stream[0] << 8 | p_stream[0] >> 8); + WAE_SEARCHDONE(1); + WAE_FIRST_TS(initial_table_size_); + AE_VLDL16C(p_stream); + + while (remaining--) { + xtbool complete = 0; + unsigned long int symbol; + + while (!complete) { + AE_VLDL32T(complete, symbol, huffman_tables); + AE_VLDL16C(p_stream); + } + + *buffer++ = symbol; + } +} + +template void XtensaDecodeStateHuffman::Decompress32BitTable_Xtensa( + int8_t*); +template void XtensaDecodeStateHuffman::Decompress32BitTable_Xtensa( + int16_t*); + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.h new file mode 100644 index 00000000000..e8016a2832a --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.h @@ -0,0 +1,51 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_HUFFMAN_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_HUFFMAN_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state_huffman.h" + +namespace tflite { + +struct XtensaDecodeStateHuffman : public DecodeStateHuffman { + XtensaDecodeStateHuffman() = delete; + + XtensaDecodeStateHuffman(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeStateHuffman(context, profiler) {} + + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + virtual ~XtensaDecodeStateHuffman() = default; + + template + void Decompress32BitTable_Xtensa(T* buffer); + + void Decompress16BitTable_Xtensa(int8_t* buffer); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_HUFFMAN_H_ diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc new file mode 100644 index 00000000000..de5435f4b00 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc @@ -0,0 +1,609 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +void XtensaDecodeStateLUT::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = value_table_channel_stride_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + + value_table += stride; + if (elements_per_channel_t_rem) { + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 1)); /* Loading 48 bits for decoding 16 weight values */ + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidth3_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = value_table_channel_stride_; + + 
int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL); + ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL); + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 6); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 3); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = 
AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidth2_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = value_table_channel_stride_; + + int elements_per_channel_t_by_5 = elements_per_channel_ >> 5; + int elements_per_channel_t_rem = elements_per_channel_ & 0x1F; + int elements_per_channel_t_rem_minus_16 = 0; + if (elements_per_channel_t_rem > 16) { + elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16; + } + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d0, d1, d2, d3, d4, d5; + ae_int8x8 q0, q1, q2, q3; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_5; j++) { + // AE_LA8X8_IP( d_index, align_index, pIn_tmp ); /* Loading 64 bits + // for decoding 32 weight values */ + + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 8); /* Loading 64 bits for decoding 32 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... 
+ + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 2)); /* Loading 48 bits for decoding 16 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem_minus_16); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt8_Xtensa( + int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint8_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + uint32_t index_1, index_2; + uint32_t mask_bits = (1 << compressed_bit_width_) - 1; + + for (int i = 0; i < num_channels_t; i++) { + elements_per_channel_t = elements_per_channel_; + /* if output pointer is not 2 byte aligned */ + if ((unsigned int)p_out_tmp & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + elements_per_channel_t = elements_per_channel_t - 1; + } + for (int j = 0; j < (elements_per_channel_t >> 1); j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, 2 * bw); + index_1 = (index >> compressed_bit_width_) & 
mask_bits; + index_2 = (index)&mask_bits; + ae_int8x8 d_tmp1 = AE_L8_X((const ae_int8*)value_table, index_1); + ae_int8x8 d_tmp2 = AE_L8_X((const ae_int8*)value_table, index_2); + ae_int16x4 d_tmp = + AE_MOVINT16X4_FROMINT8X8(AE_SEL8X8I(d_tmp2, d_tmp1, 21)); + AE_S16_0_IP(d_tmp, (ae_int16*)p_out_tmp, 2); + } + if (elements_per_channel_t & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + } + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt16_Xtensa( + int16_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint16_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int16* __restrict p_out_tmp = (ae_int16*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint16_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt32_Xtensa( + int32_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint32_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int32* __restrict p_out_tmp = (ae_int32*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint32_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + } + + value_table += stride; + } + } +} + +void 
XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt64_Xtensa( + int64_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint64_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int64* __restrict p_out_tmp = (ae_int64*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint64_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBuffer(int8_t* buffer) { + if (compressed_bit_width_ == 4 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x01)) { + DecompressToBufferWidth4_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else if (compressed_bit_width_ == 3 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x07)) { + DecompressToBufferWidth3_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else if (compressed_bit_width_ == 2 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x03)) { + DecompressToBufferWidth2_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } +} + +TfLiteStatus XtensaDecodeStateLUT::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + TFLITE_DCHECK(compressed_bit_width_ <= kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBuffer(static_cast(buffer)); + break; + case kTfLiteFloat32: + DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt8: + DecompressToBuffer(static_cast(buffer)); + break; + case kTfLiteInt16: + DecompressToBufferWidthAnyInt16_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt32: + DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt64: + DecompressToBufferWidthAnyInt64_Xtensa(static_cast(buffer)); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h new file mode 100644 index 00000000000..b614887a4cc --- /dev/null +++ 
b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h
@@ -0,0 +1,57 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
+
+namespace tflite {
+
+struct XtensaDecodeStateLUT : public DecodeStateLUT {
+  XtensaDecodeStateLUT() = delete;
+
+  XtensaDecodeStateLUT(const TfLiteContext* context,
+                       MicroProfilerInterface* profiler)
+      : DecodeStateLUT(context, profiler) {}
+
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) override;
+
+ protected:
+  virtual ~XtensaDecodeStateLUT() = default;
+
+  void DecompressToBuffer(int8_t* buffer);
+
+  void DecompressToBufferWidth4_Xtensa(int8_t* buffer);
+  void DecompressToBufferWidth3_Xtensa(int8_t* buffer);
+  void DecompressToBufferWidth2_Xtensa(int8_t* buffer);
+
+  void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer);
+  void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer);
+  void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer);
+  void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer);
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_
diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc
new file mode 100644
index 00000000000..c237ee3b44f
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc
@@ -0,0 +1,443 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h" + +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +TfLiteStatus XtensaDecodeStatePrune::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBufferInt8_Xtensa(buffer); + break; + case kTfLiteFloat32: + DecodeStatePrune::DecompressToBuffer(buffer); + break; + case kTfLiteInt8: + DecompressToBufferInt8_Xtensa(buffer); + break; + case kTfLiteInt16: + DecompressToBufferInt16_Xtensa(buffer); + break; + case kTfLiteInt32: + DecodeStatePrune::DecompressToBuffer(buffer); + break; + case kTfLiteInt64: + DecodeStatePrune::DecompressToBuffer(buffer); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +void XtensaDecodeStatePrune::DecompressToBufferInt8_Xtensa(void* buffer) { + if (num_channels_ > 1) { + DecompressToBufferPerChannelInt8_Xtensa(buffer); + return; + } + + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + int* __restrict p_mask32 = (int*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + const int count = count_indices_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 zero = single_zero_point_; + ae_int8x8 discarded; + + for (int i = 0; i < count >> 5; i++) { + // unpack elements + int mask = *p_mask32++; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 3); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // merge into elements + AE_MOVT8X16_L(discarded, data0, zero, data0, mask); + AE_MOVT8X16_L(discarded, data1, zero, data1, mask >> 8); + AE_MOVT8X16_H(discarded, data2, zero, data2, mask); + AE_MOVT8X16_H(discarded, data3, zero, data3, mask >> 8); + + // move merged elements to output + AE_S8X8X2_IP(data0, data1, (ae_int8x16*)pCoeff, 16); + AE_S8X8X2_IP(data2, data3, (ae_int8x16*)pCoeff, 16); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + ae_valignx2 align2 = AE_ZALIGN128(); + int8_t* __restrict p_mask8 = reinterpret_cast(p_mask32); + + // unpack and merge into remaining elements + int mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + AE_MOVT8X16_L(discarded, data0, zero, data0, mask); + if (count_rem > 8) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 0); + data1 = AE_SHFL8X8(data1, shfl1); + AE_MOVT8X16_L(discarded, data1, zero, data1, mask); + } + if (count_rem > 16) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data2, 
shfl2, align, p_weights, mask, 0); + data2 = AE_SHFL8X8(data2, shfl2); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask); + } + if (count_rem > 24) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data3 = AE_SHFL8X8(data3, shfl3); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + } + + // move merged elements to output + if (count_rem <= 16) { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } +} + +void XtensaDecodeStatePrune::DecompressToBufferPerChannelInt8_Xtensa( + void* buffer) { + if (use_alternate_axis_) { + DecompressToBufferPerChannelAltAxisInt8_Xtensa(buffer); + return; + } + + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + short* __restrict p_stream = (short*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_valignx2 align2 = AE_ZALIGN128(); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + const int count = elements_per_channel_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 discarded; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + for (size_t channel = 0; channel < num_channels_; channel++) { + ae_int8x8 zero = zero_points_[channel]; + uint32_t mask_low, mask_high; + + for (int i = 0; i < count >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // merge into elements + AE_MOVT8X16_H(discarded, data0, zero, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + + // move merged elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + AE_SA128POS_FP(align2, pCoeff); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + if (count_rem > 16) { + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LB_DB_IP((unsigned short*)p_stream, mask_low, count_rem - 16); + mask_low <<= 32 - count_rem; + } else { + AE_LB_DB_IP((unsigned short*)p_stream, mask_high, count_rem); + mask_high <<= 16 - count_rem; + mask_low = 0; + } + const int mask = (mask_high << 16) | mask_low; + + // unpack and merge into remaining elements + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + data0 = AE_SHFL8X8(data0, shfl0); + AE_MOVT8X16_H(discarded, data0, zero, data0, mask >> 8); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + data1 = AE_SHFL8X8(data1, shfl1); + AE_MOVT8X16_H(discarded, data1, zero, data1, mask); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + data2 = AE_SHFL8X8(data2, 
shfl2); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask >> 8); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data3 = AE_SHFL8X8(data3, shfl3); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + + // move merged elements to output + if (count_rem <= 16) { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } + } +} + +void XtensaDecodeStatePrune::DecompressToBufferPerChannelAltAxisInt8_Xtensa( + void* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + short* __restrict p_stream = (short*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_valignx2 align2 = AE_ZALIGN128(); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + int count = count_indices_ / num_channels_; + const int max_channels = num_channels_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 discarded; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + while (count-- > 0) { + ae_int8x8 zero0, zero1, zero2, zero3; + uint32_t mask_low, mask_high; + // p_zero is always 16 byte aligned due to copy during Setup(). + int8_t* __restrict p_zero = (int8_t*)zero_points_; + + for (int i = 0; i < max_channels >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // load values + AE_L8X8X2_IP(zero0, zero1, (ae_int8x16*)p_zero, 16); + AE_L8X8X2_IP(zero2, zero3, (ae_int8x16*)p_zero, 16); + + // merge into elements + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero2, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero3, data3, mask); + + // move merged elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + AE_SA128POS_FP(align2, pCoeff); + } + + const int count_rem = max_channels & 0x1F; + if (count_rem) { + if (count_rem > 16) { + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LB_DB_IP((unsigned short*)p_stream, mask_low, count_rem - 16); + mask_low <<= 32 - count_rem; + } else { + AE_LB_DB_IP((unsigned short*)p_stream, mask_high, count_rem); + mask_high <<= 16 - count_rem; + mask_low = 0; + } + const int mask = (mask_high << 16) | mask_low; + + // unpack remaining elements + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, 
shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // load values, merge into elements and + // move merged elements to output + ae_valignx2 align_zero = AE_LA128_PP(p_zero); + if (count_rem <= 16) { + AE_LAV8X8X2_XP(zero0, zero1, align_zero, (ae_int8x16*)p_zero, + count_rem); + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_LAV8X8X2_XP(zero0, zero1, align_zero, (ae_int8x16*)p_zero, 16); + AE_LAV8X8X2_XP(zero2, zero3, align_zero, (ae_int8x16*)p_zero, + count_rem & 0xF); + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero2, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero3, data3, mask); + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } + } +} + +void XtensaDecodeStatePrune::DecompressToBufferInt16_Xtensa(void* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int16x8* p_weights = (ae_int16x8*)value_table_; + int* __restrict p_mask32 = (int*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_int16x4 data0, data1, data2, data3; + ae_int16x4 data4, data5, data6, data7; + ae_int16x4 shfl0, shfl1, shfl2, shfl3; + ae_int16x4 shfl4, shfl5, shfl6, shfl7; + const int count = count_indices_; + int16_t* __restrict pCoeff = static_cast(buffer); + + for (int i = 0; i < count >> 5; i++) { + // unpack elements and merge 0 (zero) elements + int mask = *p_mask32++; + AE_LAVUNSQZ16X4_XP(data0, shfl0, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data1, shfl1, align, p_weights, mask, 0); + AE_LAVUNSQZ16X4_XP(data2, shfl2, align, p_weights, mask, 3); + AE_LAVUNSQZ16X4_XP(data3, shfl3, align, p_weights, mask, 2); + AE_LAVUNSQZ16X4_XP(data4, shfl4, align, p_weights, mask, 5); + AE_LAVUNSQZ16X4_XP(data5, shfl5, align, p_weights, mask, 4); + AE_LAVUNSQZ16X4_XP(data6, shfl6, align, p_weights, mask, 7); + AE_LAVUNSQZ16X4_XP(data7, shfl7, align, p_weights, mask, 6); + data0 = AE_SHFL16X4(data0, shfl0); + data1 = AE_SHFL16X4(data1, shfl1); + data2 = AE_SHFL16X4(data2, shfl2); + data3 = AE_SHFL16X4(data3, shfl3); + data4 = AE_SHFL16X4(data4, shfl4); + data5 = AE_SHFL16X4(data5, shfl5); + data6 = AE_SHFL16X4(data6, shfl6); + data7 = AE_SHFL16X4(data7, shfl7); + + // move merged elements to output + AE_S16X4X2_IP(data0, data1, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data2, data3, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data4, data5, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data6, data7, (ae_int16x8*)pCoeff, 16); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + ae_valignx2 align2 = AE_ZALIGN128(); + int8_t* __restrict p_mask8 = reinterpret_cast(p_mask32); + + // unpack and merge into remaining elements + int mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data0, shfl0, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data1, shfl1, align, p_weights, mask, 0); + data0 = AE_SHFL16X4(data0, shfl0); + data1 = AE_SHFL16X4(data1, shfl1); + if (count_rem > 8) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data3, shfl3, align, p_weights, mask, 0); + data2 = AE_SHFL16X4(data2, shfl2); + data3 = AE_SHFL16X4(data3, shfl3); + } + if (count_rem > 16) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data4, shfl4, align, 
p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data5, shfl5, align, p_weights, mask, 0); + data4 = AE_SHFL16X4(data4, shfl4); + data5 = AE_SHFL16X4(data5, shfl5); + } + if (count_rem > 24) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data6, shfl6, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data7, shfl7, align, p_weights, mask, 0); + data6 = AE_SHFL16X4(data6, shfl6); + data7 = AE_SHFL16X4(data7, shfl7); + } + + // move merged elements to output + if (count_rem <= 8) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, + count_rem << 1); + } else if (count_rem <= 16) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, + (count_rem - 8) << 1); + } else if (count_rem <= 24) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data4, data5, align2, (ae_int16x8*)pCoeff, + (count_rem - 16) << 1); + } else { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data4, data5, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data6, data7, align2, (ae_int16x8*)pCoeff, + (count_rem - 24) << 1); + } + AE_SA128POS_FP(align2, pCoeff); + } +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h new file mode 100644 index 00000000000..fb6935f3383 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h @@ -0,0 +1,51 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" + +namespace tflite { + +struct XtensaDecodeStatePrune : public DecodeStatePrune { + XtensaDecodeStatePrune() = delete; + + XtensaDecodeStatePrune(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeStatePrune(context, profiler) {} + + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + virtual ~XtensaDecodeStatePrune() = default; + + void DecompressToBufferInt8_Xtensa(void* buffer); + void DecompressToBufferPerChannelInt8_Xtensa(void* buffer); + void DecompressToBufferPerChannelAltAxisInt8_Xtensa(void* buffer); + void DecompressToBufferInt16_Xtensa(void* buffer); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc index ea4fd8e8dc7..ffcecb7c99d 100644 --- a/tensorflow/lite/micro/micro_context.cc +++ b/tensorflow/lite/micro/micro_context.cc @@ -15,11 +15,13 @@ limitations under the License. #include "tensorflow/lite/micro/micro_context.h" +#include #include #include #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/kernels/decompress.h" +#include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/micro_common.h" #include "tensorflow/lite/micro/micro_log.h" #include "tensorflow/lite/micro/micro_utils.h" @@ -125,18 +127,50 @@ void* MicroContext::DecompressTensorToBuffer( return nullptr; } +#endif // USE_TFLM_COMPRESSION + TfLiteStatus MicroContext::SetDecompressionMemory( const std::initializer_list& regions) { - return kTfLiteError; + if (decompress_regions_ != nullptr) { + return kTfLiteError; + } + + decompress_regions_ = ®ions; + decompress_regions_allocations_ = static_cast( + AllocatePersistentBuffer(sizeof(size_t) * regions.size())); + if (decompress_regions_allocations_ == nullptr) { + return kTfLiteError; + } + ResetDecompressionMemoryAllocations(); + + return kTfLiteOk; } void* MicroContext::AllocateDecompressionMemory(size_t bytes, size_t alignment) { + if (decompress_regions_ != nullptr) { + for (size_t i = 0; i < decompress_regions_->size(); i++) { + const AlternateMemoryRegion* region = &decompress_regions_->begin()[i]; + uint8_t* start = static_cast(region->address) + + decompress_regions_allocations_[i]; + uint8_t* aligned_start = AlignPointerUp(start, alignment); + size_t total = bytes + (aligned_start - start); + if (total + decompress_regions_allocations_[i] <= region->bytes) { + decompress_regions_allocations_[i] += total; + return aligned_start; + } + } + } + return nullptr; } -void MicroContext::ResetDecompressionMemoryAllocations() {} - -#endif // USE_TFLM_COMPRESSION +void MicroContext::ResetDecompressionMemoryAllocations() { + if (decompress_regions_ == nullptr) { + return; + } + TFLITE_DCHECK(decompress_regions_allocations_ != nullptr); + std::fill_n(decompress_regions_allocations_, decompress_regions_->size(), 0); +} } // namespace tflite diff --git a/tensorflow/lite/micro/micro_context.h b/tensorflow/lite/micro/micro_context.h index 
5b1ea9ca798..f9a70bf5d32 100644 --- a/tensorflow/lite/micro/micro_context.h +++ b/tensorflow/lite/micro/micro_context.h @@ -16,14 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_ #define TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_ +#include +#include + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/micro_graph.h" #include "tensorflow/lite/micro/micro_profiler_interface.h" #ifdef USE_TFLM_COMPRESSION -#include - #include "tensorflow/lite/micro/compression.h" #endif // USE_TFLM_COMPRESSION @@ -126,6 +127,8 @@ class MicroContext { const TfLiteEvalTensor& tensor, const CompressionTensorData& compression_data, void* buffer); +#endif // USE_TFLM_COMPRESSION + // Used for configuring alternate decompression memory struct AlternateMemoryRegion { void* address; @@ -140,14 +143,13 @@ class MicroContext { // Return a pointer to memory that can be used for decompression. // The pointer will be aligned to the value. // Return nullptr if the requested size is not available. - // Can be called during kPrepare and kInvoke states. + // Can be called during kPrepare state. virtual void* AllocateDecompressionMemory(size_t bytes, size_t alignment); - // reset all allocation tracking + // Reset all allocation tracking. + // Can be called during kPrepare state. virtual void ResetDecompressionMemoryAllocations(); -#endif // USE_TFLM_COMPRESSION - // Set the alternate MicroProfilerInterface. // This can be used to profile subsystems simultaneously with the profiling // of kernels during the Eval phase. See (b/379584353). @@ -168,6 +170,11 @@ class MicroContext { } private: + const std::initializer_list* decompress_regions_ = + nullptr; + // array of size_t elements with length equal to decompress_regions_.size() + size_t* decompress_regions_allocations_ = nullptr; + TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 666516c6a3a..0de18d3a928 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -339,14 +339,10 @@ TfLiteStatus MicroInterpreter::SetAlternateProfiler( return micro_context_.SetAlternateProfiler(alt_profiler); } -#ifdef USE_TFLM_COMPRESSION - TfLiteStatus MicroInterpreter::SetDecompressionMemory( const std::initializer_list& regions) { return micro_context_.SetDecompressionMemory(regions); } -#endif // USE_TFLM_COMPRESSION - } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 4a03c3fe825..e47c2b8ef0f 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -160,8 +160,6 @@ class MicroInterpreter { // decompression subsystem. TfLiteStatus SetAlternateProfiler(MicroProfilerInterface* alt_profiler); -#ifdef USE_TFLM_COMPRESSION - // Set the alternate decompression memory regions. // Can only be called during the MicroInterpreter kInit state (i.e. must // be called before MicroInterpreter::AllocateTensors). 
@@ -169,8 +167,6 @@ class MicroInterpreter { const std::initializer_list& regions); -#endif // USE_TFLM_COMPRESSION - protected: const MicroAllocator& allocator() const { return allocator_; } const TfLiteContext& context() const { return context_; } diff --git a/tensorflow/lite/micro/micro_interpreter_context.cc b/tensorflow/lite/micro/micro_interpreter_context.cc index 62b33afe631..3bcd115729f 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.cc +++ b/tensorflow/lite/micro/micro_interpreter_context.cc @@ -15,18 +15,12 @@ limitations under the License. #include "tensorflow/lite/micro/micro_interpreter_context.h" -#include - -#ifdef USE_TFLM_COMPRESSION - #include +#include +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/micro_arena_constants.h" - -#endif // USE_TFLM_COMPRESSION - -#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/micro_utils.h" namespace tflite { @@ -220,54 +214,29 @@ void* MicroInterpreterContext::DecompressTensorToBuffer( buffer); } +#endif // USE_TFLM_COMPRESSION + TfLiteStatus MicroInterpreterContext::SetDecompressionMemory( const std::initializer_list& regions) { if (state_ != InterpreterState::kInit) { return kTfLiteError; } - decompress_regions_ = ®ions; - decompress_regions_allocations_ = static_cast( - AllocatePersistentBuffer(sizeof(size_t) * regions.size())); - if (decompress_regions_allocations_ == nullptr) { - return kTfLiteError; - } - ResetDecompressionMemoryAllocations(); - - return kTfLiteOk; + return MicroContext::SetDecompressionMemory(regions); } void* MicroInterpreterContext::AllocateDecompressionMemory(size_t bytes, size_t alignment) { +#ifdef USE_TFLM_COMPRESSION TFLITE_DCHECK(state_ == InterpreterState::kPrepare || state_ == InterpreterState::kInvoke); - if (decompress_regions_ != nullptr) { - for (size_t i = 0; i < decompress_regions_->size(); i++) { - const AlternateMemoryRegion* region = &decompress_regions_->begin()[i]; - uint8_t* start = static_cast(region->address) + - decompress_regions_allocations_[i]; - uint8_t* aligned_start = AlignPointerUp(start, alignment); - size_t total = bytes + (aligned_start - start); - if (total + decompress_regions_allocations_[i] <= region->bytes) { - decompress_regions_allocations_[i] += total; - return aligned_start; - } - } - } - - return nullptr; -} +#else + TFLITE_DCHECK(state_ == InterpreterState::kPrepare); +#endif // USE_TFLM_COMPRESSION -void MicroInterpreterContext::ResetDecompressionMemoryAllocations() { - if (decompress_regions_ == nullptr) { - return; - } - TFLITE_DCHECK(decompress_regions_allocations_ != nullptr); - std::fill_n(decompress_regions_allocations_, decompress_regions_->size(), 0); + return MicroContext::AllocateDecompressionMemory(bytes, alignment); } -#endif // USE_TFLM_COMPRESSION - TfLiteStatus MicroInterpreterContext::SetAlternateProfiler( tflite::MicroProfilerInterface* alt_profiler) { alt_profiler_ = alt_profiler; diff --git a/tensorflow/lite/micro/micro_interpreter_context.h b/tensorflow/lite/micro/micro_interpreter_context.h index a3927580d51..5f17c1efac6 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.h +++ b/tensorflow/lite/micro/micro_interpreter_context.h @@ -16,6 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_CONTEXT_H_ #define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_CONTEXT_H_ +#include +#include + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/micro_allocator.h" #include "tensorflow/lite/micro/micro_context.h" @@ -128,6 +131,8 @@ class MicroInterpreterContext : public MicroContext { const CompressionTensorData& compression_data, void* buffer) override; +#endif // USE_TFLM_COMPRESSION + // Set the alternate decompression memory regions. // Can only be called during the MicroInterpreter kInit state. TfLiteStatus SetDecompressionMemory( @@ -136,14 +141,9 @@ class MicroInterpreterContext : public MicroContext { // Return a pointer to memory that can be used for decompression. // The pointer will be aligned to the value. // Return nullptr if the requested size is not available. - // Can be called during kPrepare and kInvoke states. + // Can be called during kPrepare state. void* AllocateDecompressionMemory(size_t bytes, size_t alignment) override; - // reset all allocation tracking - void ResetDecompressionMemoryAllocations() override; - -#endif // USE_TFLM_COMPRESSION - // Set the alternate MicroProfilerInterface. // This can be used to profile subsystems simultaneously with the profiling // of kernels during the Eval phase. See (b/379584353). @@ -169,15 +169,6 @@ class MicroInterpreterContext : public MicroContext { void* external_context_payload_ = nullptr; MicroProfilerInterface* alt_profiler_ = nullptr; -#ifdef USE_TFLM_COMPRESSION - - const std::initializer_list* decompress_regions_ = - nullptr; - // array of size_t elements with length equal to decompress_regions_.size() - size_t* decompress_regions_allocations_; - -#endif // USE_TFLM_COMPRESSION - TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_interpreter_context_test.cc b/tensorflow/lite/micro/micro_interpreter_context_test.cc index fd7fb43831f..e61514910e8 100644 --- a/tensorflow/lite/micro/micro_interpreter_context_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_context_test.cc @@ -32,7 +32,7 @@ tflite::MicroInterpreterContext CreateMicroInterpreterContext() { // the test need to place non-transient memories in static variables. This is // safe because tests are guaranteed to run serially. 
   constexpr size_t kArenaSize = 1024;
-  static uint8_t tensor_arena[kArenaSize];
+  alignas(16) static uint8_t tensor_arena[kArenaSize];
 
   const tflite::Model* model = tflite::testing::GetSimpleMockModel();
   MicroAllocator* micro_allocator =
@@ -199,4 +199,109 @@ TF_LITE_MICRO_TEST(TestGetTempIntermediateTensor) {
   TF_LITE_MICRO_EXPECT_TRUE(invalid_output == nullptr);
 }
 
+TF_LITE_MICRO_TEST(TestSetDecompressionMemory) {
+  tflite::MicroInterpreterContext micro_context =
+      tflite::CreateMicroInterpreterContext();
+
+  constexpr size_t kAltMemorySize = 1;
+  alignas(16) uint8_t g_alt_memory[kAltMemorySize];
+  std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
+      alt_memory_region = {{g_alt_memory, kAltMemorySize}};
+  TfLiteStatus status;
+
+  // fail during Prepare state
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kPrepare);
+  status = micro_context.SetDecompressionMemory(alt_memory_region);
+  TF_LITE_MICRO_EXPECT(status == kTfLiteError);
+
+  // fail during Invoke state
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kInvoke);
+  status = micro_context.SetDecompressionMemory(alt_memory_region);
+  TF_LITE_MICRO_EXPECT(status == kTfLiteError);
+
+  // succeed during Init state
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kInit);
+  status = micro_context.SetDecompressionMemory(alt_memory_region);
+  TF_LITE_MICRO_EXPECT(status == kTfLiteOk);
+
+  // fail on second Init state attempt
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kInit);
+  status = micro_context.SetDecompressionMemory(alt_memory_region);
+  TF_LITE_MICRO_EXPECT(status == kTfLiteError);
+}
+
+TF_LITE_MICRO_TEST(TestAllocateDecompressionMemory) {
+  tflite::MicroInterpreterContext micro_context =
+      tflite::CreateMicroInterpreterContext();
+
+  constexpr size_t kAltMemorySize = 30;
+  constexpr size_t kAllocateSize = 10;
+  alignas(16) uint8_t g_alt_memory[kAltMemorySize];
+  std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
+      alt_memory_region = {{g_alt_memory, kAltMemorySize}};
+
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kInit);
+  TfLiteStatus status = micro_context.SetDecompressionMemory(alt_memory_region);
+  TF_LITE_MICRO_EXPECT(status == kTfLiteOk);
+
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kPrepare);
+
+  // allocate first 10 bytes
+  uint8_t* p = static_cast<uint8_t*>(micro_context.AllocateDecompressionMemory(
+      kAllocateSize, tflite::MicroArenaBufferAlignment()));
+  TF_LITE_MICRO_EXPECT(p == &g_alt_memory[0]);
+
+  // allocate next 10 bytes
+  p = static_cast<uint8_t*>(micro_context.AllocateDecompressionMemory(
+      kAllocateSize, tflite::MicroArenaBufferAlignment()));
+  TF_LITE_MICRO_EXPECT(p == &g_alt_memory[16]);
+
+  // fail next allocation
+  p = static_cast<uint8_t*>(micro_context.AllocateDecompressionMemory(
+      kAllocateSize, tflite::MicroArenaBufferAlignment()));
+  TF_LITE_MICRO_EXPECT(p == nullptr);
+}
+
+TF_LITE_MICRO_TEST(TestResetDecompressionMemory) {
+  tflite::MicroInterpreterContext micro_context =
+      tflite::CreateMicroInterpreterContext();
+
+  constexpr size_t kAltMemorySize = 30;
+  constexpr size_t kAllocateSize = 10;
+  alignas(16) uint8_t g_alt_memory[kAltMemorySize];
+  std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
+      alt_memory_region = {{g_alt_memory, kAltMemorySize}};
+
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kInit);
+  TfLiteStatus status = micro_context.SetDecompressionMemory(alt_memory_region);
+  TF_LITE_MICRO_EXPECT(status == kTfLiteOk);
+
+  micro_context.SetInterpreterState(
+      tflite::MicroInterpreterContext::InterpreterState::kPrepare);
+
+  // allocate first 10 bytes
+  uint8_t* p = static_cast<uint8_t*>(micro_context.AllocateDecompressionMemory(
+      kAllocateSize, tflite::MicroArenaBufferAlignment()));
+  TF_LITE_MICRO_EXPECT(p == &g_alt_memory[0]);
+
+  // allocate next 10 bytes
+  p = static_cast<uint8_t*>(micro_context.AllocateDecompressionMemory(
+      kAllocateSize, tflite::MicroArenaBufferAlignment()));
+  TF_LITE_MICRO_EXPECT(p == &g_alt_memory[16]);
+
+  micro_context.ResetDecompressionMemoryAllocations();
+
+  // allocate first 10 bytes again
+  p = static_cast<uint8_t*>(micro_context.AllocateDecompressionMemory(
+      kAllocateSize, tflite::MicroArenaBufferAlignment()));
+  TF_LITE_MICRO_EXPECT(p == &g_alt_memory[0]);
+}
+
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index ba94ac19482..cf28f8ccf2c 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -213,6 +213,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
     return AddBuiltin(BuiltinOperator_CUMSUM, registration, ParseCumsum);
   }
 
+  TfLiteStatus AddDecode() {
+    const TFLMRegistration& registration = tflite::Register_DECODE();
+    return AddCustom("TFLM_DECODE", &registration);
+  }
+
   TfLiteStatus AddDelay() {
     // TODO(b/286250473): change back name to "Delay" and remove namespace
     return AddCustom("SignalDelay", tflite::tflm_signal::Register_DELAY());
diff --git a/tensorflow/lite/micro/tools/benchmarking/Makefile.inc b/tensorflow/lite/micro/tools/benchmarking/Makefile.inc
index a79420cb982..8094f2edbd6 100644
--- a/tensorflow/lite/micro/tools/benchmarking/Makefile.inc
+++ b/tensorflow/lite/micro/tools/benchmarking/Makefile.inc
@@ -20,14 +20,12 @@ endif
 $(GENERATED_SRCS_DIR)$(GENERIC_BENCHMARK_MODEL_DIR)$(GENERIC_BENCHMARK_MODEL_NAME)_model_data.h
 endif
 
-ifeq ($(ENABLE_COMPRESSION), yes)
 ifneq ($(GENERIC_BENCHMARK_ALT_MEM_ATTR),)
   CXXFLAGS += -DGENERIC_BENCHMARK_ALT_MEM_ATTR=$(GENERIC_BENCHMARK_ALT_MEM_ATTR)
 endif
 ifneq ($(GENERIC_BENCHMARK_ALT_MEM_SIZE),)
   CXXFLAGS += -DGENERIC_BENCHMARK_ALT_MEM_SIZE=$(GENERIC_BENCHMARK_ALT_MEM_SIZE)
 endif
-endif
 
 GENERIC_BENCHMARK_SRCS := \
 $(MICROLITE_BENCHMARK_ROOT_DIR)/generic_model_benchmark.cc \
diff --git a/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc b/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc
index 0f58219644b..704a5075bdc 100644
--- a/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc
+++ b/tensorflow/lite/micro/tools/benchmarking/generic_model_benchmark.cc
@@ -70,11 +70,10 @@ limitations under the License.
         // !defined(GENERIC_BENCHMARK_ALT_MEM_ATTR)
 
 #if defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) && \
-    defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) && defined(USE_TFLM_COMPRESSION)
+    defined(GENERIC_BENCHMARK_ALT_MEM_ATTR)
 #define USE_ALT_DECOMPRESSION_MEM
 #endif  // defined(GENERIC_BENCHMARK_ALT_MEM_SIZE) &&
-        // defined(GENERIC_BENCHMARK_ALT_MEM_ATTR) &&
-        // defined(USE_TFLM_COMPRESSION)
+        // defined(GENERIC_BENCHMARK_ALT_MEM_ATTR)
 
 /*
 * Generic model benchmark.
Evaluates runtime performance of a provided @@ -220,11 +219,6 @@ int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) { alignas(16) static uint8_t tensor_arena[kTensorArenaSize]; -#ifdef USE_ALT_DECOMPRESSION_MEM - std::initializer_list - alt_memory_region = {{g_alt_memory, kAltMemorySize}}; -#endif // USE_ALT_DECOMPRESSION_MEM - uint32_t event_handle = profiler.BeginEvent("tflite::GetModel"); const tflite::Model* model = tflite::GetModel(model_data); profiler.EndEvent(event_handle); @@ -252,6 +246,8 @@ int Benchmark(const uint8_t* model_data, tflite::PrettyPrintType print_type) { #ifdef USE_ALT_DECOMPRESSION_MEM event_handle = profiler.BeginEvent("tflite::MicroInterpreter::SetDecompressionMemory"); + std::initializer_list + alt_memory_region = {{g_alt_memory, kAltMemorySize}}; status = interpreter.SetDecompressionMemory(alt_memory_region); if (status != kTfLiteOk) { MicroPrintf("tflite::MicroInterpreter::SetDecompressionMemory failed"); diff --git a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h index 7817eaed0e5..42063dcca7e 100644 --- a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h +++ b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h @@ -45,6 +45,7 @@ inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddConv2D()); TF_LITE_ENSURE_STATUS(op_resolver.AddCos()); TF_LITE_ENSURE_STATUS(op_resolver.AddCumSum()); + TF_LITE_ENSURE_STATUS(op_resolver.AddDecode()); TF_LITE_ENSURE_STATUS(op_resolver.AddDelay()); TF_LITE_ENSURE_STATUS(op_resolver.AddDepthToSpace()); TF_LITE_ENSURE_STATUS(op_resolver.AddDepthwiseConv2D()); diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index a21765b3454..b77bf010dbb 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -386,6 +386,11 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/conv_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_huffman.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_lut.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space.cc \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc index b05a0670248..6001a90067b 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc @@ -124,6 +124,14 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa) MICROLITE_CC_KERNEL_SRCS += \ $(TENSORFLOW_ROOT)tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/kernels/kernel_util.cc + + # Additional kernel sources for DECODE operator support + ifeq ($(TARGET_ARCH), $(filter $(TARGET_ARCH), hifi5)) + MICROLITE_CC_KERNEL_SRCS += \ + $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.cc \ + 
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc \ + $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc + endif endif # override KERNEL_OPTIMIZATION_LEVEL to enable higher performance @@ -131,3 +139,15 @@ endif $(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_huffman.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@
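
A minimal usage sketch of the pieces added above, for orientation only. It assumes a hypothetical application: the function name RunModel, the buffers g_arena and g_alt_memory, their sizes, the resolver capacity, and the FullyConnected registration are illustrative and not part of this change. The DECODE custom operator is registered through MicroMutableOpResolver::AddDecode(), and an alternate decompression region is handed to the interpreter through MicroInterpreter::SetDecompressionMemory(), which per the kInit-state check above must be called before AllocateTensors(); DECODE output tensors can then be placed in that region instead of the tensor arena.

#include <cstddef>
#include <cstdint>
#include <initializer_list>

#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {
// Illustrative sizes; a real application sizes these for its model.
constexpr size_t kArenaSize = 16 * 1024;
alignas(16) uint8_t g_arena[kArenaSize];
constexpr size_t kAltMemorySize = 8 * 1024;
alignas(16) uint8_t g_alt_memory[kAltMemorySize];
}  // namespace

TfLiteStatus RunModel(const uint8_t* model_data) {
  const tflite::Model* model = tflite::GetModel(model_data);

  // Register only the operators the model needs; DECODE is added as a
  // custom operator by AddDecode().
  tflite::MicroMutableOpResolver<2> op_resolver;
  if (op_resolver.AddDecode() != kTfLiteOk) {
    return kTfLiteError;
  }
  if (op_resolver.AddFullyConnected() != kTfLiteOk) {
    return kTfLiteError;
  }

  tflite::MicroInterpreter interpreter(model, op_resolver, g_arena, kArenaSize);

  // Alternate decompression memory must be supplied while the interpreter is
  // still in the kInit state, i.e. before AllocateTensors(). DECODE output
  // tensors can then be placed in this region instead of the tensor arena.
  std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
      alt_memory_region = {{g_alt_memory, kAltMemorySize}};
  if (interpreter.SetDecompressionMemory(alt_memory_region) != kTfLiteOk) {
    return kTfLiteError;
  }

  if (interpreter.AllocateTensors() != kTfLiteOk) {
    return kTfLiteError;
  }
  return interpreter.Invoke();
}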