Skip to content

Commit 3d73ed2

Browse files
Xhark authored and tensorflower-gardener committed
Add a fallback that uses (name, weight) tuple list when weight map key has a collision.
PiperOrigin-RevId: 367911645
1 parent 3082412 commit 3d73ed2

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def _get_layers(self, layer_names):
9595
def _get_layer_weights(self, layer_name):
9696
return self._layer_weights_map.get(layer_name, {})
9797

98+
def _get_layer_names_and_weights(self, layer_name):
99+
return self._layer_names_and_weights_map.get(layer_name, {})
100+
98101
def _get_layer_metadata(self, layer_name):
99102
return self._layer_metadata_map.get(layer_name, {})
100103

@@ -191,7 +194,9 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):
191194
if len(pattern.inputs) == 0:
192195
# Leaf layer in pattern.
193196
return LayerNode(layer, self._get_layer_weights(layer['config']['name']),
194-
[], self._get_layer_metadata(layer['config']['name']))
197+
[], self._get_layer_metadata(layer['config']['name']),
198+
self._get_layer_names_and_weights(
199+
layer['config']['name']))
195200

196201
# There is a possible edge case where a single layer may output multiple
197202
# tensors and multiple tensors from that layer may be used by the
@@ -221,7 +226,8 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):
221226

222227
return LayerNode(layer, self._get_layer_weights(layer['config']['name']),
223228
input_match_layer_nodes,
224-
self._get_layer_metadata(layer['config']['name']))
229+
self._get_layer_metadata(layer['config']['name']),
230+
self._get_layer_names_and_weights(layer['config']['name']))
225231

226232
def _find_pattern(self, pattern, matched_layers=None):
227233
for layer in self._config['layers']:
@@ -265,6 +271,7 @@ def _remove_layers(self, layers_to_remove, layers_to_remove_names):
265271
# now that layer has been removed.
266272
for layer_name in layers_to_remove_names:
267273
self._layer_weights_map.pop(layer_name, None)
274+
self._layer_names_and_weights_map.pop(layer_name, None)
268275
self._layer_metadata_map.pop(layer_name, None)
269276

270277
def _replace(self, match_layer_node, replacement_layer_node):
@@ -355,8 +362,12 @@ def _add_replacement_layer(layer_node):
355362
"""Recursively add new layers."""
356363
self._config['layers'].append(layer_node.layer)
357364
layer_name = layer_node.layer['config']['name']
365+
# TODO(b/184603494): Remove weight map structure from model_transformer.
358366
if layer_node.weights:
359367
self._layer_weights_map[layer_name] = layer_node.weights
368+
if layer_node.names_and_weights:
369+
self._layer_names_and_weights_map[
370+
layer_name] = layer_node.names_and_weights
360371
if layer_node.metadata:
361372
self._layer_metadata_map[layer_name] = layer_node.metadata
362373
if self.candidate_layers:
@@ -403,6 +414,9 @@ def _add_replacement_nodes(first_layer_removed_index, replacement_nodes):
403414
layer_name = replacement_node.layer['config']['name']
404415
if replacement_node.weights:
405416
self._layer_weights_map[layer_name] = replacement_node.weights
417+
if replacement_node.names_and_weights:
418+
self._layer_names_and_weights_map[
419+
layer_name] = replacement_node.names_and_weights
406420
if replacement_node.metadata:
407421
self._layer_metadata_map[layer_name] = replacement_node.metadata
408422
if self.candidate_layers:
@@ -433,8 +447,17 @@ def _get_keras_layer_weights(self, keras_layer):
433447
zip(keras_layer.weights, keras_layer.get_weights()):
434448
weights_map[self._weight_name(weight_tensor.name)] = weight_numpy
435449

450+
if len(weights_map) != len(keras_layer.weights):
451+
# The case that variable identifier is not unique. It's a fallback that
452+
# uses weight list instead of the weights map.
453+
return None
454+
436455
return weights_map
437456

457+
def _get_keras_layer_names_and_weights(self, keras_layer):
458+
return zip([weight.name for weight in keras_layer.weights],
459+
keras_layer.get_weights())
460+
438461
def _set_layer_weights(self, layer, weights_map):
439462
"""Sets the values of weights in a Keras layer."""
440463

@@ -447,6 +470,9 @@ def _set_layer_weights(self, layer, weights_map):
447470

448471
K.batch_set_value(weight_value_tuples)
449472

