Skip to content

Commit 105ec70

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Add support for pruning summaries in 1.X fashion.
PiperOrigin-RevId: 286103751
1 parent 21a1fde commit 105ec70

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ def __init__(self, log_dir, update_freq='epoch', **kwargs):
8484
super(PruningSummaries, self).__init__(
8585
log_dir=log_dir, update_freq=update_freq, **kwargs)
8686

87+
def _log_pruning_metrics(self, logs, prefix, step):
88+
if tf.__version__[0] == '1':
89+
self._write_custom_summaries(step, logs)
90+
else:
91+
self._log_metrics(logs, prefix, step)
92+
8793
def on_epoch_end(self, batch, logs=None):
8894
super(PruningSummaries, self).on_epoch_end(batch, logs)
8995

@@ -112,4 +118,4 @@ def on_epoch_end(self, batch, logs=None):
112118
for threshold, threshold_value in param_value_pairs[1::2]:
113119
pruning_logs.update({threshold.name + '/threshold': threshold_value})
114120

115-
self._log_metrics(pruning_logs, '', iteration)
121+
self._log_pruning_metrics(pruning_logs, '', iteration)

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,29 @@
1414
# ==============================================================================
1515
"""Tests for Pruning callbacks."""
1616

17+
import os
18+
import tempfile
19+
1720
from absl.testing import parameterized
1821
import numpy as np
1922
import tensorflow as tf
2023

21-
# TODO(b/139939526): move to public API.
2224
from tensorflow.python.keras import keras_parameterized
2325
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
2426
from tensorflow_model_optimization.python.core.sparsity.keras import prune
2527
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
2628

29+
# TODO(b/139939526): move to public API.
30+
2731

2832
@keras_parameterized.run_all_keras_modes
2933
class PruneTest(tf.test.TestCase, parameterized.TestCase):
3034

31-
def testUpdatesPruningStep(self):
35+
def _assertLogsExist(self, log_dir):
36+
self.assertNotEmpty(os.listdir(log_dir))
37+
38+
def testUpdatePruningStepsAndLogsSummaries(self):
39+
log_dir = tempfile.mkdtemp()
3240
model = prune.prune_low_magnitude(
3341
keras_test_utils.build_simple_dense_model())
3442
model.compile(
@@ -38,13 +46,17 @@ def testUpdatesPruningStep(self):
3846
tf.keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
3947
batch_size=20,
4048
epochs=3,
41-
callbacks=[pruning_callbacks.UpdatePruningStep()])
49+
callbacks=[
50+
pruning_callbacks.UpdatePruningStep(),
51+
pruning_callbacks.PruningSummaries(log_dir=log_dir)
52+
])
4253

4354
self.assertEqual(2,
4455
tf.keras.backend.get_value(model.layers[0].pruning_step))
4556
self.assertEqual(2,
4657
tf.keras.backend.get_value(model.layers[1].pruning_step))
4758

59+
self._assertLogsExist(log_dir)
4860

4961
if __name__ == '__main__':
5062
tf.test.main()

0 commit comments

Comments
 (0)