@@ -1903,7 +1903,8 @@ def auto_train_steps(batch_size,
19031903
19041904
19051905@gin .configurable
1906- def get_checkpoint_iterator (checkpoint_step , model_dir , skip_until = 0 ):
1906+ def get_checkpoint_iterator (checkpoint_step , model_dir , skip_until = 0 ,
1907+ stop_after = None ):
19071908 """Get an iterable of checkpoint paths from a provided checkpoint step(s).
19081909
19091910 Args:
@@ -1917,6 +1918,9 @@ def get_checkpoint_iterator(checkpoint_step, model_dir, skip_until=0):
19171918 model_dir: str, directory to look for checkpoints in.
19181919 skip_until: an integer - for "all" or "None" behavior, filter out
19191920 checkpoint numbers that are <= skip_until.
1921+ stop_after: an optional integer - for "None behavior, if specified
1922+ stop after finding a checkpoint number that is >= stop_at. When a
1923+ checkpoint number == stop_at is found, it is yielded before exiting.
19201924
19211925 Returns:
19221926 An iterable which yields checkpoint paths.
@@ -1957,7 +1961,19 @@ def _filter_fn(p):
19571961 return filter (_filter_fn ,
19581962 [_get_checkpoint_path (s ) for s in sorted (list (ckpt_steps ))])
19591963 elif checkpoint_step is None :
1960- return filter (_filter_fn , tf .train .checkpoints_iterator (model_dir ))
1964+ checkpoints_iterator = filter (
1965+ _filter_fn , tf .train .checkpoints_iterator (model_dir ))
1966+ if stop_after is not None :
1967+ def _generate_checkpoints ():
1968+ for p in checkpoints_iterator :
1969+ step = get_step_from_checkpoint_path (p )
1970+ if step <= stop_after :
1971+ yield p
1972+ if step >= stop_after :
1973+ break
1974+ return _generate_checkpoints ()
1975+ else :
1976+ return checkpoints_iterator
19611977 elif isinstance (checkpoint_step , int ):
19621978 return [_get_checkpoint_path (_get_closest_checkpoint (checkpoint_step ))]
19631979 else :
0 commit comments