# Command-line options: sequence length, device selection, and the
# learning-rate / model-size hyperparameters used by train().
parser.add_argument("--max_seq_len", type=int, default=64,
                    help="Number of words of the longest sequence.")
parser.add_argument("--device", default="gpu", type=str, choices=["cpu", "gpu"],
                    help="The device to select to train the model; it must be cpu/gpu.")
parser.add_argument("--base_lr", type=float, default=0.001,
                    help="The basic learning rate that affects the entire network.")
parser.add_argument("--crf_lr", type=float, default=0.2,
                    help="The learning rate ratio that affects CRF layers.")
parser.add_argument("--emb_dim", type=int, default=128,
                    help="The dimension in which a word is embedded.")
parser.add_argument("--hidden_size", type=int, default=128,
                    help="The number of hidden nodes in the GRU layer.")
parser.add_argument("--logging_steps", type=int, default=10,
                    help="Log every X updates steps.")
@@ -61,14 +62,19 @@ def evaluate(model, metric, data_loader):
6162 metric .update (num_infer_chunks .numpy (),
6263 num_label_chunks .numpy (), num_correct_chunks .numpy ())
6364 precision , recall , f1_score = metric .accumulate ()
64- print ("eval precision: %f, recall: %f, f1: %f" %
65- (precision , recall , f1_score ))
65+ logger . info ("eval precision: %f, recall: %f, f1: %f" %
66+ (precision , recall , f1_score ))
6667 model .train ()
68+ return precision , recall , f1_score
6769
6870
6971def train (args ):
7072 paddle .set_device (args .device )
7173
74+ trainer_num = paddle .distributed .get_world_size ()
75+ if trainer_num > 1 :
76+ paddle .distributed .init_parallel_env ()
77+ rank = paddle .distributed .get_rank ()
7278 # Create dataset.
7379 train_ds , test_ds = load_dataset (datafiles = (os .path .join (
7480 args .data_dir , 'train.tsv' ), os .path .join (args .data_dir , 'test.tsv' )))
@@ -117,24 +123,34 @@ def train(args):
117123 collate_fn = batchify_fn )
118124
119125 # Define the model netword and its loss
120- model = BiGruCrf (args .emb_dim , args .hidden_size ,
121- len (word_vocab ), len (label_vocab ))
126+ model = BiGruCrf (
127+ args .emb_dim ,
128+ args .hidden_size ,
129+ len (word_vocab ),
130+ len (label_vocab ),
131+ crf_lr = args .crf_lr )
122132 # Prepare optimizer, loss and metric evaluator
123133 optimizer = paddle .optimizer .Adam (
124134 learning_rate = args .base_lr , parameters = model .parameters ())
125135 chunk_evaluator = ChunkEvaluator (label_list = label_vocab .keys (), suffix = True )
126136
127137 if args .init_checkpoint :
128- model_dict = paddle .load (args .init_checkpoint )
129- model .load_dict (model_dict )
130-
138+ if os .path .exists (args .init_checkpoint ):
139+ logger .info ("Init checkpoint from %s" % args .init_checkpoint )
140+ model_dict = paddle .load (args .init_checkpoint )
141+ model .load_dict (model_dict )
142+ else :
143+ logger .info ("Cannot init checkpoint from %s which doesn't exist" %
144+ args .init_checkpoint )
145+ logger .info ("Start training" )
131146 # Start training
132147 global_step = 0
133148 last_step = args .epochs * len (train_loader )
134149 train_reader_cost = 0.0
135150 train_run_cost = 0.0
136151 total_samples = 0
137152 reader_start = time .time ()
153+ max_f1_score = - 1
138154 for epoch in range (args .epochs ):
139155 for step , batch in enumerate (train_loader ):
140156 train_reader_cost += time .time () - reader_start
@@ -146,7 +162,7 @@ def train(args):
146162 train_run_cost += time .time () - train_start
147163 total_samples += args .batch_size
148164 if global_step % args .logging_steps == 0 :
149- print (
165+ logger . info (
150166 "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
151167 % (global_step , last_step , avg_loss , train_reader_cost /
152168 args .logging_steps , (train_reader_cost + train_run_cost )
@@ -159,12 +175,21 @@ def train(args):
159175 optimizer .step ()
160176 optimizer .clear_grad ()
161177 if global_step % args .save_steps == 0 or global_step == last_step :
162- if paddle .distributed .get_rank () == 0 :
163- if args .do_eval :
164- evaluate (model , chunk_evaluator , test_loader )
178+ if rank == 0 :
165179 paddle .save (model .state_dict (),
166180 os .path .join (args .model_save_dir ,
167181 "model_%d.pdparams" % global_step ))
182+ logger .info ("Save %d steps model." % (global_step ))
183+ if args .do_eval :
184+ precision , recall , f1_score = evaluate (
185+ model , chunk_evaluator , test_loader )
186+ if f1_score > max_f1_score :
187+ max_f1_score = f1_score
188+ paddle .save (model .state_dict (),
189+ os .path .join (args .model_save_dir ,
190+ "best_model.pdparams" ))
191+ logger .info ("Save best model." )
192+
168193 reader_start = time .time ()
169194
170195
0 commit comments