Skip to content

Commit 20b20c1

Browse files
Merge pull request #467 from arovir01:toupstream/centroid_init_doc
PiperOrigin-RevId: 322167185
2 parents 1bf87ed + 362242c commit 20b20c1

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def cluster_weights(to_cluster,
8181
clustering_params = {
8282
'number_of_clusters': 8,
8383
'cluster_centroids_init':
84-
cluster_config.CentroidInitialization.DENSITY_BASED
84+
CentroidInitialization.DENSITY_BASED
8585
}
8686
8787
clustered_model = cluster_weights(original_model, **clustering_params)
@@ -93,7 +93,7 @@ def cluster_weights(to_cluster,
9393
clustering_params = {
9494
'number_of_clusters': 8,
9595
'cluster_centroids_init':
96-
cluster_config.CentroidInitialization.DENSITY_BASED
96+
CentroidInitialization.DENSITY_BASED
9797
}
9898
9999
model = keras.Sequential([
@@ -108,17 +108,8 @@ def cluster_weights(to_cluster,
108108
number_of_clusters: the number of cluster centroids to form when
109109
clustering a layer/model. For example, if number_of_clusters=8 then only
110110
8 unique values will be used in each weight array.
111-
cluster_centroids_init: enum value that determines how the cluster
112-
centroids will be initialized.
113-
Can have following values:
114-
1. RANDOM : centroids are sampled using the uniform distribution
115-
between the minimum and maximum weight values in a given layer
116-
2. DENSITY_BASED : density-based sampling. First, cumulative
117-
distribution function is built for weights, then y-axis is evenly
118-
spaced into number_of_clusters regions. After this the corresponding x
119-
values are obtained and used to initialize clusters centroids.
120-
3. LINEAR : cluster centroids are evenly spaced between the minimum
121-
and maximum values of a given weight
111+
cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
112+
instance that determines how the cluster centroids will be initialized.
122113
**kwargs: Additional keyword arguments to be passed to the keras layer.
123114
Ignored when to_cluster is not a keras layer.
124115

tensorflow_model_optimization/python/core/clustering/keras/cluster_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@
1818

1919

2020
class CentroidInitialization(str, enum.Enum):
21+
"""Specifies how the cluster centroids should be initialized.
22+
23+
* `LINEAR`: Cluster centroids are evenly spaced between the minimum and
24+
maximum values of a given weight tensor.
25+
* `RANDOM`: Centroids are sampled using the uniform distribution between the
26+
minimum and maximum weight values in a given layer.
27+
* `DENSITY_BASED`: Density-based sampling obtained as follows: first a
28+
cumulative distribution function is built for the weights, then the Y
29+
axis is evenly spaced into as many regions as many clusters we want to
30+
have. After this the corresponding X values are obtained and used to
31+
initialize the clusters centroids.
32+
"""
2133
LINEAR = "LINEAR"
2234
RANDOM = "RANDOM"
2335
DENSITY_BASED = "DENSITY_BASED"

0 commit comments

Comments
 (0)