Skip to content

Commit ae0326e

Browse files
Xharktensorflower-gardener
authored andcommitted
Supports kwarg for the layer call with non-tensor value.
PiperOrigin-RevId: 397241967
1 parent e64ea0b commit ae0326e

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def _get_inbound_layer_names(self, layer):
9191
inbound_layer_names.append(connection_info[0])
9292
# **kwarg argument case.
9393
inbound_layer_names += [
94-
value[0] for value in connection_info[3].items()
94+
value[0] for value in connection_info[3].values() if isinstance(
95+
value, list)
9596
]
9697

9798
return inbound_layer_names
@@ -327,8 +328,9 @@ def _replace_layer_name_for_connection_info(connection_info, match_name,
327328
if connection_info[0] == match_name:
328329
connection_info[0] = replacement_name
329330
for key in connection_info[3]:
330-
if connection_info[3][key][0] == match_name:
331-
connection_info[3][key][0] = replacement_name
331+
if isinstance(connection_info[3][key], list):
332+
if connection_info[3][key][0] == match_name:
333+
connection_info[3][key][0] = replacement_name
332334

333335
for consumer in consuming_layers:
334336
for inbound_node in self._inbound_node_generator(consumer):

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,61 @@ def replacement(self, match_layer):
289289
# Should match since bias is initialized with zeros.
290290
self._assert_model_results_equal(model, transformed_model)
291291

292+
def testReplaceSingleLayerWithSingleLayer_CustomLayerWithNontensorInput(self):
293+
294+
class CustomDense(keras.layers.Dense):
295+
296+
def call(self, inputs, identity=False):
297+
if identity:
298+
return inputs
299+
return super().call(inputs)
300+
301+
class QuantizedCustomDense(CustomDense):
302+
pass
303+
304+
class ReplaceCustomDenseLayer(transforms.Transform):
305+
"""Replaces `CustomDense` layers with `QuantizedCustomDense`."""
306+
307+
def pattern(self):
308+
return LayerPattern('CustomDense')
309+
310+
def replacement(self, match_layer):
311+
match_layer_config = match_layer.layer['config']
312+
my_dense_layer = QuantizedCustomDense(**match_layer_config)
313+
314+
replace_layer = keras.layers.serialize(my_dense_layer)
315+
replace_layer['name'] = replace_layer['config']['name']
316+
317+
return LayerNode(replace_layer, match_layer.weights, [])
318+
319+
def custom_objects(self):
320+
return {
321+
'CustomDense': CustomDense,
322+
'QuantizedCustomDense': QuantizedCustomDense}
323+
324+
inp = keras.layers.Input((3,))
325+
x1 = CustomDense(2)(inp, identity=False)
326+
x2 = CustomDense(2)(x1, identity=True)
327+
out = keras.layers.ReLU(6.0)(x2)
328+
model = keras.Model(inp, [out])
329+
330+
transformed_model, _ = ModelTransformer(
331+
model, [ReplaceCustomDenseLayer()]).transform()
332+
333+
# build_input_shape is a TensorShape object and the two objects are not
334+
# considered the same even though the shapes are the same.
335+
self._assert_config(model.get_config(), transformed_model.get_config(),
336+
['class_name', 'build_input_shape'])
337+
338+
self.assertEqual(
339+
'QuantizedCustomDense',
340+
self._get_layer(transformed_model, 0, 'functional').__class__.__name__)
341+
self.assertEqual(
342+
'QuantizedCustomDense',
343+
self._get_layer(transformed_model, 1, 'functional').__class__.__name__)
344+
345+
self._assert_model_results_equal(model, transformed_model)
346+
292347
@parameterized.parameters(['sequential', 'functional'])
293348
def testReplaceSingleLayer_WithMultipleLayers(self, model_type):
294349

0 commit comments

Comments
 (0)