Skip to content

Commit 471b5c1

Browse files
Internal change
PiperOrigin-RevId: 531273973
1 parent 5a08ff8 commit 471b5c1

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

official/vision/modeling/layers/nn_blocks.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,9 +1564,7 @@ def build(self, inputs_shape):
15641564

15651565
def call(self, inputs, inputs_positions=None):
15661566
del inputs_positions
1567-
input_dtype = inputs.dtype
1568-
gamma = self.gamma
1569-
return tf.cast(tf.cast(inputs, tf.float32) * gamma, input_dtype)
1567+
return tf.cast(self.gamma, inputs.dtype) * inputs
15701568

15711569

15721570
@tf.keras.utils.register_keras_serializable(package='Vision')

0 commit comments

Comments
 (0)