473+
def _set_layer_names_and_weights(self, layer, names_and_weights):
474+
layer.set_weights([weight for _, weight in names_and_weights])
475+
450476
@staticmethod
451477
def _name(obj):
452478
return obj.__class__.__name__
@@ -496,8 +522,12 @@ def transform(self):
496522
# to prevent infinite loops.
497523
self._transform_matched_layers_map = {}
498524
self._layer_weights_map = {}
525+
self._layer_names_and_weights_map = {}
526+
499527
for layer in self.model.layers:
500528
self._layer_weights_map[layer.name] = self._get_keras_layer_weights(layer)
529+
self._layer_names_and_weights_map[
530+
layer.name] = self._get_keras_layer_names_and_weights(layer)
501531

502532
# Maintains a current mutable copy of the metadata through transformation.
503533
self._layer_metadata_map = copy.deepcopy(self.layer_metadata)
@@ -557,5 +587,9 @@ def transform(self):
557587
weights = self._layer_weights_map.get(layer.name)
558588
if weights:
559589
self._set_layer_weights(layer, weights)
590+
else:
591+
names_and_weights = self._layer_names_and_weights_map.get(layer.name)
592+
if names_and_weights:
593+
self._set_layer_names_and_weights(layer, names_and_weights)
560594

561595
return transformed_model, copy.deepcopy(self._layer_metadata_map)

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ def _simple_dense_model(self, model_type='functional'):
7575
[keras.layers.Dense(2, input_shape=(3,)),
7676
keras.layers.ReLU(6.0)])
7777

78+
def _nested_model(self, model_type='functional', submodel_type='functional'):
79+
if submodel_type == 'functional':
80+
inp = keras.layers.Input((3,))
81+
x = keras.layers.Dense(2)(inp)
82+
out = keras.layers.Dense(2)(x)
83+
submodel = keras.Model(inp, out)
84+
elif submodel_type == 'sequential':
85+
submodel = keras.Sequential(
86+
[keras.layers.Dense(2, input_shape=(3,)),
87+
keras.layers.Dense(2)])
88+
89+
if model_type == 'functional':
90+
inp = keras.layers.Input((3,))
91+
x = submodel(inp)
92+
out = keras.layers.ReLU(6.0)(x)
93+
return keras.Model(inp, out)
94+
elif model_type == 'sequential':
95+
return keras.Sequential(
96+
[submodel,
97+
keras.layers.ReLU(6.0)])
98+
7899
def _assert_config(self, expected_config, actual_config, exclude_keys=None):
79100
"""Asserts that the two config dictionaries are equal.
80101
@@ -681,6 +702,24 @@ def replacement(self, match_layer):
681702
self._assert_config(model.get_config(), transformed_model.get_config(),
682703
['build_input_shape'])
683704

705+
@parameterized.parameters([
706+
('sequential', 'sequential'),
707+
('sequential', 'functional'),
708+
('functional', 'sequential'),
709+
('functional', 'functional'),])
710+
def testNestedModelNoChange(self, model_type, submodel_type):
711+
model = self._nested_model(model_type, submodel_type)
712+
713+
transformed_model, _ = ModelTransformer(
714+
model, []).transform()
715+
716+
# build_input_shape is a TensorShape object and the two objects are not
717+
# considered the same even though the shapes are the same.
718+
self._assert_config(model.get_config(), transformed_model.get_config(),
719+
['class_name', 'build_input_shape'])
720+
721+
self._assert_model_results_equal(model, transformed_model)
722+
684723
# Validation Tests
685724

686725
def testRaisesErrorForSubclassModels(self):

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/transforms.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,26 +79,37 @@ class LayerNode(object):
7979
been found in a model, and layers which should be replaced inside the model.
8080
"""
8181

82-
def __init__(self, layer, weights=None, input_layers=None, metadata=None):
82+
def __init__(
83+
self,
84+
layer,
85+
weights=None,
86+
input_layers=None,
87+
metadata=None,
88+
names_and_weights=None):
8389
"""Construct a LayerNode representing a tree of layers.
8490
8591
Args:
8692
layer: layer config of this node.
8793
weights: An OrderedDict of weight name => value for the layer.
8894
input_layers: List of `LayerNode`s that feed into this layer.
8995
metadata: Dictionary of metadata for a given layer.
96+
names_and_weights: A list of tuples (name, weight). It only used when
97+
weights dictionary is empty.
9098
"""
9199
if weights is None:
92100
weights = collections.OrderedDict()
93101
if input_layers is None:
94102
input_layers = []
95103
if metadata is None:
96104
metadata = {}
105+
if names_and_weights is None:
106+
names_and_weights = []
97107

98108
self.layer = layer
99109
self.weights = weights
100110
self.input_layers = input_layers
101111
self.metadata = metadata
112+
self.names_and_weights = names_and_weights
102113

103114
def __str__(self):
104115
return '{} <- [{}]'.format(

0 commit comments

Comments
 (0)