Commit fe69df5

upgrade lac (PaddlePaddle#969)
1 parent 14b6f89 commit fe69df5

File tree

1 file changed (+36, -11 lines)


examples/lexical_analysis/train.py

Lines changed: 36 additions & 11 deletions
@@ -41,6 +41,7 @@
 parser.add_argument("--max_seq_len", type=int, default=64, help="Number of words of the longest sequence.")
 parser.add_argument("--device", default="gpu", type=str, choices=["cpu", "gpu"], help="The device to select to train the model; it must be cpu/gpu.")
 parser.add_argument("--base_lr", type=float, default=0.001, help="The basic learning rate that affects the entire network.")
+parser.add_argument("--crf_lr", type=float, default=0.2, help="The learning rate ratio that affects the CRF layer.")
 parser.add_argument("--emb_dim", type=int, default=128, help="The dimension in which a word is embedded.")
 parser.add_argument("--hidden_size", type=int, default=128, help="The number of hidden nodes in the GRU layer.")
 parser.add_argument("--logging_steps", type=int, default=10, help="Log every X update steps.")
@@ -61,14 +62,19 @@ def evaluate(model, metric, data_loader):
         metric.update(num_infer_chunks.numpy(),
                       num_label_chunks.numpy(), num_correct_chunks.numpy())
     precision, recall, f1_score = metric.accumulate()
-    print("eval precision: %f, recall: %f, f1: %f" %
-          (precision, recall, f1_score))
+    logger.info("eval precision: %f, recall: %f, f1: %f" %
+                (precision, recall, f1_score))
     model.train()
+    return precision, recall, f1_score


 def train(args):
     paddle.set_device(args.device)

+    trainer_num = paddle.distributed.get_world_size()
+    if trainer_num > 1:
+        paddle.distributed.init_parallel_env()
+    rank = paddle.distributed.get_rank()
     # Create dataset.
     train_ds, test_ds = load_dataset(datafiles=(os.path.join(
         args.data_dir, 'train.tsv'), os.path.join(args.data_dir, 'test.tsv')))
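
Note: this hunk (and the ones below) switches print to logger.info, but the logger itself is defined outside the lines shown. In PaddleNLP examples it typically comes from the package's logging helper; the import below is an assumption about the rest of train.py, not part of this commit.

# Assumed module-level logger backing the logger.info() calls in this diff;
# the actual import in train.py is not shown in these hunks.
from paddlenlp.utils.log import logger

logger.info("Start training")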
@@ -117,24 +123,34 @@ def train(args):
         collate_fn=batchify_fn)

     # Define the model network and its loss
-    model = BiGruCrf(args.emb_dim, args.hidden_size,
-                     len(word_vocab), len(label_vocab))
+    model = BiGruCrf(
+        args.emb_dim,
+        args.hidden_size,
+        len(word_vocab),
+        len(label_vocab),
+        crf_lr=args.crf_lr)
     # Prepare optimizer, loss and metric evaluator
     optimizer = paddle.optimizer.Adam(
         learning_rate=args.base_lr, parameters=model.parameters())
     chunk_evaluator = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True)

     if args.init_checkpoint:
-        model_dict = paddle.load(args.init_checkpoint)
-        model.load_dict(model_dict)
-
+        if os.path.exists(args.init_checkpoint):
+            logger.info("Init checkpoint from %s" % args.init_checkpoint)
+            model_dict = paddle.load(args.init_checkpoint)
+            model.load_dict(model_dict)
+        else:
+            logger.info("Cannot init checkpoint from %s which doesn't exist" %
+                        args.init_checkpoint)
+    logger.info("Start training")
     # Start training
     global_step = 0
     last_step = args.epochs * len(train_loader)
     train_reader_cost = 0.0
     train_run_cost = 0.0
     total_samples = 0
     reader_start = time.time()
+    max_f1_score = -1
     for epoch in range(args.epochs):
         for step, batch in enumerate(train_loader):
             train_reader_cost += time.time() - reader_start
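
Note: the new crf_lr value is forwarded into BiGruCrf above, but how the model applies it is not shown in this commit. The usual Paddle mechanism for a per-layer learning-rate ratio is ParamAttr(learning_rate=...), which scales the optimizer's base LR for the parameters it is attached to. The sketch below illustrates only that mechanism, under that assumption; the Linear layer is a stand-in, not the model's CRF.

import paddle

# Sketch only: ParamAttr(learning_rate=0.2) makes this layer's weights train at
# 0.2 * base_lr under the Adam optimizer configured in train(). Whether BiGruCrf
# wires crf_lr this way internally is an assumption, not shown in the diff.
crf_lr = 0.2
attr = paddle.ParamAttr(learning_rate=crf_lr)
layer = paddle.nn.Linear(128, 57, weight_attr=attr)
optimizer = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=layer.parameters())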
@@ -146,7 +162,7 @@ def train(args):
             train_run_cost += time.time() - train_start
             total_samples += args.batch_size
             if global_step % args.logging_steps == 0:
-                print(
+                logger.info(
                     "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec"
                     % (global_step, last_step, avg_loss, train_reader_cost /
                        args.logging_steps, (train_reader_cost + train_run_cost)
@@ -159,12 +175,21 @@ def train(args):
             optimizer.step()
             optimizer.clear_grad()
             if global_step % args.save_steps == 0 or global_step == last_step:
-                if paddle.distributed.get_rank() == 0:
-                    if args.do_eval:
-                        evaluate(model, chunk_evaluator, test_loader)
+                if rank == 0:
                     paddle.save(model.state_dict(),
                                 os.path.join(args.model_save_dir,
                                              "model_%d.pdparams" % global_step))
+                    logger.info("Save %d steps model." % (global_step))
+                    if args.do_eval:
+                        precision, recall, f1_score = evaluate(
+                            model, chunk_evaluator, test_loader)
+                        if f1_score > max_f1_score:
+                            max_f1_score = f1_score
+                            paddle.save(model.state_dict(),
+                                        os.path.join(args.model_save_dir,
+                                                     "best_model.pdparams"))
+                            logger.info("Save best model.")
+
             reader_start = time.time()
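
Note: the loop above now writes both per-step checkpoints (model_<step>.pdparams) and a best_model.pdparams selected by eval F1, which pairs with the existence-checked --init_checkpoint path added earlier. A hedged loading sketch follows; "./save_dir" is an illustrative directory (the real location is whatever args.model_save_dir was set to), and building the BiGruCrf is elided.

import os
import paddle

# Illustrative path; args.model_save_dir's actual value is not shown in this diff.
ckpt = os.path.join("./save_dir", "best_model.pdparams")
if os.path.exists(ckpt):
    state_dict = paddle.load(ckpt)
    # model.load_dict(state_dict)  # `model` would be a BiGruCrf built as in train()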
