Skip to content

Commit 025a7ac

Browse files
Merge pull request #248 from tushar2407/tests
Tests
2 parents d113990 + 773a927 commit 025a7ac

File tree

8 files changed

+186
-39
lines changed

8 files changed

+186
-39
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,10 +43,10 @@ keywords = [
4343
dependencies = [
4444
"torch >= 1.9.0",
4545
"pytorch-lightning",
46-
"transformers==4.28.1",
46+
"transformers==4.31.0",
4747
"datasets",
4848
"evaluate",
49-
"bitsandbytes==0.37.2",
49+
"bitsandbytes==0.41.1",
5050
"sentencepiece",
5151
"deepspeed",
5252
"gradio",

src/xturing/config/finetuning_config.yaml

Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,13 @@ bloom_lora_int8:
3232
batch_size: 8
3333
max_length: 256
3434

35+
bloom_int8:
36+
learning_rate: 1e-4
37+
weight_decay: 0.01
38+
num_train_epochs: 3
39+
batch_size: 8
40+
max_length: 256
41+
3542
cerebras:
3643
learning_rate: 5e-5
3744
weight_decay: 0.01
@@ -50,6 +57,13 @@ cerebras_lora_int8:
5057
batch_size: 8
5158
max_length: 256
5259

60+
cerebras_int8:
61+
learning_rate: 1e-4
62+
weight_decay: 0.01
63+
num_train_epochs: 3
64+
batch_size: 8
65+
max_length: 256
66+
5367
distilgpt2:
5468
learning_rate: 1e-3
5569
weight_decay: 0.01
@@ -115,6 +129,13 @@ galactica_lora_int8:
115129
batch_size: 8
116130
max_length: 256
117131

132+
galactica_int8:
133+
learning_rate: 1e-4
134+
weight_decay: 0.01
135+
num_train_epochs: 3
136+
batch_size: 8
137+
max_length: 256
138+
118139
generic:
119140
learning_rate: 1e-4
120141
weight_decay: 0.01
@@ -169,6 +190,13 @@ gptj_lora_int8:
169190
batch_size: 8
170191
max_length: 256
171192

193+
gptj_int8:
194+
learning_rate: 1e-4
195+
weight_decay: 0.01
196+
num_train_epochs: 3
197+
batch_size: 8
198+
max_length: 256
199+
172200
gpt2:
173201
learning_rate: 1e-3
174202
weight_decay: 0.01
@@ -187,13 +215,18 @@ gpt2_lora_int8:
187215
num_train_epochs: 3
188216
batch_size: 16
189217

218+
gpt2_int8:
219+
learning_rate: 3e-3
220+
weight_decay: 0.01
221+
num_train_epochs: 3
222+
batch_size: 16
223+
190224
llama:
191225
learning_rate: 5e-5
192226
weight_decay: 0.01
193227
num_train_epochs: 3
194228
optimizer_name: cpu_adam
195229

196-
197230
llama_lora:
198231
learning_rate: 1e-4
199232
weight_decay: 0.01
@@ -207,6 +240,13 @@ llama_lora_int8:
207240
batch_size: 8
208241
max_length: 256
209242

243+
llama_int8:
244+
learning_rate: 1e-4
245+
weight_decay: 0.01
246+
num_train_epochs: 3
247+
batch_size: 8
248+
max_length: 256
249+
210250
llama_lora_kbit:
211251
learning_rate: 3e-4
212252
num_train_epochs: 3
@@ -275,3 +315,10 @@ opt_lora_int8:
275315
num_train_epochs: 3
276316
batch_size: 8
277317
max_length: 256
318+
319+
opt_int8:
320+
learning_rate: 1e-4
321+
weight_decay: 0.01
322+
num_train_epochs: 3
323+
batch_size: 8
324+
max_length: 256

src/xturing/config/generation_config.yaml

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,11 @@ bloom_lora_int8:
2525
max_new_tokens: 256
2626
do_sample: false
2727

28+
# Greedy search
29+
bloom_int8:
30+
max_new_tokens: 256
31+
do_sample: false
32+
2833
# Contrastive search
2934
cerebras:
3035
penalty_alpha: 0.6
@@ -44,6 +49,11 @@ cerebras_lora_int8:
4449
max_new_tokens: 256
4550
do_sample: false
4651

52+
# Greedy search
53+
cerebras_int8:
54+
max_new_tokens: 256
55+
do_sample: false
56+
4757
# Top-p sampling
4858
distilgpt2:
4959
do_sample: true
@@ -102,6 +112,11 @@ galactica_lora_int8:
102112
max_new_tokens: 256
103113
do_sample: false
104114

115+
# Greedy search
116+
galactica_int8:
117+
max_new_tokens: 256
118+
do_sample: false
119+
105120
# Greedy search
106121
generic:
107122
max_new_tokens: 256
@@ -146,6 +161,11 @@ gptj_lora_int8:
146161
max_new_tokens: 256
147162
do_sample: false
148163

164+
# Greedy search
165+
gptj_int8:
166+
max_new_tokens: 256
167+
do_sample: false
168+
149169
# Top-p sampling
150170
gpt2:
151171
do_sample: true
@@ -167,6 +187,13 @@ gpt2_lora_int8:
167187
top_p: 0.92
168188
max_new_tokens: 256
169189

190+
# Top-p sampling
191+
gpt2_int8:
192+
do_sample: true
193+
top_k: 0
194+
top_p: 0.92
195+
max_new_tokens: 256
196+
170197
# Contrastive search
171198
llama:
172199
penalty_alpha: 0.6
@@ -186,6 +213,11 @@ llama_lora_int8:
186213
max_new_tokens: 256
187214
do_sample: false
188215

216+
# Greedy search
217+
llama_int8:
218+
max_new_tokens: 256
219+
do_sample: false
220+
189221
# Greedy search
190222
llama_lora_kbit:
191223
max_new_tokens: 256
@@ -238,3 +270,8 @@ opt_lora:
238270
opt_lora_int8:
239271
max_new_tokens: 256
240272
do_sample: false
273+
274+
# Greedy search
275+
opt_int8:
276+
max_new_tokens: 256
277+
do_sample: false

src/xturing/engines/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,13 @@
4444
GPTJLoraEngine,
4545
GPTJLoraInt8Engine,
4646
)
47-
from xturing.engines.llama2_engine import LLama2Engine
47+
from xturing.engines.llama2_engine import (
48+
LLama2Engine,
49+
LLama2Int8Engine,
50+
LLama2LoraEngine,
51+
LLama2LoraInt8Engine,
52+
LLama2LoraKbitEngine,
53+
)
4854
from xturing.engines.llama_engine import (
4955
LLamaEngine,
5056
LLamaInt8Engine,
@@ -97,6 +103,10 @@
97103
BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
98104
BaseEngine.add_to_registry(LlamaLoraKbitEngine.config_name, LlamaLoraKbitEngine)
99105
BaseEngine.add_to_registry(LLama2Engine.config_name, LLama2Engine)
106+
BaseEngine.add_to_registry(LLama2Int8Engine.config_name, LLama2Int8Engine)
107+
BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
108+
BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
109+
BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
100110
BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
101111
BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
102112
BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)

src/xturing/engines/generic_engine.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ def __init__(
6464

6565

6666
class GenericLoraKbitEngine(CausalLoraKbitEngine):
67-
config_name: str = "generic+lora_kbit_engine"
67+
config_name: str = "generic_lora_kbit_engine"
6868

6969
def __init__(
7070
self,
@@ -75,7 +75,6 @@ def __init__(
7575
super().__init__(
7676
model_name=model_name,
7777
weights_path=weights_path,
78-
load_4bit=True,
7978
target_modules=target_modules,
8079
)
8180

src/xturing/models/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,13 @@
3636
LlamaLoraInt8,
3737
LlamaLoraKbit,
3838
)
39-
from xturing.models.llama2 import Llama2
39+
from xturing.models.llama2 import (
40+
Llama2,
41+
Llama2Int8,
42+
Llama2Lora,
43+
Llama2LoraInt8,
44+
Llama2LoraKbit,
45+
)
4046
from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
4147
from xturing.models.stable_diffusion import StableDiffusion
4248

@@ -78,6 +84,10 @@
7884
BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8)
7985
BaseModel.add_to_registry(LlamaLoraKbit.config_name, LlamaLoraKbit)
8086
BaseModel.add_to_registry(Llama2.config_name, Llama2)
87+
BaseModel.add_to_registry(Llama2Int8.config_name, Llama2Int8)
88+
BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
89+
BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
90+
BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
8191
BaseModel.add_to_registry(OPT.config_name, OPT)
8292
BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
8393
BaseModel.add_to_registry(OPTLora.config_name, OPTLora)

src/xturing/models/causal.py

Lines changed: 21 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,8 @@
11
import json
22
from pathlib import Path
3-
4-
from typing import Iterable, List, Optional, Tuple, Type, Union
3+
from typing import Iterable, List, Optional, Tuple, Union
54

65
import torch
7-
import torch.nn.functional as F
86
from pytorch_lightning.loggers import Logger
97
from torch.utils.data import DataLoader
108
from tqdm import tqdm
@@ -21,15 +19,7 @@
2119
from xturing.trainers.base import BaseTrainer
2220
from xturing.trainers.lightning_trainer import LightningTrainer
2321
from xturing.utils.logging import configure_logger
24-
from xturing.utils.metrics import get_accuracy
25-
from xturing.utils.prompt import (
26-
OpenAIChatMessage,
27-
OpenAICreateChatPrompt,
28-
OpenAICreatePrompt,
29-
Prompt,
30-
chat_prompt_to_text,
31-
is_chat_prompt,
32-
)
22+
from xturing.utils.prompt import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt
3323
from xturing.utils.utils import _filter_args, _index_samples
3424

3525
TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
@@ -44,6 +34,7 @@ def __init__(
4434
weights_path: Optional[str] = None,
4535
model_name: Optional[str] = None,
4636
target_modules: Optional[List[str]] = None,
37+
transfer_to_device: Optional[bool] = True,
4738
**kwargs,
4839
):
4940
arguments = dict(
@@ -82,6 +73,8 @@ def __init__(
8273
logger.debug(f"Finetuning parameters: {self.finetuning_args}")
8374
logger.debug(f"Generation parameters: {self.generation_args}")
8475

76+
self.transfer_to_device = transfer_to_device
77+
8578
def finetuning_config(self):
8679
return self.finetuning_args
8780

@@ -163,7 +156,9 @@ def generate(
163156
batch_size: Optional[int] = 1,
164157
):
165158
self.engine.model.eval()
166-
self.engine.model = self.engine.model.to(DEFAULT_DEVICE)
159+
160+
if self.transfer_to_device:
161+
self.engine.model = self.engine.model.to(DEFAULT_DEVICE)
167162

168163
outputs = []
169164

@@ -239,18 +234,9 @@ def _model_call(
239234
def completion_query(
240235
self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt]
241236
):
242-
# actual_prompt = chat_prompt_to_text(prompt)
243237
actual_prompt = prompt
244238
logger.info(prompt)
245239
text_out = self.generate(texts=[actual_prompt])
246-
247-
# parse results
248-
# result = {
249-
# "text": text_out,
250-
# "tokens": None,
251-
# "logprobs": None,
252-
# }
253-
254240
return text_out, actual_prompt
255241

256242
def check_sampled_text(
@@ -314,8 +300,6 @@ def evaluate(
314300
dataset: Union[TextDataset, InstructionDataset],
315301
batch_size: Optional[int] = 1,
316302
):
317-
# outputs = self.eval_all_samples(dataset)
318-
# return get_accuracy(outputs)
319303
collate_fn = self._make_collate_fn(dataset)
320304
dataloader = DataLoader(
321305
dataset,
@@ -338,7 +322,11 @@ def __init__(
338322
):
339323
assert_not_cpu_int8()
340324
super().__init__(
341-
engine, weights_path=weights_path, model_name=model_name, **kwargs
325+
engine,
326+
weights_path=weights_path,
327+
model_name=model_name,
328+
transfer_to_device=False,
329+
**kwargs,
342330
)
343331

344332

@@ -400,18 +388,19 @@ def __init__(
400388

401389
class CausalLoraKbitModel(CausalLoraModel):
402390
def __init__(
403-
self,
404-
engine: str,
405-
weights_path: Optional[str] = None,
406-
model_name: Optional[str] = None,
407-
target_modules: Optional[List[str]] = None,
408-
**kwargs,
409-
):
391+
self,
392+
engine: str,
393+
weights_path: Optional[str] = None,
394+
model_name: Optional[str] = None,
395+
target_modules: Optional[List[str]] = None,
396+
**kwargs,
397+
):
410398
assert_not_cpu_int8()
411399
super().__init__(
412400
engine,
413401
weights_path=weights_path,
414402
model_name=model_name,
415403
target_modules=target_modules,
404+
transfer_to_device=False,
416405
**kwargs,
417406
)

0 commit comments

Comments
 (0)