-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
91 lines (75 loc) · 2.67 KB
/
main.py
File metadata and controls
91 lines (75 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import absolute_import
import time
import logging
from utils import read_prompt_examples, get_elapse_time, calculate_rouge
from SOTitlePlus import SOTitlePlus
class Config(object):
    """Container for the run settings handed to ``SOTitlePlus``.

    Every value is a hard-coded default assigned in ``__init__``; there is no
    command-line or file parsing here.  How each setting is interpreted is
    decided by the code that consumes this object, which is not part of this
    module.
    """

    def __init__(self):
        self.cuda = True
        self.train_filename = 'data/train.csv'
        self.dev_filename = 'data/valid.csv'
        self.test_filename = 'data/test.csv'
        self.model_type = 'codet5'
        self.model_name_or_path = 'Salesforce/codet5-base'
        self.log_name = 'log/python.log'
        self.output_dir = "model"
        self.data_dir = "data"
        # NOTE(review): the name looks like a typo of ``result_dir``; it is
        # kept as-is because other modules may read this attribute.
        self.result_dit = 'results'
        self.langs = ['python', 'java', 'c#', 'javascript', 'php', 'html']
        # NOTE(review): ``cuda`` above and ``no_cuda`` here both exist;
        # which one the consumer honours is not visible from this file.
        self.no_cuda = False
        self.visible_gpu = ""
        self.add_task_prefix = False
        self.add_lang_ids = False
        self.num_train_epochs = 50
        self.train_batch_size = 16
        self.eval_batch_size = 16
        self.gradient_accumulation_steps = 2
        # other configs
        self.load_model_path = ''
        self.train_load_model_path = None
        self.config_name = ""
        self.tokenizer_name = ""
        self.max_source_length = 512
        self.max_target_length = 64
        self.warm_up_ratio = 0.1
        # controlling configs
        self.do_train = True
        self.do_eval = True
        self.do_test = True
        self.learning_rate = 5e-5
        self.beam_size = 10
        self.weight_decay = 0.0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.max_steps = -1
        self.eval_steps = -1
        self.train_steps = 2000
        self.local_rank = -1
        self.seed = 42
        self.early_stop_threshold = 5

    def __repr__(self):
        """Return ``Config(name=value, ...)`` covering every instance attribute.

        Without this method, logging or printing a ``Config`` shows only the
        default ``<... object at 0x...>`` form, which records none of the
        settings used for the run.
        """
        fields = ', '.join(
            '{}={!r}'.format(name, value) for name, value in vars(self).items()
        )
        return '{}({})'.format(type(self).__name__, fields)
if __name__ == '__main__':
    # Script entry point: build the config, set up logging, train the model,
    # then run the test step once per configured language.

    # Imported locally because only this entry point needs it.
    import os

    my_config = Config()
    # begin time
    begin_time = time.time()

    # logger for record
    # basicConfig attaches a formatted console handler to the root logger;
    # records from `logger` below propagate to it, so no separate console
    # handler is added here (a second one would print every message twice).
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    # write to file
    # FileHandler does not create missing parent directories, so make sure
    # the log directory exists before opening the file.
    log_dir = os.path.dirname(my_config.log_name)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    handler = logging.FileHandler(my_config.log_name)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    # print config
    logger.info(my_config)

    model = SOTitlePlus(my_config)
    model.train()

    # NOTE(review): these test paths are relative to the parent directory
    # ('../data/...'), whereas the dataset paths in Config are relative to the
    # working directory ('data/...') - confirm this is intended.
    for lan in my_config.langs:
        logger.info(f'lan:{lan}')
        model.test(lan, f'../data/{lan}/test.csv')
    # model.predict()

    logger.info("Finish training and take %s", get_elapse_time(begin_time))