Skip to content

Commit 8b47c48

Browse files
reedwmtensorflower-gardener
authored andcommitted
Improve error message when certain flags are not specified.
In nlp/train.py and vision/beta/train.py, certain flags are marked as required. Additionally, in certain functions, error messages are improved if a necessary flag is not specified, which is a fallback in case a file calling define_flags() does not mark the necessary flags are required. Previously if any of these flags were not specified, it would crash with a cryptic error message, making it hard to tell what went wrong. In a subsequent change, I will mark flags as required in more files which call define_flags(). PiperOrigin-RevId: 381066985
1 parent 6dc4ae7 commit 8b47c48

File tree

5 files changed

+30
-3
lines changed

5 files changed

+30
-3
lines changed

official/common/flags.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,27 @@
1818

1919

2020
def define_flags():
21-
"""Defines flags."""
21+
"""Defines flags.
22+
23+
All flags are defined as optional, but in practice most models use some of
24+
these flags and so mark_flags_as_required() should be called after calling
25+
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
26+
For example:
27+
28+
```
29+
from absl import flags
30+
from official.common import flags as tfm_flags # pylint: disable=line-too-long
31+
...
32+
tfm_flags.define_flags()
33+
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
34+
```
35+
36+
The reason all flags are optional is because unit tests often do not set or
37+
use any of the flags.
38+
"""
2239
flags.DEFINE_string(
23-
'experiment', default=None, help='The experiment type registered.')
40+
'experiment', default=None, help=
41+
'The experiment type registered, specifying an ExperimentConfig.')
2442

2543
flags.DEFINE_enum(
2644
'mode',

official/core/train_lib.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def run_experiment(
7878
params, model_dir))
7979

8080
if trainer.checkpoint:
81+
if model_dir is None:
82+
raise ValueError('model_dir must be specified, but got None')
8183
checkpoint_manager = tf.train.CheckpointManager(
8284
trainer.checkpoint,
8385
directory=model_dir,

official/core/train_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ def __contains__(self, name):
241241
def parse_configuration(flags_obj, lock_return=True, print_return=True):
242242
"""Parses ExperimentConfig from flags."""
243243

244+
if flags_obj.experiment is None:
245+
raise ValueError('The flag --experiment must be specified.')
246+
244247
# 1. Get the default config from the registered experiment.
245248
params = exp_factory.get_exp_config(flags_obj.experiment)
246249

@@ -285,7 +288,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
285288

286289
if print_return:
287290
pp = pprint.PrettyPrinter()
288-
logging.info('Final experiment parameters: %s',
291+
logging.info('Final experiment parameters:\n%s',
289292
pp.pformat(params.as_dict()))
290293

291294
return params
@@ -294,6 +297,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
294297
def serialize_config(params: config_definitions.ExperimentConfig,
295298
model_dir: str):
296299
"""Serializes and saves the experiment config."""
300+
if model_dir is None:
301+
raise ValueError('model_dir must be specified, but got None')
297302
params_save_path = os.path.join(model_dir, 'params.yaml')
298303
logging.info('Saving experiment configuration to %s', params_save_path)
299304
tf.io.gfile.makedirs(model_dir)

official/nlp/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,5 @@ def main(_):
6666

6767
if __name__ == '__main__':
6868
tfm_flags.define_flags()
69+
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
6970
app.run(main)

official/vision/beta/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,5 @@ def main(_):
6666

6767
if __name__ == '__main__':
6868
tfm_flags.define_flags()
69+
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
6970
app.run(main)

0 commit comments

Comments
 (0)