Skip to content

Commit 86ecc96

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Sequential model transformations are broken. Replacing
a pattern with replacement layers adds the replacement layers to the wrong part of the model. This basically ends up creating an incorrect model. Adding a test to reproduce the error. The previous tests were passing only because the model was solele composed of matching layers. Further, `array.insert` in python allows inserting into arbitrary index in an empty list. This CL fixes the indexing bug. However, the replacement code is quite complicated and should be simplified. PiperOrigin-RevId: 313129431
1 parent c2642e5 commit 86ecc96

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def _replace_sequential(self, match_layer_node, replacement_layer_node):
374374

375375
# These variables are needed when adding the new layers
376376
# and must be set before _remove_layers removes them.
377-
first_layer_removed = layers_to_remove[-1] # layers_to_remove is reversed.
377+
first_layer_removed = layers_to_remove[0]
378378
first_layer_removed_index = self._config['layers'].index(
379379
first_layer_removed)
380380

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,38 @@ def replacement(self, match_layer):
410410
self._assert_config(model.get_config(), transformed_model.get_config(),
411411
['build_input_shape'])
412412

413+
def testReplaceListOfLayers_Sequential(self):
414+
class ReplaceConvBatchNorm(transforms.Transform):
415+
"""Replaces a ConvBatchNorm pattern with the same set of layers.
416+
417+
Doesn't make any meaningful change to the layer. Just verifies that
418+
replacing multiple layers works as expected.
419+
"""
420+
421+
def pattern(self):
422+
return LayerPattern('BatchNormalization',
423+
inputs=[LayerPattern('Conv2D')])
424+
425+
def replacement(self, match_layer):
426+
# Adds a modification so the transform happens. If the layers are
427+
# exactly the same, they get ignored by transformer.
428+
match_layer.metadata['key'] = 'value'
429+
return match_layer
430+
431+
model = tf.keras.Sequential([
432+
tf.keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1)),
433+
tf.keras.layers.BatchNormalization(),
434+
tf.keras.layers.ReLU(),
435+
])
436+
model_layer_names = [layer.name for layer in model.layers]
437+
438+
transformed_model, _ = ModelTransformer(
439+
model, [ReplaceConvBatchNorm()]).transform()
440+
transformed_model_layer_names = [
441+
layer.name for layer in transformed_model.layers]
442+
443+
self.assertEqual(model_layer_names, transformed_model_layer_names)
444+
413445
def testReplaceTreeOfLayers_WithSingleLayer(self):
414446
# TODO(pulkitb): Implement
415447
pass

0 commit comments

Comments
 (0)