Skip to content

Commit ec778f9

Browse files
Some customer layers' call method have more than one input arguments.
For example: The Masked LM layer for BERT. PiperOrigin-RevId: 354593462
1 parent 4f4e5d0 commit ec778f9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def training_step_fn():
229229
block_size=self.block_size,
230230
block_pooling_type=self.block_pooling_type)
231231

232-
def call(self, inputs, training=None):
232+
def call(self, inputs, training=None, **kwargs):
233233
if training is None:
234234
training = K.learning_phase()
235235

@@ -273,9 +273,9 @@ def no_op():
273273
# Propagate the training bool to the underlying layer if it accepts
274274
# training as an arg.
275275
if 'training' in args:
276-
return self.layer.call(inputs, training=training)
276+
return self.layer.call(inputs, training=training, **kwargs)
277277

278-
return self.layer.call(inputs)
278+
return self.layer.call(inputs, **kwargs)
279279

280280
def compute_output_shape(self, input_shape):
281281
return self.layer.compute_output_shape(input_shape)

0 commit comments

Comments
 (0)