@@ -518,11 +518,44 @@ def _verify_feature_exists(feature_name, should_exist):
518518 inputs , variable_dtype = get_variable_dtype ())
519519 else :
520520 raise ValueError ("unrecognized class" )
521+
522+ # calculate probabilities for the output texts
523+ # Replaces everything after EOS with 0 (along last dim).
524+ eos_and_after = mtf .cumsum (mtf .cast (mtf .equal (mtf_samples , 1 ), tf .int32 ),
525+ exclusive = True , dim = mtf_samples .shape [1 ])
526+ valid_ids = mtf .equal (eos_and_after , 0 )
527+ targets_for_score = mtf .where (valid_ids , mtf_samples , 0 )
528+
529+ logits , _ = transformer_model .call_simple (
530+ inputs = inputs ,
531+ targets = targets_for_score ,
532+ compute_loss = False ,
533+ mode = 'score' ,
534+ variable_dtype = get_variable_dtype ())
535+
536+ # calculate log probability
537+ targets = mtf_features ["targets" ] = targets_for_score
538+
539+ batch_dim , length_dim , vocab_dim = logits .shape .dims
540+ cross_entropy = mtf .layers .softmax_cross_entropy_with_logits (
541+ logits , mtf_features ["targets" ], vocab_dim )
542+ cross_entropy *= mtf .cast (
543+ mtf .not_equal (targets , 0 ), cross_entropy .dtype )
544+ if mode == "delimited_lm" :
545+ cross_entropy *= mtf .cast (mtf .logical_not (
546+ transformer .delimited_lm_inputs_mask (targets )), cross_entropy .dtype )
547+ scores = - mtf .reduce_sum (cross_entropy , reduced_dim = length_dim )
548+
549+ # convert log prob to prob
550+ probabilities = mtf .exp (scores )
551+ probabilities = mtf .anonymize (probabilities )
552+
521553 mtf_samples = mtf .anonymize (mtf_samples )
522554 inputs = mtf .anonymize (inputs )
523555 lowering = mtf .Lowering (graph , {mesh : mesh_impl }, autostack = autostack )
524556 inputs = clean_decodes (lowering .export_to_tf_tensor (inputs ))
525557 outputs = clean_decodes (lowering .export_to_tf_tensor (mtf_samples ))
558+ probabilities = lowering .export_to_tf_tensor (probabilities )
526559
527560 # Detokenize in the graph if supported by vocabulary and accelerator.
528561 def _maybe_detokenize (ids , vocab ):
@@ -535,7 +568,9 @@ def _maybe_detokenize(ids, vocab):
535568
536569 predictions = {
537570 "inputs" : inputs ,
538- "outputs" : outputs }
571+ "outputs" : outputs ,
572+ "probabilities" : probabilities
573+ }
539574
540575 if mode in ["score" , tf .estimator .ModeKeys .PREDICT ]:
541576 # When exporting a model, we need to communicate to TF-Serving that
@@ -1203,6 +1238,7 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
12031238 return tf .where_v2 (valid_ids , ids , pad_id )
12041239
12051240
1241+
12061242def _score_with_estimator (estimator , input_fn , eval_checkpoint_step , model_dir ,
12071243 scores_filename , num_examples = None ):
12081244 """For each example returned by input_fn, compute log likelihood.
@@ -2217,4 +2253,4 @@ def _input_fn(params, eval_dataset):
22172253 else :
22182254 raise ValueError (
22192255 "unknown mode %s - must be train/perplexity_eval/eval/infer/export"
2220- % mode )
2256+ % mode )
0 commit comments