Skip to content

Commit 27452ca

Browse files
liyunlu0618alanchiao
authored andcommitted
Update on_epoch_end callback.
PiperOrigin-RevId: 246562858
1 parent ecb0412 commit 27452ca

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ py_library(
114114
visibility = ["//visibility:public"],
115115
deps = [
116116
":pruning_wrapper",
117+
# numpy dep1,
117118
# tensorflow dep1,
118119
# python:math_ops tensorflow dep2,
119120
# python/keras:backend tensorflow dep2,

tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import print_function
2020

2121
# import g3
22+
import numpy as np
2223
import tensorflow as tf
2324

2425
from tensorflow.python.keras import backend as K
@@ -61,12 +62,16 @@ def on_epoch_end(self, batch, logs=None):
6162
# At the end of every epoch, remask the weights. This ensures that when
6263
# the model is saved after completion, the weights represent mask*weights.
6364
layers = self.model.layers
65+
weight_mask_ops = []
66+
6467
for layer in layers:
6568
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
6669
if tf.executing_eagerly():
6770
layer.pruning_obj.weight_mask_op()
6871
else:
69-
K.get_session().run(layer.pruning_obj.weight_mask_op())
72+
weight_mask_ops.append(layer.pruning_obj.weight_mask_op())
73+
74+
K.batch_get_value(weight_mask_ops)
7075

7176

7277
class PruningSummaries(callbacks.TensorBoard):
@@ -83,15 +88,28 @@ def on_epoch_end(self, batch, logs=None):
8388
super(PruningSummaries, self).on_epoch_end(batch, logs)
8489

8590
pruning_logs = {}
91+
params = []
8692
layers = self.model.layers
8793
for layer in layers:
8894
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
8995
for _, mask, threshold in layer.pruning_vars:
90-
pruning_logs.update({
91-
mask.name + '/sparsity':
92-
K.get_value(1.0 - math_ops.reduce_mean(mask))
93-
})
94-
pruning_logs.update(
95-
{threshold.name + '/threshold': K.get_value(threshold)})
96-
self._log_metrics(pruning_logs, '',
97-
K.get_value(self.model.optimizer.iterations))
96+
params.append(mask)
97+
params.append(threshold)
98+
params.append(self.model.optimizer.iterations)
99+
100+
values = K.batch_get_value(params)
101+
iteration = values[-1]
102+
del values[-1]
103+
del params[-1]
104+
105+
param_value_pairs = zip(params, values)
106+
107+
for mask, mask_value in param_value_pairs[::2]:
108+
pruning_logs.update({
109+
mask.name + '/sparsity': 1 - np.mean(mask_value)
110+
})
111+
112+
for threshold, threshold_value in param_value_pairs[1::2]:
113+
pruning_logs.update({threshold.name + '/threshold': threshold_value})
114+
115+
self._log_metrics(pruning_logs, '', iteration)

0 commit comments

Comments
 (0)