Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 7a96222

Browse files
allen-qin
authored and committed
added probabilities for generated text in inference mode
1 parent 897511d commit 7a96222

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,44 @@ def _verify_feature_exists(feature_name, should_exist):
518518
inputs, variable_dtype=get_variable_dtype())
519519
else:
520520
raise ValueError("unrecognized class")
521+
522+
# calculate probabilities for the output texts
523+
# Replaces everything after EOS with 0 (along last dim).
524+
eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
525+
exclusive=True, dim=mtf_samples.shape[1])
526+
valid_ids = mtf.equal(eos_and_after, 0)
527+
targets_for_score = mtf.where(valid_ids, mtf_samples, 0)
528+
529+
logits, _ = transformer_model.call_simple(
530+
inputs=inputs,
531+
targets=targets_for_score,
532+
compute_loss=False,
533+
mode='score',
534+
variable_dtype=get_variable_dtype())
535+
536+
# calculate log probability
537+
targets = mtf_features["targets"] = targets_for_score
538+
539+
batch_dim, length_dim, vocab_dim = logits.shape.dims
540+
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
541+
logits, mtf_features["targets"], vocab_dim)
542+
cross_entropy *= mtf.cast(
543+
mtf.not_equal(targets, 0), cross_entropy.dtype)
544+
if mode == "delimited_lm":
545+
cross_entropy *= mtf.cast(mtf.logical_not(
546+
transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
547+
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
548+
549+
# convert log prob to prob
550+
probabilities = mtf.exp(scores)
551+
probabilities = mtf.anonymize(probabilities)
552+
521553
mtf_samples = mtf.anonymize(mtf_samples)
522554
inputs = mtf.anonymize(inputs)
523555
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
524556
inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
525557
outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
558+
probabilities = lowering.export_to_tf_tensor(probabilities)
526559

527560
# Detokenize in the graph if supported by vocabulary and accelerator.
528561
def _maybe_detokenize(ids, vocab):
@@ -535,7 +568,9 @@ def _maybe_detokenize(ids, vocab):
535568

536569
predictions = {
537570
"inputs": inputs,
538-
"outputs": outputs}
571+
"outputs": outputs,
572+
"probabilities": probabilities
573+
}
539574

540575
if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
541576
# When exporting a model, we need to communicate to TF-Serving that
@@ -1203,6 +1238,7 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
12031238
return tf.where_v2(valid_ids, ids, pad_id)
12041239

12051240

1241+
12061242
def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
12071243
scores_filename, num_examples=None):
12081244
"""For each example returned by input_fn, compute log likelihood.
@@ -2217,4 +2253,4 @@ def _input_fn(params, eval_dataset):
22172253
else:
22182254
raise ValueError(
22192255
"unknown mode %s - must be train/perplexity_eval/eval/infer/export"
2220-
% mode)
2256+
% mode)

0 commit comments

Comments (0)