@@ -1606,9 +1606,11 @@ def get_estimator(model_type, vocabulary, mesh_shape,
16061606 return estimator
16071607
16081608
1609+ @gin.configurable
16091610def train_model(estimator, vocabulary, sequence_length, batch_size,
16101611 train_dataset_fn, train_steps, ensemble_inputs,
1611- dataset_split="train", skip_seen_data=False):
1612+ dataset_split="train", skip_seen_data=False,
1613+ seen_data_init_step=0):
16121614 """Train a Mesh-TF model.
16131615
16141616 Args:
@@ -1634,6 +1636,8 @@ def train_model(estimator, vocabulary, sequence_length, batch_size,
16341636 restarts to skip already seen data. This flag is only consistent when
16351637 every setting (such as batch size and random seed) on the model is the
16361638 same between the original run and the new run.
1639+ seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
1640+ steps from this starting point. Useful when finetuning.
16371641 """
16381642
16381642 def input_fn(params):
@@ -1651,8 +1655,10 @@ def input_fn(params):
16511655 # already been seen.
16521656 if skip_seen_data and estimator.latest_checkpoint() is not None:
16531657 recovered_step = estimator.get_variable_value("global_step")
1654- tf.logging.info("Skipping %d steps of data.", recovered_step)
1655- dataset = dataset.skip(recovered_step)
1658+ steps_to_skip = recovered_step - seen_data_init_step
1659+ if steps_to_skip > 0:
1660+ tf.logging.info("Skipping %d steps of data.", steps_to_skip)
1661+ dataset = dataset.skip(steps_to_skip)
16561662 return dataset
16571663
16581664 estimator.train(input_fn=input_fn, max_steps=train_steps)
0 commit comments