Skip to content

Commit b092458

Browse files
chaoyan1037tensorflower-gardener
authored andcommitted
No public description
PiperOrigin-RevId: 592656146
1 parent c72bce8 commit b092458

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

official/projects/yt8m/modeling/yt8m_model_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def frame_pooling(frames, method="average", num_frames=None):
120120
tf.ones_like(frames, dtype=frames.dtype)
121121
* _large_compatible_negative(frames.dtype),
122122
)
123-
reduced = tf.reduce_max(frames, 1)
123+
# Magic to avoid loss NaN when bfloat16 is enabled.
124+
# See yaqs/5377152819545505792 and b/214396297 for more discussion.
125+
reduced = tf.reduce_max(frames, 1) + tf.reduce_mean(frames, 1) * 0
124126
elif method == "swap":
125127
# Note we assume the frames are in the shape of
126128
# [batch_size, num_frames, feature_size]. Otherwise this function might

0 commit comments

Comments
 (0)