Skip to content

Commit 0607d0b

Browse files
haothutensorflower-gardener
authored andcommitted
Internal change.
PiperOrigin-RevId: 394577401
1 parent 3327f63 commit 0607d0b

File tree

2 files changed

+84
-38
lines changed

2 files changed

+84
-38
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@
3232
class ModelTransformer(object):
3333
"""Matches patterns to apply transforms in a tf.keras model graph."""
3434

35-
def __init__(
36-
self, model, transforms, candidate_layers=None, layer_metadata=None):
35+
def __init__(self,
36+
model,
37+
transforms,
38+
candidate_layers=None,
39+
layer_metadata=None):
3740
"""Construct ModelTransformer.
3841
3942
Args:
@@ -68,7 +71,8 @@ def _is_functional_model(self, model):
6871

6972
def _inbound_node_generator(self, layer):
7073
for inbound_node in layer['inbound_nodes']:
71-
if len(inbound_node) > 0 and isinstance(inbound_node[0], str):
74+
if (isinstance(inbound_node, list) and len(inbound_node) > 0 and
75+
isinstance(inbound_node[0], str)):
7276
# TODO(tfmot): The case for the SlicingOpLambda.
7377
yield [inbound_node]
7478
else:
@@ -78,12 +82,17 @@ def _get_inbound_layer_names(self, layer):
7882
"""Return all the inbound connection layer names for the layer."""
7983
inbound_layer_names = []
8084
for inbound_node in self._inbound_node_generator(layer):
85+
# TODO(b/197935452): temporary fix when the input is a dictionary of
86+
# tensors. A comprehensive solution may be needed.
87+
if isinstance(inbound_node, dict):
88+
inbound_node = inbound_node.values()
8189
for connection_info in inbound_node:
8290
# input argument case.
8391
inbound_layer_names.append(connection_info[0])
8492
# **kwarg argument case.
8593
inbound_layer_names += [
86-
value[0] for value in connection_info[3].items()]
94+
value[0] for value in connection_info[3].items()
95+
]
8796

8897
return inbound_layer_names
8998

@@ -212,10 +221,10 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):
212221

213222
if len(pattern.inputs) == 0:
214223
# Leaf layer in pattern.
215-
return LayerNode(layer, self._get_layer_weights(layer['config']['name']),
216-
[], self._get_layer_metadata(layer['config']['name']),
217-
self._get_layer_names_and_weights(
218-
layer['config']['name']))
224+
return LayerNode(
225+
layer, self._get_layer_weights(layer['config']['name']), [],
226+
self._get_layer_metadata(layer['config']['name']),
227+
self._get_layer_names_and_weights(layer['config']['name']))
219228

220229
# There is a possible edge case where a single layer may output multiple
221230
# tensors and multiple tensors from that layer may be used by the
@@ -313,8 +322,8 @@ def _replace_functional(self, match_layer_node, replacement_layer_node):
313322
match_name = match_layer_node.layer['config']['name']
314323
replacement_name = replacement_layer_node.layer['config']['name']
315324

316-
def _replace_layer_name_for_connection_info(
317-
connection_info, match_name, replacement_name):
325+
def _replace_layer_name_for_connection_info(connection_info, match_name,
326+
replacement_name):
318327
if connection_info[0] == match_name:
319328
connection_info[0] = replacement_name
320329
for key in connection_info[3]:
@@ -323,9 +332,11 @@ def _replace_layer_name_for_connection_info(
323332

324333
for consumer in consuming_layers:
325334
for inbound_node in self._inbound_node_generator(consumer):
335+
if isinstance(inbound_node, dict):
336+
inbound_node = inbound_node.values()
326337
for connection_info in inbound_node:
327-
_replace_layer_name_for_connection_info(
328-
connection_info, match_name, replacement_name)
338+
_replace_layer_name_for_connection_info(connection_info, match_name,
339+
replacement_name)
329340

330341
output_consumers = self._get_output_consumers(match_layer_node.layer)
331342
for output_consumer in output_consumers:
@@ -493,8 +504,7 @@ def _set_layer_weights(self, layer, weights_map):
493504
for weight_tensor in layer.weights:
494505
weight_name = self._weight_name(weight_tensor.name)
495506
if weight_name in weights_map:
496-
weight_value_tuples.append(
497-
(weight_tensor, weights_map[weight_name]))
507+
weight_value_tuples.append((weight_tensor, weights_map[weight_name]))
498508

499509
K.batch_set_value(weight_value_tuples)
500510

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,14 @@ def _get_layer(self, model, n_excluding_input, model_type):
6060
return model.layers[n_excluding_input]
6161

6262
def _create_model_inputs(self, model):
63-
return np.random.randn(*self._batch(model.input.get_shape().as_list(), 1))
63+
if isinstance(model.input, dict):
64+
inputs = {}
65+
for key, input_layer in model.input.items():
66+
inputs[key] = np.random.randn(
67+
*self._batch(input_layer.get_shape().as_list(), 1))
68+
return inputs
69+
else:
70+
return np.random.randn(*self._batch(model.input.get_shape().as_list(), 1))
6471

6572
def _simple_dense_model(self, model_type='functional'):
6673
if model_type == 'functional':
@@ -90,9 +97,7 @@ def _nested_model(self, model_type='functional', submodel_type='functional'):
9097
out = keras.layers.ReLU(6.0)(x)
9198
return keras.Model(inp, out)
9299
elif model_type == 'sequential':
93-
return keras.Sequential(
94-
[submodel,
95-
keras.layers.ReLU(6.0)])
100+
return keras.Sequential([submodel, keras.layers.ReLU(6.0)])
96101

97102
def _assert_config(self, expected_config, actual_config, exclude_keys=None):
98103
"""Asserts that the two config dictionaries are equal.
@@ -216,6 +221,35 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences(
216221

217222
self._assert_model_results_equal(model, transformed_model)
218223

224+
def testReplaceSingleLayerWithSingleLayer_DictInputOutput(self):
225+
inp = {
226+
'input1': keras.layers.Input((3,)),
227+
'input2': keras.layers.Input((3,))
228+
}
229+
x1 = keras.layers.Dense(2)(inp['input1'])
230+
x2 = keras.layers.Dense(2)(inp['input2'])
231+
out1 = keras.layers.ReLU(6.0)(x1)
232+
out2 = keras.layers.ReLU(6.0)(x2)
233+
model = keras.Model(inp, {'output1': out1, 'output2': out2})
234+
235+
transformed_model, _ = ModelTransformer(
236+
model, [self.ReplaceDenseLayer()]).transform()
237+
238+
# build_input_shape is a TensorShape object and the two objects are not
239+
# considered the same even though the shapes are the same.
240+
self._assert_config(model.get_config(), transformed_model.get_config(),
241+
['class_name', 'build_input_shape'])
242+
243+
# There are two input layers in the input dict.
244+
self.assertEqual(
245+
'MyDense',
246+
self._get_layer(transformed_model, 1, 'functional').__class__.__name__)
247+
self.assertEqual(
248+
'MyDense',
249+
self._get_layer(transformed_model, 2, 'functional').__class__.__name__)
250+
251+
self._assert_model_results_equal(model, transformed_model)
252+
219253
@parameterized.parameters(['sequential', 'functional'])
220254
def testReplaceSingleLayerWithSingleLayer_MatchParameters(self, model_type):
221255

@@ -241,8 +275,8 @@ def replacement(self, match_layer):
241275

242276
model = self._simple_dense_model(model_type)
243277

244-
transformed_model, _ = ModelTransformer(
245-
model, [RemoveBiasInDense()]).transform()
278+
transformed_model, _ = ModelTransformer(model,
279+
[RemoveBiasInDense()]).transform()
246280

247281
# build_input_shape is a TensorShape object and the two objects are not
248282
# considered the same even though the shapes are the same.
@@ -312,8 +346,7 @@ def replacement(self, match_layer):
312346
layer_config['name'] = activation_layer.name
313347

314348
activation_layer_node = LayerNode(
315-
layer_config,
316-
input_layers=[match_layer])
349+
layer_config, input_layers=[match_layer])
317350

