Skip to content

Commit 5649d1c

Browse files
teijeongtensorflower-gardener
authored andcommitted
Use tf.keras instead of importing keras directly
PiperOrigin-RevId: 371042274
1 parent 7123df2 commit 5649d1c

File tree

7 files changed

+33
-39
lines changed

7 files changed

+33
-39
lines changed

tensorflow_model_optimization/python/core/clustering/keras/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ py_strict_library(
1919
],
2020
)
2121

22-
py_library(
22+
py_strict_library(
2323
name = "cluster",
2424
srcs = ["cluster.py"],
2525
srcs_version = "PY3",
@@ -38,7 +38,7 @@ py_strict_library(
3838
visibility = ["//visibility:public"],
3939
)
4040

41-
py_library(
41+
py_strict_library(
4242
name = "clustering_registry",
4343
srcs = ["clustering_registry.py"],
4444
srcs_version = "PY3",
@@ -73,7 +73,7 @@ py_strict_library(
7373
],
7474
)
7575

76-
py_library(
76+
py_strict_library(
7777
name = "cluster_wrapper",
7878
srcs = ["cluster_wrapper.py"],
7979
srcs_version = "PY3",
@@ -99,7 +99,7 @@ py_strict_library(
9999
],
100100
)
101101

102-
py_library(
102+
py_strict_library(
103103
name = "clustering_callbacks",
104104
srcs = ["clustering_callbacks.py"],
105105
srcs_version = "PY3",

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
# ==============================================================================
1515
"""Clustering API functions for Keras models."""
1616

17-
from tensorflow import keras
17+
import tensorflow as tf
1818

1919
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
2020
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
2121

22-
k = keras.backend
23-
CustomObjectScope = keras.utils.CustomObjectScope
24-
Layer = keras.layers.Layer
25-
InputLayer = keras.layers.InputLayer
22+
k = tf.keras.backend
23+
CustomObjectScope = tf.keras.utils.CustomObjectScope
24+
Layer = tf.keras.layers.Layer
25+
InputLayer = tf.keras.layers.InputLayer
2626

2727

2828
def cluster_scope():
@@ -38,10 +38,10 @@ def cluster_scope():
3838
3939
```python
4040
clustered_model = cluster_weights(model, **self.params)
41-
keras.models.save_model(clustered_model, keras_file)
41+
tf.keras.models.save_model(clustered_model, keras_file)
4242
4343
with cluster_scope():
44-
loaded_model = keras.models.load_model(keras_file)
44+
loaded_model = tf.keras.models.load_model(keras_file)
4545
```
4646
"""
4747
return CustomObjectScope(
@@ -92,7 +92,7 @@ def cluster_weights(to_cluster,
9292
'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
9393
}
9494
95-
model = keras.Sequential([
95+
model = tf.keras.Sequential([
9696
layers.Dense(10, activation='relu', input_shape=(100,)),
9797
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
9898
])
@@ -174,7 +174,7 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
174174
'preserve_sparsity': False
175175
}
176176
177-
model = keras.Sequential([
177+
model = tf.keras.Sequential([
178178
layers.Dense(10, activation='relu', input_shape=(100,)),
179179
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
180180
])
@@ -189,7 +189,7 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
189189
'preserve_sparsity': True
190190
}
191191
192-
model = keras.Sequential([
192+
model = tf.keras.Sequential([
193193
layers.Dense(10, activation='relu', input_shape=(100,)),
194194
cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
195195
])
@@ -221,16 +221,16 @@ def _cluster_weights(to_cluster, number_of_clusters, cluster_centroids_init,
221221
cluster_centroids_init))
222222

223223
def _add_clustering_wrapper(layer):
224-
if isinstance(layer, keras.Model):
224+
if isinstance(layer, tf.keras.Model):
225225
# Check whether the model is a subclass.
226226
# NB: This check is copied from keras.py file in tensorflow.
227227
# There is no available public API to do this check.
228228
# pylint: disable=protected-access
229229
if (not layer._is_graph_network and
230-
not isinstance(layer, keras.models.Sequential)):
230+
not isinstance(layer, tf.keras.models.Sequential)):
231231
raise ValueError('Subclassed models are not supported currently.')
232232

233-
return keras.models.clone_model(
233+
return tf.keras.models.clone_model(
234234
layer, input_tensors=None, clone_function=_add_clustering_wrapper)
235235
if isinstance(layer, cluster_wrapper.ClusterWeights):
236236
return layer
@@ -248,10 +248,9 @@ def _wrap_list(layers):
248248

249249
return output
250250

251-
if isinstance(to_cluster, keras.Model):
252-
return keras.models.clone_model(to_cluster,
253-
input_tensors=None,
254-
clone_function=_add_clustering_wrapper)
251+
if isinstance(to_cluster, tf.keras.Model):
252+
return tf.keras.models.clone_model(
253+
to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper)
255254
if isinstance(to_cluster, Layer):
256255
return _add_clustering_wrapper(layer=to_cluster)
257256
if isinstance(to_cluster, list):
@@ -285,13 +284,13 @@ def strip_clustering(model):
285284
```
286285
The exported_model and the orig_model have the same structure.
287286
"""
288-
if not isinstance(model, keras.Model):
287+
if not isinstance(model, tf.keras.Model):
289288
raise ValueError(
290289
'Expected model to be a `tf.keras.Model` instance but got: ', model)
291290

292291
def _strip_clustering_wrapper(layer):
293-
if isinstance(layer, keras.Model):
294-
return keras.models.clone_model(
292+
if isinstance(layer, tf.keras.Model):
293+
return tf.keras.models.clone_model(
295294
layer, input_tensors=None, clone_function=_strip_clustering_wrapper)
296295

297296
elif isinstance(layer, cluster_wrapper.ClusterWeights):
@@ -317,6 +316,5 @@ def _strip_clustering_wrapper(layer):
317316
return layer
318317

319318
# Just copy the model with the right callback
320-
return keras.models.clone_model(model,
321-
input_tensors=None,
322-
clone_function=_strip_clustering_wrapper)
319+
return tf.keras.models.clone_model(
320+
model, input_tensors=None, clone_function=_strip_clustering_wrapper)

tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
import tensorflow as tf
1818

19-
from tensorflow.keras import initializers
20-
2119
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2220
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2321
from tensorflow_model_optimization.python.core.clustering.keras import clustering_centroids
@@ -130,8 +128,8 @@ def __init__(self,
130128
# If the input shape was specified, then we need to preserve this
131129
# information in the layer. If this info is not preserved, then the `built`
132130
# state will not be preserved between serializations.
133-
if (not hasattr(self, '_batch_input_shape')
134-
and hasattr(layer, '_batch_input_shape')):
131+
if (not hasattr(self, '_batch_input_shape') and
132+
hasattr(layer, '_batch_input_shape')):
135133
self._batch_input_shape = self.layer._batch_input_shape
136134

137135
# Save the input shape specified in the build
@@ -171,7 +169,7 @@ def build(self, input_shape):
171169
shape=(self.number_of_clusters,),
172170
dtype=weight.dtype,
173171
trainable=True,
174-
initializer=initializers.Constant(value=cluster_centroids))
172+
initializer=tf.keras.initializers.Constant(value=cluster_centroids))
175173

176174
# Init the weight clustering algorithm
177175
self.clustering_algorithms[weight_name] = (
@@ -192,7 +190,7 @@ def build(self, input_shape):
192190
trainable=False,
193191
synchronization=tf.VariableSynchronization.ON_READ,
194192
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
195-
initializer=initializers.Constant(value=pulling_indices))
193+
initializer=tf.keras.initializers.Constant(value=pulling_indices))
196194

197195
if self.preserve_sparsity:
198196
# Init the sparsity mask

tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
import tensorflow as tf
1818

19-
from tensorflow import keras
2019
from tensorflow_model_optimization.python.core.keras import compat
2120

2221

23-
class ClusteringSummaries(keras.callbacks.TensorBoard):
22+
class ClusteringSummaries(tf.keras.callbacks.TensorBoard):
2423
"""Helper class to create tensorboard summaries for the clustering progress.
2524
2625
This class is derived from tf.keras.callbacks.TensorBoard and just adds

tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
"""Registry responsible for built-in keras classes."""
1616

1717
import tensorflow as tf
18-
from tensorflow.keras import layers
1918

2019
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
2120
from tensorflow_model_optimization.python.core.clustering.keras import clustering_algorithm
2221

22+
layers = tf.keras.layers
2323
AbstractClusteringAlgorithm = clustering_algorithm.AbstractClusteringAlgorithm
2424

2525

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@ py_strict_library(
2828
],
2929
)
3030

31-
py_library(
31+
py_strict_library(
3232
name = "cluster_preserve_quantize_registry",
3333
srcs = [
3434
"cluster_preserve_quantize_registry.py",
3535
],
3636
srcs_version = "PY3",
3737
deps = [
38-
":cluster_utils",
3938
# tensorflow dep1,
4039
"//tensorflow_model_optimization/python/core/clustering/keras:clustering_registry",
4140
"//tensorflow_model_optimization/python/core/quantization/keras:quant_ops",

tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
"""Registry responsible for built-in keras classes."""
1616

1717
import tensorflow as tf
18-
from tensorflow.keras import backend as K
1918

2019
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
2120
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
2221
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
2322
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
2423
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
2524

25+
K = tf.keras.backend
2626
layers = tf.keras.layers
2727

2828

0 commit comments

Comments
 (0)