@@ -265,7 +265,8 @@ def get_updater(for_weight_name):
265
265
def fn ():
266
266
# Get the clustered weights
267
267
pulling_indices = self .pulling_indices_tf [for_weight_name ]
268
- clustered_weights = self .clustering_impl [for_weight_name ].get_clustered_weight (pulling_indices )
268
+ clustered_weights = self .clustering_impl [for_weight_name ].\
269
+ get_clustered_weight (pulling_indices )
269
270
270
271
if self .preserve_sparsity :
271
272
# Get the sparsity mask
@@ -295,10 +296,20 @@ def call(self, inputs):
295
296
# since they are integers and not differentiable. Gradients won't flow back
296
297
# through tf.argmin
297
298
# Go through all tensors and replace them with their clustered copies.
298
- for weight_name , _ in self .clustered_vars :
299
- # Get the clustered weights
299
+ for weight_name in self .ori_weights_vars_tf :
300
300
pulling_indices = self .pulling_indices_tf [weight_name ]
301
- clustered_weights = self .clustering_impl [weight_name ].get_clustered_weight (pulling_indices )
301
+
302
+ # Update cluster associations
303
+ pulling_indices .assign (tf .dtypes .cast (
304
+ self .clustering_impl [weight_name ].\
305
+ get_pulling_indices (self .ori_weights_vars_tf [weight_name ]),
306
+ pulling_indices .dtype
307
+ ))
308
+
309
+ # Get the clustered weights
310
+ clustered_weights = self .clustering_impl [weight_name ].\
311
+ get_clustered_weight_forward (pulling_indices ,\
312
+ self .ori_weights_vars_tf [weight_name ])
302
313
303
314
if self .preserve_sparsity :
304
315
# Get the sparsity mask
0 commit comments