Skip to content

Commit 974f60f

Browse files
lingvo-bot authored and copybara-github committed
Refactor the way large initial numbers are used so that they don't cause overflow.
PiperOrigin-RevId: 491478236
1 parent 0f8caa2 commit 974f60f

File tree

7 files changed

+15
-15
lines changed

7 files changed

+15
-15
lines changed

lingvo/core/attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,8 @@ def _PaddedSoftmax(self, logits, padding):
394394
assert logits.dtype.is_floating
395395
assert hasattr(logits.dtype, 'max')
396396
very_negative_logits = (
397-
tf.ones_like(logits) * logits.dtype.max *
398-
tf.constant(-0.7, dtype=logits.dtype))
397+
tf.ones_like(logits) *
398+
tf.constant(-0.7 * logits.dtype.max, dtype=logits.dtype))
399399
if self.do_eval:
400400
very_negative_logits = self.QAct('logits', very_negative_logits)
401401
padded_logits = tf.where(padding > 0.0, very_negative_logits, logits)

lingvo/core/attention_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def FProp(self, theta, x, paddings=None, update=False):
786786

787787
# For padded positions we update the distances to very large numbers.
788788
very_large_dists = tf.ones_like(dists) * tf.constant(
789-
0.1, dtype=dists.dtype) * dists.dtype.max
789+
0.1 * dists.dtype.max, dtype=dists.dtype)
790790
paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
791791
dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)
792792

@@ -977,8 +977,8 @@ def ComputeSparseAttention(q, k, v, sparsity_indices, paddings=None):
977977
logits *= tf.math.rsqrt(tf.cast(dim_per_head, q.dtype))
978978

979979
very_negative_logits = (
980-
tf.ones_like(logits) * logits.dtype.max *
981-
tf.constant(-0.7, dtype=logits.dtype))
980+
tf.ones_like(logits) *
981+
tf.constant(-0.7 * logits.dtype.max, dtype=logits.dtype))
982982
padded_logits = tf.where(
983983
tf.math.logical_or(sparsity_indices < 0, paddings > 0.0),
984984
very_negative_logits, logits)

lingvo/core/batch_major_attention_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,8 @@ def testMultiHeadedAttentionDotProductSegmentMask(self):
327327
segment_id = tf.zeros([6, 6])
328328
segment_mask = attention.SegmentMask(segment_id, segment_id)
329329
padding = tf.tile(tf.reshape(input_padding, [6, 1, 1, 6]), [1, 1, 6, 1])
330-
padding_mask = padding * segment_mask.dtype.max * tf.constant(
331-
-0.7, dtype=segment_mask.dtype)
330+
padding_mask = padding * tf.constant(
331+
-0.7 * segment_mask.dtype.max, dtype=segment_mask.dtype)
332332
segment_mask += padding_mask
333333

334334
l = p.Instantiate()

lingvo/core/conv_layers_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def FProp(self, theta, inputs, paddings):
9595

9696
window_size = p.left_context
9797
left_pad_size = window_size - 1
98-
large_negative = p.dtype.max * tf.constant(-0.7, dtype=p.dtype)
98+
large_negative = tf.constant(-0.7 * p.dtype.max, dtype=p.dtype)
9999
# For max pooling, use a large negative padding value such that the max
100100
# element is almost always from a non-padding position.
101101
pad_value = 0 if p.pooling_type == 'AVG' else large_negative

lingvo/core/conv_layers_with_time_padding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ def FProp(self, theta, inputs, paddings):
10061006
out_feature = global_sum / tf.maximum(1.0, count)
10071007
elif p.pooling_type == 'MAX':
10081008
large_negative = (
1009-
tf.ones_like(inputs) * p.dtype.max * tf.constant(-0.7, dtype=p.dtype))
1009+
tf.ones_like(inputs) * tf.constant(-0.7 * p.dtype.max, dtype=p.dtype))
10101010
padded_inputs = tf.where_v2(mask > 0.0, inputs, large_negative)
10111011
out_feature = tf.reduce_max(padded_inputs, axis=[1, 2], keepdims=True)
10121012
if paddings is None:

lingvo/core/gshard_layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,8 +2121,8 @@ def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
21212121
# Generates standard Gumbel(0, 1) noise, GSE Tensors
21222122
noise = -tf.math.log(-tf.math.log(noise))
21232123
very_negative_logits = _MaybeSplit(
2124-
(tf.ones_like(logits) * logits.dtype.max *
2125-
tf.constant(-0.7, dtype=logits.dtype)))
2124+
(tf.ones_like(logits) *
2125+
tf.constant(-0.7 * logits.dtype.max, dtype=logits.dtype)))
21262126
# Gets rid of the first expert by setting its logit to be very negative
21272127
updated_logits = _MaybeSplit(
21282128
tf.where(mask_1 > 0.0, very_negative_logits, logits))

lingvo/tasks/mt/decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,8 +1170,8 @@ def _ForceAlignment(self, log_probs, source_num_sentences, hyp_num_sentences):
11701170
# the current hyp contains fewer sentences than expected to disallow
11711171
# eos in such misaligned cases.
11721172
large_negative_value = tf.ones_like(log_probs[:, eos_id]) * tf.constant(
1173-
-self._FLOAT_DTYPE_MAX_SCALER,
1174-
dtype=log_probs.dtype) * log_probs.dtype.max
1173+
-self._FLOAT_DTYPE_MAX_SCALER * log_probs.dtype.max,
1174+
dtype=log_probs.dtype)
11751175
eos_log_probs = tf.where(
11761176
tf.math.greater(source_num_sentences, hyp_num_sentences),
11771177
large_negative_value, log_probs[:, eos_id])
@@ -1214,8 +1214,8 @@ def _UpdateLogitsForSingleTokenFastDecode(self, log_probs, is_single_token,
12141214
is_eos = tf.math.equal(tf.range(v), tf.ones_like(tf.range(v)) * eos_id)
12151215
is_eos = tf.tile(tf.expand_dims(is_eos, 0), [b, 1])
12161216
large_neg_probs = tf.ones_like(log_probs) * tf.constant(
1217-
-self._FLOAT_DTYPE_MAX_SCALER,
1218-
dtype=log_probs.dtype) * log_probs.dtype.max
1217+
-self._FLOAT_DTYPE_MAX_SCALER * log_probs.dtype.max,
1218+
dtype=log_probs.dtype)
12191219
new_log_probs = tf.where(is_eos, tf.zeros_like(large_neg_probs),
12201220
large_neg_probs)
12211221
return tf.where(is_single_token_2d, new_log_probs, log_probs)

0 commit comments

Comments (0)