Skip to content

Commit 7f69eb3

Browse files
Internal change
PiperOrigin-RevId: 397842944
1 parent dbe3927 commit 7f69eb3

File tree

4 files changed

+6
-1
lines changed

4 files changed

+6
-1
lines changed

official/nlp/configs/pretraining_experiments.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
def bert_pretraining() -> cfg.ExperimentConfig:
5252
"""BERT pretraining experiment."""
5353
config = cfg.ExperimentConfig(
54+
runtime=cfg.RuntimeConfig(enable_xla=True),
5455
task=masked_lm.MaskedLMConfig(
5556
train_data=pretrain_dataloader.BertPretrainDataConfig(),
5657
validation_data=pretrain_dataloader.BertPretrainDataConfig(
@@ -70,6 +71,7 @@ def bert_dynamic() -> cfg.ExperimentConfig:
7071
TPU needs to run with tf.data service with round-robin behavior.
7172
"""
7273
config = cfg.ExperimentConfig(
74+
runtime=cfg.RuntimeConfig(enable_xla=True),
7375
task=masked_lm.MaskedLMConfig(
7476
train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(),
7577
validation_data=pretrain_dataloader.BertPretrainDataConfig(

official/nlp/configs/wmt_transformer_experiments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def wmt_transformer_large() -> cfg.ExperimentConfig:
4343
encdecoder = translation.EncDecoder(
4444
num_attention_heads=16, intermediate_size=hidden_size * 4)
4545
config = cfg.ExperimentConfig(
46+
runtime=cfg.RuntimeConfig(enable_xla=True),
4647
task=translation.TranslationConfig(
4748
model=translation.ModelConfig(
4849
encoder=encdecoder,

official/vision/beta/configs/image_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
119119
eval_batch_size = 4096
120120
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
121121
config = cfg.ExperimentConfig(
122+
runtime=cfg.RuntimeConfig(enable_xla=True),
122123
task=ImageClassificationTask(
123124
model=ImageClassificationModel(
124125
num_classes=1001,

official/vision/beta/configs/maskrcnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
292292
eval_batch_size = 8
293293

294294
config = cfg.ExperimentConfig(
295-
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
295+
runtime=cfg.RuntimeConfig(
296+
mixed_precision_dtype='bfloat16', enable_xla=True),
296297
task=MaskRCNNTask(
297298
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
298299
init_checkpoint_modules='backbone',

0 commit comments

Comments
 (0)