Commit f15ad8e

qlzh727 authored and tensorflower-gardener committed
Convert tf_model_optimization code to use public tf and Keras API.
PiperOrigin-RevId: 376979959
1 parent a3ae06a commit f15ad8e
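
For readers skimming the diff: the change swaps TensorFlow's private keras/ops modules (initializers, activations, backend as K, convolutional, serialization, array_ops, nn, nn_ops, dtypes) for the public tf.* and tf.keras.* entry points. A minimal sketch of the public symbols now used, assuming TF 2.x; this snippet is illustrative only and is not part of the commit:

import tensorflow as tf

# Public replacements for the private-module imports removed in the diff below:
init = tf.compat.v1.keras.initializers.constant(-1)          # was initializers.Constant(-1)
relu = tf.keras.activations.get('relu')                      # was activations.get('relu')
phase = tf.keras.backend.learning_phase()                    # was K.learning_phase()
scale = tf.math.rsqrt(tf.constant([4.0]))                    # was math_ops.rsqrt(...)
out = tf.nn.bias_add(tf.zeros([1, 2, 2, 3]), tf.ones([3]))   # was nn.bias_add(...)
conv_cls = tf.keras.layers.Convolution2D                     # was convolutional.Conv2D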

File tree

1 file changed: +26 / -35 lines


tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm.py

Lines changed: 26 additions & 35 deletions
@@ -19,18 +19,8 @@
 from __future__ import print_function
 
 import tensorflow as tf
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.keras import activations
-from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import initializers
-from tensorflow.python.keras.layers import convolutional
-from tensorflow.python.keras.layers import serialization
 from tensorflow.python.keras.utils import conv_utils
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
 from tensorflow_model_optimization.python.core.keras import utils
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
@@ -51,19 +41,19 @@ def _build_for_quantization(self):
 
     self.optimizer_step = self.add_weight(
         'optimizer_step',
-        initializer=initializers.Constant(-1),
-        dtype=dtypes.int32,
+        initializer=tf.compat.v1.keras.initializers.constant(-1),
+        dtype=tf.int32,
         trainable=False)
 
     # TODO(alanchiao): re-explore if we can handle this with
     # QuantizeAwareActivation.
     self._activation_min_var = self.add_variable(  # pylint: disable=protected-access
         'activation_min',
-        initializer=initializers.Constant(-6.0),
+        initializer=tf.compat.v1.keras.initializers.constant(-6.0),
         trainable=False)
     self._activation_max_var = self.add_variable(  # pylint: disable=protected-access
         'activation_max',
-        initializer=initializers.Constant(6.0),
+        initializer=tf.compat.v1.keras.initializers.constant(6.0),
         trainable=False)
 
   def _apply_weight_quantizer(self, training, folded_conv_kernel):
@@ -112,10 +102,10 @@ def _from_config(cls_initializer, config):
     config.pop('use_bias')
     is_advanced_activation = 'class_name' in config['post_activation']
     if is_advanced_activation:
-      config['post_activation'] = serialization.deserialize(
+      config['post_activation'] = tf.keras.layers.deserialize(
           config['post_activation'])
     else:
-      config['post_activation'] = activations.deserialize(
+      config['post_activation'] = tf.keras.activations.deserialize(
          config['post_activation'])
 
     return cls_initializer(**config)
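
For context, a minimal sketch (not from this commit) of the two public deserialization paths _from_config() now relies on: advanced-activation layers serialize as a layer config carrying a 'class_name' key and go through tf.keras.layers.deserialize, while plain activation names or functions go through tf.keras.activations.deserialize.

import tensorflow as tf

# Plain activation: serialized simply as a name.
act = tf.keras.activations.deserialize('relu')               # -> the relu function

# Advanced activation: serialized as a layer config with 'class_name'.
cfg = tf.keras.layers.serialize(tf.keras.layers.ReLU(max_value=6.0))
assert 'class_name' in cfg                                    # what _from_config() checks
relu6 = tf.keras.layers.deserialize(cfg)                      # -> a ReLU layer instance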
@@ -138,7 +128,8 @@ def _get_config(self, conv_config):
       serialized_activation = keras.utils.serialize_keras_object(
           self.post_activation)
     else:
-      serialized_activation = activations.serialize(self.post_activation)
+      serialized_activation = tf.keras.activations.serialize(
+          self.post_activation)
     config = {
         'is_quantized': self.is_quantized,
         'post_activation': serialized_activation
@@ -149,7 +140,7 @@ def _get_config(self, conv_config):
         list(config.items()))
 
 
-class _ConvBatchNorm2D(_ConvBatchNormMixin, convolutional.Conv2D):
+class _ConvBatchNorm2D(_ConvBatchNormMixin, tf.keras.layers.Convolution2D):
   """Layer for emulating the folding of batch normalization into Conv during serving.
 
   Implements the emulation, as described in https://arxiv.org/abs/1712.05877.
@@ -255,7 +246,7 @@ def __init__(
     )
 
     # Named as post_activation to not conflict with Layer self.activation.
-    self.post_activation = activations.get(post_activation)
+    self.post_activation = tf.keras.activations.get(post_activation)
 
     self.is_quantized = is_quantized
     if self.is_quantized:
@@ -276,20 +267,20 @@ def build(self, input_shape):
 
   def call(self, inputs, training=None):
     if training is None:
-      training = K.learning_phase()
+      training = tf.keras.backend.learning_phase()
 
     conv_out = super(_ConvBatchNorm2D, self).call(inputs)
 
     # Not all the computations in the batchnorm need to happen,
     # but this avoids duplicating code (e.g. moving_average).
     self.batchnorm.call(conv_out)
 
-    folded_conv_kernel_multiplier = self.batchnorm.gamma * math_ops.rsqrt(
+    folded_conv_kernel_multiplier = self.batchnorm.gamma * tf.math.rsqrt(
         self.batchnorm.moving_variance + self.batchnorm.epsilon)
     folded_conv_kernel = math_ops.mul(
         folded_conv_kernel_multiplier, self.kernel, name='folded_conv_kernel')
 
-    folded_conv_bias = math_ops.subtract(
+    folded_conv_bias = tf.math.subtract(
         self.batchnorm.beta,
         self.batchnorm.moving_mean * folded_conv_kernel_multiplier,
         name='folded_conv_bias')
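
The two tf.math calls above implement the standard batch-norm folding from https://arxiv.org/abs/1712.05877: at inference time, BN(conv(x, w)) equals conv(x, m * w) + b_fold with m = gamma / sqrt(moving_variance + eps) and b_fold = beta - moving_mean * m. A quick numpy check of that identity, illustrative only and not part of the commit:

import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=(8, 4))                      # stand-in for conv output, 4 channels
gamma, beta = rng.normal(size=4), rng.normal(size=4)
mean, var, eps = rng.normal(size=4), rng.uniform(0.5, 2.0, size=4), 1e-3

multiplier = gamma / np.sqrt(var + eps)          # folded_conv_kernel_multiplier
folded_bias = beta - mean * multiplier           # folded_conv_bias

bn_out = gamma * (y - mean) / np.sqrt(var + eps) + beta    # inference-mode BatchNorm
assert np.allclose(bn_out, y * multiplier + folded_bias)   # folded conv + bias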
@@ -313,7 +304,7 @@ def call(self, inputs, training=None):
     if not isinstance(op_padding, (list, tuple)):
       op_padding = op_padding.upper()
 
-    folded_conv_out = nn_ops.conv2d(
+    folded_conv_out = tf.compat.v1.nn.conv2d(
         inputs,
         folded_conv_kernel,
         strides=self.strides,
@@ -328,13 +319,13 @@ def call(self, inputs, training=None):
     if self.data_format == 'channels_first':
       if self.rank == 1:
         # nn.bias_add does not accept a 1D input tensor.
-        bias = array_ops.reshape(folded_conv_bias, (1, self.filters, 1))
+        bias = tf.reshape(folded_conv_bias, (1, self.filters, 1))
         folded_conv_out += bias
       else:
-        outputs = nn.bias_add(
+        outputs = tf.nn.bias_add(
            folded_conv_out, folded_conv_bias, data_format='NCHW')
     else:
-      outputs = nn.bias_add(
+      outputs = tf.nn.bias_add(
          folded_conv_out, folded_conv_bias, data_format='NHWC')
 
     if self.post_activation is not None:
@@ -353,7 +344,7 @@ def from_config(cls, config):
 
 
 class _DepthwiseConvBatchNorm2D(_ConvBatchNormMixin,
-                                convolutional.DepthwiseConv2D):
+                                tf.keras.layers.DepthwiseConv2D):
   """Layer for emulating the folding of batch normalization into DepthwiseConv during serving.
 
   See ConvBatchNorm2D for detailed comments.
@@ -439,7 +430,7 @@ def __init__(
         virtual_batch_size=virtual_batch_size,
         adjustment=adjustment,
     )
-    self.post_activation = activations.get(post_activation)
+    self.post_activation = tf.keras.activations.get(post_activation)
 
     self.is_quantized = is_quantized
     if self.is_quantized:
@@ -460,16 +451,16 @@ def build(self, input_shape):
 
   def call(self, inputs, training=None):
     if training is None:
-      training = K.learning_phase()
+      training = tf.keras.backend.learning_phase()
 
     conv_out = super(_DepthwiseConvBatchNorm2D, self).call(inputs)
 
     self.batchnorm.call(conv_out)
 
-    folded_conv_kernel_multiplier = self.batchnorm.gamma * math_ops.rsqrt(
+    folded_conv_kernel_multiplier = self.batchnorm.gamma * tf.math.rsqrt(
         self.batchnorm.moving_variance + self.batchnorm.epsilon)
 
-    folded_conv_bias = math_ops.subtract(
+    folded_conv_bias = tf.math.subtract(
         self.batchnorm.beta,
         self.batchnorm.moving_mean * folded_conv_kernel_multiplier,
         name='folded_conv_bias')
@@ -478,8 +469,8 @@ def call(self, inputs, training=None):
         self.depthwise_kernel.get_shape().as_list()[2],
         self.depthwise_kernel.get_shape().as_list()[3]
     ]
-    folded_conv_kernel_multiplier = array_ops.reshape(
-        folded_conv_kernel_multiplier, depthwise_weights_shape)
+    folded_conv_kernel_multiplier = tf.reshape(folded_conv_kernel_multiplier,
+                                               depthwise_weights_shape)
 
     folded_conv_kernel = math_ops.mul(
         folded_conv_kernel_multiplier,
@@ -495,7 +486,7 @@ def call(self, inputs, training=None):
     # backend.conv2d is.
     #
     # From DepthwiseConv2D layer call() function.
-    folded_conv_out = K.depthwise_conv2d(
+    folded_conv_out = tf.keras.backend.depthwise_conv2d(
         inputs,
         folded_conv_kernel,
         strides=self.strides,
@@ -504,7 +495,7 @@ def call(self, inputs, training=None):
         data_format=self.data_format,
     )
 
-    outputs = K.bias_add(
+    outputs = tf.keras.backend.bias_add(
         folded_conv_out, folded_conv_bias, data_format=self.data_format)
 
     if self.post_activation is not None:
