Skip to content

Commit aa2cb7c

Browse files
nkovela1tensorflower-gardener
authored andcommitted
Fix broken TFMO tests after Sequential serialization fix.
PiperOrigin-RevId: 544483744
1 parent 70bb2e3 commit aa2cb7c

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,12 +463,13 @@ def replacement(self, match_layer):
463463
inp = keras.layers.Input((3,))
464464
out = keras.layers.Dense(2, activation='relu')(inp)
465465
model_fused = keras.Model(inp, out)
466-
# Ensures old Keras serialization format
467-
model_fused.use_legacy_config = True
468466
else:
469467
model_fused = keras.Sequential(
470468
[keras.layers.Dense(2, activation='relu', input_shape=(3,))])
471469

470+
# Ensures old Keras serialization format
471+
model_fused.use_legacy_config = True
472+
472473
if model_type == 'functional':
473474
inp = keras.layers.Input((3,))
474475
x = keras.layers.Dense(2)(inp)

tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ def _remove_keys(config):
7979
if isinstance(item, dict):
8080
_remove_keys(item)
8181

82+
# Ensure the same config format
83+
model1.use_legacy_config, model2.use_legacy_config = True, True
8284
model1_config = model1.get_config()
8385
model2_config = model2.get_config()
8486
exclude_keys = exclude_keys or []
85-
exclude_keys += ['build_input_shape', 'build_config']
87+
exclude_keys += ['build_config'] # Exclude model build information
8688
_remove_keys(model1_config)
8789
_remove_keys(model2_config)
8890
self.assertEqual(model1_config, model2_config)

0 commit comments

Comments
 (0)