Skip to content

feat(compression): integrate automatic FlatBuffer alignment #3177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/tflite_micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ py_package(
packages = [
"python.tflite_micro",
"tensorflow.lite.micro.compression",
"tensorflow.lite.micro.tools",
"tensorflow.lite.micro.tools.generate_test_for_model",
"tensorflow.lite.python",
"tensorflow.lite.tools.flatbuffer_utils",
Expand All @@ -162,6 +163,7 @@ py_package(
":runtime",
":version",
"//tensorflow/lite/micro/compression",
"//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
],
)

Expand Down
6 changes: 6 additions & 0 deletions tensorflow/lite/micro/compression/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ py_library(
":metadata_py",
":model_facade",
":spec",
"//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@flatbuffers//:runtime_py",
Expand All @@ -140,6 +141,11 @@ py_test(
srcs = [
"compress_test.py",
],
tags = [
"noasan",
"nomsan", # Sanitizer symbols don't work with Python extension libs
"noubsan",
],
deps = [
":compress",
":metadata_py",
Expand Down
44 changes: 43 additions & 1 deletion tensorflow/lite/micro/compression/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import bitarray
import bitarray.util
from dataclasses import dataclass, field
import os
import sys
import tempfile
from typing import ByteString, Iterable, Optional

import absl.app
Expand All @@ -30,6 +32,7 @@
from tflite_micro.tensorflow.lite.micro.compression import model_facade
from tflite_micro.tensorflow.lite.micro.compression import spec
from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema
from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper

USAGE = f"""\
Usage: compress.py --input <in.tflite> --spec <spec.yaml> [--output <out.tflite>]
Expand Down Expand Up @@ -250,6 +253,43 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray:
return buffer


def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray:
"""Applies proper FlatBuffer alignment to a model.
The Python flatbuffers library doesn't respect `force_align` schema attributes,
so we use the C++ wrapper which properly handles alignment requirements.
Args:
model_bytes: The model flatbuffer to align
Returns:
The properly aligned model flatbuffer
"""
# C++ wrapper requires file paths, not byte buffers
with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_in:
temp_in.write(model_bytes)
temp_in_path = temp_in.name

with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_out:
temp_out_path = temp_out.name

try:
# Unpack and repack with proper alignment
tflite_flatbuffer_align_wrapper.align_tflite_model(temp_in_path,
temp_out_path)

with open(temp_out_path, 'rb') as f:
aligned_model = bytearray(f.read())

return aligned_model
finally:
# Clean up temporary files
if os.path.exists(temp_in_path):
os.unlink(temp_in_path)
if os.path.exists(temp_out_path):
os.unlink(temp_out_path)


def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
"""Compresses a model .tflite flatbuffer.
Expand Down Expand Up @@ -291,7 +331,9 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
# add compression metadata to model
model.add_metadata(TFLITE_METADATA_KEY, metadata.compile())

return model.compile()
# Compile the model and apply proper alignment
unaligned_model = model.compile()
return _apply_flatbuffer_alignment(unaligned_model)


def _fail_w_usage() -> int:
Expand Down
5 changes: 0 additions & 5 deletions tensorflow/lite/micro/docs/compression.md
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,6 @@ Once the `YAML` specification is ready, compress the model using the following:
bazel run -s tensorflow/lite/micro/compression:compress -- --input=binned.tflite --output=compressed.tflite --spec=spec.yaml
```
Then align the model:
```
bazel run -s tensorflow/lite/micro/tools:tflite_flatbuffer_align -- compressed.tflite compressed_and_aligned.tflite
```
# The Generic Benchmark Application
The Generic Benchmark Application can be used to see the size of the model, the
Expand Down
Loading