Skip to content

Commit f13568b

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Pruning for TF 2.X batchnorm
PiperOrigin-RevId: 284038440
1 parent f78f974 commit f13568b

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ py_library(
5252
visibility = ["//visibility:public"],
5353
deps = [
5454
":prunable_layer",
55+
# tensorflow dep1,
5556
# python/keras:layers_base tensorflow dep2,
5657
],
5758
)

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
# ==============================================================================
1515
"""Registry responsible for built-in keras classes."""
1616

17-
from tensorflow.python.keras import layers
17+
import tensorflow as tf
1818

19+
from tensorflow.python.keras import layers
1920
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
2021

2122

@@ -75,7 +76,7 @@ class PruneRegistry(object):
7576
layers.noise.AlphaDropout: [],
7677
layers.noise.GaussianDropout: [],
7778
layers.noise.GaussianNoise: [],
78-
layers.normalization.BatchNormalization: [],
79+
tf.keras.layers.BatchNormalization: [],
7980
layers.normalization.LayerNormalization: [],
8081
layers.pooling.AveragePooling1D: [],
8182
layers.pooling.AveragePooling2D: [],

0 commit comments

Comments
 (0)