14
14
# ==============================================================================
15
15
"""Tests for Pruning callbacks."""
16
16
17
+ import os
18
+ import tempfile
19
+
17
20
from absl .testing import parameterized
18
21
import numpy as np
19
22
import tensorflow as tf
20
23
21
- # TODO(b/139939526): move to public API.
22
24
from tensorflow .python .keras import keras_parameterized
23
25
from tensorflow_model_optimization .python .core .keras import test_utils as keras_test_utils
24
26
from tensorflow_model_optimization .python .core .sparsity .keras import prune
25
27
from tensorflow_model_optimization .python .core .sparsity .keras import pruning_callbacks
26
28
29
+ # TODO(b/139939526): move to public API.
30
+
27
31
28
32
@keras_parameterized .run_all_keras_modes
29
33
class PruneTest (tf .test .TestCase , parameterized .TestCase ):
30
34
31
- def testUpdatesPruningStep (self ):
35
+ def _assertLogsExist (self , log_dir ):
36
+ self .assertNotEmpty (os .listdir (log_dir ))
37
+
38
+ def testUpdatePruningStepsAndLogsSummaries (self ):
39
+ log_dir = tempfile .mkdtemp ()
32
40
model = prune .prune_low_magnitude (
33
41
keras_test_utils .build_simple_dense_model ())
34
42
model .compile (
@@ -38,13 +46,17 @@ def testUpdatesPruningStep(self):
38
46
tf .keras .utils .to_categorical (np .random .randint (5 , size = (20 , 1 )), 5 ),
39
47
batch_size = 20 ,
40
48
epochs = 3 ,
41
- callbacks = [pruning_callbacks .UpdatePruningStep ()])
49
+ callbacks = [
50
+ pruning_callbacks .UpdatePruningStep (),
51
+ pruning_callbacks .PruningSummaries (log_dir = log_dir )
52
+ ])
42
53
43
54
self .assertEqual (2 ,
44
55
tf .keras .backend .get_value (model .layers [0 ].pruning_step ))
45
56
self .assertEqual (2 ,
46
57
tf .keras .backend .get_value (model .layers [1 ].pruning_step ))
47
58
59
+ self ._assertLogsExist (log_dir )
48
60
49
61
if __name__ == '__main__' :
50
62
tf .test .main ()
0 commit comments