2626
2727import functools
2828import itertools
29+ import math
2930import os
3031import random
3132import re
33+ import time
3234
3335import gin
3436import gin .tf
@@ -1654,6 +1656,54 @@ def get_sequence_length(tokens, pad_id=0):
16541656 return scores
16551657
16561658
@gin.configurable
def save_scores_to_tfrecords(
    results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
  """Processes results from scoring examples and saves them to a tfrecord file.

  One `tf.train.Example` is written per result, holding the bytes features
  "input" and "target" and the float feature "score". The output path is
  "{scores_filename}_{shard_idx}.tfrecord".

  Args:
    results: list of dictionaries containing the results for each scored
      example.
    vocabulary: forwarded to `_maybe_add_pretokenized_features` to decode
      token ids into strings (presumably a vocabulary.Vocabulary or an
      (inputs_vocabulary, targets_vocabulary) tuple -- confirm with callers).
    scores_filename: a string (path prefix of the file to write scores to).
    shard_idx: an integer indicating the current index of the file for sharding.
    save_ids_only: if true, save only the ID that is prepended to the inputs,
      delimited by a space.
  """
  def _to_utf8_bytes(value):
    """Returns `value` as bytes, UTF-8 encoding it unless it already is bytes."""
    # `bytes(value, "utf8")` raises TypeError for values that are already
    # bytes, so those are passed through untouched.
    if isinstance(value, bytes):
      return value
    return bytes(value, "utf8")

  results = _maybe_add_pretokenized_features(results, vocabulary)
  scores = [r.get("scores", 0.0) for r in results]
  targets = [
      _to_utf8_bytes(r.get("targets_pretokenized", r["targets"]))
      for r in results]
  inputs = [
      _to_utf8_bytes(r.get("targets_neg_pretokenized", r.get("inputs", "")))
      for r in results]

  if save_ids_only:
    # Splitting after encoding is equivalent to splitting the text first: in
    # UTF-8 the byte 0x20 only ever encodes an ASCII space.
    inputs = [encoded.split(b" ", 1)[0] for encoded in inputs]

  table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
  tf.logging.info("Saving results to {}".format(table_path))

  with tf.io.TFRecordWriter(table_path) as file_writer:
    for input_, target, score in zip(inputs, targets, scores):
      feature = {
          "input": tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[input_])),
          "target": tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[target])),
          "score": tf.train.Feature(
              float_list=tf.train.FloatList(value=[score])),
      }
      example = tf.train.Example(features=tf.train.Features(feature=feature))
      file_writer.write(example.SerializeToString())
1704+
1705+
1706+ @gin .configurable
16571707def score_with_estimator (estimator , input_fn , eval_checkpoint_step , model_dir ,
16581708 vocabulary , score_postprocess_fn = save_scores ,
16591709 num_examples = None ):
@@ -1691,6 +1741,74 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
16911741 return score_postprocess_fn (results , vocabulary )
16921742
16931743
@gin.configurable
def score_with_estimator_lazy(
    estimator, input_fn, eval_checkpoint_step, model_dir,
    vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
    num_examples=None, num_examples_per_shard=100000):
  """Score each example returned by input_fn lazily.

  Predictions are consumed one at a time from `estimator.predict` and buffered
  in memory. Whenever the buffer holds `num_examples_per_shard` results, or
  `num_examples` results have been consumed in total, the buffer is handed to
  `score_postprocess_fn(results, vocabulary, shard_idx=...)` and cleared.

  Args:
    estimator: a TPUEstimator
    input_fn: a function that returns a tf.data.Dataset with examples
      containing the string field 'targets' and optionally the field 'inputs'
    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
      docstring. Must resolve to exactly one checkpoint.
    model_dir: string, estimator model_dir
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple
    score_postprocess_fn: a function that takes in model outputs,
      post-processes, and saves them. It must accept a `shard_idx` keyword
      argument.
    num_examples: int, the total # of examples being scored, None if unknown
    num_examples_per_shard: int, the number of examples per file shard. If
      None or not positive, results are not split into shards.

  Returns:
    None. Results are only passed to `score_postprocess_fn`; nothing is
    returned to the caller.
  """
  sharding_enabled = (
      num_examples_per_shard is not None and num_examples_per_shard > 0)
  if num_examples is None:
    num_shards = None
  elif sharding_enabled:
    num_shards = math.ceil(num_examples / num_examples_per_shard)
  else:
    # Without a valid shard size everything ends up in a single flush.
    num_shards = 1
  tf.logging.info(
      "Scoring {} examples with {} shards at {} examples per shard".format(
          num_examples, num_shards, num_examples_per_shard))

  if num_examples is not None and num_examples <= 0:
    # Nothing was requested: avoid pulling and saving an unrequested example.
    return

  # Single-element unpacking: raises if more than one checkpoint is selected.
  checkpoint_path, = get_checkpoint_iterator(
      eval_checkpoint_step, model_dir)
  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

  start = time.time()
  results = []
  shard_idx = 0
  num_scored = 0

  for result in result_iter:
    results.append(result)
    num_scored += 1
    num_results = len(results)
    shard_is_full = sharding_enabled and num_results >= num_examples_per_shard
    # Compare against the count of consumed results (not a 0-based index) so
    # that exactly `num_examples` results are saved.
    reached_num_examples = (
        num_examples is not None and num_scored >= num_examples)

    if shard_is_full or reached_num_examples:
      score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)

      elapsed = time.time() - start
      # Guard against a zero interval on coarse clocks.
      rate = num_results / elapsed if elapsed > 0 else float("inf")
      tf.logging.info(
          "Scored {} results in {} s, {} examples/s for shard {}".format(
              num_results, elapsed, rate, shard_idx))

      results = []
      shard_idx += 1
      start = time.time()

    if reached_num_examples:
      # Anything the iterator would still yield is beyond the requested
      # count (presumably batch padding -- confirm against input_fn).
      break

  # Flush the final, partially filled shard.
  if results:
    score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
