Skip to content

Commit 73a314c

Browse files
alanchiaotensorflower-gardener
authored andcommitted
Add basic error checking for log_dir for PruningSummaries.
PiperOrigin-RevId: 316778704
1 parent 3fec5a7 commit 73a314c

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ py_library(
105105
deps = [
106106
":pruning_wrapper",
107107
# numpy dep1,
108+
# six dep1,
108109
# tensorflow dep1,
109110
"//tensorflow_model_optimization/python/core/keras:compat",
110111
],

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
# import g3
2222
import numpy as np
23+
import six
2324
import tensorflow as tf
2425

2526
from tensorflow_model_optimization.python.core.keras import compat
@@ -95,6 +96,11 @@ class PruningSummaries(callbacks.TensorBoard):
9596
"""
9697

9798
def __init__(self, log_dir, update_freq='epoch', **kwargs):
99+
if not isinstance(log_dir, six.string_types) or not log_dir:
100+
raise ValueError(
101+
'`log_dir` must be a non-empty string. You passed `log_dir`='
102+
'{input}.'.format(input=log_dir))
103+
98104
super(PruningSummaries, self).__init__(
99105
log_dir=log_dir, update_freq=update_freq, **kwargs)
100106

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,17 @@ def testPruneTrainingLoopRaisesError_PruningStepCallbackMissing_CustomTrainingLo
137137
with tf.GradientTape():
138138
pruned_model(inp, training=True)
139139

140+
@keras_parameterized.run_all_keras_modes
141+
def testPruningSummariesRaisesError_LogDirNotNonEmptyString(self):
142+
with self.assertRaises(ValueError):
143+
pruning_callbacks.PruningSummaries(log_dir='')
144+
145+
with self.assertRaises(ValueError):
146+
pruning_callbacks.PruningSummaries(log_dir=None)
147+
148+
with self.assertRaises(ValueError):
149+
pruning_callbacks.PruningSummaries(log_dir=object())
150+
140151

141152
if __name__ == '__main__':
142153
tf.test.main()

0 commit comments

Comments
 (0)