@@ -483,16 +483,9 @@ def _verify_feature_exists(feature_name, should_exist):
           compute_loss=False,
           mode=mode,
           variable_dtype=get_variable_dtype())
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if model_type == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-      scores = mtf.anonymize(scores)
+
+      # calculate log likelihood
+      scores = compute_score(logits, targets, model_type)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       predictions = {
           "scores": lowering.export_to_tf_tensor(scores)
@@ -533,29 +526,15 @@ def _verify_feature_exists(feature_name, should_exist):
           mode='score',
           variable_dtype=get_variable_dtype())

-      # calculate log probability
-      targets = mtf_features["targets"] = targets_for_score
-
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if mode == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-
-      # convert log prob to prob
-      probabilities = mtf.exp(scores)
-      probabilities = mtf.anonymize(probabilities)
+      # calculate log likelihood
+      scores = compute_score(logits, targets_for_score, model_type)

       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
       outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
-      probabilities = lowering.export_to_tf_tensor(probabilities)
+      scores = lowering.export_to_tf_tensor(scores)

       # Detokenize in the graph if supported by vocabulary and accelerator.
       def _maybe_detokenize(ids, vocab):
@@ -569,9 +548,9 @@ def _maybe_detokenize(ids, vocab):
       predictions = {
           "inputs": inputs,
           "outputs": outputs,
-          "probabilities": probabilities
+          "scores": scores
       }
-
+
       if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
         # When exporting a model, we need to communicate to TF-Serving that
         # master variables need to be copied to their slave slice variables.
@@ -1238,6 +1217,32 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   return tf.where_v2(valid_ids, ids, pad_id)


+def compute_score(logits, targets, model_type):
+  """Compute the log likelihood given logits and targets.
+
+  Args:
+    logits: A mtf Tensor with floating-point dtype, containing the predicted
+      relative log probabilities of the classes.
+    targets: A mtf Tensor with integer dtype whose values are in the range
+      [0, vocab_dim.size).
+    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
+      "aligned", or "bi_teacher_student"
+
+  Returns:
+    a float mtf.Tensor with the log likelihood.
+  """
+  batch_dim, length_dim, vocab_dim = logits.shape.dims
+  cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
+      logits, targets, vocab_dim)
+  cross_entropy *= mtf.cast(
+      mtf.not_equal(targets, 0), cross_entropy.dtype)
+  if model_type == "delimited_lm":
+    cross_entropy *= mtf.cast(mtf.logical_not(
+        transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
+  scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
+  scores = mtf.anonymize(scores)
+  return scores
+

 def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                           scores_filename, num_examples=None):
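
For reference, here is a minimal, self-contained sketch (not part of the commit) of how the new compute_score helper could be exercised outside the estimator path. The mesh name, dimension names, and sizes are illustrative, and compute_score is assumed to be in scope from the patched module:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
length = mtf.Dimension("length", 4)
vocab = mtf.Dimension("vocab", 8)

# Random logits and integer targets; token id 0 is padding and gets
# masked out inside compute_score.
logits = mtf.import_tf_tensor(
    mesh, tf.random.normal([2, 4, 8]),
    shape=mtf.Shape([batch, length, vocab]))
targets = mtf.import_tf_tensor(
    mesh, tf.constant([[5, 3, 1, 0], [2, 2, 1, 0]]),
    shape=mtf.Shape([batch, length]))

# Per-example sequence log likelihood: the negated sum over positions of
# the per-token cross entropy.
scores = compute_score(logits, targets, model_type="lm")

# Lower the mtf graph onto a single-device placement mesh and evaluate.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
with tf.Session() as sess:
  print(sess.run(lowering.export_to_tf_tensor(scores)))  # shape [2]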