diff --git a/finetune/finetune_visualglm.sh b/finetune/finetune_visualglm.sh
index 56270a6..f973883 100755
--- a/finetune/finetune_visualglm.sh
+++ b/finetune/finetune_visualglm.sh
@@ -36,10 +36,10 @@ gpt_options=" \
        --warmup .02 \
        --checkpoint-activations \
        --save-interval 300 \
-       --eval-interval 10000 \
+       --eval-interval 100 \
        --save "./checkpoints" \
        --split 1 \
-       --eval-iters 10 \
+       --eval-iters 1 \
        --eval-batch-size 8 \
        --zero-stage 1 \
        --lr 0.0001 \
diff --git a/finetune_visualglm.py b/finetune_visualglm.py
index 325493a..7a82b56 100644
--- a/finetune_visualglm.py
+++ b/finetune_visualglm.py
@@ -1,12 +1,99 @@
 import os
 import torch
 import argparse
+from functools import partial
+import json
+from tqdm import tqdm
 
 from sat import mpu, get_args, get_tokenizer
 from sat.training.deepspeed_training import training_main
 from model import VisualGLMModel
 from sat.model.finetune import PTuningV2Mixin
 from sat.model.finetune.lora2 import LoraMixin
+from sat.generation.autoregressive_sampling import filling_sequence, BaseStrategy
+from sat.model.mixins import CachedAutoregressiveMixin
+
+from model import chat, VisualGLMModel
+
+from torch.nn import CrossEntropyLoss
+import jieba
+from rouge_chinese import Rouge
+from nltk.translate.bleu_score import sentence_bleu
+import numpy as np
+
+
+def forward_step_eval(data_iterator, model, args, timers):
+    # Generation-based evaluation: decode responses and score them with ROUGE / BLEU.
+    model.eval()
+    tokenizer = get_tokenizer(args)  # fetch sat's cached tokenizer; `tokenizer` is used below but was otherwise undefined in this scope
+
+    try:
+        model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
+    except Exception:
+        # Avoid assertion errors caused by adding the mixin twice.
+        pass
+    def compute_metrics(eval_preds):
+        preds, labels = eval_preds
+        if isinstance(preds, tuple):
+            preds = preds[0]
+        # decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_preds = preds
+        if args.ignore_pad_token_for_loss:
+            # Replace -100 in the labels as we can't decode them.
+            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        score_dict = {
+            "rouge-1": [],
+            "rouge-2": [],
+            "rouge-l": [],
+            "bleu-4": []
+        }
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))
+            rouge = Rouge()
+            scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
+            result = scores[0]
+
+            for k, v in result.items():
+                score_dict[k].append(round(v["f"] * 100, 4))
+            bleu_score = sentence_bleu([list(label)], list(pred))
+            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+        for k, v in score_dict.items():
+            score_dict[k] = float(np.mean(v))
+        return score_dict
+
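+    # The evaluation below uses the batch only for its reference `labels`;
+    # responses are regenerated by re-reading the raw validation JSON, since
+    # chat() expects image paths and prompt strings rather than the
+    # preprocessed tensors returned by get_batch.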
+    # Get the batch.
+    tokens, labels, image, pre_image = get_batch(data_iterator, args, timers)
+
+    gen_kwargs = {"max_length": 1024, "top_p": 0.4, "top_k": 100, "temperature": 0.8}
+
+    with open(args.valid_data[0], 'r', encoding='utf-8') as file:
+        data = json.load(file)
+    outputs = []
+    with torch.no_grad():
+        for item in tqdm(data):
+            response, _, _ = chat(
+                item['img'],
+                model,
+                tokenizer,
+                item['prompt'],
+                history=[],
+                max_length=gen_kwargs['max_length'],
+                top_p=gen_kwargs['top_p'],
+                temperature=gen_kwargs['temperature'],
+                top_k=gen_kwargs['top_k'],
+            )
+            sep = '答:'
+            outputs.append(response.split(sep)[-1].strip())
+    return torch.tensor(0, device=torch.device('cuda')), {k: torch.tensor(v, device=torch.device('cuda')) for k, v in compute_metrics((outputs, labels.cpu())).items()}
+
 
 class FineTuneVisualGLMModel(VisualGLMModel):
     def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
@@ -37,6 +124,8 @@ def disable_untrainable_params(self):
             enable.extend(['ptuning'])
         if self.args.use_lora or self.args.use_qlora:
             enable.extend(['matrix_A', 'matrix_B'])
+
+
         for n, p in self.named_parameters():
             flag = False
             for e in enable:
@@ -60,6 +149,7 @@ def get_batch(data_iterator, args, timers):
         data = next(data_iterator)
     else:
         data = None
+
     timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)
     data_i = mpu.broadcast_data(['image'], data, torch.float32)
@@ -191,4 +281,4 @@ def data_collator(examples):
             'pre_image': example['pre_image']
         }
         return ret
-    training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator)
\ No newline at end of file
+    training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator, forward_step_eval=forward_step_eval)
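
Usage note: the new forward_step_eval regenerates answers from the raw JSON file passed via --valid-data, so each record needs an `img` path and a `prompt`; the `label` field is what the dataset pipeline turns into the reference `labels` scored by ROUGE/BLEU. A minimal sketch of building such a file follows (file name and sample content are illustrative assumptions modeled on the repo's fewshot-data example, not part of the patch):

    # make_valid_data.py: illustrative sketch of the record format the new
    # eval loop reads. `img` and `prompt` are consumed by chat(); `label`
    # is the reference answer the metrics compare against.
    import json

    samples = [{
        "img": "fewshot-data/2p.png",
        "prompt": "这张图片的背景里有什么内容?",  # "What is in the background of this picture?"
        "label": "人物的背景是蓝天白云。",  # "The background is blue sky and white clouds."
    }]
    with open("fewshot-data/valid.json", "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)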