
Commit 125d290

nkovela1 authored and tensorflower-gardener committed
Fixes improper serialization code route for functional models.
PiperOrigin-RevId: 524097977
1 parent 5b241ef commit 125d290

4 files changed: 17 additions & 2 deletions

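All four changes apply the same fix: a functional model is opted into the old Keras serialization format by setting use_legacy_config = True before its config is read or compared. A minimal sketch of the pattern, assuming a Keras release whose saving code consults a use_legacy_config attribute (on versions that ignore it, the assignment is just a harmless extra attribute):

    from tensorflow import keras

    # Small functional model, mirroring the test fixtures touched in this commit.
    inp = keras.layers.Input((3,))
    out = keras.layers.Dense(2, activation='relu')(inp)
    model = keras.Model(inp, out)

    # Ensures old Keras serialization format, as the patched code and tests do.
    model.use_legacy_config = True

    # get_config() then follows the legacy functional-model route, and the
    # config can be rebuilt with Model.from_config() or compared against
    # another model serialized the same way.
    config = model.get_config()
    rebuilt = keras.Model.from_config(config)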

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

Lines changed: 4 additions & 0 deletions
@@ -686,6 +686,10 @@ def testClusterStrippingFunctionalModel(self):
     stripped_model = cluster.strip_clustering(clustered_model)
 
     self.assertEqual(self._count_clustered_layers(stripped_model), 0)
+
+    # Ensures old Keras serialization format
+    model.use_legacy_config = True
+    stripped_model.use_legacy_config = True
     self.assertEqual(model.get_config(), stripped_model.get_config())
 
   def testClusterWeightsStrippedWeights(self):

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 5 additions & 0 deletions
@@ -554,6 +554,9 @@ def transform(self):
     # 'output_layers': [ ... ],
     # 'name': 'MODEL_NAME',
     #
+
+    # Ensures old Keras serialization format
+    self.model.use_legacy_config = True
     self._config = self.model.get_config()
 
     # Stores map of Transform -> List of layer names matched by transform.
@@ -631,4 +634,6 @@ def transform(self):
       if names_and_weights:
         self._set_layer_names_and_weights(layer, names_and_weights)
 
+    # Ensures old Keras serialization format
+    transformed_model.use_legacy_config = True
     return transformed_model, copy.deepcopy(self._layer_metadata_map)
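In model_transformer.py the flag is set at both ends of the rewrite: on self.model before its config is read, and on the rebuilt transformed_model so that downstream get_config() calls also take the legacy route. A hypothetical, heavily simplified sketch of that round trip (transform_with_legacy_config and rewrite_config are stand-ins for the transformer's real matching and replacement logic, not part of its API):

    from tensorflow import keras

    def transform_with_legacy_config(model, rewrite_config):
      # Hypothetical outline of the shape of ModelTransformer.transform()
      # after this change, not the real implementation.
      model.use_legacy_config = True          # legacy format for the source config
      config = rewrite_config(model.get_config())
      transformed = keras.Model.from_config(config)
      transformed.use_legacy_config = True    # legacy format for downstream callers
      return transformed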

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 2 additions & 0 deletions
@@ -463,6 +463,8 @@ def replacement(self, match_layer):
         inp = keras.layers.Input((3,))
         out = keras.layers.Dense(2, activation='relu')(inp)
         model_fused = keras.Model(inp, out)
+        # Ensures old Keras serialization format
+        model_fused.use_legacy_config = True
       else:
         model_fused = keras.Sequential(
             [keras.layers.Dense(2, activation='relu', input_shape=(3,))])

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 6 additions & 2 deletions
@@ -461,11 +461,15 @@ def _quantize(layer): # pylint: disable=missing-docstring
           '`quantize_scope` for your calls to `quantize_model` and '
           '`quantize_apply`. [%s].' % er) from er
 
+  if hasattr(model, 'use_legacy_config'):
+    model_copy.use_legacy_config = model.use_legacy_config
+
   # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
   # extracts the original model structure (easier to transform), and
   # stores relevant quantization information in a map.
-  (unwrapped_model, layer_quantize_map,
-   requires_output_quantize) = _extract_original_model(model_copy)
+  (unwrapped_model, layer_quantize_map, requires_output_quantize) = (
+      _extract_original_model(model_copy)
+  )
   # Model cloning excludes input layers. Add input layers into the map
   # since they need to be matched for patterns as well.
   # pylint: disable=protected-access
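The hasattr guard in quantize.py is needed because quantize_apply operates on a copy of the annotated model (model_copy in the diff above), and a plain Python attribute such as use_legacy_config would not survive cloning, so it is forwarded explicitly when present. A small sketch of that guard in isolation (_copy_legacy_config_flag is a hypothetical helper, not part of the module):

    def _copy_legacy_config_flag(source_model, cloned_model):
      # Forward the opt-in flag only when the caller actually set it, so models
      # that never used it keep the default serialization behavior.
      if hasattr(source_model, 'use_legacy_config'):
        cloned_model.use_legacy_config = source_model.use_legacy_config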
