11import json
22from pathlib import Path
3-
4- from typing import Iterable , List , Optional , Tuple , Type , Union
3+ from typing import Iterable , List , Optional , Tuple , Union
54
65import torch
7- import torch .nn .functional as F
86from pytorch_lightning .loggers import Logger
97from torch .utils .data import DataLoader
108from tqdm import tqdm
2119from xturing .trainers .base import BaseTrainer
2220from xturing .trainers .lightning_trainer import LightningTrainer
2321from xturing .utils .logging import configure_logger
24- from xturing .utils .metrics import get_accuracy
25- from xturing .utils .prompt import (
26- OpenAIChatMessage ,
27- OpenAICreateChatPrompt ,
28- OpenAICreatePrompt ,
29- Prompt ,
30- chat_prompt_to_text ,
31- is_chat_prompt ,
32- )
22+ from xturing .utils .prompt import OpenAICreateChatPrompt , OpenAICreatePrompt , Prompt
3323from xturing .utils .utils import _filter_args , _index_samples
3424
3525TokenSequence = Union [List [int ], torch .LongTensor , torch .Tensor , BatchEncoding ]
@@ -44,6 +34,7 @@ def __init__(
4434 weights_path : Optional [str ] = None ,
4535 model_name : Optional [str ] = None ,
4636 target_modules : Optional [List [str ]] = None ,
37+ transfer_to_device : Optional [bool ] = True ,
4738 ** kwargs ,
4839 ):
4940 arguments = dict (
@@ -82,6 +73,8 @@ def __init__(
8273 logger .debug (f"Finetuning parameters: { self .finetuning_args } " )
8374 logger .debug (f"Generation parameters: { self .generation_args } " )
8475
76+ self .transfer_to_device = transfer_to_device
77+
8578 def finetuning_config (self ):
8679 return self .finetuning_args
8780
@@ -163,7 +156,9 @@ def generate(
163156 batch_size : Optional [int ] = 1 ,
164157 ):
165158 self .engine .model .eval ()
166- self .engine .model = self .engine .model .to (DEFAULT_DEVICE )
159+
160+ if self .transfer_to_device :
161+ self .engine .model = self .engine .model .to (DEFAULT_DEVICE )
167162
168163 outputs = []
169164
@@ -239,18 +234,9 @@ def _model_call(
239234 def completion_query (
240235 self , prompt : Union [OpenAICreatePrompt , OpenAICreateChatPrompt , Prompt ]
241236 ):
242- # actual_prompt = chat_prompt_to_text(prompt)
243237 actual_prompt = prompt
244238 logger .info (prompt )
245239 text_out = self .generate (texts = [actual_prompt ])
246-
247- # parse results
248- # result = {
249- # "text": text_out,
250- # "tokens": None,
251- # "logprobs": None,
252- # }
253-
254240 return text_out , actual_prompt
255241
256242 def check_sampled_text (
@@ -314,8 +300,6 @@ def evaluate(
314300 dataset : Union [TextDataset , InstructionDataset ],
315301 batch_size : Optional [int ] = 1 ,
316302 ):
317- # outputs = self.eval_all_samples(dataset)
318- # return get_accuracy(outputs)
319303 collate_fn = self ._make_collate_fn (dataset )
320304 dataloader = DataLoader (
321305 dataset ,
@@ -338,7 +322,11 @@ def __init__(
338322 ):
339323 assert_not_cpu_int8 ()
340324 super ().__init__ (
341- engine , weights_path = weights_path , model_name = model_name , ** kwargs
325+ engine ,
326+ weights_path = weights_path ,
327+ model_name = model_name ,
328+ transfer_to_device = False ,
329+ ** kwargs ,
342330 )
343331
344332
@@ -400,18 +388,19 @@ def __init__(
400388
401389class CausalLoraKbitModel (CausalLoraModel ):
402390 def __init__ (
403- self ,
404- engine : str ,
405- weights_path : Optional [str ] = None ,
406- model_name : Optional [str ] = None ,
407- target_modules : Optional [List [str ]] = None ,
408- ** kwargs ,
409- ):
391+ self ,
392+ engine : str ,
393+ weights_path : Optional [str ] = None ,
394+ model_name : Optional [str ] = None ,
395+ target_modules : Optional [List [str ]] = None ,
396+ ** kwargs ,
397+ ):
410398 assert_not_cpu_int8 ()
411399 super ().__init__ (
412400 engine ,
413401 weights_path = weights_path ,
414402 model_name = model_name ,
415403 target_modules = target_modules ,
404+ transfer_to_device = False ,
416405 ** kwargs ,
417406 )
0 commit comments