318351
return activation_layer_node
319352

@@ -371,8 +404,8 @@ def replacement(self, match_layer):
371404
keras.layers.ReLU()])
372405
model.set_weights(model_fused.get_weights())
373406

374-
transformed_model, _ = ModelTransformer(
375-
model, [FuseReLUIntoDense()]).transform()
407+
transformed_model, _ = ModelTransformer(model,
408+
[FuseReLUIntoDense()]).transform()
376409

377410
self._assert_config(
378411
model_fused.get_config(),
@@ -430,6 +463,7 @@ def replacement(self, match_layer):
430463
['build_input_shape'])
431464

432465
def testReplaceListOfLayers_Sequential(self):
466+
433467
class ReplaceConvBatchNorm(transforms.Transform):
434468
"""Replaces a ConvBatchNorm pattern with the same set of layers.
435469
@@ -438,8 +472,8 @@ class ReplaceConvBatchNorm(transforms.Transform):
438472
"""
439473

440474
def pattern(self):
441-
return LayerPattern('BatchNormalization',
442-
inputs=[LayerPattern('Conv2D')])
475+
return LayerPattern(
476+
'BatchNormalization', inputs=[LayerPattern('Conv2D')])
443477

444478
def replacement(self, match_layer):
445479
# Adds a modification so the transform happens. If the layers are
@@ -457,7 +491,8 @@ def replacement(self, match_layer):
457491
transformed_model, _ = ModelTransformer(
458492
model, [ReplaceConvBatchNorm()]).transform()
459493
transformed_model_layer_names = [
460-
layer.name for layer in transformed_model.layers]
494+
layer.name for layer in transformed_model.layers
495+
]
461496

