train_sft.py
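"""Supervised fine-tuning (SFT) entry point.

Fine-tunes a causal LM (default: Qwen/Qwen2.5-7B-Instruct) with TRL's
SFTTrainer, optionally with LoRA adapters. A SMATCH-F1 metric hook is
defined for evaluation, but the eval settings are currently commented out.
"""
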
import argparse
import os

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

from .data_loader import get_data
from .reward import compute_smatch_f1


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use bf16 when the GPU supports it, otherwise fall back to fp16, and keep
    # the Trainer's mixed-precision flags consistent with the model dtype.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_fp16 = torch.cuda.is_available() and not use_bf16

    # Load the training dataset (optionally concatenating a second one).
    train_dataset = get_data(args.dataset1_path, args.dataset2_path, type="sft")
    # split_dataset = dataset.train_test_split(test_size=0.1, seed=42)
    # train_dataset = split_dataset["train"]
    # eval_dataset = split_dataset["test"]

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        trust_remote_code=True,
    ).to(device)

    # Build a LoRA config only when requested; otherwise pass None so the
    # SFTTrainer call below never references an undefined name and the run
    # falls back to full fine-tuning.
    peft_config = None
    if args.use_lora:
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules="all-linear",
        )

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        # Padded label positions are -100; map them back to the pad token id
        # so they can be decoded.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Compute SMATCH F1 between gold and predicted graphs.
        f1 = compute_smatch_f1(decoded_labels, decoded_preds)
        return {"smatch_f1": f1}

    training_args = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        lr_scheduler_type=args.lr_scheduler_type,
        bf16=use_bf16,
        fp16=use_fp16,
        logging_dir=os.path.join(args.output_dir, "logs"),
        save_total_limit=2,
        report_to="none",
        completion_only_loss=False,
        deepspeed=args.deepspeed_path,
        max_length=args.max_input_length,
        # eval_strategy="steps",
        # eval_steps=args.eval_steps,
    )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        # eval_dataset=eval_dataset,
        args=training_args,
        compute_metrics=compute_metrics,
        peft_config=peft_config,
    )
    trainer.train()

    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset1_path", type=str, required=True)
    parser.add_argument("--dataset2_path", type=str, default=None,
                        help="Optional second dataset path for concatenation")
    parser.add_argument("--output_dir", type=str, default="./sft_lora_output")
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    # Training parameters
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--deepspeed_path", type=str, default=None)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")
    parser.add_argument("--max_input_length", type=int, default=1024)
    parser.add_argument("--eval_steps", type=int, default=500)
    # LoRA parameters
    parser.add_argument("--use_lora", type=int, default=0,
                        help="Set to 1 to train with LoRA adapters instead of full fine-tuning")
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.1)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)
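
# Example invocation (illustrative only: the `src` package name and the data
# paths below are placeholders, not part of this repo; since the module uses
# relative imports, it must be launched with `python -m` from the project root):
#
#   python -m src.train_sft \
#       --dataset1_path data/amr_train.jsonl \
#       --output_dir ./sft_lora_output \
#       --use_lora 1 --lora_r 8 --lora_alpha 16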