Skip to content

Commit 3d03ad7

Browse files
teijeong authored and tensorflower-gardener committed
Use TF API instead of directly importing internal members when possible
Fixed formatting to comply with python style guide PiperOrigin-RevId: 371043407
1 parent 5649d1c commit 3d03ad7

File tree

4 files changed

+30
-45
lines changed

4 files changed

+30
-45
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def _replace(self, bn_layer_node, conv_layer_node):
181181
if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
182182
return bn_layer_node
183183

184-
conv_layer_node.layer['config']['activation'] = \
185-
keras.activations.serialize(quantize_aware_activation.NoOpActivation())
186-
bn_layer_node.metadata['quantize_config'] = \
187-
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()
184+
conv_layer_node.layer['config']['activation'] = (
185+
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
186+
bn_layer_node.metadata['quantize_config'] = (
187+
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())
188188

189189
return bn_layer_node
190190

@@ -235,10 +235,10 @@ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
235235
relu_layer_node, bn_layer_node, conv_layer_node):
236236
return relu_layer_node
237237

238-
conv_layer_node.layer['config']['activation'] = \
239-
keras.activations.serialize(quantize_aware_activation.NoOpActivation())
240-
bn_layer_node.metadata['quantize_config'] = \
241-
default_8bit_quantize_configs.NoOpQuantizeConfig()
238+
conv_layer_node.layer['config']['activation'] = (
239+
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
240+
bn_layer_node.metadata['quantize_config'] = (
241+
default_8bit_quantize_configs.NoOpQuantizeConfig())
242242

243243
return relu_layer_node
244244

@@ -508,8 +508,8 @@ def replacement(self, match_layer):
508508
relu_layer_node = match_layer
509509
add_layer_node = relu_layer_node.input_layers[0]
510510

511-
add_layer_node.metadata['quantize_config'] = \
512-
default_8bit_quantize_configs.NoOpQuantizeConfig()
511+
add_layer_node.metadata['quantize_config'] = (
512+
default_8bit_quantize_configs.NoOpQuantizeConfig())
513513

514514
return match_layer
515515

@@ -585,8 +585,8 @@ def replacement(self, match_layer):
585585
concat_layer_node = match_layer
586586
feeding_layer_nodes = match_layer.input_layers
587587

588-
default_registry = default_8bit_quantize_registry.\
589-
Default8BitQuantizeRegistry()
588+
default_registry = (
589+
default_8bit_quantize_registry.Default8BitQuantizeRegistry())
590590

591591
feed_quantize_configs = []
592592
for feed_layer_node in feeding_layer_nodes:
@@ -599,8 +599,8 @@ def replacement(self, match_layer):
599599

600600
if layer_class == keras.layers.Concatenate:
601601
# Input layer to Concat is also Concat. Don't quantize it.
602-
feed_layer_node.metadata['quantize_config'] = \
603-
default_8bit_quantize_configs.NoOpQuantizeConfig()
602+
feed_layer_node.metadata['quantize_config'] = (
603+
default_8bit_quantize_configs.NoOpQuantizeConfig())
604604
continue
605605

606606
if not default_registry._is_supported_layer(layer_class):
@@ -619,8 +619,8 @@ def replacement(self, match_layer):
619619
self._disable_output_quantize(quantize_config)
620620

621621
if not concat_layer_node.metadata.get('quantize_config'):
622-
concat_layer_node.metadata['quantize_config'] = \
623-
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()
622+
concat_layer_node.metadata['quantize_config'] = (
623+
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())
624624

625625
return concat_layer_node
626626

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,14 @@ py_strict_test(
131131
],
132132
)
133133

134-
py_library(
134+
py_strict_library(
135135
name = "pruning_impl",
136136
srcs = ["pruning_impl.py"],
137137
srcs_version = "PY3",
138138
visibility = ["//visibility:public"],
139139
deps = [
140140
":pruning_utils",
141141
# tensorflow dep1,
142-
# python:summary tensorflow dep2,
143142
"//tensorflow_model_optimization/python/core/keras:compat",
144143
],
145144
)
@@ -155,16 +154,14 @@ py_strict_library(
155154
],
156155
)
157156

