14
14
15
15
"""Common configuration settings."""
16
16
17
- from typing import Optional , Sequence , Union
18
-
19
17
import dataclasses
18
+ from typing import Optional , Sequence , Union
20
19
21
20
from official .modeling .hyperparams import base_config
22
21
from official .modeling .optimization .configs import optimization_config
@@ -41,7 +40,9 @@ class DataConfig(base_config.Config):
41
40
tfds_split: A str indicating which split of the data to load from TFDS. It
42
41
is required when above `tfds_name` is specified.
43
42
global_batch_size: The global batch size across all replicas.
44
- is_training: Whether this data is used for training or not.
43
+ is_training: Whether this data is used for training or not. This flag is
44
+ useful for consumers of this object to determine whether the data should
45
+ be repeated or shuffled.
45
46
drop_remainder: Whether the last batch should be dropped in the case it has
46
47
fewer than `global_batch_size` elements.
47
48
shuffle_buffer_size: The buffer size used for shuffling training data.
@@ -178,7 +179,8 @@ class TrainerConfig(base_config.Config):
178
179
eval_tf_function: whether or not to use tf_function for eval.
179
180
allow_tpu_summary: Whether to allow summary happen inside the XLA program
180
181
runs on TPU through automatic outside compilation.
181
- steps_per_loop: number of steps per loop.
182
+ steps_per_loop: number of steps per loop to report training metrics. This
183
+ can also be used to reduce host worker communication in a TPU setup.
182
184
summary_interval: number of steps between each summary.
183
185
checkpoint_interval: number of steps between checkpoints.
184
186
max_to_keep: max checkpoints to keep.
0 commit comments