Skip to content

Commit 2b3f8d0

Browse files
authored
Add const-tensor checks (#3141)
@tensorflow/micro Add consistent const-tensor checks and error messages across these kernels: TRANSPOSE STRIDED_SLICE FILL BROADCAST_TO EXPAND_DIMS Modify associated kernel unit tests. bug=fixes #3140
1 parent 3d4cdc1 commit 2b3f8d0

File tree

10 files changed

+59
-25
lines changed

10 files changed

+59
-25
lines changed

tensorflow/lite/micro/kernels/broadcast_to.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -95,6 +95,9 @@ TfLiteStatus BroadcastToPrepare(TfLiteContext* context, TfLiteNode* node) {
9595
// the same as TFLite.
9696
TF_LITE_ENSURE(context, input->type != kTfLiteString);
9797

98+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(shape),
99+
"Non-constant >shape< tensor is not supported");
100+
98101
TF_LITE_ENSURE_STATUS(ValidateOutputTensor(context, input, shape, output));
99102
micro_context->DeallocateTempTfLiteTensor(input);
100103
micro_context->DeallocateTempTfLiteTensor(shape);

tensorflow/lite/micro/kernels/broadcast_to_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -44,6 +44,8 @@ tflite::micro::KernelRunner CreateBroadcastToTestRunner(
4444

4545
tensors[0] = CreateTensor(input_data, IntArrayFromInts(input_shape));
4646
tensors[1] = CreateTensor(dims_data, IntArrayFromInts(dims_shape));
47+
// shape must be a const tensor
48+
tensors[1].allocation_type = kTfLiteMmapRo;
4749
tensors[2] = CreateTensor(output_data, IntArrayFromInts(output_shape));
4850

4951
// The output type matches the value type.

tensorflow/lite/micro/kernels/expand_dims.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -99,10 +99,8 @@ TfLiteStatus ExpandDimsPrepare(TfLiteContext* context, TfLiteNode* node) {
9999
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
100100
TF_LITE_ENSURE(context, output != nullptr);
101101
output->type = input->type;
102-
if (IsDynamicTensor(axis)) {
103-
MicroPrintf("DynamicTensor is not yet supported by Expand_Dims.");
104-
return kTfLiteError;
105-
}
102+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
103+
"Non-constant >axis< tensor is not supported");
106104
TF_LITE_ENSURE_OK(context, VerifyTensorDim(context, input, axis, output));
107105

108106
micro_context->DeallocateTempTfLiteTensor(input);

tensorflow/lite/micro/kernels/expand_dims_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -57,6 +57,8 @@ micro::KernelRunner CreateExpandDimsKernelRunner(
5757

5858
tensors[kDimsTensorIndex] = CreateTensor(input_data, in_dims);
5959
tensors[kAxisTensorIndex] = CreateTensor(axis_data, ax_dims);
60+
// axis must be a const tensor
61+
tensors[kAxisTensorIndex].allocation_type = kTfLiteMmapRo;
6062
tensors[kOutputTensorIndex] = CreateTensor(output_data, out_dims, true);
6163

6264
TfLiteIntArray* inputs_array =

tensorflow/lite/micro/kernels/fill.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -87,12 +87,11 @@ TfLiteStatus FillPrepare(TfLiteContext* context, TfLiteNode* node) {
8787
// The dimension of the output tensor is known in model already.
8888
TFLITE_DCHECK(output->dims != nullptr);
8989

90-
if (dims->data.data != nullptr) {
91-
// When the dims tensor is specified in model already (i.e. is not an
92-
// activation tensor), the dims tensor must match the output tensor shape.
93-
// As a byproduct, ensures the dims tensor is of an integer type.
94-
TF_LITE_ENSURE_OK(context, EnsureEq(context, output->dims, dims));
95-
}
90+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(dims),
91+
"Non-constant >dims< tensor is not supported");
92+
// The dims tensor must match the output tensor shape.
93+
// As a byproduct, ensures the dims tensor is of an integer type.
94+
TF_LITE_ENSURE_OK(context, EnsureEq(context, output->dims, dims));
9695

9796
micro_context->DeallocateTempTfLiteTensor(dims);
9897
micro_context->DeallocateTempTfLiteTensor(value);

tensorflow/lite/micro/kernels/fill_test.cc

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -44,6 +44,10 @@ tflite::micro::KernelRunner CreateFillTestRunner(
4444
static TfLiteTensor tensors[3];
4545

4646
tensors[0] = CreateTensor(dims_data, IntArrayFromInts(dims_shape));
47+
if (dims_data != nullptr) {
48+
// dims must be a const tensor
49+
tensors[0].allocation_type = kTfLiteMmapRo;
50+
}
4751
tensors[1] = CreateTensor(value_data, IntArrayFromInts(value_shape));
4852
tensors[2] = CreateTensor(output_data, IntArrayFromInts(output_shape));
4953

@@ -154,15 +158,13 @@ TF_LITE_MICRO_TEST(FillInt8Int32Dims) {
154158
output_data);
155159
}
156160

157-
// Verify the FILL still works when the input dims tensor is an activation
158-
// tensor (i.e. has not prepopulated value). Fill a 2x2x2 tensor with a int8
159-
// scalar value.
160-
TF_LITE_MICRO_TEST(FillInt8NoInputDimsData) {
161+
TF_LITE_MICRO_TEST(FillInt8NonConstDimsTensorFail) {
161162
constexpr int kDim1 = 2;
162163
constexpr int kDim2 = 2;
163164
constexpr int kDim3 = 2;
164165

165-
// The dims tensor with unknown data. Note that shape is always known.
166+
// Simulate the dims tensor with dynamic data. Note that shape is always
167+
// known.
166168
int dims_shape[] = {1, 3};
167169
int32_t* dims_data = nullptr;
168170

@@ -172,8 +174,11 @@ TF_LITE_MICRO_TEST(FillInt8NoInputDimsData) {
172174
int output_shape[] = {3, kDim1, kDim2, kDim3};
173175
int8_t output_data[kDim1 * kDim2 * kDim3];
174176

175-
TestFill(dims_shape, dims_data, value_shape, value_data, output_shape,
176-
output_data);
177+
tflite::micro::KernelRunner runner =
178+
CreateFillTestRunner(dims_shape, dims_data, value_shape, value_data,
179+
output_shape, output_data);
180+
181+
TF_LITE_MICRO_EXPECT_EQ(runner.InitAndPrepare(), kTfLiteError);
177182
}
178183

179184
TF_LITE_MICRO_TEST(FillFloatInt32Dims) {

tensorflow/lite/micro/kernels/strided_slice_common.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -141,6 +141,12 @@ TfLiteStatus StridedSlicePrepare(TfLiteContext* context, TfLiteNode* node) {
141141
StridedSliceContext op_context(context, node);
142142
TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
143143
"input dim should not exceed 4");
144+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(op_context.begin),
145+
"Non-constant >begin< tensor is not supported");
146+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(op_context.end),
147+
"Non-constant >end< tensor is not supported");
148+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(op_context.strides),
149+
"Non-constant >strides< tensor is not supported");
144150
auto params = BuildStridedSliceParams(&op_context);
145151
memcpy(op_params, &params, sizeof(StridedSliceParams));
146152
return CheckOutputSize(context, &op_context);

tensorflow/lite/micro/kernels/strided_slice_test.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15+
#include "tensorflow/lite/micro/kernels/strided_slice.h"
16+
1517
#include <cstdint>
1618

1719
#include "tensorflow/lite/c/builtin_op_data.h"
@@ -82,6 +84,12 @@ void TestStridedSliceFloat(int* input_shape, int* begin_shape, int* end_shape,
8284
CreateTensor(strides_data, strides_dims),
8385
CreateTensor(output_data, output_dims),
8486
};
87+
// begin must be a const tensor
88+
tensors[kStridedSliceBeginTensor].allocation_type = kTfLiteMmapRo;
89+
// end must be a const tensor
90+
tensors[kStridedSliceEndTensor].allocation_type = kTfLiteMmapRo;
91+
// strides must be a const tensor
92+
tensors[kStridedSliceStridesTensor].allocation_type = kTfLiteMmapRo;
8593

8694
ValidateStridedSliceGoldens(tensors, tensors_size, expected_output,
8795
output_data, ElementCount(*output_dims),
@@ -116,6 +124,12 @@ void TestStridedSliceQuantized(int* input_shape, int* begin_shape,
116124
CreateTensor(strides_data, strides_dims),
117125
CreateQuantizedTensor(output_data, output_dims, 1.0, zero_point),
118126
};
127+
// begin must be a const tensor
128+
tensors[kStridedSliceBeginTensor].allocation_type = kTfLiteMmapRo;
129+
// end must be a const tensor
130+
tensors[kStridedSliceEndTensor].allocation_type = kTfLiteMmapRo;
131+
// strides must be a const tensor
132+
tensors[kStridedSliceStridesTensor].allocation_type = kTfLiteMmapRo;
119133

120134
ValidateStridedSliceGoldens(tensors, tensors_size, expected_output,
121135
output_data, ElementCount(*output_dims),

tensorflow/lite/micro/kernels/transpose_common.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ TfLiteStatus TransposePrepare(TfLiteContext* context, TfLiteNode* node) {
3434
"Transpose op only supports 1D-5D input arrays.");
3535
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
3636
op_context.output->type);
37+
TF_LITE_ENSURE_MSG(context, IsConstantTensor(op_context.perm),
38+
"Non-constant >perm< tensor is not supported");
3739

3840
int dims = NumDimensions(op_context.input);
3941
const int32_t* perm_data = GetTensorData<int32_t>(op_context.perm);

tensorflow/lite/micro/kernels/transpose_test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include "tensorflow/lite/c/builtin_op_data.h"
1919
#include "tensorflow/lite/c/common.h"
2020
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
21+
#include "tensorflow/lite/micro/kernels/transpose.h"
2122
#include "tensorflow/lite/micro/micro_utils.h"
2223
#include "tensorflow/lite/micro/test_helpers.h"
2324
#include "tensorflow/lite/micro/testing/micro_test.h"
@@ -123,6 +124,8 @@ void TestTranspose(int* input_dims_data, T* input_data, int* output_dims_data,
123124
CreateTensor(params->perm, perm_dims),
124125
CreateTensor(output_data, output_dims),
125126
};
127+
// perm must be a const tensor
128+
tensors[kTransposePermTensor].allocation_type = kTfLiteMmapRo;
126129

127130
TF_LITE_MICRO_EXPECT_EQ(
128131
kTfLiteOk, ValidateTranspose(tensors, tensors_size, expected_output_data,

0 commit comments

Comments
 (0)