@@ -184,43 +184,91 @@ def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
184
184
self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
185
185
186
186
@parameterized .parameters (
187
- ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
188
- [[[[0 ], [0 ]], [[0 ], [1 ]]],
189
- [[[0 ], [2 ]], [[1 ], [0 ]]]],
190
- [[[[0 ], [0 ]], [[0 ], [0 ]]],
191
- [[[0 ], [0 ]], [[1 ], [1 ]]]]))
187
+ ("channels_last" ,
188
+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
189
+ # pulling indices has shape (2, 2, 1, 3)
190
+ [[[[0 , 1 , 0 ]], [[0 , 1 , 1 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
191
+ [[[[1 , 4 , 5 ]], [[1 , 4 , 6 ]]], [[[2 , 3 , 6 ]], [[1 , 4 , 5 ]]]]),
192
+ ("channels_first" ,
193
+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]], # 4 channels and 2 clusters per channel
194
+ # pulling indices has shape (1, 4, 2, 2)
195
+ [[[[0 , 1 ], [1 , 1 ]], [[0 , 0 ], [0 , 1 ]],
196
+ [[1 , 0 ], [0 , 0 ]], [[1 , 1 ], [0 , 0 ]]]],
197
+ [[[[1 , 2 ], [2 , 2 ]], [[3 , 3 ], [3 , 4 ]],
198
+ [[5 , 4 ], [4 , 4 ]], [[7 , 7 ], [6 , 6 ]]]])
199
+ )
192
200
def testConvolutionalWeightsPerChannelCA (self ,
201
+ data_format ,
193
202
clustering_centroids ,
194
203
pulling_indices ,
195
204
expected_output ):
196
- """Verifies that PerChannelCA works as expected."""
205
+ """Verifies that get_clustered_weight function works as expected."""
197
206
clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
198
- clustering_algo = clustering_registry .PerChannelCA (
199
- clustering_centroids , GradientAggregation .SUM
207
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
208
+ clustering_centroids , GradientAggregation .SUM , data_format
200
209
)
210
+ # Note that clustered_weights has the same shape as pulling_indices,
211
+ # because they are defined inside of the check function.
201
212
self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
202
213
214
+ @parameterized .parameters (
215
+ ("channels_last" ,
216
+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
217
+ # weight has shape (2, 2, 1, 3)
218
+ [[[[1.1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
219
+ [[[2.1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]],
220
+ # expected pulling indices
221
+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]]),
222
+ ("channels_first" ,
223
+ # 4 channels and 2 clusters per channel
224
+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
225
+ # weight has shape (1, 4, 2, 2)
226
+ [[[[0.1 , 1.5 ], [2.0 , 1.1 ]], [[0. , 3.5 ], [4.4 , 4. ]],
227
+ [[4.1 , 4.2 ], [5.3 , 6. ]], [[7. , 7.1 ], [6.1 , 5.8 ]]]],
228
+ # expected pulling indices
229
+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]],
230
+ [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]])
231
+ )
232
+ def testConvolutionalPullingIndicesPerChannelCA (self ,
233
+ data_format ,
234
+ clustering_centroids ,
235
+ weight ,
236
+ expected_output ):
237
+ """Verifies that get_pulling_indices function works as expected."""
238
+ clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
239
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
240
+ clustering_centroids , GradientAggregation .SUM , data_format
241
+ )
242
+ weight = tf .convert_to_tensor (weight )
243
+ pulling_indices = clustering_algo .get_pulling_indices (weight )
244
+
245
+ # check that pulling_indices has the same shape as weight
246
+ self .assertEqual (pulling_indices .shape , weight .shape )
247
+ self .assertAllEqual (pulling_indices , expected_output )
248
+
203
249
@parameterized .parameters (
204
250
(GradientAggregation .AVG ,
205
- [[[[0 ], [0 ]], [[0 ], [1 ]]],
206
- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[1 , 1 , 0 ], [1 , 1 , 1 ]]),
251
+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]],
252
+ [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
253
+ [[1 , 1 ], [1 , 1 ], [1 , 1 ]]),
207
254
(GradientAggregation .SUM ,
208
- [[[[0 ], [0 ]], [[0 ], [1 ]]],
209
- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[3 , 1 , 0 ], [2 , 1 , 1 ]])
255
+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]],
256
+ [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
257
+ [[2 , 2 ], [2 , 2 ], [3 , 1 ]])
210
258
)
211
- def testConvolutionalPerChannelCAGrad (self ,
259
+ def testConvolutionalPerChannelCAGradChannelsLast (self ,
212
260
cluster_gradient_aggregation ,
213
261
pulling_indices ,
214
262
expected_grad_centroids ):
215
- """Verifies that the gradients of convolutional layer work as expected ."""
263
+ """Verifies that the gradients of convolutional layer works ."""
216
264
217
- clustering_centroids = tf .Variable ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
265
+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [ 5 , 6 ]],
218
266
dtype = tf .float32 )
219
- weight = tf .constant ([[[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]],
220
- [[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]]])
267
+ weight = tf .constant ([[[[1 .1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
268
+ [[[2 .1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]])
221
269
222
- clustering_algo = clustering_registry .PerChannelCA (
223
- clustering_centroids , cluster_gradient_aggregation
270
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
271
+ clustering_centroids , cluster_gradient_aggregation , "channels_last"
224
272
)
225
273
self ._check_gradients_clustered_weight (
226
274
clustering_algo ,
@@ -229,6 +277,37 @@ def testConvolutionalPerChannelCAGrad(self,
229
277
expected_grad_centroids ,
230
278
)
231
279
280
+ @parameterized .parameters (
281
+ (GradientAggregation .AVG ,
282
+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]],
283
+ [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]],
284
+ [[1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ]]),
285
+ (GradientAggregation .SUM ,
286
+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]],
287
+ [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]],
288
+ [[3 , 1 ], [2 , 2 ], [2 , 2 ], [2 , 2 ]])
289
+ )
290
+ def testConvolutionalPerChannelCAGradChannelsFirst (self ,
291
+ cluster_gradient_aggregation ,
292
+ pulling_indices ,
293
+ expected_grad_centroids ):
294
+ """Verifies that the gradients of convolutional layer works."""
295
+
296
+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
297
+ dtype = tf .float32 )
298
+ weight = tf .constant ([[[[0.1 , 1.5 ], [2.0 , 1.1 ]],
299
+ [[0. , 3.5 ], [4.4 , 4. ]], [[4.1 , 4.2 ], [5.3 , 6. ]],
300
+ [[7. , 7.1 ], [6.1 , 5.8 ]]]])
301
+
302
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
303
+ clustering_centroids , cluster_gradient_aggregation , "channels_first"
304
+ )
305
+ self ._check_gradients_clustered_weight (
306
+ clustering_algo ,
307
+ weight ,
308
+ pulling_indices ,
309
+ expected_grad_centroids ,
310
+ )
232
311
233
312
class CustomLayer (layers .Layer ):
234
313
"""A custom non-clusterable layer class."""
0 commit comments