diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD
index 7ed7f21e36e..539fecac078 100644
--- a/python/tflite_micro/BUILD
+++ b/python/tflite_micro/BUILD
@@ -153,6 +153,7 @@ py_package(
     packages = [
         "python.tflite_micro",
         "tensorflow.lite.micro.compression",
+        "tensorflow.lite.micro.tools",
         "tensorflow.lite.micro.tools.generate_test_for_model",
         "tensorflow.lite.python",
         "tensorflow.lite.tools.flatbuffer_utils",
@@ -162,6 +163,7 @@ py_package(
         ":runtime",
         ":version",
         "//tensorflow/lite/micro/compression",
+        "//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
     ],
 )
 
diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD
index bfb0d6e4c2d..4690fbe44bd 100644
--- a/tensorflow/lite/micro/compression/BUILD
+++ b/tensorflow/lite/micro/compression/BUILD
@@ -116,6 +116,7 @@ py_library(
         ":metadata_py",
         ":model_facade",
         ":spec",
+        "//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
         "@absl_py//absl:app",
         "@absl_py//absl/flags",
         "@flatbuffers//:runtime_py",
@@ -140,6 +141,11 @@ py_test(
     srcs = [
         "compress_test.py",
     ],
+    tags = [
+        "noasan",
+        "nomsan",  # Sanitizer symbols don't work with Python extension libs
+        "noubsan",
+    ],
     deps = [
         ":compress",
         ":metadata_py",
diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py
index 93bfb2e814b..79959a7f612 100644
--- a/tensorflow/lite/micro/compression/compress.py
+++ b/tensorflow/lite/micro/compression/compress.py
@@ -19,7 +19,9 @@
 import bitarray
 import bitarray.util
 from dataclasses import dataclass, field
+import os
 import sys
+import tempfile
 from typing import ByteString, Iterable, Optional
 
 import absl.app
@@ -30,6 +32,7 @@
 from tflite_micro.tensorflow.lite.micro.compression import model_facade
 from tflite_micro.tensorflow.lite.micro.compression import spec
 from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema
+from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper
 
 USAGE = f"""\
 Usage: compress.py --input --spec [--output ]
@@ -250,6 +253,43 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray:
   return buffer
 
 
+def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray:
+  """Applies proper FlatBuffer alignment to a model.
+
+  The Python flatbuffers library doesn't respect `force_align` schema attributes,
+  so we use the C++ wrapper which properly handles alignment requirements.
+
+  Args:
+    model_bytes: The model flatbuffer to align
+
+  Returns:
+    The properly aligned model flatbuffer
+  """
+  # C++ wrapper requires file paths, not byte buffers
+  with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_in:
+    temp_in.write(model_bytes)
+    temp_in_path = temp_in.name
+
+  with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_out:
+    temp_out_path = temp_out.name
+
+  try:
+    # Unpack and repack with proper alignment
+    tflite_flatbuffer_align_wrapper.align_tflite_model(temp_in_path,
+                                                       temp_out_path)
+
+    with open(temp_out_path, 'rb') as f:
+      aligned_model = bytearray(f.read())
+
+    return aligned_model
+  finally:
+    # Clean up temporary files
+    if os.path.exists(temp_in_path):
+      os.unlink(temp_in_path)
+    if os.path.exists(temp_out_path):
+      os.unlink(temp_out_path)
+
+
 def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
   """Compresses a model .tflite flatbuffer.
 
@@ -291,7 +331,9 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
 
   # add compression metadata to model
   model.add_metadata(TFLITE_METADATA_KEY, metadata.compile())
-  return model.compile()
+  # Compile the model and apply proper alignment
+  unaligned_model = model.compile()
+  return _apply_flatbuffer_alignment(unaligned_model)
 
 
 def _fail_w_usage() -> int:
diff --git a/tensorflow/lite/micro/docs/compression.md b/tensorflow/lite/micro/docs/compression.md
index e29069fbfef..4698987c6e8 100644
--- a/tensorflow/lite/micro/docs/compression.md
+++ b/tensorflow/lite/micro/docs/compression.md
@@ -323,11 +323,6 @@ Once the `YAML` specification is ready, compress the model using the following:
 ```
 bazel run -s tensorflow/lite/micro/compression:compress -- --input=binned.tflite --output=compressed.tflite --spec=spec.yaml
 ```
-Then align the model:
-```
-bazel run -s tensorflow/lite/micro/tools:tflite_flatbuffer_align -- compressed.tflite compressed_and_aligned.tflite
-```
-
 # The Generic Benchmark Application
 
 The Generic Benchmark Application can be used to see the size of the model, the