-
Notifications
You must be signed in to change notification settings - Fork 981
TFLM: Add ONE_HOT operator and unit tests #3260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5f6a2f4
e4fe10e
2795ddf
61aa04b
cd28bad
05edd10
9d738c1
d20bab0
54a9d78
931a4c4
fd570e3
f9fb4be
8ecac95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| grass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,239 @@ | ||
| /* Copyright 2025 The TensorFlow Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| ==============================================================================*/ | ||
| #include <stdint.h> | ||
|
|
||
| #include "tensorflow/lite/c/builtin_op_data.h" | ||
| #include "tensorflow/lite/c/common.h" | ||
| #include "tensorflow/lite/micro/kernels/kernel_util.h" | ||
| #include "tensorflow/lite/micro/micro_common.h" | ||
|
|
||
| namespace tflite { | ||
| namespace ops { | ||
| namespace micro { | ||
| namespace one_hot { | ||
|
Comment on lines
+22
to
+25
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please flatten the namespace to just
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I'll reflect it in the next commit. |
||
|
|
||
| constexpr int kIndicesTensor = 0; | ||
| constexpr int kDepthTensor = 1; | ||
| constexpr int kOnValueTensor = 2; | ||
| constexpr int kOffValueTensor = 3; | ||
| constexpr int kOutputTensor = 0; | ||
|
|
||
| namespace { // Local Util functions | ||
| inline int NumElements(const TfLiteEvalTensor* t) { | ||
| int count = 1; | ||
| for (int i = 0; i < t->dims->size; ++i) { | ||
| count *= t->dims->data[i]; | ||
| } | ||
| return count; | ||
| } | ||
| } // namespace | ||
|
|
||
| // Retrieves the input tensors (indices, depth, on_value, off_value) and the | ||
| // output tensor (output) from the TfLiteNode. | ||
| // Reads params->axis to compute the actual position (axis) where the depth | ||
| // dimension will be inserted. | ||
| // These values are created temporarily within the Prepare and Eval functions | ||
| // and are destroyed afterward → efficient use of stack memory. | ||
| struct OneHotContext { | ||
| OneHotContext(TfLiteContext* context, TfLiteNode* node) { | ||
| indices = tflite::micro::GetEvalInput(context, node, kIndicesTensor); | ||
| depth = tflite::micro::GetEvalInput(context, node, kDepthTensor); | ||
| on_value = tflite::micro::GetEvalInput(context, node, kOnValueTensor); | ||
| off_value = tflite::micro::GetEvalInput(context, node, kOffValueTensor); | ||
| output = tflite::micro::GetEvalOutput(context, node, kOutputTensor); | ||
|
|
||
| const auto* params = | ||
| reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data); | ||
| const int indices_dims = indices->dims->size; | ||
| axis = (params->axis == -1) ? indices_dims : params->axis; | ||
| output_dims = indices_dims + 1; | ||
| dtype = on_value->type; | ||
| } | ||
|
|
||
| const TfLiteEvalTensor* indices; | ||
| const TfLiteEvalTensor* depth; | ||
| const TfLiteEvalTensor* on_value; | ||
| const TfLiteEvalTensor* off_value; | ||
| TfLiteEvalTensor* output; | ||
|
|
||
| int axis; | ||
| int output_dims; | ||
| TfLiteType dtype; | ||
| }; | ||
|
|
||
| // Operation function | ||
| template <typename T, typename TI> | ||
| void OneHotComputeImpl(const OneHotContext& op_context) { | ||
| int prefix_dim_size = 1; | ||
| for (int i = 0; i < op_context.axis; ++i) { | ||
| prefix_dim_size *= op_context.indices->dims->data[i]; | ||
| } | ||
| if (prefix_dim_size == 0) { | ||
| return; | ||
| } | ||
|
|
||
| const RuntimeShape indices_shape = | ||
| tflite::micro::GetTensorShape(op_context.indices); | ||
| const int suffix_dim_size = indices_shape.FlatSize() / prefix_dim_size; | ||
|
|
||
| const int depth = *op_context.depth->data.i32; | ||
|
|
||
| const T on_value = *tflite::micro::GetTensorData<T>(op_context.on_value); | ||
| const T off_value = *tflite::micro::GetTensorData<T>(op_context.off_value); | ||
|
|
||
| T* output_data = tflite::micro::GetTensorData<T>(op_context.output); | ||
| const TI* indices_data = tflite::micro::GetTensorData<TI>(op_context.indices); | ||
|
|
||
| for (int i = 0; i < prefix_dim_size; ++i) { | ||
| for (int j = 0; j < depth; ++j) { | ||
| for (int k = 0; k < suffix_dim_size; ++k, ++output_data) { | ||
| *output_data = | ||
| static_cast<int>(indices_data[i * suffix_dim_size + k]) == j | ||
| ? on_value | ||
| : off_value; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| void OneHotCompute(const OneHotContext& op_context) { | ||
| if (op_context.indices->type == kTfLiteInt64) { | ||
| OneHotComputeImpl<T, int64_t>(op_context); | ||
| } else { | ||
| OneHotComputeImpl<T, int32_t>(op_context); | ||
| } | ||
| } | ||
|
|
||
| TfLiteStatus ResizeOutputTensor(TfLiteContext* context, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Method should have a name more representative of it's functionality.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I'll fix the method in the next commit. |
||
| const OneHotContext& op_context) { | ||
| TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0); | ||
|
|
||
| // read depth data | ||
| const int depth_val = | ||
| *tflite::micro::GetTensorData<int32_t>(op_context.depth); | ||
| TF_LITE_ENSURE(context, depth_val >= 0); | ||
|
|
||
| // Output Tensor evaluation | ||
| TF_LITE_ENSURE(context, op_context.output != nullptr); | ||
|
|
||
| TF_LITE_ENSURE(context, op_context.output->dims != nullptr); | ||
|
|
||
| // TFLM assumes that the output tensor’s dims are already allocated | ||
| const int expected_dims_size = op_context.output_dims; | ||
| TF_LITE_ENSURE_EQ(context, op_context.output->dims->size, expected_dims_size); | ||
|
|
||
| for (int i = 0; i < expected_dims_size; ++i) { | ||
| int expected_dim_i; | ||
| if (i < op_context.axis) { | ||
| expected_dim_i = op_context.indices->dims->data[i]; | ||
| } else if (i == op_context.axis) { | ||
| expected_dim_i = depth_val; | ||
| } else { | ||
| expected_dim_i = op_context.indices->dims->data[i - 1]; | ||
| } | ||
|
|
||
| // If the size pre-allocated by the TFLM compiler (Offline Memory Planner) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I'll change the memory planner. |
||
| // does not match the actual computed size, an error is raised. | ||
| TF_LITE_ENSURE_EQ(context, op_context.output->dims->data[i], | ||
| expected_dim_i); | ||
| } | ||
|
|
||
| return kTfLiteOk; | ||
| } | ||
|
|
||
| TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { | ||
| TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); | ||
| TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); | ||
|
|
||
| OneHotContext op_context{context, node}; | ||
| TF_LITE_ENSURE(context, op_context.output != nullptr); | ||
|
|
||
| switch (op_context.dtype) { | ||
| case kTfLiteFloat32: | ||
| case kTfLiteInt16: | ||
| case kTfLiteInt32: | ||
| case kTfLiteInt64: | ||
| case kTfLiteInt8: | ||
| case kTfLiteUInt8: | ||
| case kTfLiteBool: | ||
| op_context.output->type = op_context.dtype; | ||
| break; | ||
| default: | ||
| TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s", | ||
| TfLiteTypeGetName(op_context.dtype)); | ||
| return kTfLiteError; | ||
| } | ||
|
|
||
| TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 || | ||
| op_context.indices->type == kTfLiteInt64); | ||
| TF_LITE_ENSURE(context, op_context.axis >= 0 && | ||
| op_context.axis < op_context.output_dims); | ||
| TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1); | ||
| TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1); | ||
| TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1); | ||
| TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype); | ||
| TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type, | ||
| op_context.dtype); | ||
|
|
||
| // Even if the depth tensor is not a constant, the test predefines the output | ||
| // shape, so here we only perform validation. | ||
| return ResizeOutputTensor(context, op_context); | ||
| } | ||
|
|
||
| TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { | ||
| OneHotContext op_context{context, node}; | ||
|
|
||
| switch (op_context.output->type) { | ||
| case kTfLiteFloat32: | ||
| OneHotCompute<float>(op_context); | ||
| break; | ||
| case kTfLiteInt32: | ||
| OneHotCompute<int32_t>(op_context); | ||
| break; | ||
| case kTfLiteInt64: | ||
| OneHotCompute<int64_t>(op_context); | ||
| break; | ||
| case kTfLiteInt8: | ||
| OneHotCompute<int8_t>(op_context); | ||
| break; | ||
| case kTfLiteUInt8: | ||
| OneHotCompute<uint8_t>(op_context); | ||
| break; | ||
| case kTfLiteBool: | ||
| OneHotCompute<bool>(op_context); | ||
| break; | ||
| default: | ||
| return kTfLiteError; | ||
| } | ||
|
|
||
| return kTfLiteOk; | ||
| } | ||
|
|
||
| } // namespace one_hot | ||
|
|
||
| // Implementation of Register_ONE_HOT declared in the header | ||
| const TFLMRegistration* Register_ONE_HOT() { | ||
| static TFLMRegistration r = {}; | ||
|
|
||
| r.prepare = one_hot::Prepare; | ||
| r.invoke = one_hot::Eval; | ||
|
|
||
| return &r; | ||
| } | ||
|
Comment on lines
+228
to
+235
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please follow the registration implementation of all other kernels.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I will modify the registration implementation soon. |
||
|
|
||
| } // namespace micro | ||
| } // namespace ops | ||
| } // namespace tflite | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file is not required. External declaration of the registration method is handled elsewhere in TFLM.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay. I will modify at next PR. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| #ifndef TENSORFLOW_LITE_MICRO_KERNELS_ONE_HOT_H_ | ||
| #define TENSORFLOW_LITE_MICRO_KERNELS_ONE_HOT_H_ | ||
|
|
||
| #include "tensorflow/lite/c/common.h" | ||
| #include "tensorflow/lite/micro/micro_common.h" | ||
|
|
||
| namespace tflite { | ||
| namespace ops { | ||
| namespace micro { | ||
|
|
||
| // ONE_HOT Kernel regist function (use at all_ops_resolver) | ||
| const TFLMRegistration* Register_ONE_HOT(); | ||
|
|
||
| } // namespace micro | ||
| } // namespace ops | ||
| } // namespace tflite | ||
|
|
||
| #endif // TENSORFLOW_LITE_MICRO_KERNELS_ONE_HOT_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| #include "tensorflow/lite/c/builtin_op_data.h" | ||
| #include "tensorflow/lite/c/common.h" | ||
| #include "tensorflow/lite/micro/kernels/kernel_runner.h" | ||
| #include "tensorflow/lite/micro/test_helpers.h" | ||
| #include "tensorflow/lite/micro/testing/micro_test.h" | ||
|
|
||
| namespace tflite { | ||
| namespace ops { | ||
| namespace micro { | ||
|
|
||
| const TFLMRegistration* Register_ONE_HOT(); | ||
| } // namespace micro | ||
| } // namespace ops | ||
| } // namespace tflite | ||
|
|
||
| namespace tflite { | ||
| namespace testing { | ||
| namespace { | ||
|
|
||
| // Helper function for OneHot operation test | ||
| template <typename T> | ||
| void TestOneHot(const int* indices_dims, const int32_t* indices_data, | ||
| const int* depth_dims, const int32_t* depth_data, | ||
| const int* on_dims, const T* on_data, const int* off_dims, | ||
| const T* off_data, const int* output_dims, | ||
| const T* expected_output_data, T* output_data, int axis = -1) { | ||
| // 1. Tensor Setting | ||
| TfLiteIntArray* in_dims = IntArrayFromInts(indices_dims); | ||
| TfLiteIntArray* d_dims = IntArrayFromInts(depth_dims); | ||
| TfLiteIntArray* on_val_dims = IntArrayFromInts(on_dims); | ||
| TfLiteIntArray* off_val_dims = IntArrayFromInts(off_dims); | ||
| TfLiteIntArray* out_dims = IntArrayFromInts(output_dims); | ||
|
|
||
| const int output_dims_count = ElementCount(*out_dims); | ||
|
|
||
| // 2. Create Input Tensor | ||
| constexpr int inputs_size = 4; | ||
| constexpr int outputs_size = 1; | ||
| constexpr int tensors_size = inputs_size + outputs_size; | ||
| TfLiteTensor tensors[tensors_size] = { | ||
| CreateTensor(indices_data, in_dims), CreateTensor(depth_data, d_dims), | ||
| CreateTensor(on_data, on_val_dims), CreateTensor(off_data, off_val_dims), | ||
| CreateTensor(output_data, out_dims), // Output Tensor | ||
| }; | ||
|
|
||
| // 3. Parameter setting | ||
| TfLiteOneHotParams builtin_data = {axis}; | ||
|
|
||
| // 4. KernelRunner execution | ||
| int inputs_array_data[] = {4, 0, 1, 2, 3}; // indices, depth, on, off | ||
| TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); | ||
| int outputs_array_data[] = {1, 4}; // output tensor index | ||
| TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); | ||
|
|
||
| // tflite::ops::micro::Register_ONE_HOT) | ||
| const TFLMRegistration registration = *tflite::ops::micro::Register_ONE_HOT(); | ||
| micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, | ||
| outputs_array, | ||
| reinterpret_cast<void*>(&builtin_data)); | ||
|
|
||
| TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); | ||
| TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); | ||
|
|
||
| // 5. Result evaluation | ||
| for (int i = 0; i < output_dims_count; ++i) { | ||
| TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); | ||
| } | ||
| } | ||
|
|
||
| } // namespace | ||
| } // namespace testing | ||
| } // namespace tflite | ||
|
|
||
| // UNIT TEST | ||
| TF_LITE_MICRO_TESTS_BEGIN | ||
|
|
||
| TF_LITE_MICRO_TEST(OneHot_BasicInt32) { | ||
| // Indices: [0, 1, 2] | ||
| const int indices_dims[] = {1, 3}; | ||
| const int32_t indices_data[] = {0, 1, 2}; | ||
|
|
||
| // Depth: 3 | ||
| const int depth_dims[] = {1, 1}; | ||
| const int32_t depth_data[] = {3}; | ||
|
|
||
| // On: 1, Off: 0 | ||
| const int on_dims[] = {1, 1}; | ||
| const int32_t on_data[] = {1}; | ||
| const int off_dims[] = {1, 1}; | ||
| const int32_t off_data[] = {0}; | ||
|
|
||
| // Output: [3, 3] -> Identity Matrix | ||
| const int output_dims[] = {2, 3, 3}; | ||
| const int32_t expected_output[] = {1, 0, 0, 0, 1, 0, 0, 0, 1}; | ||
|
|
||
| int32_t output_data[9]; | ||
|
|
||
| tflite::testing::TestOneHot(indices_dims, indices_data, depth_dims, | ||
| depth_data, on_dims, on_data, off_dims, off_data, | ||
| output_dims, expected_output, output_data); | ||
| } | ||
|
|
||
| TF_LITE_MICRO_TESTS_END |
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this file included in the PR?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's my mistake, I will remove this file at next PR. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| // one_hot_test_model_data.cc 같은 별도 파일로 두면 좋음 | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| extern "C" { | ||
|
|
||
| // 그냥 더미 바이트들 (유효한 TFLite 모델이 아님) | ||
| const unsigned char g_one_hot_basic_float_model[] = { | ||
| // FlatBuffer signature 자리에는 보통 'T','F','L','3' 가 오지만 | ||
| // 여기서는 진짜 모델을 만들지 않았으니 그냥 대충 채워둔 상태입니다. | ||
| 0x54, 0x46, 0x4C, 0x33, // 'T','F','L','3' 비슷하게 맞춰줌 | ||
| 0x00, 0x00, 0x00, 0x00, // 나머지는 전부 0 | ||
| 0x00, 0x00, 0x00, 0x00, | ||
| }; | ||
|
|
||
| const int g_one_hot_basic_float_model_len = | ||
| sizeof(g_one_hot_basic_float_model) / | ||
| sizeof(g_one_hot_basic_float_model[0]); | ||
|
|
||
| } // extern "C" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this file in the PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's my mistake, I will remove this file at next PR.