Skip to content

Commit deaceca

Browse files
committed
Return kTfLiteError if case of error
1 parent 2c13432 commit deaceca

File tree

3 files changed

+52
-15
lines changed

3 files changed

+52
-15
lines changed

tensorflow/lite/micro/micro_interpreter_graph.cc

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ TfLiteStatus MicroInterpreterGraph::ResetVariableTensors() {
307307
subgraph_idx++) {
308308
const SubGraph* subgraph = (*subgraphs_)[subgraph_idx];
309309
for (size_t i = 0; i < subgraph->tensors()->size(); ++i) {
310-
TF_LITE_ENSURE_STATUS(ResetVariableTensor(i, subgraph_idx));
310+
auto* tensor = subgraph->tensors()->Get(i);
311+
if (tensor->is_variable()) {
312+
TF_LITE_ENSURE_STATUS(ResetTensorData(
313+
tensor, &subgraph_allocations_[subgraph_idx].tensors[i]));
314+
}
311315
}
312316
}
313317
if (resource_variables_ != nullptr) {
@@ -319,21 +323,43 @@ TfLiteStatus MicroInterpreterGraph::ResetVariableTensors() {
319323

320324
TfLiteStatus MicroInterpreterGraph::ResetVariableTensor(int tensor_index,
321325
int subgraph_index) {
326+
if (static_cast<size_t>(subgraph_index) >= subgraphs_->size()) {
327+
MicroPrintf("Accessing subgraph %d but only %d subgraphs found",
328+
subgraph_index, subgraphs_->size());
329+
return kTfLiteError;
330+
}
322331
const SubGraph* subgraph = (*subgraphs_)[subgraph_index];
332+
if (subgraph->tensors() == nullptr ||
333+
static_cast<size_t>(tensor_index) >= subgraph->tensors()->size()) {
334+
MicroPrintf(
335+
"Accessing tensor %d but only %d tensors found in subgraph %d",
336+
tensor_index,
337+
(subgraph->tensors() != nullptr ? subgraph->tensors()->size() : 0),
338+
subgraph_index);
339+
return kTfLiteError;
340+
}
323341
auto* tensor = subgraph->tensors()->Get(tensor_index);
324-
if (tensor->is_variable()) {
325-
size_t buffer_size;
326-
TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength(
327-
&subgraph_allocations_[subgraph_index].tensors[tensor_index],
328-
&buffer_size));
329-
330-
int value = 0;
331-
if (tensor->type() == tflite::TensorType_INT8) {
332-
value = tensor->quantization()->zero_point()->Get(0);
333-
}
334-
memset(subgraph_allocations_[subgraph_index].tensors[tensor_index].data.raw,
335-
value, buffer_size);
342+
if (!tensor->is_variable()) {
343+
MicroPrintf("Accessing tensor %d in subgraph %d which is not a variable",
344+
tensor_index, subgraph_index);
345+
return kTfLiteError;
336346
}
347+
348+
return ResetTensorData(
349+
tensor, &subgraph_allocations_[subgraph_index].tensors[tensor_index]);
350+
}
351+
352+
TfLiteStatus MicroInterpreterGraph::ResetTensorData(
353+
const tflite::Tensor* tensor, TfLiteEvalTensor* eval_tensor) {
354+
size_t buffer_size;
355+
TF_LITE_ENSURE_STATUS(TfLiteEvalTensorByteLength(eval_tensor, &buffer_size));
356+
357+
int value = 0;
358+
if (tensor->type() == tflite::TensorType_INT8) {
359+
value = tensor->quantization()->zero_point()->Get(0);
360+
}
361+
memset(eval_tensor->data.raw, value, buffer_size);
362+
337363
return kTfLiteOk;
338364
}
339365

tensorflow/lite/micro/micro_interpreter_graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class MicroInterpreterGraph : public MicroGraph {
103103
MicroResourceVariables* GetResourceVariables() { return resource_variables_; }
104104

105105
private:
106+
TfLiteStatus ResetTensorData(const tflite::Tensor* tensor,
107+
TfLiteEvalTensor* eval_tensor);
108+
106109
TfLiteContext* context_;
107110
const Model* model_;
108111
MicroAllocator* allocator_;

tensorflow/lite/micro/micro_interpreter_graph_test.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,20 @@ TF_LITE_MICRO_TEST(TestResetVariableTensor) {
7575
for (size_t i = 0; i < buffer_size; ++i) {
7676
non_variable_tensor->data.uint8[i] = 0xBB;
7777
}
78-
TF_LITE_MICRO_EXPECT_EQ(
79-
kTfLiteOk, interpreter.ResetVariableTensor(non_variable_tensor_idx, 0));
78+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, interpreter.ResetVariableTensor(
79+
non_variable_tensor_idx, 0));
8080
for (size_t i = 0; i < buffer_size; ++i) {
8181
TF_LITE_MICRO_EXPECT_EQ(0xBB, non_variable_tensor->data.uint8[i]);
8282
}
8383
}
84+
85+
// Test invalid tensor index.
86+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
87+
interpreter.ResetVariableTensor(100, 0));
88+
89+
// Test invalid subgraph index.
90+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
91+
interpreter.ResetVariableTensor(1, 100));
8492
}
8593

8694
TF_LITE_MICRO_TESTS_END

0 commit comments

Comments
 (0)