1-
2-
3-
4-
51import time
6- import random
2+ import os
3+
74import torch
8- import bmtrain as bmt
95import numpy as np
10- import os
11- import csv
6+ from sklearn .metrics import accuracy_score , recall_score , f1_score
7+
8+ import bmtrain as bmt
129
1310from model_center import get_args
14- from model_center .model import CPM2
15- from model_center .tokenizer import CPM2Tokenizer
16- from model_center .dataset .cpm2dataset import DATASET
11+ from model_center .model import Bert
12+ from model_center .tokenizer import BertTokenizer
13+ from model_center .dataset .bertdataset import DATASET
1714from model_center .utils import print_inspect
15+ from model_center .layer import Linear
1816from model_center .dataset import DistributedDataLoader
1917
class BertModel(torch.nn.Module):
    """A pretrained BERT encoder topped with a linear classification head.

    The head width (``num_types``) is the number of labels for the task;
    classification is done from BERT's pooled output.
    """

    def __init__(self, args, num_types):
        super().__init__()
        # Pretrained backbone; its embedding width sizes the classifier input.
        self.bert: Bert = Bert.from_pretrained(args.model_config)
        hidden_size = self.bert.input_embedding.dim_model
        self.dense = Linear(hidden_size, num_types)
        # Head is freshly created, so its parameters need distributed init.
        bmt.init_parameters(self.dense)

    def forward(self, *args, **kwargs):
        """Return per-class logits computed from the pooled representation."""
        pooled = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output
        return self.dense(pooled)
30+
def get_tokenizer(args):
    """Load the pretrained BERT tokenizer identified by ``args.model_config``."""
    return BertTokenizer.from_pretrained(args.model_config)
2334
def get_model(args):
    """Build a BertModel whose head is sized for the chosen SuperGLUE task.

    NOTE: COPA uses a single output unit because each of its two candidate
    answers is scored separately and the scores are concatenated downstream
    into a 2-way decision.
    """
    label_counts = {"BoolQ": 2, "CB": 3, "COPA": 1, "RTE": 2, "WiC": 2}
    return BertModel(args, label_counts[args.dataset_name])
2745
2846def get_optimizer (args , model ):
@@ -96,38 +114,52 @@ def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size):
96114 splits = ['train' , 'dev' , 'test' ]
97115 dataset = {}
98116 for split in splits :
99- dataset [split ] = DATASET [dataset_name ](base_path , split , rank , world_size , tokenizer , args .max_encoder_length , args .max_decoder_length )
100- verbalizer = torch .LongTensor (DATASET [dataset_name ].get_verbalizer (tokenizer )).cuda ()
101- return dataset , verbalizer
117+ dataset [split ] = DATASET [dataset_name ](base_path , split , rank , world_size , tokenizer , args .max_encoder_length )
118+ return dataset
102119
103120
def _forward_batch(args, model, data):
    # Run one batch through the model and return (logits, labels).
    # COPA encodes its two candidate answers as separate inputs; each pass
    # yields one score per example (head width 1), and concatenating along
    # dim=1 produces a standard 2-way logit.
    labels = data["labels"]
    if args.dataset_name == 'COPA':
        logits = torch.cat([
            model(data["input_ids0"],
                  attention_mask=data["attention_mask0"],
                  token_type_ids=data["token_type_ids0"]),
            model(data["input_ids1"],
                  attention_mask=data["attention_mask1"],
                  token_type_ids=data["token_type_ids1"]),
        ], dim=1)
    else:
        logits = model(data["input_ids"],
                       attention_mask=data["attention_mask"],
                       token_type_ids=data["token_type_ids"])
    return logits, labels


def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
    """Fine-tune ``model`` on a SuperGLUE task and evaluate on dev each epoch.

    Args:
        args: parsed command-line arguments (batch_size, loss_scale,
            dataset_name, ...).
        tokenizer: unused here; kept for interface compatibility.
        model: classification model returned by ``get_model``.
        optimizer: optimizer managed by bmtrain's OptimManager.
        lr_scheduler: learning-rate scheduler paired with the optimizer.
        dataset: dict with 'train' and 'dev' splits from ``prepare_dataset``.
    """
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)

    optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
    optim_manager.add_optimizer(optimizer, lr_scheduler)

    print_inspect(model, '*')

    # Backward-compatible: honor args.epochs if present, else the original 12.
    num_epochs = getattr(args, "epochs", 12)

    for epoch in range(num_epochs):
        # Rebuilt each epoch so the train loader reshuffles its shards.
        dataloader = {
            "train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True),
            "dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False),
        }

        model.train()
        for it, data in enumerate(dataloader['train']):
            # Synchronize so the timer measures only this step's GPU work.
            torch.cuda.synchronize()
            st_time = time.time()

            logits, labels = _forward_batch(args, model, data)
            loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))

            global_loss = bmt.sum_loss(loss).item()

            optim_manager.zero_grad()

            optim_manager.backward(loss)
            grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad)

            optim_manager.step()

            torch.cuda.synchronize()
            elapsed_time = time.time() - st_time

            bmt.print_rank(
                "train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f} ".format(
                    epoch,
                    it,
                    len(dataloader["train"]),
                    global_loss,
                    lr_scheduler.current_lr,
                    int(optim_manager.loss_scale),
                    grad_norm,
                    elapsed_time,
                )
            )

        model.eval()
        with torch.no_grad():
            for split in ['dev']:
                pd = []
                gt = []
                for it, data in enumerate(dataloader[split]):
                    logits, labels = _forward_batch(args, model, data)

                    loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
                    predictions = logits.argmax(dim=-1)
                    pd.extend(predictions.cpu().tolist())
                    gt.extend(labels.cpu().tolist())

                    bmt.print_rank(
                        "{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format(
                            split,
                            epoch,
                            it,
                            len(dataloader[split]),
                            loss,
                        )
                    )

                # Gather predictions/labels from all ranks before scoring.
                pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist()
                gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist()

                bmt.print_rank(f"{split} epoch {epoch}:")
                if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]:
                    acc = accuracy_score(gt, pd)
                    bmt.print_rank(f"accuracy: {acc*100:.2f}")
                if args.dataset_name in ["CB"]:
                    # BUG FIX: recall and F1 were previously swapped
                    # (rcl was assigned f1_score and f1 was assigned
                    # recall_score), mislabeling both reported metrics.
                    rcl = recall_score(gt, pd, average="macro")
                    f1 = f1_score(gt, pd, average="macro")
                    bmt.print_rank(f"recall: {rcl*100:.2f}")
                    bmt.print_rank(f"Average F1: {f1*100:.2f}")
186241
def main():
    """Entry point: initialize, load the SuperGLUE split, and fine-tune."""
    args = initialize()
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    data_root = f"{args.base_path}/down_data/superglue/"
    dataset = prepare_dataset(
        args,
        tokenizer,
        data_root,
        args.dataset_name,
        bmt.rank(),
        bmt.world_size(),
    )
    finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
198253
# Script entry guard: run training only when executed directly.
if __name__ == "__main__":
    main()
0 commit comments