Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 914bb1f

Browse files
katelee168Mesh TensorFlow Team
authored andcommitted
Make train_model no longer gin configurable.
PiperOrigin-RevId: 351997527
1 parent d80f0bf commit 914bb1f

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,7 +1606,6 @@ def get_estimator(model_type, vocabulary, mesh_shape,
16061606
return estimator
16071607

16081608

1609-
@gin.configurable
16101609
def train_model(estimator, vocabulary, sequence_length, batch_size,
16111610
train_dataset_fn, train_steps, ensemble_inputs,
16121611
dataset_split="train", skip_seen_data=False,
@@ -2282,6 +2281,7 @@ def run(tpu_job_name,
22822281
ensemble_inputs=None,
22832282
train_model_fn=train_model,
22842283
skip_seen_data=False,
2284+
seen_data_init_step=0,
22852285
output_eval_examples=True):
22862286
"""Run training, eval, or inference depending on `mode`.
22872287
@@ -2349,6 +2349,8 @@ def run(tpu_job_name,
23492349
restarts to skip already seen data. This flag is only consistent when
23502350
every setting (such as batch size and random seed) on the model is the
23512351
same between the original run and the new run.
2352+
seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
2353+
steps from this starting point. Useful when finetuning.
23522354
output_eval_examples: a boolean, is `True` by default. Used to decide
23532355
whether to output whether to dump inputs, targets, and predictions of the
23542356
eval examples in plaintext to eval_summary_dir.
@@ -2440,7 +2442,8 @@ def run(tpu_job_name,
24402442

24412443
train_model_fn(estimator, vocabulary, sequence_length, batch_size,
24422444
train_dataset_fn, train_steps, ensemble_inputs,
2443-
skip_seen_data=skip_seen_data)
2445+
skip_seen_data=skip_seen_data,
2446+
seen_data_init_step=seen_data_init_step)
24442447

24452448
elif mode == "perplexity_eval":
24462449
if eval_dataset_fn is None:

0 commit comments

Comments
 (0)