Skip to content

Commit d2ecffc

Browse files
daverim and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 439776069
1 parent 92bfb45 commit d2ecffc

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,28 +380,33 @@ def _unwrap(layer):
380380

381381
unwrapped_model = keras.models.clone_model(
382382
model_to_unwrap, input_tensors=None, clone_function=_unwrap)
383-
384383
return unwrapped_model, layer_quantize_map, requires_output_quantize
385384

386385
def _quantize(layer): # pylint: disable=missing-docstring
387-
if ((layer.name not in layer_quantize_map and
388-
layer.name not in requires_output_quantize) or
389-
(isinstance(layer, quantize_wrapper.QuantizeWrapper))):
390-
# It supports for custom QuantizeWrapper.
391-
return layer
392-
386+
# Handle quantize layer before any layers.
393387
# layer is a QuantizeLayer, possibly rebuild
394388
# layer with modified config from parameters stored in the map.
395389
if isinstance(layer, quantize_layer.QuantizeLayer):
396390
# If there is more than one usage of the input, even if all are concat,
397391
# we need to quantize.
398392
if len(layer._outbound_nodes) > 1: # pylint: disable=protected-access
399393
return layer
394+
400395
layer_config = layer.get_config()
396+
if layer.name not in layer_quantize_map: # Possibly added manually.
397+
with quantize_scope():
398+
return quantize_layer.QuantizeLayer.from_config(layer_config)
399+
401400
for key, value in layer_quantize_map[layer.name].items():
402401
layer_config[key] = value
403402
return quantize_layer.QuantizeLayer.from_config(layer_config)
404403

404+
if ((layer.name not in layer_quantize_map and
405+
layer.name not in requires_output_quantize) or
406+
(isinstance(layer, quantize_wrapper.QuantizeWrapper))):
407+
# It supports for custom QuantizeWrapper.
408+
return layer
409+
405410
if layer.name in requires_output_quantize:
406411
if not quantize_registry.supports(layer):
407412
return layer

tensorflow_model_optimization/python/core/quantization/keras/quantizers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,22 @@ def from_config(cls, config):
129129
class _QuantizeHelper(object):
130130
"""Mixin with helper functions for quantizers."""
131131

132-
def _add_range_weights(self, layer, name):
132+
def _add_range_weights(self, layer, name, per_axis=False, tensor_shape=None):
133133
"""Add min and max vars to layer."""
134+
shape = None
135+
if per_axis and tensor_shape is not None:
136+
shape = (tensor_shape[-1])
137+
134138
min_weight = layer.add_weight(
135139
name + '_min',
136140
initializer=keras.initializers.Constant(-6.0),
137-
trainable=False)
141+
trainable=False,
142+
shape=shape)
138143
max_weight = layer.add_weight(
139144
name + '_max',
140145
initializer=keras.initializers.Constant(6.0),
141-
trainable=False)
146+
trainable=False,
147+
shape=shape)
142148

143149
return {'min_var': min_weight, 'max_var': max_weight}
144150

@@ -169,7 +175,7 @@ def __init__(self, num_bits, per_axis, symmetric, narrow_range):
169175
self.narrow_range = narrow_range
170176

171177
def build(self, tensor_shape, name, layer):
172-
return self._add_range_weights(layer, name)
178+
return self._add_range_weights(layer, name, self.per_axis, tensor_shape)
173179

174180
def __call__(self, inputs, training, weights, **kwargs):
175181
"""Quantize tensor.

0 commit comments

Comments (0)