Skip to content

Commit 5dbfbca

Browse files
Suyog Guptatensorflower-gardener
authored andcommitted
In pruning_wrapper, propagate training boolean to the wrapped layer.
PiperOrigin-RevId: 303829300
1 parent f0cd9b4 commit 5dbfbca

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import inspect
2223
# import g3
2324
import numpy as np
2425
import tensorflow as tf
@@ -255,6 +256,12 @@ def no_op():
255256
# self.add_update does nothing during eager execution.
256257
self.add_update(self.pruning_obj.weight_mask_op())
257258

259+
args = inspect.getargspec(self.layer.call)[0]
260+
# Propagate the training bool to the underlying layer if it accepts
261+
# training as an arg.
262+
if 'training' in args:
263+
return self.layer.call(inputs, training=training)
264+
258265
return self.layer.call(inputs)
259266

260267
def compute_output_shape(self, input_shape):

0 commit comments

Comments
 (0)