Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 82ffe9d

Browse files
katelee168Mesh TensorFlow Team
authored andcommitted
Add parameter start_step to train_model to only skip recovered_step - start_step steps.
PiperOrigin-RevId: 350655487
1 parent 3755319 commit 82ffe9d

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,9 +1606,11 @@ def get_estimator(model_type, vocabulary, mesh_shape,
16061606
return estimator
16071607

16081608

1609+
@gin.configurable
16091610
def train_model(estimator, vocabulary, sequence_length, batch_size,
16101611
train_dataset_fn, train_steps, ensemble_inputs,
1611-
dataset_split="train", skip_seen_data=False):
1612+
dataset_split="train", skip_seen_data=False,
1613+
seen_data_init_step=0):
16121614
"""Train a Mesh-TF model.
16131615
16141616
Args:
@@ -1634,6 +1636,8 @@ def train_model(estimator, vocabulary, sequence_length, batch_size,
16341636
restarts to skip already seen data. This flag is only consistent when
16351637
every setting (such as batch size and random seed) on the model is the
16361638
same between the original run and the new run.
1639+
seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
1640+
steps from this starting point. Useful when finetuning.
16371641
"""
16381642

16391643
def input_fn(params):
@@ -1651,8 +1655,10 @@ def input_fn(params):
16511655
# already been seen.
16521656
if skip_seen_data and estimator.latest_checkpoint() is not None:
16531657
recovered_step = estimator.get_variable_value("global_step")
1654-
tf.logging.info("Skipping %d steps of data.", recovered_step)
1655-
dataset = dataset.skip(recovered_step)
1658+
steps_to_skip = recovered_step - seen_data_init_step
1659+
if steps_to_skip > 0:
1660+
tf.logging.info("Skipping %d steps of data.", steps_to_skip)
1661+
dataset = dataset.skip(steps_to_skip)
16561662
return dataset
16571663

16581664
estimator.train(input_fn=input_fn, max_steps=train_steps)

0 commit comments

Comments
 (0)