1810+
1811+
16941812def _maybe_add_pretokenized_features (examples , vocabulary ):
16951813 """Ensures decoded versions of "inputs" and "targets" exist in each example.
16961814
@@ -1712,9 +1830,19 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
17121830 for example in examples :
17131831 for feature_name in ["inputs" , "targets" ]:
17141832 pretokenized_feature_name = feature_name + "_pretokenized"
1833+ neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
17151834 if feature_name in example and pretokenized_feature_name not in example :
1716- s = vocabulary [feature_name ].decode (example [feature_name ].tolist ())
1717- example [pretokenized_feature_name ] = s
1835+ ids = example [feature_name ].tolist ()
1836+
1837+ neg_ids = [abs (i ) for i in ids if i < 0 ]
1838+ ids = [i for i in ids if i > 0 ]
1839+
1840+ decoded_string = vocabulary [feature_name ].decode (ids )
1841+ example [pretokenized_feature_name ] = decoded_string
1842+
1843+ if neg_ids :
1844+ neg_decoded_string = vocabulary [feature_name ].decode (neg_ids )
1845+ example [neg_pretokenized_feature_name ] = neg_decoded_string
17181846
17191847 if not added_pretokenized [feature_name ]:
17201848 added_pretokenized [feature_name ] = True
@@ -1730,7 +1858,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
17301858 sequence_length , model_dir , eval_checkpoint_step ,
17311859 inputs = gin .REQUIRED , targets = gin .REQUIRED ,
17321860 score_postprocess_fn = gin .REQUIRED , eos_id = 1 ,
1733- score_eos = True ):
1861+ score_eos = True ,
1862+ score_with_estimator_fn = score_with_estimator ):
17341863 """Compute log likelihoods per example and write to a text file.
17351864
17361865 inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1890,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
17611890 score_eos: a boolean - whether to score the final eos token of each line
17621891 If this is set to false, the scores can be interpreted as prefix
17631892 log-likelihoods
1893+ score_with_estimator_fn: a function to run scoring with the estimator.
17641894 Returns:
17651895 a list of floats
17661896 """
@@ -1806,7 +1936,7 @@ def input_fn(params):
18061936 dataset = dataset .batch (batch_size , drop_remainder = True )
18071937 return dataset .prefetch (tf .data .experimental .AUTOTUNE )
18081938
1809- return score_with_estimator (
1939+ return score_with_estimator_fn (
18101940 estimator , input_fn , eval_checkpoint_step , model_dir ,
18111941 vocabulary , score_postprocess_fn , len (targets ))
18121942
@@ -1815,7 +1945,8 @@ def input_fn(params):
18151945def score_from_dataset (estimator , vocabulary , batch_size , sequence_length ,
18161946 model_dir , eval_checkpoint_step , dataset_split ,
18171947 score_dataset_fn = None ,
1818- score_postprocess_fn = gin .REQUIRED ):
1948+ score_postprocess_fn = gin .REQUIRED ,
1949+ score_with_estimator_fn = score_with_estimator ):
18191950 """Compute log likelihoods per example and write to a text file.
18201951
18211952 The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1968,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
18371968 See `eval_dataset_fn` argument to `eval_model` for details.
18381969 score_postprocess_fn: Function that takes in model outputs and
18391970 post-processes then returns them.
1971+ score_with_estimator_fn: a function to run scoring with the estimator.
18401972
18411973 Returns:
18421974 scores: a list of floats, the log likelihood scores
@@ -1850,9 +1982,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
18501982 input_fn = _get_combined_dataset_input_fn (
18511983 scoring_datasets , batch_size , sequence_length )
18521984
1853- return score_with_estimator (
1985+ return score_with_estimator_fn (
18541986 estimator , input_fn , eval_checkpoint_step , model_dir ,
1855- vocabulary , score_postprocess_fn , None )
1987+ vocabulary , score_postprocess_fn )
18561988
18571989
18581990def get_estimator (model_type , vocabulary , mesh_shape ,
@@ -2093,7 +2225,8 @@ def eval_model(estimator,
20932225 eval_checkpoint_step ,
20942226 eval_with_score = False ,
20952227 output_eval_examples = True ,
2096- eval_dir_suffix = None ):
2228+ eval_dir_suffix = None ,
2229+ score_with_estimator_fn = score_with_estimator ):
20972230 """Eval a Mesh-TF model.
20982231
20992232 Args:
@@ -2137,6 +2270,7 @@ def eval_model(estimator,
21372270 of the eval examples in plaintext to eval_summary_dir.
21382271 eval_dir_suffix: string, if not None then will appended to the
21392272 eval_summary_dir.
2273+ score_with_estimator_fn: a function to run scoring with the estimator.
21402274 """
21412275 if eval_dataset_fn is None :
21422276 raise ValueError ("Must provide eval_dataset_fn through gin for eval." )
@@ -2248,7 +2382,7 @@ def eval_model(estimator,
22482382 tf .logging .info ("Checkpoint path %s" % checkpoint_path )
22492383 global_step = int (get_step_from_checkpoint_path (checkpoint_path ))
22502384 if eval_with_score :
2251- outputs , _ = score_with_estimator (
2385+ outputs , _ = score_with_estimator_fn (
22522386 estimator , input_fn , global_step , model_dir , vocabulary ,
22532387 num_examples = sum (len (cex ) for cex in cached_examples .values ()))
22542388 else :
0 commit comments