
Commit 452b4f8: small update
Parent: 397c5d2

Files changed (3): +133, -74 lines
- .gitignore
- docs/source/notes/update.md
- examples/unittest/test_bmtrain.py

.gitignore (4 additions, 3 deletions)

```diff
@@ -55,6 +55,7 @@ t.sh
 **/outputs/
 
 
-unittest/outputs/
-unittest/tmp/
-**/tmp/
+**/unittest/**
+!unittest/**.py
+!unittest/**.sh
+
```

docs/source/notes/update.md (3 additions, 0 deletions)

```diff
@@ -1,5 +1,8 @@
 # Update Logs and Known Issues
 
+## Version 0.3.1
+- We updated [must_try.py](https://github.com/thunlp/OpenDelta/tree/main/examples/unittest/must_try.py) as a simple introduction to the core functionality of OpenDelta.
+
 
 ## Version 0.3.0
 ### Updates:
```
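
The changelog entry above points to must_try.py for OpenDelta's core workflow: wrap a pretrained backbone with a delta model, freeze the backbone, and train only the small set of delta parameters. A minimal sketch of that workflow, assuming the LoRA variant; the checkpoint name and `modified_modules` choice here are illustrative, so treat must_try.py as the authoritative version:

```python
# Minimal sketch of OpenDelta's core workflow, in the spirit of must_try.py.
# Assumes `opendelta` and `transformers` are installed; the checkpoint and
# modified_modules below are illustrative choices, not fixed by this commit.
from transformers import AutoModelForSequenceClassification
from opendelta import LoraModel

backbone = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# Attach LoRA delta modules to the backbone's attention layers.
delta_model = LoraModel(backbone_model=backbone, modified_modules=["attention"])

# Freeze everything except the delta parameters and the task head.
delta_model.freeze_module(exclude=["deltas", "classifier"], set_state_dict=True)
delta_model.log()  # print the wrapped structure and trainable-parameter counts
```
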

examples/unittest/test_bmtrain.py (126 additions, 71 deletions)

```diff
@@ -1,28 +1,46 @@
-
-
-
-
 import time
-import random
+import os
+
 import torch
-import bmtrain as bmt
 import numpy as np
-import os
-import csv
+from sklearn.metrics import accuracy_score, recall_score, f1_score
+
+import bmtrain as bmt
 
 from model_center import get_args
-from model_center.model import CPM2
-from model_center.tokenizer import CPM2Tokenizer
-from model_center.dataset.cpm2dataset import DATASET
+from model_center.model import Bert
+from model_center.tokenizer import BertTokenizer
+from model_center.dataset.bertdataset import DATASET
 from model_center.utils import print_inspect
+from model_center.layer import Linear
 from model_center.dataset import DistributedDataLoader
 
+class BertModel(torch.nn.Module):
+    def __init__(self, args, num_types):
+        super().__init__()
+        self.bert : Bert = Bert.from_pretrained(args.model_config)
+        dim_model = self.bert.input_embedding.dim_model
+        self.dense = Linear(dim_model, num_types)
+        bmt.init_parameters(self.dense)
+
+    def forward(self, *args, **kwargs):
+        pooler_output = self.bert(*args, **kwargs, output_pooler_output=True).pooler_output
+        logits = self.dense(pooler_output)
+        return logits
+
 def get_tokenizer(args):
-    tokenizer = CPM2Tokenizer.from_pretrained(args.model_config)
+    tokenizer = BertTokenizer.from_pretrained(args.model_config)
     return tokenizer
 
 def get_model(args):
-    model = CPM2.from_pretrained(args.model_config)
+    num_types = {
+        "BoolQ" : 2,
+        "CB" : 3,
+        "COPA" : 1,
+        "RTE" : 2,
+        "WiC" : 2,
+    }
+    model = BertModel(args, num_types[args.dataset_name])
     return model
 
 def get_optimizer(args, model):
```
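
A note on the `num_types` table above: COPA maps to 1 because each of its two candidate answers is scored by a single-logit head, and the two scores are then concatenated into one binary decision (the `torch.cat(..., dim=1)` branches further down). A toy illustration, with random tensors standing in for the BERT pooler outputs:

```python
# Toy illustration of the COPA scoring scheme in this test: a one-unit head
# scores each candidate, and the per-candidate scores are concatenated into
# a 2-way classification. Random tensors stand in for pooler outputs.
import torch

batch_size, dim_model = 4, 768
dense = torch.nn.Linear(dim_model, 1)         # num_types["COPA"] == 1

pooled0 = torch.randn(batch_size, dim_model)  # pooler output, candidate 0
pooled1 = torch.randn(batch_size, dim_model)  # pooler output, candidate 1

logits = torch.cat([dense(pooled0), dense(pooled1)], dim=1)  # shape (4, 2)
print(logits.argmax(dim=-1))                  # index of the preferred candidate
```
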
```diff
@@ -96,38 +114,52 @@ def prepare_dataset(args, tokenizer, base_path, dataset_name, rank, world_size):
     splits = ['train', 'dev', 'test']
     dataset = {}
     for split in splits:
-        dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length, args.max_decoder_length)
-    verbalizer = torch.LongTensor(DATASET[dataset_name].get_verbalizer(tokenizer)).cuda()
-    return dataset, verbalizer
+        dataset[split] = DATASET[dataset_name](base_path, split, rank, world_size, tokenizer, args.max_encoder_length)
+    return dataset
 
 
-def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset, verbalizer):
+def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
     loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
 
     optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
     optim_manager.add_optimizer(optimizer, lr_scheduler)
 
-    dataloader = {
-        "train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True),
-        "dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False),
-        "test": DistributedDataLoader(dataset['test'], batch_size=args.batch_size, shuffle=False),
-    }
+    print_inspect(model, '*')
+
+    for epoch in range(12):
+        dataloader = {
+            "train": DistributedDataLoader(dataset['train'], batch_size=args.batch_size, shuffle=True),
+            "dev": DistributedDataLoader(dataset['dev'], batch_size=args.batch_size, shuffle=False),
+        }
 
-    for epoch in range(5):
         model.train()
         for it, data in enumerate(dataloader['train']):
-            enc_input = data["enc_input"]
-            enc_length = data["enc_length"]
-            dec_input = data["dec_input"]
-            dec_length = data["dec_length"]
-            targets = data["targets"]
-            index = data["index"]
-
-            logits = model(enc_input, enc_length, dec_input, dec_length)
-            logits = logits.index_select(dim=-1, index=verbalizer)
-            logits = logits[torch.where(index==1)]
-
-            loss = loss_func(logits, targets)
+            if args.dataset_name == 'COPA':
+                input_ids0 = data["input_ids0"]
+                attention_mask0 = data["attention_mask0"]
+                token_type_ids0 = data["token_type_ids0"]
+                input_ids1 = data["input_ids1"]
+                attention_mask1 = data["attention_mask1"]
+                token_type_ids1 = data["token_type_ids1"]
+                labels = data["labels"]
+            else:
+                input_ids = data["input_ids"]
+                attention_mask = data["attention_mask"]
+                token_type_ids = data["token_type_ids"]
+                labels = data["labels"]
+
+            torch.cuda.synchronize()
+            st_time = time.time()
+
+            if args.dataset_name == 'COPA':
+                logits = torch.cat([
+                    model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
+                    model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
+                ], dim=1)
+            else:
+                logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+            loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
+
             global_loss = bmt.sum_loss(loss).item()
 
             optim_manager.zero_grad()
```
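
The `torch.cuda.synchronize()` calls bracketing the timed region are what make the new `time:` field meaningful: CUDA kernels launch asynchronously, so an unbracketed `time.time()` pair would mostly measure kernel-launch overhead. The same pattern in isolation (a sketch; requires a CUDA device):

```python
# Standalone sketch of the timing pattern used in the training loop above.
# CUDA kernels run asynchronously, so the clock must be bracketed by
# torch.cuda.synchronize() to capture real execution time.
import time
import torch

x = torch.randn(1024, 1024, device="cuda")

torch.cuda.synchronize()            # drain pending work before starting the clock
st_time = time.time()
y = x @ x                           # the operation being timed
torch.cuda.synchronize()            # wait for the matmul to actually finish
elapsed_time = time.time() - st_time
print(f"time: {elapsed_time:.3f}")
```
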
```diff
@@ -137,64 +169,87 @@ def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset, verbalizer):
 
             optim_manager.step()
 
+            torch.cuda.synchronize()
+            elapsed_time = time.time() - st_time
+
             bmt.print_rank(
-                "train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} |".format(
+                "train | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | grad_norm: {:.4f} | time: {:.3f}".format(
                     epoch,
                     it,
                     len(dataloader["train"]),
                     global_loss,
                     lr_scheduler.current_lr,
                     int(optim_manager.loss_scale),
                     grad_norm,
+                    elapsed_time,
                 )
             )
-            # if it % args.inspect_iters == 0: print_inspect(model, "*")
-            # if args.save != None and it % args.save_iters == 0:
-            #     bmt.save(model, os.path.join(args.save, args.save_name+("-%d.pt" % it)))
 
         model.eval()
         with torch.no_grad():
-            acc = 0
-            total = 0
-            for it, data in enumerate(dataloader['dev']):
-                enc_input = data["enc_input"]
-                enc_length = data["enc_length"]
-                dec_input = data["dec_input"]
-                dec_length = data["dec_length"]
-                targets = data["targets"]
-                index = data["index"]
-
-                logits = model(enc_input, enc_length, dec_input, dec_length)
-                logits = logits.index_select(dim=-1, index=verbalizer)
-                logits = logits[torch.where(index==1)]
-                logits = logits.argmax(dim=-1)
-
-                acc += torch.sum(logits == targets).item()
-                total += logits.shape[0]
-                bmt.print_rank(
-                    "dev | epoch {:3d} | Iter: {:6d}/{:6d} | acc: {:6d} | total: {:6d} |".format(
-                        epoch,
-                        it,
-                        len(dataloader["dev"]),
-                        acc,
-                        total,
+            for split in ['dev']:
+                pd = []
+                gt = []
+                for it, data in enumerate(dataloader[split]):
+                    if args.dataset_name == 'COPA':
+                        input_ids0 = data["input_ids0"]
+                        attention_mask0 = data["attention_mask0"]
+                        token_type_ids0 = data["token_type_ids0"]
+                        input_ids1 = data["input_ids1"]
+                        attention_mask1 = data["attention_mask1"]
+                        token_type_ids1 = data["token_type_ids1"]
+                        labels = data["labels"]
+                        logits = torch.cat([
+                            model(input_ids0, attention_mask=attention_mask0, token_type_ids=token_type_ids0),
+                            model(input_ids1, attention_mask=attention_mask1, token_type_ids=token_type_ids1),
+                        ], dim=1)
+                    else:
+                        input_ids = data["input_ids"]
+                        attention_mask = data["attention_mask"]
+                        token_type_ids = data["token_type_ids"]
+                        labels = data["labels"]
+                        logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
+
+                    loss = loss_func(logits.view(-1, logits.shape[-1]), labels.view(-1))
+                    logits = logits.argmax(dim=-1)
+                    pd.extend(logits.cpu().tolist())
+                    gt.extend(labels.cpu().tolist())
+
+                    bmt.print_rank(
+                        "{} | epoch {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f}".format(
+                            split,
+                            epoch,
+                            it,
+                            len(dataloader[split]),
+                            loss,
+                        )
                     )
-                )
-            acc = torch.tensor(acc / total).cuda()
-            acc = bmt.sum_loss(acc).cpu().item()
-            bmt.print_rank(f"dev epoch {epoch}: accuracy: {acc}")
+
+                pd = bmt.gather_result(torch.tensor(pd).int()).cpu().tolist()
+                gt = bmt.gather_result(torch.tensor(gt).int()).cpu().tolist()
+
+                bmt.print_rank(f"{split} epoch {epoch}:")
+                if args.dataset_name in ["BoolQ", "CB", "COPA", "RTE", "WiC", "WSC"]:
+                    acc = accuracy_score(gt, pd)
+                    bmt.print_rank(f"accuracy: {acc*100:.2f}")
+                if args.dataset_name in ["CB"]:
+                    rcl = recall_score(gt, pd, average="macro")
+                    f1 = f1_score(gt, pd, average="macro")
+                    bmt.print_rank(f"recall: {rcl*100:.2f}")
+                    bmt.print_rank(f"Average F1: {f1*100:.2f}")
+
 
 def main():
     args = initialize()
     tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
-    dataset, verbalizer = prepare_dataset(
+    dataset = prepare_dataset(
         args,
         tokenizer,
-        f"{args.base_path}/down_data/paraphrase",
+        f"{args.base_path}/down_data/superglue/",
         args.dataset_name,
         bmt.rank(), bmt.world_size(),
     )
-    finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset, verbalizer)
+    finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
 
 if __name__ == "__main__":
-    main()
+    main()
```
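
Since `bmt.gather_result` collects the per-rank predictions into plain tensors, the final scoring reduces to ordinary scikit-learn calls, and the same numbers can be reproduced locally without bmtrain. A small example with illustrative labels:

```python
# The gathered predictions and gold labels are plain Python lists, so the
# reported metrics are standard scikit-learn calls. Values are illustrative.
from sklearn.metrics import accuracy_score, f1_score, recall_score

gt = [0, 1, 2, 1, 0, 2]  # gold labels after gathering across ranks
pd = [0, 1, 1, 1, 0, 2]  # model predictions after gathering across ranks

print(f"accuracy: {accuracy_score(gt, pd) * 100:.2f}")
print(f"recall: {recall_score(gt, pd, average='macro') * 100:.2f}")
print(f"Average F1: {f1_score(gt, pd, average='macro') * 100:.2f}")
```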
