-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__.py
More file actions
81 lines (63 loc) · 3.19 KB
/
__main__.py
File metadata and controls
81 lines (63 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import sys
sys.path.extend(["../","./"])
import os
from torch.utils.data import DataLoader
from dataset import WordVocab
from model.bert import BERT
from dataset import BERTDataset,collate_mlm
from driver import BERTTrainer
from module import Paths
import torch
import numpy as np
import configs.hparams as hp
import random
import pdb
def _str2bool(value):
    """Parse a command-line boolean flag ("true"/"1"/"yes", case-insensitive).

    argparse's ``type=bool`` treats ANY non-empty string (including "False")
    as True; this helper gives the intuitive behavior while keeping defaults.
    """
    if isinstance(value, bool):
        return value
    return value.lower() in ("true", "1", "yes", "y")


def train():
    """Parse CLI arguments, build vocab/datasets/dataloaders, and run BERT pre-training.

    Required flags: -c/--train_dataset, -t/--valid_dataset, -v/--vocab_path,
    -o/--output_path. Side effects: sets CUDA_LAUNCH_BLOCKING, seeds RNGs,
    and launches BERTTrainer.train().
    """
    # Make CUDA errors surface synchronously at the failing kernel
    # instead of at some later, unrelated call site.
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert")
    parser.add_argument("-t", "--valid_dataset", required=True, type=str, help="valid set for evaluate train set")
    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with vocab")
    parser.add_argument("-o", "--output_path", required=True, type=str, help="output/bert.model")
    parser.add_argument("-w", "--num_workers", type=int, default=0, help="dataloader worker size")
    # Fixed: was type=bool, which parses "--with_cuda False" as True.
    parser.add_argument("--with_cuda", type=_str2bool, default=True, help="training with CUDA: true, or false")
    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
    parser.add_argument("--cuda_devices", type=int, nargs='+', default=[0], help="CUDA device ids")
    # Fixed: was type=bool (same truthy-string pitfall as --with_cuda).
    parser.add_argument("--on_memory", type=_str2bool, default=True, help="Loading on memory: true or false")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()
    print(args)

    set_seed(args)
    paths = Paths(args.output_path)

    print("Loading Vocab", args.vocab_path)
    vocab = WordVocab.load_vocab(args.vocab_path)
    print("Vocab Size: ", vocab.vocab_size)
    # Expose the vocab size to downstream model construction via args.
    args.char_nums = vocab.vocab_size

    print("Loading Train Dataset", args.train_dataset)
    train_dataset = BERTDataset(args.train_dataset, vocab, corpus_lines=args.corpus_lines, on_memory=args.on_memory)

    print("Loading Valid Dataset", args.valid_dataset)
    # NOTE(review): --valid_dataset is required=True, so the None guard here is
    # currently dead; kept for safety should the flag ever become optional.
    valid_dataset = BERTDataset(args.valid_dataset, vocab, on_memory=args.on_memory) \
        if args.valid_dataset is not None else None

    print("Creating Dataloader")
    # shuffle=False: the training corpus is pre-sorted by sequence length.
    # Fixed: batch_size was hard-coded to 2 (a debug leftover); use the
    # configured hp.batch_size, matching the validation loader.
    train_data_loader = DataLoader(train_dataset, batch_size=hp.batch_size, collate_fn=collate_mlm,
                                   num_workers=args.num_workers, shuffle=False)
    valid_data_loader = DataLoader(valid_dataset, batch_size=hp.batch_size, collate_fn=collate_mlm,
                                   num_workers=args.num_workers, shuffle=False) \
        if valid_dataset is not None else None
    # Removed: a leftover debugging loop (pdb.set_trace() + print + exit())
    # that dumped the first batch and terminated the process, making the
    # model construction and training below unreachable.

    print("Building BERT model")
    bert = BERT(embed_dim=hp.embed_dim, hidden=hp.hidden, args=args)
    print(bert)

    print("Creating BERT Trainer")
    trainer = BERTTrainer(bert, vocab.vocab_size, train_dataloader=train_data_loader, test_dataloader=valid_data_loader,
                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, args=args, path=paths)

    print("Training Start")
    trainer.train()
def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs from ``args.seed`` for reproducibility."""
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
# Script entry point: run BERT pre-training when executed directly
# (e.g. ``python -m <package> -c train.txt -t valid.txt -v vocab.pkl -o out/``).
if __name__ == '__main__':
    train()