This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit cfc7a67

Hyperparticle authored and Mesh TensorFlow Team committed

Add utility to save score predictions to TFRecords for scoring large datasets.

PiperOrigin-RevId: 396705745

1 parent f08b18e · commit cfc7a67

1 file changed (+143 −9 lines)

mesh_tensorflow/transformer/utils.py

@@ -26,9 +26,11 @@
 
 import functools
 import itertools
+import math
 import os
 import random
 import re
+import time
 
 import gin
 import gin.tf
@@ -1654,6 +1656,54 @@ def get_sequence_length(tokens, pad_id=0):
   return scores
 
 
+@gin.configurable
+def save_scores_to_tfrecords(
+    results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
+  """Processes results from scoring examples and saves them to TFRecord files.
+
+  Args:
+    results: list of dictionaries containing the results for each scored
+      example.
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple, used to decode the scored examples.
+    scores_filename: a string (path of file to write scores to).
+    shard_idx: an integer indicating the current index of the file for sharding.
+    save_ids_only: if true, save only the ID that is prepended to the inputs,
+      delimited by a space.
+  """
+  results = _maybe_add_pretokenized_features(results, vocabulary)
+  scores = [r.get("scores", 0.0) for r in results]
+  targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
+  inputs = [r.get("targets_neg_pretokenized",
+                  r.get("inputs", "")) for r in results]
+
+  if save_ids_only:
+    inputs = [r.split(" ", 1)[0] for r in inputs]
+
+  table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
+  tf.logging.info("Saving results to {}".format(table_path))
+
+  with tf.io.TFRecordWriter(table_path) as file_writer:
+    for input_, target, score in zip(inputs, targets, scores):
+      record_bytes = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  "input":
+                      tf.train.Feature(
+                          bytes_list=tf.train.BytesList(
+                              value=[bytes(input_, "utf8")])),
+                  "target":
+                      tf.train.Feature(
+                          bytes_list=tf.train.BytesList(
+                              value=[bytes(target, "utf8")])),
+                  "score":
+                      tf.train.Feature(
+                          float_list=tf.train.FloatList(value=[score])),
+              })).SerializeToString()
+      file_writer.write(record_bytes)
+
+
+@gin.configurable
 def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                          vocabulary, score_postprocess_fn=save_scores,
                          num_examples=None):
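Note: each record written above is a tf.train.Example with string features "input" and "target" and a float feature "score". For orientation, here is a minimal read-back sketch (not part of this commit); the path assumes scores_filename="scores" and shard_idx=0:

import tensorflow.compat.v1 as tf

# Schema mirroring the Example written by save_scores_to_tfrecords.
features = {
    "input": tf.io.FixedLenFeature([], tf.string),
    "target": tf.io.FixedLenFeature([], tf.string),
    "score": tf.io.FixedLenFeature([], tf.float32),
}

dataset = tf.data.TFRecordDataset("scores_0.tfrecord")
dataset = dataset.map(
    lambda record: tf.io.parse_single_example(record, features))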
@@ -1691,6 +1741,74 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
   return score_postprocess_fn(results, vocabulary)
 
 
+@gin.configurable
+def score_with_estimator_lazy(
+    estimator, input_fn, eval_checkpoint_step, model_dir,
+    vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
+    num_examples=None, num_examples_per_shard=100000):
+  """Score each example returned by input_fn lazily.
+
+  Args:
+    estimator: a TPUEstimator
+    input_fn: a function that returns a tf.data.Dataset with examples
+      containing the string field 'targets' and optionally the field 'inputs'
+    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
+      docstring.
+    model_dir: string, estimator model_dir
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple
+    score_postprocess_fn: a function that takes in model outputs,
+      post-processes them, and saves them.
+    num_examples: int, the total # of examples being scored, None if unknown
+    num_examples_per_shard: int, the number of examples per file shard.
+
+  Returns:
+    None. Results are written in shards by `score_postprocess_fn`.
+  """
+  if num_examples is not None:
+    num_shards = math.ceil(num_examples / num_examples_per_shard)
+  else:
+    num_shards = None
+  tf.logging.info(
+      "Scoring {} examples with {} shards at {} examples per shard".format(
+          num_examples, num_shards, num_examples_per_shard))
+
+  checkpoint_path, = get_checkpoint_iterator(
+      eval_checkpoint_step, model_dir)
+  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
+
+  start = time.time()
+  results = []
+  shard_idx = 0
+
+  for i, result in enumerate(result_iter):
+    results.append(result)
+    num_results = len(results)
+    exceeded_examples_per_shard = (
+        num_examples_per_shard is not None
+        and num_examples_per_shard > 0
+        and num_results >= num_examples_per_shard)
+    exceeded_num_examples = num_examples is not None and i >= num_examples
+
+    if exceeded_examples_per_shard or exceeded_num_examples:
+      score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
+
+      elapsed = time.time() - start
+      tf.logging.info(
+          "Scored {} results in {} s, {} examples/s for shard {}".format(
+              num_results, elapsed, num_results / elapsed, shard_idx))
+
+      results = []
+      shard_idx += 1
+      start = time.time()
+
+      if exceeded_num_examples:
+        break
+
+  if results:
+    score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
+
+
 def _maybe_add_pretokenized_features(examples, vocabulary):
   """Ensures decoded versions of "inputs" and "targets" exist in each example.
 
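Note: the shard bookkeeping above flushes a shard whenever num_examples_per_shard results have accumulated, and the trailing `if results:` flushes one final partial shard. A small self-contained sketch of the resulting shard sizes (illustrative only, not part of this commit):

import math

def shard_sizes(num_results, per_shard):
  # Mirrors the loop above: full shards of `per_shard`, then one partial shard.
  num_shards = math.ceil(num_results / per_shard)
  return [min(per_shard, num_results - i * per_shard)
          for i in range(num_shards)]

print(shard_sizes(250000, 100000))  # [100000, 100000, 50000]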
@@ -1712,9 +1830,19 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
   for example in examples:
     for feature_name in ["inputs", "targets"]:
       pretokenized_feature_name = feature_name + "_pretokenized"
+      neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
       if feature_name in example and pretokenized_feature_name not in example:
-        s = vocabulary[feature_name].decode(example[feature_name].tolist())
-        example[pretokenized_feature_name] = s
+        ids = example[feature_name].tolist()
+
+        neg_ids = [abs(i) for i in ids if i < 0]
+        ids = [i for i in ids if i > 0]
+
+        decoded_string = vocabulary[feature_name].decode(ids)
+        example[pretokenized_feature_name] = decoded_string
+
+        if neg_ids:
+          neg_decoded_string = vocabulary[feature_name].decode(neg_ids)
+          example[neg_pretokenized_feature_name] = neg_decoded_string
 
       if not added_pretokenized[feature_name]:
         added_pretokenized[feature_name] = True
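Note: the new branch splits a token ID list into a positive stream (decoded into `<feature>_pretokenized`) and a negated stream (absolute values decoded into `<feature>_neg_pretokenized`); zeros, i.e. padding, fail both filters and are dropped. A toy illustration with made-up IDs (the vocabulary itself is hypothetical):

ids = [5, -17, 8, 0, -23, 12, 0]   # mixed positive, negative, and pad IDs

neg_ids = [abs(i) for i in ids if i < 0]   # [17, 23] -> "_neg_pretokenized"
pos_ids = [i for i in ids if i > 0]        # [5, 8, 12] -> "_pretokenized"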
@@ -1730,7 +1858,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
                        sequence_length, model_dir, eval_checkpoint_step,
                        inputs=gin.REQUIRED, targets=gin.REQUIRED,
                        score_postprocess_fn=gin.REQUIRED, eos_id=1,
-                       score_eos=True):
+                       score_eos=True,
+                       score_with_estimator_fn=score_with_estimator):
   """Compute log likelihoods per example and write to a text file.
 
   inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1890,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
     score_eos: a boolean - whether to score the final eos token of each line
       If this is set to false, the scores can be interpreted as prefix
       log-likelihoods
+    score_with_estimator_fn: a function to run scoring with the estimator.
   Returns:
     a list of floats
   """
@@ -1806,7 +1936,7 @@ def input_fn(params):
     dataset = dataset.batch(batch_size, drop_remainder=True)
     return dataset.prefetch(tf.data.experimental.AUTOTUNE)
 
-  return score_with_estimator(
+  return score_with_estimator_fn(
       estimator, input_fn, eval_checkpoint_step, model_dir,
       vocabulary, score_postprocess_fn, len(targets))
 
@@ -1815,7 +1945,8 @@ def input_fn(params):
 def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
                        model_dir, eval_checkpoint_step, dataset_split,
                        score_dataset_fn=None,
-                       score_postprocess_fn=gin.REQUIRED):
+                       score_postprocess_fn=gin.REQUIRED,
+                       score_with_estimator_fn=score_with_estimator):
   """Compute log likelihoods per example and write to a text file.
 
   The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1968,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
       See `eval_dataset_fn` argument to `eval_model` for details.
     score_postprocess_fn: Function that takes in model outputs and
       post-processes then returns them.
+    score_with_estimator_fn: a function to run scoring with the estimator.
 
   Returns:
     scores: a list of floats, the log likelihood scores
@@ -1850,9 +1982,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
   input_fn = _get_combined_dataset_input_fn(
       scoring_datasets, batch_size, sequence_length)
 
-  return score_with_estimator(
+  return score_with_estimator_fn(
       estimator, input_fn, eval_checkpoint_step, model_dir,
-      vocabulary, score_postprocess_fn, None)
+      vocabulary, score_postprocess_fn)
 
 
 def get_estimator(model_type, vocabulary, mesh_shape,
@@ -2093,7 +2225,8 @@ def eval_model(estimator,
                eval_checkpoint_step,
                eval_with_score=False,
                output_eval_examples=True,
-               eval_dir_suffix=None):
+               eval_dir_suffix=None,
+               score_with_estimator_fn=score_with_estimator):
   """Eval a Mesh-TF model.
 
   Args:
@@ -2137,6 +2270,7 @@ def eval_model(estimator,
       of the eval examples in plaintext to eval_summary_dir.
     eval_dir_suffix: string, if not None then will be appended to the
       eval_summary_dir.
+    score_with_estimator_fn: a function to run scoring with the estimator.
   """
   if eval_dataset_fn is None:
     raise ValueError("Must provide eval_dataset_fn through gin for eval.")
@@ -2248,7 +2382,7 @@ def eval_model(estimator,
     tf.logging.info("Checkpoint path %s" % checkpoint_path)
     global_step = int(get_step_from_checkpoint_path(checkpoint_path))
     if eval_with_score:
-      outputs, _ = score_with_estimator(
+      outputs, _ = score_with_estimator_fn(
          estimator, input_fn, global_step, model_dir, vocabulary,
          num_examples=sum(len(cex) for cex in cached_examples.values()))
     else:
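Note: since every call site now routes through the score_with_estimator_fn hook (default: the original score_with_estimator), the lazy sharded path can be selected per caller. A hedged usage sketch, not from this commit: the estimator, vocabulary, and score_dataset_fn values are placeholders for objects built by get_estimator and the gin-configured dataset helpers, and the output path is illustrative.

import functools

score_from_dataset(
    estimator, vocabulary,
    batch_size=64,
    sequence_length={"inputs": 512, "targets": 512},
    model_dir="/tmp/model",
    eval_checkpoint_step=-1,
    dataset_split="validation",
    score_dataset_fn=score_dataset_fn,
    # Bind the output prefix; shards land at /tmp/scores_0.tfrecord, _1, ...
    score_postprocess_fn=functools.partial(
        save_scores_to_tfrecords, scores_filename="/tmp/scores"),
    score_with_estimator_fn=score_with_estimator_lazy)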

0 commit comments

Comments
 (0)