@@ -184,44 +184,111 @@ def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
184
184
self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
185
185
186
186
@parameterized .parameters (
187
- ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
188
- [[[[0 ], [0 ]], [[0 ], [1 ]]],
189
- [[[0 ], [2 ]], [[1 ], [0 ]]]],
190
- [[[[0 ], [0 ]], [[0 ], [0 ]]],
191
- [[[0 ], [0 ]], [[1 ], [1 ]]]]))
192
- def testConvolutionalWeightsPerChannelCA (self ,
187
+ (
188
+ 'channels_last' ,
189
+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
190
+ # pulling indices has shape (2, 2, 1, 3)
191
+ [[[[0 , 1 , 0 ]], [[0 , 1 , 1 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
192
+ [[[[1 , 4 , 5 ]], [[1 , 4 , 6 ]]], [[[2 , 3 , 6 ]], [[1 , 4 , 5 ]]]]),
193
+ (
194
+ 'channels_first' ,
195
+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]
196
+ ], # 4 channels and 2 clusters per channel
197
+ # pulling indices has shape (1, 4, 2, 2)
198
+ [[[[0 , 1 ], [1 , 1 ]], [[0 , 0 ], [0 , 1 ]], [[1 , 0 ], [0 , 0 ]],
199
+ [[1 , 1 ], [0 , 0 ]]]],
200
+ [[[[1 , 2 ], [2 , 2 ]], [[3 , 3 ], [3 , 4 ]], [[5 , 4 ], [4 , 4 ]],
201
+ [[7 , 7 ], [6 , 6 ]]]]))
202
+ def testConvolutionalWeightsPerChannelCA (self , data_format ,
193
203
clustering_centroids ,
194
204
pulling_indices ,
195
205
expected_output ):
196
- """Verifies that PerChannelCA works as expected."""
206
+ """Verifies that get_clustered_weight function works as expected."""
197
207
clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
198
- clustering_algo = clustering_registry .PerChannelCA (
199
- clustering_centroids , GradientAggregation .SUM
208
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
209
+ clustering_centroids , GradientAggregation .SUM , data_format
200
210
)
211
+ # Note that clustered_weights has the same shape as pulling_indices,
212
+ # because they are defined inside of the check function.
201
213
self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
202
214
203
215
@parameterized .parameters (
204
- (GradientAggregation .AVG ,
205
- [[[[0 ], [0 ]], [[0 ], [1 ]]],
206
- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[1 , 1 , 0 ], [1 , 1 , 1 ]]),
207
- (GradientAggregation .SUM ,
208
- [[[[0 ], [0 ]], [[0 ], [1 ]]],
209
- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[3 , 1 , 0 ], [2 , 1 , 1 ]])
210
- )
211
- def testConvolutionalPerChannelCAGrad (self ,
212
- cluster_gradient_aggregation ,
213
- pulling_indices ,
214
- expected_grad_centroids ):
215
- """Verifies that the gradients of convolutional layer work as expected."""
216
+ (
217
+ 'channels_last' ,
218
+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
219
+ # weight has shape (2, 2, 1, 3)
220
+ [[[[1.1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
221
+ [[[2.1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]],
222
+ # expected pulling indices
223
+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]]),
224
+ (
225
+ 'channels_first' ,
226
+ # 4 channels and 2 clusters per channel
227
+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
228
+ # weight has shape (1, 4, 2, 2)
229
+ [[[[0.1 , 1.5 ], [2.0 , 1.1 ]], [[0. , 3.5 ], [4.4 , 4. ]],
230
+ [[4.1 , 4.2 ], [5.3 , 6. ]], [[7. , 7.1 ], [6.1 , 5.8 ]]]],
231
+ # expected pulling indices
232
+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]], [[0 , 0 ], [1 , 1 ]],
233
+ [[1 , 1 ], [0 , 0 ]]]]))
234
+ def testConvolutionalPullingIndicesPerChannelCA (self , data_format ,
235
+ clustering_centroids , weight ,
236
+ expected_output ):
237
+ """Verifies that get_pulling_indices function works as expected."""
238
+ clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
239
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
240
+ clustering_centroids , GradientAggregation .SUM , data_format
241
+ )
242
+ weight = tf .convert_to_tensor (weight )
243
+ pulling_indices = clustering_algo .get_pulling_indices (weight )
216
244
217
- clustering_centroids = tf .Variable ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
245
+ # check that pulling_indices has the same shape as weight
246
+ self .assertEqual (pulling_indices .shape , weight .shape )
247
+ self .assertAllEqual (pulling_indices , expected_output )
248
+
249
+ @parameterized .parameters (
250
+ (GradientAggregation .AVG , [
251
+ [[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
252
+ [[1 , 1 ], [1 , 1 ], [1 , 1 ]]),
253
+ (GradientAggregation .SUM , [
254
+ [[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
255
+ [[2 , 2 ], [2 , 2 ], [3 , 1 ]]))
256
+ def testConvolutionalPerChannelCAGradChannelsLast (
257
+ self , cluster_gradient_aggregation , pulling_indices ,
258
+ expected_grad_centroids ):
259
+ """Verifies that the gradients of convolutional layer works."""
260
+
261
+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [5 , 6 ]],
218
262
dtype = tf .float32 )
219
- weight = tf .constant ([[[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]],
220
- [[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]]])
263
+ weight = tf .constant ([[[[1 .1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
264
+ [[[2 .1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]])
221
265
222
- clustering_algo = clustering_registry .PerChannelCA (
223
- clustering_centroids , cluster_gradient_aggregation
266
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
267
+ clustering_centroids , cluster_gradient_aggregation , 'channels_last' )
268
+ self ._check_gradients_clustered_weight (
269
+ clustering_algo ,
270
+ weight ,
271
+ pulling_indices ,
272
+ expected_grad_centroids ,
224
273
)
274
+
275
+ @parameterized .parameters ((GradientAggregation .AVG , [
276
+ [[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]], [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]
277
+ ], [[1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ]]), (GradientAggregation .SUM , [
278
+ [[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]], [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]
279
+ ], [[3 , 1 ], [2 , 2 ], [2 , 2 ], [2 , 2 ]]))
280
+ def testConvolutionalPerChannelCAGradChannelsFirst (
281
+ self , cluster_gradient_aggregation , pulling_indices ,
282
+ expected_grad_centroids ):
283
+ """Verifies that the gradients of convolutional layer works."""
284
+
285
+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
286
+ dtype = tf .float32 )
287
+ weight = tf .constant ([[[[0.1 , 1.5 ], [2.0 , 1.1 ]], [[0. , 3.5 ], [4.4 , 4. ]],
288
+ [[4.1 , 4.2 ], [5.3 , 6. ]], [[7. , 7.1 ], [6.1 , 5.8 ]]]])
289
+
290
+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
291
+ clustering_centroids , cluster_gradient_aggregation , 'channels_first' )
225
292
self ._check_gradients_clustered_weight (
226
293
clustering_algo ,
227
294
weight ,
0 commit comments