# Command-line hyperparameters for the BiGRU-CRF trainer.
# NOTE(review): `parser` is an argparse.ArgumentParser created earlier in this
# file (outside the visible chunk) — flag names, types, and defaults below are
# part of the script's CLI contract and are kept unchanged.
parser.add_argument("--max_seq_len", type=int, default=64,
                    help="Number of words of the longest sequence.")
parser.add_argument("--device", default="gpu", type=str, choices=["cpu", "gpu"],
                    help="The device to select to train the model, it must be cpu/gpu.")
parser.add_argument("--base_lr", type=float, default=0.001,
                    help="The basic learning rate that affects the entire network.")
parser.add_argument("--crf_lr", type=float, default=0.2,
                    help="The learning rate ratio that affects CRF layers.")
parser.add_argument("--emb_dim", type=int, default=128,
                    help="The dimension in which a word is embedded.")
parser.add_argument("--hidden_size", type=int, default=128,
                    help="The number of hidden nodes in the GRU layer.")
parser.add_argument("--logging_steps", type=int, default=10,
                    help="Log every X updates steps.")
@@ -61,14 +62,19 @@ def evaluate(model, metric, data_loader):
61
62
metric .update (num_infer_chunks .numpy (),
62
63
num_label_chunks .numpy (), num_correct_chunks .numpy ())
63
64
precision , recall , f1_score = metric .accumulate ()
64
- print ("eval precision: %f, recall: %f, f1: %f" %
65
- (precision , recall , f1_score ))
65
+ logger . info ("eval precision: %f, recall: %f, f1: %f" %
66
+ (precision , recall , f1_score ))
66
67
model .train ()
68
+ return precision , recall , f1_score
67
69
68
70
69
71
def train (args ):
70
72
paddle .set_device (args .device )
71
73
74
+ trainer_num = paddle .distributed .get_world_size ()
75
+ if trainer_num > 1 :
76
+ paddle .distributed .init_parallel_env ()
77
+ rank = paddle .distributed .get_rank ()
72
78
# Create dataset.
73
79
train_ds , test_ds = load_dataset (datafiles = (os .path .join (
74
80
args .data_dir , 'train.tsv' ), os .path .join (args .data_dir , 'test.tsv' )))
@@ -117,24 +123,34 @@ def train(args):
117
123
collate_fn = batchify_fn )
118
124
119
125
# Define the model netword and its loss
120
- model = BiGruCrf (args .emb_dim , args .hidden_size ,
121
- len (word_vocab ), len (label_vocab ))
126
+ model = BiGruCrf (
127
+ args .emb_dim ,
128
+ args .hidden_size ,
129
+ len (word_vocab ),
130
+ len (label_vocab ),
131
+ crf_lr = args .crf_lr )
122
132
# Prepare optimizer, loss and metric evaluator
123
133
optimizer = paddle .optimizer .Adam (
124
134
learning_rate = args .base_lr , parameters = model .parameters ())
125
135
chunk_evaluator = ChunkEvaluator (label_list = label_vocab .keys (), suffix = True )
126
136
127
137
if args .init_checkpoint :
128
- model_dict = paddle .load (args .init_checkpoint )
129
- model .load_dict (model_dict )
130
-
138
+ if os .path .exists (args .init_checkpoint ):
139
+ logger .info ("Init checkpoint from %s" % args .init_checkpoint )
140
+ model_dict = paddle .load (args .init_checkpoint )
141
+ model .load_dict (model_dict )
142
+ else :
143
+ logger .info ("Cannot init checkpoint from %s which doesn't exist" %
144
+ args .init_checkpoint )
145
+ logger .info ("Start training" )
131
146
# Start training
132
147
global_step = 0
133
148
last_step = args .epochs * len (train_loader )
134
149
train_reader_cost = 0.0
135
150
train_run_cost = 0.0
136
151
total_samples = 0
137
152
reader_start = time .time ()
153
+ max_f1_score = - 1
138
154
for epoch in range (args .epochs ):
139
155
for step , batch in enumerate (train_loader ):
140
156
train_reader_cost += time .time () - reader_start
@@ -146,7 +162,7 @@ def train(args):
146
162
train_run_cost += time .time () - train_start
147
163
total_samples += args .batch_size
148
164
if global_step % args .logging_steps == 0 :
149
- print (
165
+ logger . info (
150
166
"global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
151
167
% (global_step , last_step , avg_loss , train_reader_cost /
152
168
args .logging_steps , (train_reader_cost + train_run_cost )
@@ -159,12 +175,21 @@ def train(args):
159
175
optimizer .step ()
160
176
optimizer .clear_grad ()
161
177
if global_step % args .save_steps == 0 or global_step == last_step :
162
- if paddle .distributed .get_rank () == 0 :
163
- if args .do_eval :
164
- evaluate (model , chunk_evaluator , test_loader )
178
+ if rank == 0 :
165
179
paddle .save (model .state_dict (),
166
180
os .path .join (args .model_save_dir ,
167
181
"model_%d.pdparams" % global_step ))
182
+ logger .info ("Save %d steps model." % (global_step ))
183
+ if args .do_eval :
184
+ precision , recall , f1_score = evaluate (
185
+ model , chunk_evaluator , test_loader )
186
+ if f1_score > max_f1_score :
187
+ max_f1_score = f1_score
188
+ paddle .save (model .state_dict (),
189
+ os .path .join (args .model_save_dir ,
190
+ "best_model.pdparams" ))
191
+ logger .info ("Save best model." )
192
+
168
193
reader_start = time .time ()
169
194
170
195
0 commit comments