Skip to content

Commit 982abc8

Browse files
Fixed minor bug in mnist example
1 parent 8465279 commit 982abc8

File tree

1 file changed

+4
-3
lines changed
  • tensorflow_model_optimization/python/examples/clustering/keras/mnist

1 file changed

+4
-3
lines changed

tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
from absl import app as absl_app
2323
from absl import flags
2424

25-
import tensorflow.compat.v1 as tf
26-
from tensorflow.python import keras
25+
import tensorflow as tf
2726
from tensorflow_model_optimization.python.core.clustering.keras import cluster
27+
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2828

29+
keras = tf.keras
2930
l = keras.layers
3031

3132
FLAGS = flags.FLAGS
@@ -96,7 +97,7 @@ def train_and_save(models, x_train, y_train, x_test, y_test):
9697

9798
clustering_params = {
9899
'number_of_clusters': 8,
99-
'cluster_centroids_init': 'density-based'
100+
'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED
100101
}
101102

102103
# Cluster model

0 commit comments

Comments
 (0)