
SPACE_TO_BATCH_ND output shape mismatch with TFLite #3185

@jianyuzzz

Description


When testing the SPACE_TO_BATCH_ND operator, I noticed that TFLite and TFLite Micro produce different output shapes for the same model and input. TFLite Micro raises no error, but its output shape does not match that of TFLite (or Keras).

Minimal Example:

```python
import tensorflow as tf

# Minimal Keras model using SPACE_TO_BATCH_ND
inputs = tf.keras.Input(shape=(8, 8, 1))
block_shape = [2, 2]
paddings = [[0, 0], [0, 0]]
outputs = tf.keras.layers.Lambda(lambda x: tf.space_to_batch_nd(x, block_shape, paddings))(inputs)
model = tf.keras.Model(inputs, outputs)

# Convert to TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("space_to_batch_nd.tflite", "wb") as f:
    f.write(tflite_model)
```

Inference Comparison:

```python
import numpy as np
from ai_edge_litert.interpreter import Interpreter
from tflite_micro.python.tflite_micro import runtime

# Load TFLite model and prepare input
model_path = "space_to_batch_nd.tflite"
interpreter = Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
input_shape = input_details[0]["shape"]
input_data = np.random.rand(*input_shape).astype(np.float32)

# TFLite inference
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
tflite_output = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])

# TFLM inference
arena_size = 1000000
tflm_interpreter = runtime.Interpreter.from_file(model_path, arena_size=arena_size)
tflm_interpreter.set_input(input_data, 0)
tflm_interpreter.invoke()
tflm_output = tflm_interpreter.get_output(0)

print("TFLite output shape:", tflite_output.shape)
print("TFLM output shape:", tflm_output.shape)
```
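Since the script above only compares shapes, it may also help to check values. Below is a minimal NumPy reference sketch of SPACE_TO_BATCH_ND, written from the pad → reshape → transpose → reshape steps described in the TensorFlow `space_to_batch_nd` documentation. The function name `space_to_batch_nd_ref` is hypothetical (not part of TFLite or TFLM). The sketch also assumes `paddings` holds one `[before, after]` pair per spatial dimension.

```python
import numpy as np


def space_to_batch_nd_ref(x, block_shape, paddings):
    """Hypothetical NumPy reference for SPACE_TO_BATCH_ND (TF op semantics)."""
    x = np.asarray(x)
    m = len(block_shape)
    # 1. Zero-pad the M spatial dims; batch and trailing dims are not padded.
    pad_width = [(0, 0)] + [tuple(p) for p in paddings] + [(0, 0)] * (x.ndim - 1 - m)
    padded = np.pad(x, pad_width)
    batch = padded.shape[0]
    remaining = list(padded.shape[1 + m:])
    # 2. Split each padded spatial dim into (dim // block, block).
    split = []
    for i, block in enumerate(block_shape):
        dim = padded.shape[1 + i]
        if dim % block != 0:
            raise ValueError(f"padded dim {i} ({dim}) is not divisible by {block}")
        split += [dim // block, block]
    reshaped = padded.reshape([batch] + split + remaining)
    # 3. Move the block axes in front of the batch axis.
    block_axes = [2 + 2 * i for i in range(m)]
    spatial_axes = [1 + 2 * i for i in range(m)]
    remaining_axes = list(range(1 + 2 * m, reshaped.ndim))
    permuted = reshaped.transpose(block_axes + [0] + spatial_axes + remaining_axes)
    # 4. Fold the block axes into the batch axis.
    out_shape = [batch * int(np.prod(block_shape))] + split[0::2] + remaining
    return permuted.reshape(out_shape)


# Small worked example: a 4x4 single-channel image with values 1..16.
x = np.arange(1, 17, dtype=np.float32).reshape(1, 4, 4, 1)
out = space_to_batch_nd_ref(x, [2, 2], [[0, 0], [0, 0]])
print(out.shape)                   # (4, 2, 2, 1)
print(out[0, :, :, 0].tolist())    # [[1.0, 3.0], [9.0, 11.0]]
```

To compare this against the interpreters in the script above, one could call `np.testing.assert_allclose(space_to_batch_nd_ref(input_data, [2, 2], [[0, 0], [0, 0]]), tflite_output)` and run the same check against `tflm_output`. Given the reported TFLM shape of (1, 4, 4, 1), the TFLM comparison would fail on shape alone.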

Observed Behavior:

  • TFLite output shape: (4, 4, 4, 1)
  • TFLite Micro output shape: (1, 4, 4, 1)
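For reference, the TensorFlow documentation for `space_to_batch_nd` defines the output shape as `[batch * prod(block_shape)]`, followed by each padded spatial dimension divided by its block size, followed by any remaining dimensions. The sketch below applies that rule; the helper name `expected_space_to_batch_nd_shape` is hypothetical. For the (1, 8, 8, 1) input used here, with `block_shape = [2, 2]` and zero padding, the rule gives (4, 4, 4, 1). That matches the TFLite result, not the TFLite Micro result.

```python
import math


def expected_space_to_batch_nd_shape(input_shape, block_shape, paddings):
    """Hypothetical helper: SPACE_TO_BATCH_ND output shape per the TF op definition."""
    batch = input_shape[0]
    m = len(block_shape)
    spatial = []
    for i, (block, (pad_before, pad_after)) in enumerate(zip(block_shape, paddings)):
        padded = input_shape[1 + i] + pad_before + pad_after
        if padded % block != 0:
            raise ValueError(f"padded dim {i} ({padded}) is not divisible by {block}")
        spatial.append(padded // block)
    # Dimensions after the M spatial dims (here: channels) pass through unchanged.
    remaining = list(input_shape[1 + m:])
    return tuple([batch * math.prod(block_shape)] + spatial + remaining)


print(expected_space_to_batch_nd_shape((1, 8, 8, 1), [2, 2], [[0, 0], [0, 0]]))
# (4, 4, 4, 1)
```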

Environment:

  • TensorFlow version: 2.18.1
  • tflite-micro version: 0.dev20250812203306
  • Platform: Linux