462497
self.assertEqual(model_layer_names, transformed_model_layer_names)
463498

@@ -495,10 +530,10 @@ def replacement(self, match_layer):
495530

496531
model = self._simple_dense_model(model_type)
497532
transformed_model, _ = ModelTransformer(
498-
model,
499-
[ReplaceReLUWithSoftmax(), ReplaceSoftmaxWithELU()],
500-
candidate_layers=set([layer.name for layer in model.layers])
501-
).transform()
533+
model, [ReplaceReLUWithSoftmax(),
534+
ReplaceSoftmaxWithELU()],
535+
candidate_layers=set([layer.name for layer in model.layers
536+
])).transform()
502537

503538
self.assertEqual(transformed_model.layers[-1].__class__.__name__, 'ELU')
504539

@@ -515,8 +550,8 @@ def replacement(self, match_layer):
515550

516551
model = self._simple_dense_model(model_type)
517552

518-
transformed_model, _ = ModelTransformer(
519-
model, [ReplaceWithSelf()]).transform()
553+
transformed_model, _ = ModelTransformer(model,
554+
[ReplaceWithSelf()]).transform()
520555

521556
# build_input_shape is a TensorShape object and the two objects are not
522557
# considered the same even though the shapes are the same.
@@ -689,8 +724,8 @@ def replacement(self, match_layer):
689724
}
690725
}
691726

692-
transformer = ModelTransformer(
693-
model, [ReplaceLayerMetadata()], None, layer_metadata)
727+
transformer = ModelTransformer(model, [ReplaceLayerMetadata()], None,
728+
layer_metadata)
694729
transformed_model, updated_metadata = transformer.transform()
695730

696731
self.assertEqual(expected_metadata, updated_metadata)
@@ -704,12 +739,12 @@ def replacement(self, match_layer):
704739
('sequential', 'sequential'),
705740
('sequential', 'functional'),
706741
('functional', 'sequential'),
707-
('functional', 'functional'),])
742+
('functional', 'functional'),
743+
])
708744
def testNestedModelNoChange(self, model_type, submodel_type):
709745
model = self._nested_model(model_type, submodel_type)
710746

711-
transformed_model, _ = ModelTransformer(
712-
model, []).transform()
747+
transformed_model, _ = ModelTransformer(model, []).transform()
713748

714749
# build_input_shape is a TensorShape object and the two objects are not
715750
# considered the same even though the shapes are the same.
@@ -721,6 +756,7 @@ def testNestedModelNoChange(self, model_type, submodel_type):
721756
# Validation Tests
722757

723758
def testRaisesErrorForSubclassModels(self):
759+
724760
class MyModel(keras.Model):
725761
pass
726762

0 commit comments

Comments
 (0)