@@ -1170,8 +1170,8 @@ def _ForceAlignment(self, log_probs, source_num_sentences, hyp_num_sentences):
11701170 # the current hyp contains fewer sentences than expected to disallow
11711171 # eos in such misaligned cases.
11721172 large_negative_value = tf .ones_like (log_probs [:, eos_id ]) * tf .constant (
1173- - self ._FLOAT_DTYPE_MAX_SCALER ,
1174- dtype = log_probs .dtype ) * log_probs . dtype . max
1173+ - self ._FLOAT_DTYPE_MAX_SCALER * log_probs . dtype . max ,
1174+ dtype = log_probs .dtype )
11751175 eos_log_probs = tf .where (
11761176 tf .math .greater (source_num_sentences , hyp_num_sentences ),
11771177 large_negative_value , log_probs [:, eos_id ])
@@ -1214,8 +1214,8 @@ def _UpdateLogitsForSingleTokenFastDecode(self, log_probs, is_single_token,
12141214 is_eos = tf .math .equal (tf .range (v ), tf .ones_like (tf .range (v )) * eos_id )
12151215 is_eos = tf .tile (tf .expand_dims (is_eos , 0 ), [b , 1 ])
12161216 large_neg_probs = tf .ones_like (log_probs ) * tf .constant (
1217- - self ._FLOAT_DTYPE_MAX_SCALER ,
1218- dtype = log_probs .dtype ) * log_probs . dtype . max
1217+ - self ._FLOAT_DTYPE_MAX_SCALER * log_probs . dtype . max ,
1218+ dtype = log_probs .dtype )
12191219 new_log_probs = tf .where (is_eos , tf .zeros_like (large_neg_probs ),
12201220 large_neg_probs )
12211221 return tf .where (is_single_token_2d , new_log_probs , log_probs )
0 commit comments