Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tensorflow/lite/micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,19 @@ tflm_cc_test(
],
)

tflm_cc_test(
name = "micro_interpreter_graph_test",
srcs = [
"micro_interpreter_graph_test.cc",
],
deps = [
":micro_allocator",
":micro_interpreter_graph",
":test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)

tflm_cc_test(
name = "micro_interpreter_test",
srcs = [
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/lite/micro/micro_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ TfLiteEvalTensor* MicroInterpreter::GetTensor(int tensor_index,
return &graph_.GetAllocations()[subgraph_index].tensors[tensor_index];
}

TfLiteStatus MicroInterpreter::ResetVariableTensor(int tensor_index,
int subgraph_index) {
return graph_.ResetVariableTensor(tensor_index, subgraph_index);
}

TfLiteStatus MicroInterpreter::SetMicroExternalContext(
void* external_context_payload) {
return micro_context_.set_external_context(external_context_payload);
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/lite/micro/micro_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class MicroInterpreter {
// Returns a pointer to the tensor for the corresponding tensor_index
TfLiteEvalTensor* GetTensor(int tensor_index, int subgraph_index = 0);

// Zeros out a single variable tensor in a specified subgraph in the model.
TfLiteStatus ResetVariableTensor(int tensor_index, int subgraph_index = 0);

// Reset the state to be what you would expect when the interpreter is first
// created. i.e. after Init and Prepare is called for the very first time.
TfLiteStatus Reset();
Expand Down
56 changes: 46 additions & 10 deletions tensorflow/lite/micro/micro_interpreter_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,8 @@ TfLiteStatus MicroInterpreterGraph::ResetVariableTensors() {
for (size_t i = 0; i < subgraph->tensors()->size(); ++i) {
auto* tensor = subgraph->tensors()->Get(i);
if (tensor->is_variable()) {
size_t buffer_size;
TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength(
&subgraph_allocations_[subgraph_idx].tensors[i], &buffer_size));

int value = 0;
if (tensor->type() == tflite::TensorType_INT8) {
value = tensor->quantization()->zero_point()->Get(0);
}
memset(subgraph_allocations_[subgraph_idx].tensors[i].data.raw, value,
buffer_size);
TF_LITE_ENSURE_STATUS(ResetTensorData(
tensor, &subgraph_allocations_[subgraph_idx].tensors[i]));
}
}
}
Expand All @@ -329,6 +321,50 @@ TfLiteStatus MicroInterpreterGraph::ResetVariableTensors() {
return kTfLiteOk;
}

TfLiteStatus MicroInterpreterGraph::ResetVariableTensor(int tensor_index,
int subgraph_index) {
if (static_cast<size_t>(subgraph_index) >= subgraphs_->size()) {
MicroPrintf("Accessing subgraph %d but only %d subgraphs found",
subgraph_index, subgraphs_->size());
return kTfLiteError;
}
const SubGraph* subgraph = (*subgraphs_)[subgraph_index];
if (subgraph->tensors() == nullptr ||
static_cast<size_t>(tensor_index) >= subgraph->tensors()->size()) {
MicroPrintf(
"Accessing tensor %d but only %d tensors found in subgraph %d",
tensor_index,
(subgraph->tensors() != nullptr ? subgraph->tensors()->size() : 0),
subgraph_index);
return kTfLiteError;
}
auto* tensor = subgraph->tensors()->Get(tensor_index);
if (!tensor->is_variable()) {
MicroPrintf("Accessing tensor %d in subgraph %d which is not a variable",
tensor_index, subgraph_index);
return kTfLiteError;
}

return ResetTensorData(
tensor, &subgraph_allocations_[subgraph_index].tensors[tensor_index]);
}

TfLiteStatus MicroInterpreterGraph::ResetTensorData(
const tflite::Tensor* tensor, TfLiteEvalTensor* eval_tensor) {
size_t buffer_size;
TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength(eval_tensor, &buffer_size));

int value = 0;
if (tensor->type() == tflite::TensorType_INT8 && tensor->quantization() &&
tensor->quantization()->zero_point() &&
tensor->quantization()->zero_point()->size() > 0) {
value = tensor->quantization()->zero_point()->Get(0);
}
memset(eval_tensor->data.raw, value, buffer_size);

return kTfLiteOk;
}

int MicroInterpreterGraph::NumSubgraphs() {
return model_->subgraphs()->size();
}
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/lite/micro/micro_interpreter_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class MicroInterpreterGraph : public MicroGraph {
// Zeros out all variable tensors in all subgraphs in the model.
virtual TfLiteStatus ResetVariableTensors();

// Zeros out a single variable tensor in a specified subgraph in the model.
virtual TfLiteStatus ResetVariableTensor(int tensor_index,
int subgraph_index = 0);

// Number of tensor inputs to a specified subgraph in the model.
virtual size_t NumSubgraphInputs(int subgraph_idx);

Expand Down Expand Up @@ -99,6 +103,9 @@ class MicroInterpreterGraph : public MicroGraph {
MicroResourceVariables* GetResourceVariables() { return resource_variables_; }

private:
TfLiteStatus ResetTensorData(const tflite::Tensor* tensor,
TfLiteEvalTensor* eval_tensor);

TfLiteContext* context_;
const Model* model_;
MicroAllocator* allocator_;
Expand Down
99 changes: 99 additions & 0 deletions tensorflow/lite/micro/micro_interpreter_graph_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(TestResetVariableTensor) {
const tflite::Model* model = tflite::testing::GetComplexMockModel();
TF_LITE_MICRO_EXPECT(nullptr != model);

tflite::testing::TestingOpResolver op_resolver;
TF_LITE_MICRO_ASSERT_EQ(kTfLiteOk,
tflite::testing::GetTestingOpResolver(op_resolver));

constexpr size_t allocator_buffer_size = 1024 * 16;
uint8_t allocator_buffer[allocator_buffer_size];

tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
allocator_buffer_size, nullptr, nullptr,
true /* preserve_all_tensors */);
TF_LITE_MICRO_ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);

// In GetComplexMockModel, tensor index 1 is a variable tensor.
int variable_tensor_idx = 1;
TfLiteEvalTensor* variable_tensor =
interpreter.GetTensor(variable_tensor_idx);
TF_LITE_MICRO_ASSERT(nullptr != variable_tensor);

if (variable_tensor->data.data != nullptr) {
// Fill the variable tensor with non-zero values.
size_t buffer_size;
TF_LITE_MICRO_ASSERT_EQ(kTfLiteOk, tflite::TfLiteEvalTensorByteLength(
variable_tensor, &buffer_size));
uint8_t* variable_tensor_buffer =
tflite::micro::GetTensorData<uint8_t>(variable_tensor);
for (size_t i = 0; i < buffer_size; ++i) {
variable_tensor_buffer[i] = 0xAA;
}

// Reset the variable tensor.
TF_LITE_MICRO_ASSERT_EQ(
kTfLiteOk, interpreter.ResetVariableTensor(variable_tensor_idx, 0));

// Verify that the variable tensor is zeroed out.
for (size_t i = 0; i < buffer_size; ++i) {
TF_LITE_MICRO_EXPECT_EQ(0, variable_tensor_buffer[i]);
}
}

// Non-variable tensor should NOT be reset.
int non_variable_tensor_idx = 0;
TfLiteEvalTensor* non_variable_tensor =
interpreter.GetTensor(non_variable_tensor_idx);
TF_LITE_MICRO_ASSERT(nullptr != non_variable_tensor);
if (non_variable_tensor->data.data != nullptr) {
size_t buffer_size;
TF_LITE_MICRO_ASSERT_EQ(kTfLiteOk, tflite::TfLiteEvalTensorByteLength(
non_variable_tensor, &buffer_size));
uint8_t* non_variable_tensor_buffer =
tflite::micro::GetTensorData<uint8_t>(non_variable_tensor);
for (size_t i = 0; i < buffer_size; ++i) {
non_variable_tensor_buffer[i] = 0xBB;
}
TF_LITE_MICRO_ASSERT_EQ(kTfLiteError, interpreter.ResetVariableTensor(
non_variable_tensor_idx, 0));
for (size_t i = 0; i < buffer_size; ++i) {
TF_LITE_MICRO_EXPECT_EQ(0xBB, non_variable_tensor_buffer[i]);
}
}

// Test invalid tensor index.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
interpreter.ResetVariableTensor(100, 0));

// Test invalid subgraph index.
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
interpreter.ResetVariableTensor(1, 100));
}

TF_LITE_MICRO_TESTS_END
121 changes: 115 additions & 6 deletions tensorflow/lite/micro/testing/micro_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,6 @@ inline void InitializeTest() { InitializeTarget(); }
} \
} while (false)

