Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions mesh_tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,9 @@ def _verify_feature_exists(feature_name, should_exist):
compute_loss=False,
mode=mode,
variable_dtype=get_variable_dtype())
batch_dim, length_dim, vocab_dim = logits.shape.dims
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
logits, mtf_features["targets"], vocab_dim)
cross_entropy *= mtf.cast(
mtf.not_equal(targets, 0), cross_entropy.dtype)
if model_type == "delimited_lm":
cross_entropy *= mtf.cast(mtf.logical_not(
transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
scores = mtf.anonymize(scores)

# calculate log likelihood
scores = compute_score(logits, targets, model_type)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
predictions = {
"scores": lowering.export_to_tf_tensor(scores)
Expand All @@ -518,11 +511,30 @@ def _verify_feature_exists(feature_name, should_exist):
inputs, variable_dtype=get_variable_dtype())
else:
raise ValueError("unrecognized class")

# Calculate log-likelihood scores for the sampled output texts.
# Replaces everything after EOS with 0 (along last dim).
eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
exclusive=True, dim=mtf_samples.shape[1])
valid_ids = mtf.equal(eos_and_after, 0)
targets_for_score = mtf.where(valid_ids, mtf_samples, 0)

logits, _ = transformer_model.call_simple(
inputs=inputs,
targets=targets_for_score,
compute_loss=False,
mode='score',
variable_dtype=get_variable_dtype())

# calculate log likelihood
scores = compute_score(logits, targets_for_score, model_type)

mtf_samples = mtf.anonymize(mtf_samples)
inputs = mtf.anonymize(inputs)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
scores = lowering.export_to_tf_tensor(scores)

# Detokenize in the graph if supported by vocabulary and accelerator.
def _maybe_detokenize(ids, vocab):
Expand All @@ -535,8 +547,10 @@ def _maybe_detokenize(ids, vocab):

predictions = {
"inputs": inputs,
"outputs": outputs}

"outputs": outputs,
"scores": scores
}

if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
# When exporting a model, we need to communicate to TF-Serving that
# master variables need to be copied to their slave slice variables.
Expand Down Expand Up @@ -1203,6 +1217,33 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
return tf.where_v2(valid_ids, ids, pad_id)


def compute_score(logits, targets, model_type):
  """Score each sequence as the negated, masked cross-entropy summed over length.

  The per-position softmax cross-entropy between `logits` and `targets` is
  zeroed wherever the target id equals 0. When `model_type` is "delimited_lm",
  it is also zeroed wherever `transformer.delimited_lm_inputs_mask(targets)`
  is true. The result is summed over the length dimension and negated.

  Args:
    logits: an mtf Tensor with exactly three dimensions; the second is
      treated as the length dimension and the third as the vocabulary
      dimension.
    targets: an mtf Tensor of target ids; id 0 is excluded from the score.
    model_type: a string. Only the value "delimited_lm" changes the
      computation, by enabling the extra inputs mask.

  Returns:
    an anonymized mtf.Tensor holding the scores with the length dimension
    reduced away.
  """
  # Unpacking requires exactly three dims; the leading dim is not used here.
  _, length_dim, vocab_dim = logits.shape.dims

  xent = mtf.layers.softmax_cross_entropy_with_logits(
      logits, targets, vocab_dim)

  # Drop the contribution of every position whose target id is 0.
  nonzero_target = mtf.not_equal(targets, 0)
  xent *= mtf.cast(nonzero_target, xent.dtype)

  if model_type == "delimited_lm":
    # Also drop the positions flagged by the delimited-LM inputs mask.
    inputs_mask = transformer.delimited_lm_inputs_mask(targets)
    xent *= mtf.cast(mtf.logical_not(inputs_mask), xent.dtype)

  summed_xent = mtf.reduce_sum(xent, reduced_dim=length_dim)
  return mtf.anonymize(-summed_xent)


def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
scores_filename, num_examples=None):
"""For each example returned by input_fn, compute log likelihood.
Expand Down Expand Up @@ -2217,4 +2258,4 @@ def _input_fn(params, eval_dataset):
else:
raise ValueError(
"unknown mode %s - must be train/perplexity_eval/eval/infer/export"
% mode)
% mode)