19
19
from __future__ import print_function
20
20
21
21
# import g3
22
+ import numpy as np
22
23
import tensorflow as tf
23
24
24
25
from tensorflow .python .keras import backend as K
@@ -61,12 +62,16 @@ def on_epoch_end(self, batch, logs=None):
61
62
# At the end of every epoch, remask the weights. This ensures that when
62
63
# the model is saved after completion, the weights represent mask*weights.
63
64
layers = self .model .layers
65
+ weight_mask_ops = []
66
+
64
67
for layer in layers :
65
68
if isinstance (layer , pruning_wrapper .PruneLowMagnitude ):
66
69
if tf .executing_eagerly ():
67
70
layer .pruning_obj .weight_mask_op ()
68
71
else :
69
- K .get_session ().run (layer .pruning_obj .weight_mask_op ())
72
+ weight_mask_ops .append (layer .pruning_obj .weight_mask_op ())
73
+
74
+ K .batch_get_value (weight_mask_ops )
70
75
71
76
72
77
class PruningSummaries (callbacks .TensorBoard ):
@@ -83,15 +88,28 @@ def on_epoch_end(self, batch, logs=None):
83
88
super (PruningSummaries , self ).on_epoch_end (batch , logs )
84
89
85
90
pruning_logs = {}
91
+ params = []
86
92
layers = self .model .layers
87
93
for layer in layers :
88
94
if isinstance (layer , pruning_wrapper .PruneLowMagnitude ):
89
95
for _ , mask , threshold in layer .pruning_vars :
90
- pruning_logs .update ({
91
- mask .name + '/sparsity' :
92
- K .get_value (1.0 - math_ops .reduce_mean (mask ))
93
- })
94
- pruning_logs .update (
95
- {threshold .name + '/threshold' : K .get_value (threshold )})
96
- self ._log_metrics (pruning_logs , '' ,
97
- K .get_value (self .model .optimizer .iterations ))
96
+ params .append (mask )
97
+ params .append (threshold )
98
+ params .append (self .model .optimizer .iterations )
99
+
100
+ values = K .batch_get_value (params )
101
+ iteration = values [- 1 ]
102
+ del values [- 1 ]
103
+ del params [- 1 ]
104
+
105
+ param_value_pairs = zip (params , values )
106
+
107
+ for mask , mask_value in param_value_pairs [::2 ]:
108
+ pruning_logs .update ({
109
+ mask .name + '/sparsity' : 1 - np .mean (mask_value )
110
+ })
111
+
112
+ for threshold , threshold_value in param_value_pairs [1 ::2 ]:
113
+ pruning_logs .update ({threshold .name + '/threshold' : threshold_value })
114
+
115
+ self ._log_metrics (pruning_logs , '' , iteration )
0 commit comments