Skip to content

Commit 9955ad3

Browse files
committed
feat(compression): integrate automatic FlatBuffer alignment
Enhance compress() function to automatically apply proper FlatBuffer alignment after compression, eliminating the need for users to manually run tflite_flatbuffer_align as a separate step. Use the C++ alignment wrapper internally, as the Python flatbuffers library doesn't respect force_align schema attributes. Keep the API unchanged - compress() still returns a bytearray, but now the output is properly aligned for the TFLM interpreter. Update documentation, and build dependencies of the Python package. BUG=#3125
1 parent 2b3f8d0 commit 9955ad3

File tree

4 files changed

+46
-6
lines changed

4 files changed

+46
-6
lines changed

python/tflite_micro/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ py_package(
153153
packages = [
154154
"python.tflite_micro",
155155
"tensorflow.lite.micro.compression",
156+
"tensorflow.lite.micro.tools",
156157
"tensorflow.lite.micro.tools.generate_test_for_model",
157158
"tensorflow.lite.python",
158159
"tensorflow.lite.tools.flatbuffer_utils",
@@ -162,6 +163,7 @@ py_package(
162163
":runtime",
163164
":version",
164165
"//tensorflow/lite/micro/compression",
166+
"//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
165167
],
166168
)
167169

tensorflow/lite/micro/compression/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ py_library(
116116
":metadata_py",
117117
":model_facade",
118118
":spec",
119+
"//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
119120
"@absl_py//absl:app",
120121
"@absl_py//absl/flags",
121122
"@flatbuffers//:runtime_py",

tensorflow/lite/micro/compression/compress.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import bitarray
2020
import bitarray.util
2121
from dataclasses import dataclass, field
22+
import os
2223
import sys
24+
import tempfile
2325
from typing import ByteString, Iterable, Optional
2426

2527
import absl.app
@@ -30,6 +32,7 @@
3032
from tflite_micro.tensorflow.lite.micro.compression import model_facade
3133
from tflite_micro.tensorflow.lite.micro.compression import spec
3234
from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema
35+
from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper
3336

3437
USAGE = f"""\
3538
Usage: compress.py --input <in.tflite> --spec <spec.yaml> [--output <out.tflite>]
@@ -250,6 +253,43 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray:
250253
return buffer
251254

252255

256+
def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray:
257+
"""Applies proper FlatBuffer alignment to a model.
258+
259+
The Python flatbuffers library doesn't respect `force_align` schema attributes,
260+
so we use the C++ wrapper which properly handles alignment requirements.
261+
262+
Args:
263+
model_bytes: The model flatbuffer to align
264+
265+
Returns:
266+
The properly aligned model flatbuffer
267+
"""
268+
# C++ wrapper requires file paths, not byte buffers
269+
with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_in:
270+
temp_in.write(model_bytes)
271+
temp_in_path = temp_in.name
272+
273+
with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_out:
274+
temp_out_path = temp_out.name
275+
276+
try:
277+
# Unpack and repack with proper alignment
278+
tflite_flatbuffer_align_wrapper.align_tflite_model(temp_in_path,
279+
temp_out_path)
280+
281+
with open(temp_out_path, 'rb') as f:
282+
aligned_model = bytearray(f.read())
283+
284+
return aligned_model
285+
finally:
286+
# Clean up temporary files
287+
if os.path.exists(temp_in_path):
288+
os.unlink(temp_in_path)
289+
if os.path.exists(temp_out_path):
290+
os.unlink(temp_out_path)
291+
292+
253293
def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
254294
"""Compresses a model .tflite flatbuffer.
255295
@@ -291,7 +331,9 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
291331
# add compression metadata to model
292332
model.add_metadata(TFLITE_METADATA_KEY, metadata.compile())
293333

294-
return model.compile()
334+
# Compile the model and apply proper alignment
335+
unaligned_model = model.compile()
336+
return _apply_flatbuffer_alignment(unaligned_model)
295337

296338

297339
def _fail_w_usage() -> int:

tensorflow/lite/micro/docs/compression.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,6 @@ Once the `YAML` specification is ready, compress the model using the following:
323323
bazel run -s tensorflow/lite/micro/compression:compress -- --input=binned.tflite --output=compressed.tflite --spec=spec.yaml
324324
```
325325
326-
Then align the model:
327-
```
328-
bazel run -s tensorflow/lite/micro/tools:tflite_flatbuffer_align -- compressed.tflite compressed_and_aligned.tflite
329-
```
330-
331326
# The Generic Benchmark Application
332327
333328
The Generic Benchmark Application can be used to see the size of the model, the

0 commit comments

Comments
 (0)