-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__2.py
More file actions
93 lines (74 loc) · 3.67 KB
/
__main__2.py
File metadata and controls
93 lines (74 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import sys
sys.path.extend(["../","./"])
import os
from torch.utils.data import DataLoader
from dataset import WordVocab
from model.bert import BERT
from dataset import BERTDataset, collate_mlm
from driver import BERTTrainer, BERTTrainerTTS
from transformer import Encoder
from module import Paths
import torch
import numpy as np
import configs.hparams as hp
import random
import yaml
import pdb
from dataset_multi import Dataset
def train():
    """Parse CLI arguments, load configs/datasets, and launch BERT-TTS training.

    Side effects: sets CUDA_LAUNCH_BLOCKING in the environment, reads the
    vocab/config/dataset files named on the command line, and starts training.
    """
    # Make CUDA errors surface at the failing kernel (debugging aid; slows GPU work).
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

    def _str2bool(value):
        # argparse's ``type=bool`` is a known pitfall: bool("False") is True,
        # so the flag could never be disabled from the CLI. Parse the usual
        # textual spellings instead; anything else counts as False.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ("1", "true", "t", "yes", "y")

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert")
    parser.add_argument("-t", "--valid_dataset", required=True, type=str, help="valid set for evaluate train set")
    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with vocab")
    parser.add_argument("-o", "--output_path", required=True, type=str, help="output/bert.model")
    parser.add_argument("-w", "--num_workers", type=int, default=0, help="dataloader worker size")
    parser.add_argument("--with_cuda", type=_str2bool, default=True, help="training with CUDA: true, or false")
    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
    parser.add_argument("--cuda_devices", type=int, nargs='+', default=[0], help="CUDA device ids")
    parser.add_argument("--on_memory", type=_str2bool, default=True, help="Loading on memory: true or false")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()
    print(args)

    set_seed(args)
    paths = Paths(args.output_path)

    print("Loading Vocab", args.vocab_path)
    vocab = WordVocab.load_vocab(args.vocab_path)
    print("Vocab Size: ", vocab.vocab_size)
    args.char_nums = vocab.vocab_size

    # NOTE(review): the original code also built a BERTDataset from
    # args.train_dataset here, but the result was unconditionally overwritten
    # by the multilingual ``Dataset`` below — dropped as dead work.
    print("Loading Train Dataset", args.train_dataset)
    print("Loading Valid Dataset", args.valid_dataset)

    path = "LibriTTS_StyleSpeech_multilingual_diffusion_style_3layer"
    # path = "VNTTS"
    # path = "LibriTTS_StyleSpeech_multilingual_diffusion_style_EN"

    # Load the three YAML configs with context managers so file handles are
    # closed promptly (the originals leaked open handles).
    config_dir = "./config/config_kaga/{0}".format(path)
    with open("{0}/preprocess.yaml".format(config_dir), "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open("{0}/train.yaml".format(config_dir), "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
    with open("{0}/model.yaml".format(config_dir), "r") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)

    train_dataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True)
    val_dataset = Dataset("val.txt", preprocess_config, train_config, sort=False, drop_last=False)

    print("Creating Dataloader")
    train_data_loader = DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True, num_workers=0, collate_fn=train_dataset.collate_fn)
    valid_data_loader = DataLoader(val_dataset, batch_size=hp.batch_size, shuffle=False, num_workers=0, collate_fn=val_dataset.collate_fn)

    print("Building BERT model")
    # NOTE(review): hardcoded 1051 may disagree with vocab.vocab_size loaded
    # above (args.char_nums) — confirm which one the trainer should use.
    vocab_size = 1051
    bert = Encoder(model_config)
    # print(bert)

    print("Creating BERT Trainer")
    trainer = BERTTrainerTTS(bert, vocab_size, train_dataloader=train_data_loader, test_dataloader=valid_data_loader,
                             with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, args=args, path=paths)

    print("Training Start")
    trainer.train()
def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs for reproducible training runs.

    Args:
        args: namespace with an integer ``seed`` attribute (from argparse).
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        # Seed every CUDA device explicitly — the script may run on several
        # GPUs (see the --cuda_devices flag in train()).
        torch.cuda.manual_seed_all(args.seed)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    train()