Skip to content

Commit 74b9db9

Browse files
authored
feat(python): throw error when loading compressed models without support (#3167)
When a model contains COMPRESSION_METADATA but the interpreter was built without compression support, throw a RuntimeError with a helpful message directing users to build with --//:with_compression=true. The implementation uses inline functions in compression_utils.h instead of preprocessor conditionals at the call site; when compression is enabled the check is a compile-time constant and is optimized away, so all code paths remain compile-checked and readable without preprocessor clutter. Includes test_compression_unsupported.py to verify the error detection, which only runs when compression is disabled. BUG=#3125
1 parent 97166cf commit 74b9db9

File tree

4 files changed

+184
-0
lines changed

4 files changed

+184
-0
lines changed

python/tflite_micro/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pybind_extension(
4949
# target = _runtime.so because pybind_extension() appends suffix
5050
srcs = [
5151
"_runtime.cc",
52+
"compression_utils.h",
5253
"interpreter_wrapper.cc",
5354
"interpreter_wrapper.h",
5455
"numpy_utils.cc",
@@ -98,6 +99,27 @@ py_test(
9899
],
99100
)
100101

102+
# Runs test_compression_unsupported.py, which checks that loading a model
# carrying compression metadata raises RuntimeError when the runtime is built
# without compression support.
py_test(
103+
name = "test_compression_unsupported",
104+
srcs = ["test_compression_unsupported.py"],
105+
tags = [
106+
"noasan",
107+
"nomsan", # Python doesn't like these symbols in _runtime.so
108+
"noubsan",
109+
],
110+
# Only compatible when compression is NOT enabled
# (the select below marks the target incompatible whenever
# //:with_compression_enabled matches, so the test is skipped in that build).
111+
target_compatible_with = select({
112+
"//:with_compression_enabled": ["@platforms//:incompatible"],
113+
"//conditions:default": [],
114+
}),
115+
deps = [
116+
":runtime",
117+
requirement("numpy"),
118+
requirement("tensorflow"),
119+
"//tensorflow/lite/micro/compression",
120+
],
121+
)
122+
101123
py_library(
102124
name = "postinstall_check",
103125
srcs = [
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_LITE_MICRO_PYTHON_COMPRESSION_UTILS_H_
17+
#define TENSORFLOW_LITE_MICRO_PYTHON_COMPRESSION_UTILS_H_
18+
19+
#include <cstring>
20+
21+
#include "tensorflow/lite/schema/schema_generated.h"
22+
23+
namespace tflite {
24+
25+
// Reports whether this build of the interpreter includes compression support.
// The answer is `true` exactly when USE_TFLM_COMPRESSION is defined at compile
// time. Because the result is a constant expression, a guard of the form
// `if (!IsCompressionSupported() && ...)` can be folded away by the compiler
// in compression-enabled builds.
inline constexpr bool IsCompressionSupported() {
#if !defined(USE_TFLM_COMPRESSION)
  return false;
#else
  return true;
#endif
}
35+
36+
// Helper to check if model has compression metadata.
37+
// This is always compiled in, but when used with IsCompressionSupported()
38+
// the entire check can be optimized away.
39+
inline bool HasCompressionMetadata(const Model& model) {
40+
if (!model.metadata()) {
41+
return false;
42+
}
43+
44+
for (size_t i = 0; i < model.metadata()->size(); ++i) {
45+
const auto* metadata = model.metadata()->Get(i);
46+
if (metadata && metadata->name() &&
47+
strcmp(metadata->name()->c_str(), "COMPRESSION_METADATA") == 0) {
48+
return true;
49+
}
50+
}
51+
return false;
52+
}
53+
54+
} // namespace tflite
55+
56+
#endif // TENSORFLOW_LITE_MICRO_PYTHON_COMPRESSION_UTILS_H_

python/tflite_micro/interpreter_wrapper.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include <numpy/arrayobject.h>
2929
#include <pybind11/pybind11.h>
3030

31+
#include "python/tflite_micro/compression_utils.h"
3132
#include "python/tflite_micro/numpy_utils.h"
3233
#include "python/tflite_micro/pybind11_lib.h"
3334
#include "python/tflite_micro/python_ops_resolver.h"
@@ -255,6 +256,16 @@ InterpreterWrapper::InterpreterWrapper(
255256

256257
const Model* model = GetModel(buf);
257258
model_ = model_data;
259+
260+
// Check if the model has compression metadata but compression is not
261+
// supported
262+
if (!IsCompressionSupported() && HasCompressionMetadata(*model)) {
263+
ThrowRuntimeError(
264+
"Model contains compressed tensors but the interpreter was not "
265+
"built with compression support. Please build the Python wheel with "
266+
"--//:with_compression=true to enable compression support.");
267+
}
268+
258269
memory_arena_ = std::unique_ptr<uint8_t[]>(new uint8_t[arena_size]);
259270
for (const std::string& registerer : registerers_by_name) {
260271
if (!AddCustomOpRegistererByName(registerer.c_str(),
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test compression metadata detection when compression is disabled."""
16+
17+
import os
18+
import numpy as np
19+
import tensorflow as tf
20+
from tflite_micro.python.tflite_micro import runtime
21+
from tflite_micro.tensorflow.lite.micro import compression
22+
23+
24+
class CompressionDetectionTest(tf.test.TestCase):
  """Test compression metadata detection when compression is disabled.

  These tests expect a runtime built WITHOUT compression support: loading a
  model that carries compression metadata must raise RuntimeError, while
  ordinary models must keep loading normally.
  """

  # Fragment of the error message raised by the interpreter wrapper when it
  # rejects a compressed model. Matching on it keeps the tests from passing
  # because of some unrelated RuntimeError.
  _UNSUPPORTED_MESSAGE_REGEX = 'compression support'

  def _create_test_model(self):
    """Create a simple quantized model for testing.

    Returns:
      The converted TFLite flatbuffer as `bytes`.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'),
        tf.keras.layers.Dense(5, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

    # Convert to quantized TFLite
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    def representative_dataset():
      for _ in range(10):
        yield [np.random.randn(1, 5).astype(np.float32)]

    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8

    tflite_model = converter.convert()
    return bytes(tflite_model) if isinstance(tflite_model,
                                             bytearray) else tflite_model

  def _create_compressed_model(self, model_data):
    """Compress `model_data` with a 4-bit LUT spec and return it as `bytes`.

    Args:
      model_data: Flatbuffer of an uncompressed model, as produced by
        `_create_test_model`.

    Returns:
      The compressed model as `bytes`.
    """
    # NOTE(review): assumes tensor 1 of subgraph 0 is a tensor that can be
    # LUT-compressed in the generated model -- confirm if the model changes.
    spec = (compression.SpecBuilder().add_tensor(
        subgraph=0, tensor=1).with_lut(index_bitwidth=4).build())
    # bytes() is a no-op for bytes input and converts a bytearray result.
    return bytes(compression.compress(model_data, spec))

  def test_regular_model_loads_successfully(self):
    """Non-compressed models should load without issues."""
    model_data = self._create_test_model()
    interpreter = runtime.Interpreter.from_bytes(model_data)
    self.assertIsNotNone(interpreter)

  def test_compressed_model_raises_runtime_error(self):
    """Compressed models should raise RuntimeError when compression is disabled."""
    compressed_model = self._create_compressed_model(self._create_test_model())

    # Should raise the dedicated "no compression support" RuntimeError
    with self.assertRaisesRegex(RuntimeError, self._UNSUPPORTED_MESSAGE_REGEX):
      runtime.Interpreter.from_bytes(compressed_model)

  def test_can_load_regular_after_compressed_failure(self):
    """Verify we can still load regular models after compressed model fails."""
    model_data = self._create_test_model()

    # First try compressed model (should fail)
    compressed_model = self._create_compressed_model(model_data)
    with self.assertRaisesRegex(RuntimeError, self._UNSUPPORTED_MESSAGE_REGEX):
      runtime.Interpreter.from_bytes(compressed_model)

    # Then load regular model (should succeed)
    interpreter = runtime.Interpreter.from_bytes(model_data)
    self.assertIsNotNone(interpreter)
89+
90+
91+
if __name__ == '__main__':
92+
# Set TF environment variables to suppress warnings
# NOTE(review): these are assigned after `import tensorflow` has already run
# at the top of this file. If TensorFlow reads them at import time, setting
# them here has no effect -- confirm, and if so move them above the import.
93+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
94+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
95+
tf.test.main()

0 commit comments

Comments
 (0)