158-
py_library(
157+
py_strict_library(
159158
name = "estimator_utils",
160159
srcs = ["estimator_utils.py"],
161160
srcs_version = "PY3",
162161
visibility = ["//visibility:public"],
163162
deps = [
164163
":pruning_wrapper",
165-
# python:control_flow_ops tensorflow dep2,
166-
# python:math_ops tensorflow dep2,
167-
# python:state_ops tensorflow dep2,
164+
# tensorflow dep1,
168165
# python/framework:for_generated_wrappers tensorflow dep2,
169166
],
170167
)

tensorflow_model_optimization/python/core/sparsity/keras/estimator_utils.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,13 @@
1414
# ==============================================================================
1515
"""Utility functions for making pruning wrapper work with estimators."""
1616

17-
from __future__ import absolute_import
18-
from __future__ import division
19-
from __future__ import print_function
20-
# import g3
17+
import tensorflow as tf
2118

22-
from tensorflow.python.estimator.model_fn import EstimatorSpec
23-
from tensorflow.python.framework import dtypes
2419
from tensorflow.python.framework import ops
25-
from tensorflow.python.ops import control_flow_ops
26-
from tensorflow.python.ops import math_ops
27-
from tensorflow.python.ops import state_ops
28-
from tensorflow.python.training import monitored_session
2920
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import PruneLowMagnitude
3021

3122

32-
class PruningEstimatorSpec(EstimatorSpec):
23+
class PruningEstimatorSpec(tf.estimator.EstimatorSpec):
3324
"""Returns an EstimatorSpec modified to prune the model while training."""
3425

3526
def __new__(cls, model, step=None, train_op=None, **kwargs):
@@ -49,13 +40,12 @@ def _get_step_increment_ops(model, step=None):
4940
if isinstance(layer, PruneLowMagnitude):
5041
if step is None:
5142
# Add ops to increment the pruning_step by 1
52-
increment_ops.append(state_ops.assign_add(layer.pruning_step, 1))
43+
increment_ops.append(tf.assign_add(layer.pruning_step, 1))
5344
else:
5445
increment_ops.append(
55-
state_ops.assign(layer.pruning_step,
56-
math_ops.cast(step, dtypes.int32)))
46+
tf.assign(layer.pruning_step, tf.cast(step, tf.int32)))
5747

58-
return control_flow_ops.group(increment_ops)
48+
return tf.group(increment_ops)
5949

6050
pruning_ops = []
6151
# Grab the ops to update pruning step in every prunable layer
@@ -64,21 +54,21 @@ def _get_step_increment_ops(model, step=None):
6454
# Grab the model updates.
6555
pruning_ops.append(model.updates)
6656

67-
kwargs["train_op"] = control_flow_ops.group(pruning_ops, train_op)
57+
kwargs["train_op"] = tf.group(pruning_ops, train_op)
6858

6959
def init_fn(scaffold, session): # pylint: disable=unused-argument
7060
return session.run(step_increment_ops)
7161

7262
def get_new_scaffold(old_scaffold):
7363
if old_scaffold.init_fn is None:
74-
return monitored_session.Scaffold(
64+
return tf.compat.v1.train.Scaffold(
7565
init_fn=init_fn, copy_from_scaffold=old_scaffold)
7666
# TODO(suyoggupta): Figure out a way to merge the init_fn of the
7767
# original scaffold with the one defined above.
7868
raise ValueError("Scaffold provided to PruningEstimatorSpec must not "
7969
"set an init_fn.")
8070

81-
scaffold = monitored_session.Scaffold(init_fn=init_fn)
71+
scaffold = tf.compat.v1.train.Scaffold(init_fn=init_fn)
8272
if "scaffold" in kwargs:
8373
scaffold = get_new_scaffold(kwargs["scaffold"])
8474

tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
import tensorflow as tf
2222

23-
from tensorflow.python.ops import summary_ops_v2
24-
from tensorflow.python.summary import summary as summary_ops_v1
2523
from tensorflow_model_optimization.python.core.keras import compat as tf_compat
2624
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
2725

@@ -255,9 +253,9 @@ def update_distributed():
255253
def add_pruning_summaries(self):
256254
"""Adds summaries of weight sparsities and thresholds."""
257255
# b/(139939526): update to use public API.
258-
summary = summary_ops_v1
259-
if tf.executing_eagerly():
260-
summary = summary_ops_v2
256+
summary = tf.summary
257+
if not tf.executing_eagerly():
258+
summary = tf.compat.v1.summary
261259
summary.scalar('sparsity', self._pruning_schedule(self._step_fn())[1])
262260
for _, mask, threshold in self._pruning_vars:
263261
summary.scalar(mask.name + '/sparsity', 1.0 - tf.math.reduce_mean(mask))

0 commit comments

Comments (0)