Skip to content

Commit 28df91b

Browse files
rino20tensorflower-gardener
authored andcommitted
Add experimental.SyncBatchNormalization to Prune Registry
PiperOrigin-RevId: 423694994
1 parent 4cf1b45 commit 28df91b

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class PruneRegistry(object):
9898
'_query_dense.kernel', '_key_dense.kernel', '_value_dense.kernel',
9999
'_output_dense.kernel'
100100
],
101+
layers.experimental.SyncBatchNormalization: [],
101102
layers.experimental.preprocessing.Rescaling.__class__: [],
102103
TensorFlowOpLayer: [],
103104
layers_compat_v1.BatchNormalization: [],

tensorflow_model_optimization/python/core/sparsity/keras/prune_registry_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class PruneRegistryTest(tf.test.TestCase, parameterized.TestCase):
7272
layers.Conv2D(10, 5),
7373
layers.Dropout(0.5),
7474
# Supports specific layers from experimental or compat_v1.
75-
tf.keras.layers.experimental.preprocessing.Rescaling,
75+
layers.experimental.SyncBatchNormalization(),
76+
layers.experimental.preprocessing.Rescaling,
7677
tf.compat.v1.keras.layers.BatchNormalization(),
7778
# Supports Keras RNN Layers with prunable cells.
7879
layers.LSTM(10),

0 commit comments

Comments
 (0)