
Commit dca2998

Xhark authored and tensorflower-gardener committed

Supports **kwargs for QuantizeWrapper.call method.

PiperOrigin-RevId: 377240100

1 parent 8b2bd1f commit dca2998

File tree: 2 files changed, +42 -10 lines changed


tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 39 additions & 7 deletions
```diff
@@ -68,14 +68,35 @@ def _is_functional_model(model):
         and not isinstance(model, keras.Sequential) \
         and model._is_graph_network  # pylint: disable=protected-access
 
+  def _inbound_node_generator(self, layer):
+    for inbound_node in layer['inbound_nodes']:
+      if len(inbound_node) > 0 and isinstance(inbound_node[0], str):
+        # TODO(tfmot): The case for the SlicingOpLambda.
+        yield [inbound_node]
+      else:
+        yield inbound_node
+
+  def _get_inbound_layer_names(self, layer):
+    """Return all the inbound connection layer names for the layer."""
+    inbound_layer_names = []
+    for inbound_node in self._inbound_node_generator(layer):
+      for connection_info in inbound_node:
+        # input argument case.
+        inbound_layer_names.append(connection_info[0])
+        # **kwarg argument case.
+        inbound_layer_names += [
+            value[0] for value in connection_info[3].values()]
+
+    return inbound_layer_names
+
   def _get_consuming_layers(self, check_layer):
     """Returns all the layers which are out nodes from the layer."""
     consuming_layers = []
+    check_layer_name = check_layer['config']['name']
     for layer in self._config['layers']:
-      for inbound_node in layer['inbound_nodes']:
-        for connection_info in inbound_node:
-          if connection_info[0] == check_layer['config']['name']:
-            consuming_layers.append(layer)
+      if check_layer_name in self._get_inbound_layer_names(layer):
+        consuming_layers.append(layer)
+
     return consuming_layers
 
   def _get_output_consumers(self, check_layer):
```
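For context, these helpers walk the serialized functional-model config, where each connection info has the shape [layer_name, node_index, tensor_index, call_kwargs], and a tensor passed through **kwargs is itself recorded as a [layer_name, node_index, tensor_index] list (the helper in the next hunk, which rewrites connection_info[3][key][0], relies on the same shape). Below is a minimal, self-contained sketch of the extraction, with hypothetical layer names and omitting the generator's string-keyed TFOpLambda special case:

```python
# Hedged sketch of the serialized Keras functional config these helpers
# traverse; the layer names and the `mask` kwarg are hypothetical.
layer = {
    'inbound_nodes': [[
        # Connection info: [layer_name, node_index, tensor_index, kwargs].
        # A tensor passed as a keyword argument is serialized the same way,
        # as [layer_name, node_index, tensor_index].
        ['conv2d', 0, 0, {'mask': ['embedding', 0, 0]}],
    ]],
}

def get_inbound_layer_names(layer):
  names = []
  for inbound_node in layer['inbound_nodes']:
    for connection_info in inbound_node:
      names.append(connection_info[0])  # positional input argument.
      # **kwarg argument case: take the layer name from each kwarg value.
      names += [value[0] for value in connection_info[3].values()]
  return names

print(get_inbound_layer_names(layer))  # ['conv2d', 'embedding']
```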
```diff
@@ -292,11 +313,22 @@ def _replace_functional(self, match_layer_node, replacement_layer_node):
     # replaced layer should equal the original layer.
 
     consuming_layers = self._get_consuming_layers(match_layer_node.layer)
+    match_name = match_layer_node.layer['config']['name']
+    replacement_name = replacement_layer_node.layer['config']['name']
+
+    def _replace_layer_name_for_connection_info(
+        connection_info, match_name, replacement_name):
+      if connection_info[0] == match_name:
+        connection_info[0] = replacement_name
+      for key in connection_info[3]:
+        if connection_info[3][key][0] == match_name:
+          connection_info[3][key][0] = replacement_name
+
     for consumer in consuming_layers:
-      for inbound_node in consumer['inbound_nodes']:
+      for inbound_node in self._inbound_node_generator(consumer):
         for connection_info in inbound_node:
-          if connection_info[0] == match_layer_node.layer['config']['name']:
-            connection_info[0] = replacement_layer_node.layer['config']['name']
+          _replace_layer_name_for_connection_info(
+              connection_info, match_name, replacement_name)
 
     output_consumers = self._get_output_consumers(match_layer_node.layer)
     for output_consumer in output_consumers:
```
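The new helper mutates the connection info in place, so a replaced layer is renamed not only in the positional input slot but also in any keyword-argument slots that reference it. A tiny standalone illustration with hypothetical names:

```python
# Standalone illustration of the in-place rename (hypothetical names):
# both the positional slot and any kwarg slots that point at the matched
# layer are rewritten.
connection_info = ['dense', 0, 0, {'mask': ['dense', 0, 0]}]

def replace_layer_name(connection_info, match_name, replacement_name):
  if connection_info[0] == match_name:
    connection_info[0] = replacement_name
  for key in connection_info[3]:
    if connection_info[3][key][0] == match_name:
      connection_info[3][key][0] = replacement_name

replace_layer_name(connection_info, 'dense', 'quant_dense')
print(connection_info)
# ['quant_dense', 0, 0, {'mask': ['quant_dense', 0, 0]}]
```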

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -136,7 +136,7 @@ def quantizer_fn():
 
     return quantizer_fn
 
-  def call(self, inputs, training=None):
+  def call(self, inputs, training=None, **kwargs):
     if training is None:
       training = tf.keras.backend.learning_phase()
 
@@ -165,9 +165,9 @@ def call(self, inputs, training=None):
 
     args = tf_inspect.getfullargspec(self.layer.call).args
     if 'training' in args:
-      outputs = self.layer.call(inputs, training=training)
+      outputs = self.layer.call(inputs, training=training, **kwargs)
     else:
-      outputs = self.layer.call(inputs)
+      outputs = self.layer.call(inputs, **kwargs)
 
     if not self._output_quantizers:
       return outputs
```
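The dispatch itself is unchanged: training is still injected only when the wrapped layer's call accepts it, and the new **kwargs now flow through in both branches. A minimal sketch of that logic, using the standard-library inspect as a stand-in for tf_inspect and plain classes as stand-ins for Keras layers:

```python
import inspect  # stdlib stand-in for the tf_inspect used in the wrapper

class PlainLayer:
  """Stand-in for a layer whose call() takes no `training` argument."""

  def call(self, inputs, scale=1.0):
    return inputs * scale

class TrainingAwareLayer:
  """Stand-in for a layer whose call() does take `training`."""

  def call(self, inputs, training=None, scale=1.0):
    return inputs * scale * (0.5 if training else 1.0)

def dispatch(layer, inputs, training=None, **kwargs):
  # Mirrors the wrapper's dispatch: inject `training` only when the wrapped
  # layer's call() accepts it, and forward remaining kwargs either way.
  args = inspect.getfullargspec(layer.call).args
  if 'training' in args:
    return layer.call(inputs, training=training, **kwargs)
  return layer.call(inputs, **kwargs)

print(dispatch(PlainLayer(), 2.0, scale=3.0))                         # 6.0
print(dispatch(TrainingAwareLayer(), 2.0, training=True, scale=3.0))  # 3.0
```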
