Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 6f5e3f1

Browse files
author
Mesh TensorFlow Team
committed
Add stop_after argument to utils.get_checkpoint_iterator. Useful for automatically ending continuous evaluation jobs.
PiperOrigin-RevId: 327471874
1 parent ecdee99 commit 6f5e3f1

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,7 +1903,8 @@ def auto_train_steps(batch_size,
19031903

19041904

19051905
@gin.configurable
1906-
def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0):
1906+
def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0,
1907+
stop_after=None):
19071908
"""Get an iterable of checkpoint paths from a provided checkpoint step(s).
19081909
19091910
Args:
@@ -1917,6 +1918,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0):
19171918
model_dir: str, directory to look for checkpoints in.
19181919
skip_until: an integer - for "all" or "None" behavior, filter out
19191920
checkpoint numbers that are <= skip_until.
1921+
stop_after: an optional integer - for "None" behavior, if specified
1922+
stop after finding a checkpoint number that is >= stop_after. When a
1923+
checkpoint number == stop_after is found, it is yielded before exiting.
19201924
19211925
Returns:
19221926
An iterable which yields checkpoint paths.
@@ -1957,7 +1961,19 @@ def _filter_fn(p):
19571961
return filter(_filter_fn,
19581962
[_get_checkpoint_path(s) for s in sorted(list(ckpt_steps))])
19591963
elif checkpoint_step is None:
1960-
return filter(_filter_fn, tf.train.checkpoints_iterator(model_dir))
1964+
checkpoints_iterator = filter(
1965+
_filter_fn, tf.train.checkpoints_iterator(model_dir))
1966+
if stop_after is not None:
1967+
def _generate_checkpoints():
1968+
for p in checkpoints_iterator:
1969+
step = get_step_from_checkpoint_path(p)
1970+
if step <= stop_after:
1971+
yield p
1972+
if step >= stop_after:
1973+
break
1974+
return _generate_checkpoints()
1975+
else:
1976+
return checkpoints_iterator
19611977
elif isinstance(checkpoint_step, int):
19621978
return [_get_checkpoint_path(_get_closest_checkpoint(checkpoint_step))]
19631979
else:

0 commit comments

Comments
 (0)