@@ -265,7 +265,8 @@ def get_updater(for_weight_name):
265265 def fn ():
266266 # Get the clustered weights
267267 pulling_indices = self .pulling_indices_tf [for_weight_name ]
268- clustered_weights = self .clustering_impl [for_weight_name ].get_clustered_weight (pulling_indices )
268+ clustered_weights = self .clustering_impl [for_weight_name ].\
269+ get_clustered_weight (pulling_indices )
269270
270271 if self .preserve_sparsity :
271272 # Get the sparsity mask
@@ -295,10 +296,20 @@ def call(self, inputs):
295296 # since they are integers and not differentiable. Gradients won't flow back
296297 # through tf.argmin
297298 # Go through all tensors and replace them with their clustered copies.
298- for weight_name , _ in self .clustered_vars :
299- # Get the clustered weights
299+ for weight_name in self .ori_weights_vars_tf :
300300 pulling_indices = self .pulling_indices_tf [weight_name ]
301- clustered_weights = self .clustering_impl [weight_name ].get_clustered_weight (pulling_indices )
301+
302+ # Update cluster associations
303+ pulling_indices .assign (tf .dtypes .cast (
304+ self .clustering_impl [weight_name ].\
305+ get_pulling_indices (self .ori_weights_vars_tf [weight_name ]),
306+ pulling_indices .dtype
307+ ))
308+
309+ # Get the clustered weights
310+ clustered_weights = self .clustering_impl [weight_name ].\
311+ get_clustered_weight_forward (pulling_indices ,\
312+ self .ori_weights_vars_tf [weight_name ])
302313
303314 if self .preserve_sparsity :
304315 # Get the sparsity mask
0 commit comments