@@ -773,27 +773,25 @@ def serialized_fn(mtf_features):
773773 gin_config_saver_hook = gin .tf .GinConfigSaverHook (
774774 model_dir , summarize_config = True , include_step_in_filename = False )
775775
776+ training_hooks = [
777+ restore_hook ,
778+ saver_hook ,
779+ gin_config_saver_hook ,
780+ ]
781+
776782 if use_tpu :
777783 return tpu_estimator .TPUEstimatorSpec (
778784 mode = tf .estimator .ModeKeys .TRAIN ,
779785 loss = tf_loss ,
780786 train_op = train_op ,
781787 host_call = host_call ,
782- training_hooks = [
783- restore_hook ,
784- saver_hook ,
785- gin_config_saver_hook ,
786- ])
788+ training_hooks = training_hooks )
787789 else :
788790 return tf .estimator .EstimatorSpec (
789791 tf .estimator .ModeKeys .TRAIN ,
790792 loss = tf_loss ,
791793 train_op = train_op ,
792- training_chief_hooks = [
793- restore_hook ,
794- saver_hook ,
795- gin_config_saver_hook ,
796- ])
794+ training_chief_hooks = training_hooks )
797795 elif mode == tf .estimator .ModeKeys .EVAL :
798796 # perplexity eval
799797 logits , loss = logits_and_loss (mtf_features )
@@ -1698,9 +1696,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,
16981696 model_dir = model_dir ,
16991697 tpu_config = my_tpu_config ,
17001698 session_config = session_config ,
1701- # We use a saver hook, so disable checkpoints here to prevent double
1702- # saving.
1703- save_checkpoints_steps = None ,
1699+ save_checkpoints_steps = save_checkpoints_steps ,
17041700 save_checkpoints_secs = None )
17051701
17061702 transformer_model = build_model (
@@ -1748,7 +1744,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,
17481744def train_model (estimator , vocabulary , sequence_length , batch_size ,
17491745 train_dataset_fn , train_steps , ensemble_inputs ,
17501746 dataset_split = "train" , skip_seen_data = False ,
1751- seen_data_init_step = 0 ):
1747+ seen_data_init_step = 0 , checkpoint_input_pipeline = False ):
17521748 """Train a Mesh-TF model.
17531749
17541750 Args:
@@ -1773,11 +1769,20 @@ def train_model(estimator, vocabulary, sequence_length, batch_size,
17731769 skip_seen_data: a boolean, is `False` by default. Used when a training run
17741770 restarts to skip already seen data. This flag is only consistent when
17751771 every setting (such as batch size and random seed) on the model is the
1776- same between the original run and the new run.
1772+ same between the original run and the new run. May require a significant
1773+ amount of time to skip a large number of steps.
17771774 seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
17781775 steps from this starting point. Useful when finetuning.
1776+ checkpoint_input_pipeline: a boolean, whether to checkpoint the input
1777+ pipeline in order to restart from the previous run. May require a large
1778+ amount of disk space for complicated input pipelines.
17791779 """
17801780
1781+ if skip_seen_data and checkpoint_input_pipeline :
1782+ raise ValueError (
1783+ "At most one of `skip_seen_data` and `checkpoint_input_pipeline` may "
1784+ "be set." )
1785+
17811786 def input_fn (params ):
17821787 del params
17831788
@@ -1799,7 +1804,12 @@ def input_fn(params):
17991804 dataset = dataset .skip (steps_to_skip )
18001805 return dataset
18011806
1802- estimator .train (input_fn = input_fn , max_steps = train_steps )
1807+ hooks = []
1808+ if checkpoint_input_pipeline :
1809+ hooks .append (
1810+ tf .data .experimental .CheckpointInputPipelineHook (estimator ))
1811+
1812+ estimator .train (input_fn = input_fn , max_steps = train_steps , hooks = hooks )
18031813
18041814
18051815@gin .configurable
@@ -2399,7 +2409,8 @@ def run(tpu_job_name,
23992409 train_model_fn = train_model ,
24002410 skip_seen_data = False ,
24012411 seen_data_init_step = 0 ,
2402- output_eval_examples = True ):
2412+ output_eval_examples = True ,
2413+ checkpoint_input_pipeline = False ):
24032414 """Run training, eval, or inference depending on `mode`.
24042415
24052416 Args:
@@ -2465,12 +2476,16 @@ def run(tpu_job_name,
24652476 skip_seen_data: a boolean, is `False` by default. Used when a training run
24662477 restarts to skip already seen data. This flag is only consistent when
24672478 every setting (such as batch size and random seed) on the model is the
2468- same between the original run and the new run.
2479+ same between the original run and the new run. May require a significant
2480+ amount of time to skip a large number of steps.
24692481 seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
24702482 steps from this starting point. Useful when finetuning.
24712483 output_eval_examples: a boolean, is `True` by default. Used to decide
24722484 whether to dump inputs, targets, and predictions of the
24732485 eval examples in plaintext to eval_summary_dir.
2486+ checkpoint_input_pipeline: a boolean, whether to checkpoint the input
2487+ pipeline in order to restart from the previous run. May require a large
2488+ amount of disk space for complicated input pipelines.
24742489 """
24752490 if isinstance (sequence_length , int ):
24762491 sequence_length = {"inputs" : sequence_length ,
@@ -2560,7 +2575,8 @@ def run(tpu_job_name,
25602575 train_model_fn (estimator , vocabulary , sequence_length , batch_size ,
25612576 train_dataset_fn , train_steps , ensemble_inputs ,
25622577 skip_seen_data = skip_seen_data ,
2563- seen_data_init_step = seen_data_init_step )
2578+ seen_data_init_step = seen_data_init_step ,
2579+ checkpoint_input_pipeline = checkpoint_input_pipeline )
25642580
25652581 elif mode == "perplexity_eval" :
25662582 if eval_dataset_fn is None :
0 commit comments