Skip to content

Commit 983837f

Browse files
Internal change
PiperOrigin-RevId: 365082527
1 parent a49c733 commit 983837f

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

official/benchmark/bert_pretrain_benchmark.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,27 @@ def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
315315
report_accuracy=False,
316316
ds_type=FLAGS.distribution_strategy)
317317

318+
@owner_utils.Owner('tf-model-garden')
319+
def benchmark_perf_8x16_tpu_bf16_seq128_1k_steps(self):
320+
"""Test bert pretraining with 8x16 TPU for 1000 steps."""
321+
self._setup()
322+
self._specify_common_flags()
323+
self._specify_tpu_common_flags()
324+
FLAGS.train_batch_size = 4096
325+
FLAGS.warmup_steps = 0
326+
FLAGS.num_steps_per_epoch = 1000
327+
FLAGS.num_train_epochs = 1
328+
FLAGS.steps_per_loop = 500
329+
FLAGS.model_dir = self._get_model_dir(
330+
'benchmark_perf_8x16_tpu_bf16_seq128_1k_steps')
331+
summary_path = os.path.join(FLAGS.model_dir,
332+
'summaries/training_summary.txt')
333+
# Disable accuracy check.
334+
self._run_and_report_benchmark(
335+
summary_path=summary_path,
336+
report_accuracy=False,
337+
ds_type=FLAGS.distribution_strategy)
338+
318339
@owner_utils.Owner('tf-dist-strat')
319340
def benchmark_accuracy_1x8_gpu_fp16_seq128_15k_steps(self):
320341
"""Test bert pretraining with 8 GPU for 15k steps."""

0 commit comments

Comments
 (0)