@@ -35,7 +35,9 @@ def _make_model(neighbor_config, inputs, keep_rank, weight_dtype=None):
3535 Args:
3636 neighbor_config: An instance of `configs.GraphNeighborConfig`.
3737 inputs: A `tf.keras.Input` or a nested structure of `tf.keras.Input`s.
38- keep_rank: Whether to keep the extra neighborhood size dimention.
38+ keep_rank: Whether to retain the rank of the original input tensors by
39+ merging the neighborhood size with the batch_size dimension, or add an
40+ extra neighborhood size dimension.
3941 weight_dtype: Optional `tf.DType` for weights.
4042
4143 Returns:
@@ -105,15 +107,15 @@ def testDense(self, keep_rank):
105107 # Check that neighbors and weights are grouped together for each sample.
106108 for i in range (batch_size ):
107109 self .assertAllEqual (
108- neighbors [i ] if keep_rank else
109- neighbors [( i * num_neighbors ):(( i + 1 ) * num_neighbors ) ],
110+ neighbors [( i * num_neighbors ):(( i + 1 ) * num_neighbors )]
111+ if keep_rank else neighbors [ i ],
110112 np .stack ([
111113 features ['NL_nbr_0_image' ][i ],
112114 features ['NL_nbr_1_image' ][i ],
113115 features ['NL_nbr_2_image' ][i ],
114116 ]))
115117 self .assertAllEqual (
116- weights [i ] if keep_rank else np . split ( weights , batch_size ) [i ],
118+ np . split ( weights , batch_size ) [i ] if keep_rank else weights [i ],
117119 np .stack ([
118120 features ['NL_nbr_0_weight' ][i ],
119121 features ['NL_nbr_1_weight' ][i ],
@@ -160,8 +162,8 @@ def testSparse(self, keep_rank):
160162 self .assertAllClose (
161163 weights ,
162164 np .array ([0.9 , 0.25 , 0.3 , 0. , 0.6 , 0.75 , 0. ,
163- 0. ]).reshape ((batch_size , 2 ,
164- 1 ) if keep_rank else (batch_size * 2 , 1 )))
165+ 0. ]).reshape ((batch_size * 2 ,
166+ 1 ) if keep_rank else (batch_size , 2 , 1 )))
165167 # Check that neighbors are grouped together.
166168 dense_neighbors = self .evaluate (tf .sparse .to_dense (neighbors ['input' ], - 1. ))
167169 neighbor0 = self .evaluate (
@@ -170,8 +172,8 @@ def testSparse(self, keep_rank):
170172 tf .sparse .to_dense (features ['NL_nbr_1_input' ], - 1 ))
171173 for i in range (batch_size ):
172174 actual = (
173- dense_neighbors [i ]
174- if keep_rank else np . split ( dense_neighbors , batch_size ) [i ])
175+ np . split ( dense_neighbors , batch_size ) [i ]
176+ if keep_rank else dense_neighbors [i ])
175177 self .assertAllEqual (actual , np .stack ([neighbor0 [i ], neighbor1 [i ]]))
176178
177179
0 commit comments