We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f0cd9b4 commit 5dbfbcaCopy full SHA for 5dbfbca
tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
@@ -19,6 +19,7 @@
19
from __future__ import division
20
from __future__ import print_function
21
22
+import inspect
23
# import g3
24
import numpy as np
25
import tensorflow as tf
@@ -255,6 +256,12 @@ def no_op():
255
256
# self.add_update does nothing during eager execution.
257
self.add_update(self.pruning_obj.weight_mask_op())
258
259
+ args = inspect.getargspec(self.layer.call)[0]
260
+ # Propagate the training bool to the underlying layer if it accepts
261
+ # training as an arg.
262
+ if 'training' in args:
263
+ return self.layer.call(inputs, training=training)
264
+
265
return self.layer.call(inputs)
266
267
def compute_output_shape(self, input_shape):
0 commit comments