
Commit 16f4c6b

nutsiepully authored and tensorflower-gardener committed
Replace already replaced layer in model_transformer
ModelTransformer had a subtle bug: when a set of candidate layers was supplied, recursive transforms were not applied, because layers created by an earlier replacement were never themselves candidates. Without candidate layers it worked fine. Replaced layers are now added as candidate layers so they can be matched and replaced again.

PiperOrigin-RevId: 321030782
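For context, the effect of the candidate-layers guard on matching can be pictured with a minimal sketch; the names below are illustrative, not the actual internals of model_transformer.py:

    # Illustrative sketch of the failure mode, not the real ModelTransformer code.
    def matches(layer_name, candidate_layers):
      # When a candidate set is supplied, only listed layers are considered.
      return not candidate_layers or layer_name in candidate_layers

    candidate_layers = {'dense', 'relu'}          # layers of the original model

    # A first transform replaces 'relu' with a new 'softmax' layer.
    print(matches('softmax', candidate_layers))   # False: a second transform never fires

    # The fix: record replacement layers as candidates too.
    candidate_layers.add('softmax')
    print(matches('softmax', candidate_layers))   # True: it can be matched again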
1 parent 21aac43 commit 16f4c6b

4 files changed: 52 additions, 11 deletions

4 files changed

+52
-11
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 1 addition & 1 deletion
@@ -65,4 +65,4 @@ def apply(self, model, layer_quantize_map):
     ]
     return model_transformer.ModelTransformer(
         model, transforms,
-        layer_quantize_map.keys(), layer_quantize_map).transform()
+        set(layer_quantize_map.keys()), layer_quantize_map).transform()
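Note: wrapping layer_quantize_map.keys() in set() is presumably needed because dict_keys views are immutable, while the transformer now adds replacement-layer names to candidate_layers (see model_transformer.py below). A quick illustration:

    quantize_map = {'dense': None, 'relu': None}
    # quantize_map.keys().add('softmax')   # AttributeError: 'dict_keys' object has no attribute 'add'
    candidates = set(quantize_map.keys())
    candidates.add('softmax')              # works: an independent, mutable copy of the names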

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 12 additions & 10 deletions
@@ -212,9 +212,9 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):
     # Inbound layers can have different order from the list of input patterns.
     # TODO(pulkitb): Fix by checking all permutations.
     input_match_layer_nodes = []
-    for input_layer, pattern in zip(input_layers, pattern.inputs):
+    for input_layer, pattern_ in zip(input_layers, pattern.inputs):
       match_layer_node = self._match_layer_with_inputs(
-          input_layer, pattern, is_head_node=False)
+          input_layer, pattern_, is_head_node=False)
       if not match_layer_node:
         return None
       input_match_layer_nodes.append(match_layer_node)
@@ -354,12 +354,13 @@ def _assign_inbounds_for_replacement(layer_node):
     def _add_replacement_layer(layer_node):
       """Recursively add new layers."""
       self._config['layers'].append(layer_node.layer)
+      layer_name = layer_node.layer['config']['name']
       if layer_node.weights:
-        self._layer_weights_map[layer_node.layer['config']
-                                ['name']] = layer_node.weights
+        self._layer_weights_map[layer_name] = layer_node.weights
       if layer_node.metadata:
-        self._layer_metadata_map[layer_node.layer['config']
-                                 ['name']] = layer_node.metadata
+        self._layer_metadata_map[layer_name] = layer_node.metadata
+      if self.candidate_layers:
+        self.candidate_layers.add(layer_name)
 
       for input_layer in layer_node.input_layers:
         _add_replacement_layer(input_layer)
@@ -399,12 +400,13 @@ def _add_replacement_nodes(first_layer_removed_index, replacement_nodes):
       i = first_layer_removed_index
       for replacement_node in replacement_nodes:
         self._config['layers'].insert(i, replacement_node.layer)
+        layer_name = replacement_node.layer['config']['name']
         if replacement_node.weights:
-          self._layer_weights_map[replacement_node.layer['config']
-                                  ['name']] = replacement_node.weights
+          self._layer_weights_map[layer_name] = replacement_node.weights
         if replacement_node.metadata:
-          self._layer_metadata_map[replacement_node.layer['config']
-                                   ['name']] = replacement_node.metadata
+          self._layer_metadata_map[layer_name] = replacement_node.metadata
+        if self.candidate_layers:
+          self.candidate_layers.add(layer_name)
         i += 1
 
     replacement_nodes = _get_replacement_nodes(replacement_layer_node)

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 33 additions & 0 deletions
@@ -450,6 +450,39 @@ def testReplaceTreeOfLayers_WithTreeOfLayers(self):
     # TODO(pulkitb): Implement
     pass
 
+  @parameterized.parameters(['sequential', 'functional'])
+  def testReplace_AlreadyReplacedLayer_WithAnotherMatch(self, model_type):
+    """Verifies a layer replaced using a transform can be replaced again."""
+
+    class ReplaceReLUWithSoftmax(Transform):
+
+      def pattern(self):
+        return LayerPattern('ReLU')
+
+      def replacement(self, match_layer):
+        replace_layer = keras.layers.serialize(keras.layers.Softmax())
+        replace_layer['name'] = replace_layer['config']['name']
+        return LayerNode(replace_layer)
+
+    class ReplaceSoftmaxWithELU(Transform):
+
+      def pattern(self):
+        return LayerPattern('Softmax')
+
+      def replacement(self, match_layer):
+        replace_layer = keras.layers.serialize(keras.layers.ELU())
+        replace_layer['name'] = replace_layer['config']['name']
+        return LayerNode(replace_layer)
+
+    model = self._simple_dense_model(model_type)
+    transformed_model, _ = ModelTransformer(
+        model,
+        [ReplaceReLUWithSoftmax(), ReplaceSoftmaxWithELU()],
+        candidate_layers=set([layer.name for layer in model.layers])
+    ).transform()
+
+    self.assertEqual(transformed_model.layers[-1].__class__.__name__, 'ELU')
+
   @parameterized.parameters(['sequential', 'functional'])
   def testDoesNotMatchForever_IfReplacementEqualsMatch(self, model_type):

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/transforms.py

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,12 @@ def __init__(self, class_name, config=None, inputs=None):
     self.config = config
     self.inputs = inputs
 
+  def __str__(self):
+    return '{} : {} <- [{}]'.format(
+        self.class_name,
+        self.config,
+        ', '.join([str(inp) for inp in self.inputs]))
+
 
 class LayerNode(object):
   """Represents a Node in a tree containing a layer.
