forked from herenever/KoreanAnaphoraResolution
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
32 lines (29 loc) · 1.36 KB
/
train.py
File metadata and controls
32 lines (29 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from ELECTRADataModule import *
from ElectraAnaphoraResolution import *
import transformers
import logging
if __name__ == "__main__":
    # Keep the console clean: suppress transformers' per-tokenizer warnings.
    transformers.logging.set_verbosity_error()

    # Model + data: ELECTRA-based resolver fine-tuned on the w2 anaphora CSVs.
    resolution_model = ElectraForResolution(learning_rate=5e-6)
    data_module = ResolutionDataModule(
        batch_size=32,
        train_path="./anaphora_dataset/w2_train2.csv",
        valid_path="./anaphora_dataset/w2_validation2.csv",
        max_length=128,
        doc1_col='document1',
        doc2_col='document2',
        label_col='label',
        ante_col='antecedent',
        num_workers=32,
    )

    # save_top_k=-1 keeps a checkpoint for every epoch; `monitor`/`mode`
    # additionally track the best validation accuracy seen so far.
    ckpt_cb = pl.callbacks.ModelCheckpoint(
        monitor='total_Accuracy_Val',
        dirpath='./model_checkpoint',
        filename='version_2/{epoch:02d}--{total_Accuracy_Val:.4f}',
        verbose=True,
        save_last=True,
        mode='max',
        save_top_k=-1,
    )

    # TensorBoard logs live next to the checkpoints under tb_logs_v2/.
    board_logger = pl_loggers.TensorBoardLogger(
        os.path.join('./model_checkpoint', 'tb_logs_v2'),
        log_graph=True,
        default_hp_metric=False,
    )
    lr_monitor = pl.callbacks.LearningRateMonitor()

    trainer = pl.Trainer(
        default_root_dir='./model_checkpoint',
        logger=board_logger,
        callbacks=[ckpt_cb, lr_monitor],
        max_epochs=100,
        gpus=4,  # NOTE(review): `gpus` arg implies an older Lightning API — confirm installed version
    )
    trainer.fit(model=resolution_model, datamodule=data_module)