#define TF_LITE_MICRO_FAIL(msg) \
do { \
MicroPrintf("FAIL: %s", msg, __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
} while (false)

#define TF_LITE_MICRO_EXPECT_STRING_EQ(string1, string2) \
do { \
for (int i = 0; string1[i] != '\0' && string2[i] != '\0'; i++) { \
Expand All @@ -264,6 +258,121 @@ inline void InitializeTest() { InitializeTarget(); }
} \
} while (false)

#define TF_LITE_MICRO_ASSERT(x) \
if ((x)) { \
} else { \
MicroPrintf(#x " failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_ASSERT_EQ(x, y) \
if ((x) == (y)) { \
} else { \
auto vx = x; \
auto vy = y; \
bool isFloatingX = (std::is_floating_point<decltype(vx)>::value); \
bool isFloatingY = (std::is_floating_point<decltype(vy)>::value); \
if (isFloatingX && isFloatingY) { \
auto delta = ((vx) > (vy)) ? ((vx) - (vy)) : ((vy) - (vx)); \
if (delta > std::numeric_limits<decltype(delta)>::epsilon()) { \
MicroPrintf(#x " == " #y " failed at %s:%d (%f vs %f)", __FILE__, \
__LINE__, static_cast<double>(vx), \
static_cast<double>(vy)); \
micro_test::did_test_fail = true; \
continue; \
} \
} else { \
MicroPrintf(#x " == " #y " failed at %s:%d (%d vs %d)", __FILE__, \
__LINE__, static_cast<int>(vx), static_cast<int>(vy)); \
if (isFloatingX || isFloatingY) { \
MicroPrintf("-----------WARNING-----------"); \
MicroPrintf("Only one of the values is floating point value."); \
} \
micro_test::did_test_fail = true; \
continue; \
} \
}

#define TF_LITE_MICRO_ASSERT_NE(x, y) \
if (true) { \
auto vx = x; \
auto vy = y; \
bool isFloatingX = (std::is_floating_point<decltype(vx)>::value); \
bool isFloatingY = (std::is_floating_point<decltype(vy)>::value); \
if (isFloatingX && isFloatingY) { \
auto delta = ((vx) > (vy)) ? ((vx) - (vy)) : ((vy) - (vx)); \
if (delta <= std::numeric_limits<decltype(delta)>::epsilon()) { \
MicroPrintf(#x " != " #y " failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
} \
} else if ((vx) == (vy)) { \
MicroPrintf(#x " != " #y " failed at %s:%d", __FILE__, __LINE__); \
if (isFloatingX || isFloatingY) { \
MicroPrintf("-----------WARNING-----------"); \
MicroPrintf("Only one of the values is floating point value."); \
} \
micro_test::did_test_fail = true; \
continue; \
} \
} else \
(void)0

#define TF_LITE_MICRO_ASSERT_GT(x, y) \
if ((x) > (y)) { \
} else { \
MicroPrintf(#x " > " #y " failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_ASSERT_LT(x, y) \
if ((x) < (y)) { \
} else { \
MicroPrintf(#x " < " #y " failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_ASSERT_GE(x, y) \
if ((x) >= (y)) { \
} else { \
MicroPrintf(#x " >= " #y " failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_ASSERT_LE(x, y) \
if ((x) <= (y)) { \
} else { \
MicroPrintf(#x " <= " #y " failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_ASSERT_TRUE(x) \
if ((x)) { \
} else { \
MicroPrintf(#x " was not true failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_ASSERT_FALSE(x) \
if (!(x)) { \
} else { \
MicroPrintf(#x " was not false failed at %s:%d", __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
continue; \
}

#define TF_LITE_MICRO_FAIL(msg) \
do { \
MicroPrintf("FAIL: %s", msg, __FILE__, __LINE__); \
micro_test::did_test_fail = true; \
} while (false)

#define TF_LITE_MICRO_CHECK_FAIL() \
do { \
if (micro_test::did_test_fail) { \
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/lite/micro/tools/make/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,12 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/flatbuffer_utils_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/hexdump_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_arena_threshold_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_helpers_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_allocator_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_allocation_info_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_allocator_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_interpreter_context_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_log_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_interpreter_graph_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_interpreter_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_log_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_mutable_op_resolver_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_resource_variable_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/micro_time_test.cc \
Expand Down