|
19 | 19 | import bitarray
|
20 | 20 | import bitarray.util
|
21 | 21 | from dataclasses import dataclass, field
|
| 22 | +import os |
22 | 23 | import sys
|
| 24 | +import tempfile |
23 | 25 | from typing import ByteString, Iterable, Optional
|
24 | 26 |
|
25 | 27 | import absl.app
|
|
30 | 32 | from tflite_micro.tensorflow.lite.micro.compression import model_facade
|
31 | 33 | from tflite_micro.tensorflow.lite.micro.compression import spec
|
32 | 34 | from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema
|
| 35 | +from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper |
33 | 36 |
|
34 | 37 | USAGE = f"""\
|
35 | 38 | Usage: compress.py --input <in.tflite> --spec <spec.yaml> [--output <out.tflite>]
|
@@ -250,6 +253,43 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray:
|
250 | 253 | return buffer
|
251 | 254 |
|
252 | 255 |
|
| 256 | +def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: |
| 257 | + """Applies proper FlatBuffer alignment to a model. |
| 258 | + |
| 259 | + The Python flatbuffers library doesn't respect `force_align` schema attributes, |
| 260 | + so we use the C++ wrapper which properly handles alignment requirements. |
| 261 | + |
| 262 | + Args: |
| 263 | + model_bytes: The model flatbuffer to align |
| 264 | + |
| 265 | + Returns: |
| 266 | + The properly aligned model flatbuffer |
| 267 | + """ |
| 268 | + # C++ wrapper requires file paths, not byte buffers |
| 269 | + with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_in: |
| 270 | + temp_in.write(model_bytes) |
| 271 | + temp_in_path = temp_in.name |
| 272 | + |
| 273 | + with tempfile.NamedTemporaryFile(suffix='.tflite', delete=False) as temp_out: |
| 274 | + temp_out_path = temp_out.name |
| 275 | + |
| 276 | + try: |
| 277 | + # Unpack and repack with proper alignment |
| 278 | + tflite_flatbuffer_align_wrapper.align_tflite_model(temp_in_path, |
| 279 | + temp_out_path) |
| 280 | + |
| 281 | + with open(temp_out_path, 'rb') as f: |
| 282 | + aligned_model = bytearray(f.read()) |
| 283 | + |
| 284 | + return aligned_model |
| 285 | + finally: |
| 286 | + # Clean up temporary files |
| 287 | + if os.path.exists(temp_in_path): |
| 288 | + os.unlink(temp_in_path) |
| 289 | + if os.path.exists(temp_out_path): |
| 290 | + os.unlink(temp_out_path) |
| 291 | + |
| 292 | + |
253 | 293 | def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
|
254 | 294 | """Compresses a model .tflite flatbuffer.
|
255 | 295 |
|
@@ -291,7 +331,9 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray:
|
291 | 331 | # add compression metadata to model
|
292 | 332 | model.add_metadata(TFLITE_METADATA_KEY, metadata.compile())
|
293 | 333 |
|
294 |
| - return model.compile() |
| 334 | + # Compile the model and apply proper alignment |
| 335 | + unaligned_model = model.compile() |
| 336 | + return _apply_flatbuffer_alignment(unaligned_model) |
295 | 337 |
|
296 | 338 |
|
297 | 339 | def _fail_w_usage() -> int:
|
|
0 commit comments