@@ -1606,7 +1606,6 @@ def get_estimator(model_type, vocabulary, mesh_shape,
   return estimator
 
 
-@gin.configurable
 def train_model(estimator, vocabulary, sequence_length, batch_size,
                 train_dataset_fn, train_steps, ensemble_inputs,
                 dataset_split="train", skip_seen_data=False,
@@ -2282,6 +2281,7 @@ def run(tpu_job_name,
         ensemble_inputs=None,
         train_model_fn=train_model,
         skip_seen_data=False,
+        seen_data_init_step=0,
         output_eval_examples=True):
   """Run training, eval, or inference depending on `mode`.
 
@@ -2349,6 +2349,8 @@ def run(tpu_job_name,
       restarts to skip already seen data. This flag is only consistent when
       every setting (such as batch size and random seed) on the model is the
       same between the original run and the new run.
+    seen_data_init_step: an integer; when `skip_seen_data` is True, skip the
+      data seen in steps after this starting step. Useful when fine-tuning.
     output_eval_examples: a boolean, is `True` by default. Used to decide
       whether to dump inputs, targets, and predictions of the eval examples in
       plaintext to eval_summary_dir.
@@ -2440,7 +2442,8 @@ def run(tpu_job_name,
24402442
24412443 train_model_fn (estimator , vocabulary , sequence_length , batch_size ,
24422444 train_dataset_fn , train_steps , ensemble_inputs ,
2443- skip_seen_data = skip_seen_data )
2445+ skip_seen_data = skip_seen_data ,
2446+ seen_data_init_step = seen_data_init_step )
24442447
24452448 elif mode == "perplexity_eval" :
24462449 if eval_dataset_fn is None :