Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions finetune/finetune_visualglm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ gpt_options=" \
--warmup .02 \
--checkpoint-activations \
--save-interval 300 \
--eval-interval 10000 \
--eval-interval 100 \
--save "./checkpoints" \
--split 1 \
--eval-iters 10 \
--eval-iters 1 \
--eval-batch-size 8 \
--zero-stage 1 \
--lr 0.0001 \
Expand Down
88 changes: 87 additions & 1 deletion finetune_visualglm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,95 @@
import os
import torch
import argparse
from functools import partial
import json
from tqdm import tqdm

from sat import mpu, get_args, get_tokenizer
from sat.training.deepspeed_training import training_main
from model import VisualGLMModel
from sat.model.finetune import PTuningV2Mixin
from sat.model.finetune.lora2 import LoraMixin
from sat.generation.autoregressive_sampling import filling_sequence, BaseStrategy
from sat.model.mixins import CachedAutoregressiveMixin

from model import chat, VisualGLMModel

from torch.nn import CrossEntropyLoss
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu
import numpy as np


def forward_step_eval(data_iterator, model, args, timers):
    """Run one evaluation step and return generation metrics.

    Generates an answer with ``chat`` for every item of the JSON file at
    ``args.valid_data[0]`` and scores the answers against the labels of the
    current batch with ROUGE-1/2/L and BLEU.

    Args:
        data_iterator: iterator handed to ``get_batch`` to fetch the batch.
        model: model used for generation; it is switched to eval mode.
        args: run arguments; ``valid_data`` and ``ignore_pad_token_for_loss``
            are read here.
        timers: timers object forwarded to ``get_batch``.

    Returns:
        A ``(loss, metrics)`` tuple: ``loss`` is a constant zero CUDA tensor
        and ``metrics`` maps each metric name to a CUDA tensor holding its
        mean score (scaled to 0-100).
    """
    model.eval()

    # Generation needs the cached auto-regressive mixin. Registering it a
    # second time trips an assertion, which is the only failure that is safe
    # to ignore here; anything else should surface.
    try:
        model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    except AssertionError:
        pass

    def compute_metrics(eval_preds):
        """Score (predictions, labels) and return the mean of each metric."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Predictions are already decoded strings produced by ``chat``.
        decoded_preds = preds
        if args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }
        rouge_keys = ("rouge-1", "rouge-2", "rouge-l")
        rouge = Rouge()  # one scorer is enough for the whole loop
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = ' '.join(jieba.cut(pred))
            reference = ' '.join(jieba.cut(label))

            if not hypothesis.split() or not reference.split():
                # The ROUGE scorer rejects empty text; an empty answer or
                # label simply scores zero instead of aborting evaluation.
                for key in rouge_keys:
                    score_dict[key].append(0.0)
            else:
                result = rouge.get_scores(hypothesis, reference)[0]
                for k, v in result.items():
                    score_dict[k].append(round(v["f"] * 100, 4))

            # Character-level BLEU; no smoothing function is passed.
            bleu_score = sentence_bleu([list(label)], list(pred))
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        return {k: float(np.mean(v)) for k, v in score_dict.items()}

    # Get the batch. Only the labels are used below, but the call is kept
    # as-is because it also drives the data iterator and timers.
    _, labels, _, _ = get_batch(data_iterator, args, timers)

    gen_kwargs = {"max_length": 1024, "top_p": 0.4, "top_k": 100, "temperature": 0.8}

    with open(args.valid_data[0], 'r', encoding='utf-8') as file:
        data = json.load(file)

    # NOTE(review): predictions cover every item of the validation file while
    # labels come from a single batch; zip() in compute_metrics pairs them by
    # position and silently truncates to the shorter one - confirm that this
    # alignment is intended.
    sep = '答:'
    outputs = []
    with torch.no_grad():
        for item in tqdm(data):
            response, _, _ = chat(
                item['img'],
                model,
                tokenizer,
                item['prompt'],
                history=[],
                **gen_kwargs,
            )
            # Keep only the text after the last answer marker.
            outputs.append(response.split(sep)[-1].strip())

    device = torch.device('cuda')
    metrics = compute_metrics((outputs, labels.cpu()))
    return (
        torch.tensor(0, device=device),
        {k: torch.tensor(v, device=device) for k, v in metrics.items()},
    )



class FineTuneVisualGLMModel(VisualGLMModel):
def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
Expand Down Expand Up @@ -37,6 +120,8 @@ def disable_untrainable_params(self):
enable.extend(['ptuning'])
if self.args.use_lora or self.args.use_qlora:
enable.extend(['matrix_A', 'matrix_B'])


for n, p in self.named_parameters():
flag = False
for e in enable:
Expand All @@ -60,6 +145,7 @@ def get_batch(data_iterator, args, timers):
data = next(data_iterator)
else:
data = None

timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
data_i = mpu.broadcast_data(['image'], data, torch.float32)
Expand Down Expand Up @@ -191,4 +277,4 @@ def data_collator(examples):
'pre_image': example['pre_image']
}
return ret
training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator)
training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator, forward_step_eval=forward_step_eval)