@@ -21,9 +21,7 @@
 import time
 import yaml
 import shutil
-
-os.path.expandvars('$HOME')
-os.path.expanduser('~')
+import collections
 
 import numpy as np
 import paddle
@@ -524,12 +522,25 @@ def do_train(args):
         paddle.static.load(main_program,
                            os.path.join(checkpoint_dir, "static_vars"), exe)
 
+    fetch_vars = collections.OrderedDict()
+    fetch_vars["loss"] = loss
+    fetch_vars["lm_loss"] = lm_loss
+    fetch_vars["sop_loss"] = sop_loss
+    fetch_vars["learning_rate"] = main_program.global_block().vars[
+        "learning_rate_0"]
+
+    additional_vars = collections.OrderedDict()
+    if args.use_amp:
+        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
+            additional_vars[key] = main_program.global_block().vars[key + "_0"]
+
     tic_train = time.time()
-    learning_rate = main_program.global_block().vars["learning_rate_0"]
     while True:
         fetchs = []
+        fetchs_keys = []
         if topo.is_last:
-            fetchs = [loss, lm_loss, sop_loss, learning_rate]
+            fetchs = list(fetch_vars.values()) + list(additional_vars.values())
+            fetchs_keys = list(fetch_vars.keys()) + list(additional_vars.keys())
 
         # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
         # many times. and start a new random dataloader.
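Note on the hunk above: the refactor keys every fetched variable by name in insertion-ordered dicts, so the fetch list handed to the executor and the key list used later for logging always line up position by position. A minimal, self-contained sketch of that pattern, with plain strings standing in for the Paddle variables the patch stores and a hardcoded use_amp replacing args.use_amp:

# Sketch only: strings stand in for Paddle variables; use_amp replaces args.use_amp.
import collections

fetch_vars = collections.OrderedDict()
fetch_vars["loss"] = "loss_var"
fetch_vars["lm_loss"] = "lm_loss_var"
fetch_vars["sop_loss"] = "sop_loss_var"
fetch_vars["learning_rate"] = "learning_rate_0"

additional_vars = collections.OrderedDict()
use_amp = True
if use_amp:
    for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
        additional_vars[key] = key + "_0"

# Insertion order is preserved, so values and keys line up index-by-index
# and results fetched by position can later be re-keyed by name.
fetchs = list(fetch_vars.values()) + list(additional_vars.values())
fetchs_keys = list(fetch_vars.keys()) + list(additional_vars.keys())
assert len(fetchs) == len(fetchs_keys)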
@@ -550,21 +561,25 @@ def do_train(args):
 
             if global_step % args.logging_freq == 0:
                 if topo.is_last:
-                    loss_return, lm_loss_return, sop_loss_return, lr_return = ret
+                    res = {}
+                    for k, v in zip(fetchs_keys, ret):
+                        res[k] = v[0]
 
                     speed = args.logging_freq / (time.time() - tic_train)
-                    logger.info(
-                        "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e"
-                        % (global_step, loss_return[0], lm_loss_return[0],
-                           sop_loss_return[0], speed,
-                           speed * args.global_batch_size, lr_return[0]))
-                    log_writer.add_scalar("loss", loss_return[0], global_step)
-                    log_writer.add_scalar("lm_loss", lm_loss_return[0],
-                                          global_step)
-                    log_writer.add_scalar("sop_loss", sop_loss_return[0],
-                                          global_step)
-                    log_writer.add_scalar("learning_rate", lr_return[0],
-                                          global_step)
+                    common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
+                        global_step, res["loss"], res["lm_loss"],
+                        res["sop_loss"], speed, speed * args.global_batch_size,
+                        res["learning_rate"])
+                    additional_loginfo = ", ".join([
+                        "{}: {}".format(k, res[k])
+                        for k in additional_vars.keys()
+                    ])
+                    if additional_loginfo:
+                        common_loginfo += ", " + additional_loginfo
+                    logger.info(common_loginfo)
+                    for k, v in res.items():
+                        log_writer.add_scalar(k, v, global_step)
+
                 tic_train = time.time()
 
             #if args.check_accuracy:
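Note on the logging hunk: exe.run returns fetched values positionally, so the patch zips them with fetchs_keys to rebuild a name-to-value dict, appends the optional AMP entries to one log line, and writes every scalar through a single loop instead of one add_scalar call per metric. A rough standalone sketch of that flow, with fabricated one-element result arrays in place of exe.run output, the speed/ips fields omitted, and print standing in for logger.info and log_writer.add_scalar:

# Sketch only: fabricated results; print replaces logger.info / log_writer.add_scalar.
global_step = 100
fetchs_keys = ["loss", "lm_loss", "sop_loss", "learning_rate", "loss_scaling"]
ret = [[2.31], [1.80], [0.51], [1e-4], [32768.0]]

res = {}
for k, v in zip(fetchs_keys, ret):
    res[k] = v[0]  # each fetched result is an array; keep the scalar

additional_keys = ["loss_scaling"]  # empty when AMP is off
common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, learning rate: %.5e" % (
    global_step, res["loss"], res["lm_loss"], res["sop_loss"],
    res["learning_rate"])
additional_loginfo = ", ".join(
    "{}: {}".format(k, res[k]) for k in additional_keys)
if additional_loginfo:
    common_loginfo += ", " + additional_loginfo
print(common_loginfo)  # the patch calls logger.info here
for k, v in res.items():
    print("scalar:", k, v, global_step)  # the patch calls log_writer.add_scalar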