Commit d1699cd

Add additional log for training with amp. (PaddlePaddle#1042)
1 parent c63a9ad commit d1699cd

File tree

1 file changed, +33 -18 lines


examples/language_model/ernie-1.0/run_pretrain_static.py

Lines changed: 33 additions & 18 deletions
@@ -21,9 +21,7 @@
 import time
 import yaml
 import shutil
-
-os.path.expandvars('$HOME')
-os.path.expanduser('~')
+import collections
 
 import numpy as np
 import paddle
@@ -524,12 +522,25 @@ def do_train(args):
             paddle.static.load(main_program,
                                os.path.join(checkpoint_dir, "static_vars"), exe)
 
+    fetch_vars = collections.OrderedDict()
+    fetch_vars["loss"] = loss
+    fetch_vars["lm_loss"] = lm_loss
+    fetch_vars["sop_loss"] = sop_loss
+    fetch_vars["learning_rate"] = main_program.global_block().vars[
+        "learning_rate_0"]
+
+    additional_vars = collections.OrderedDict()
+    if args.use_amp:
+        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
+            additional_vars[key] = main_program.global_block().vars[key + "_0"]
+
     tic_train = time.time()
-    learning_rate = main_program.global_block().vars["learning_rate_0"]
     while True:
         fetchs = []
+        fetchs_keys = []
         if topo.is_last:
-            fetchs = [loss, lm_loss, sop_loss, learning_rate]
+            fetchs = list(fetch_vars.values()) + list(additional_vars.values())
+            fetchs_keys = list(fetch_vars.keys()) + list(additional_vars.keys())
 
         # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
         # many times. and start a new random dataloader.
@@ -550,21 +561,25 @@ def do_train(args):
 
             if global_step % args.logging_freq == 0:
                 if topo.is_last:
-                    loss_return, lm_loss_return, sop_loss_return, lr_return = ret
+                    res = {}
+                    for k, v in zip(fetchs_keys, ret):
+                        res[k] = v[0]
 
                     speed = args.logging_freq / (time.time() - tic_train)
-                    logger.info(
-                        "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e"
-                        % (global_step, loss_return[0], lm_loss_return[0],
-                           sop_loss_return[0], speed,
-                           speed * args.global_batch_size, lr_return[0]))
-                    log_writer.add_scalar("loss", loss_return[0], global_step)
-                    log_writer.add_scalar("lm_loss", lm_loss_return[0],
-                                          global_step)
-                    log_writer.add_scalar("sop_loss", sop_loss_return[0],
-                                          global_step)
-                    log_writer.add_scalar("learning_rate", lr_return[0],
-                                          global_step)
+                    common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
+                        global_step, res["loss"], res["lm_loss"],
+                        res["sop_loss"], speed, speed * args.global_batch_size,
+                        res["learning_rate"])
+                    additional_loginfo = ", ".join([
+                        "{}: {}".format(k, res[k])
+                        for k in additional_vars.keys()
+                    ])
+                    if additional_loginfo:
+                        common_loginfo += ", " + additional_loginfo
+                    logger.info(common_loginfo)
+                    for k, v in res.items():
+                        log_writer.add_scalar(k, v, global_step)
+
                 tic_train = time.time()
 
             #if args.check_accuracy:
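In short, the change replaces the fixed four-variable fetch list with two ordered dicts: fetch_vars for the usual losses and learning rate, and additional_vars which, when args.use_amp is set, picks up the loss_scaling, num_good_steps and num_bad_steps counters maintained by dynamic loss scaling. The fetch keys are then zipped with the values the executor returns to build the log line. The snippet below is a minimal, framework-free sketch of that pattern only; the step number and numeric values are placeholders rather than output of the real script, and no Paddle APIs are called.

# Minimal sketch of the keyed-fetch logging pattern added in this commit.
# Placeholder values stand in for what the executor would return per step.
import collections

fetch_vars = collections.OrderedDict(
    [("loss", None), ("lm_loss", None), ("sop_loss", None),
     ("learning_rate", None)])

additional_vars = collections.OrderedDict()
use_amp = True  # assumption: mirrors args.use_amp
if use_amp:
    for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
        additional_vars[key] = None  # would hold the program's "<key>_0" var

fetchs_keys = list(fetch_vars.keys()) + list(additional_vars.keys())

# Pretend these came back from the executor; each entry is a length-1
# array in the real script, hence the [0] indexing there.
ret = [[10.2], [9.8], [0.4], [1e-4], [32768.0], [5], [0]]
res = {k: v[0] for k, v in zip(fetchs_keys, ret)}

common_loginfo = "global step %d, loss: %.9f, learning rate: %.5e" % (
    1000, res["loss"], res["learning_rate"])
additional_loginfo = ", ".join(
    "{}: {}".format(k, res[k]) for k in additional_vars)
if additional_loginfo:
    common_loginfo += ", " + additional_loginfo
print(common_loginfo)

With use_amp disabled, additional_vars stays empty and the log line keeps its previous shape.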
