Skip to content

Commit ae4d198

Browse files
liyunlu0618tensorflower-gardener
authored andcommitted
Fix MOT TAP test failres.
PiperOrigin-RevId: 341087125
1 parent 1fbc2ed commit ae4d198

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,13 @@ def testEndToEndDeepLayer(self):
171171

172172
def clusters_check(stripped_model):
173173
# inner dense layer
174-
weights_as_list = stripped_model._layers[1]._layers[1].trainable_weights[0].\
174+
weights_as_list = stripped_model.submodules[1].trainable_weights[0].\
175175
numpy().flatten()
176176
unique_weights = set(weights_as_list)
177177
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
178178

179179
# outer dense layer
180-
weights_as_list = stripped_model._layers[2].trainable_weights[0].\
180+
weights_as_list = stripped_model.submodules[4].trainable_weights[0].\
181181
numpy().flatten()
182182
unique_weights = set(weights_as_list)
183183
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
@@ -199,20 +199,20 @@ def testEndToEndDeepLayer2(self):
199199

200200
def clusters_check(stripped_model):
201201
# first inner dense layer
202-
weights_as_list = stripped_model._layers[1]._layers[1].trainable_weights[0].\
202+
weights_as_list = stripped_model.submodules[1].trainable_weights[0].\
203203
numpy().flatten()
204204
unique_weights = set(weights_as_list)
205205
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
206206

207207
# second inner dense layer
208-
weights_as_list = stripped_model._layers[1]._layers[1]._layers[1].\
208+
weights_as_list = stripped_model.submodules[4].\
209209
trainable_weights[0].\
210210
numpy().flatten()
211211
unique_weights = set(weights_as_list)
212212
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
213213

214214
# outer dense layer
215-
weights_as_list = stripped_model._layers[2].trainable_weights[0].\
215+
weights_as_list = stripped_model.submodules[7].trainable_weights[0].\
216216
numpy().flatten()
217217
unique_weights = set(weights_as_list)
218218
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])

tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _validate_pruned_layer(self, original_layer, wrapped_layer):
116116
@staticmethod
117117
def _count_pruned_layers(model):
118118
count = 0
119-
for layer in model._layers:
119+
for layer in model.submodules:
120120
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
121121
count += 1
122122
return count

0 commit comments

Comments
 (0)