Skip to content

Commit fe1ff88

Browse files
Merge pull request #284 from mapmeld/mamba
Add Mamba to available LLMs
2 parents 6dfcbad + 6812284 commit fe1ff88

File tree

7 files changed

+46
-1
lines changed

7 files changed

+46
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ keywords = [
4343
dependencies = [
4444
"torch >= 1.9.0",
4545
"pytorch-lightning",
46-
"transformers==4.31.0",
46+
"transformers==4.39.3",
4747
"datasets==2.14.5",
4848
"evaluate==0.4.0",
4949
"bitsandbytes==0.41.1",

src/xturing/config/finetuning_config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ llama2_lora_kbit:
298298
num_train_epochs: 3
299299
optimizer_name: cpu_adam
300300

301+
mamba:
302+
learning_rate: 5e-5
303+
weight_decay: 0.01
304+
301305
opt:
302306
learning_rate: 5e-5
303307
weight_decay: 0.01

src/xturing/config/generation_config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ llama2_lora_kbit:
252252
max_new_tokens: 256
253253
do_sample: false
254254

255+
# Greedy search
256+
mamba:
257+
do_sample: false
258+
255259
# Contrastive search
256260
opt:
257261
penalty_alpha: 0.6

src/xturing/engines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
LlamaLoraInt8Engine,
5959
LlamaLoraKbitEngine,
6060
)
61+
from xturing.engines.mamba_engine import MambaEngine
6162
from xturing.engines.opt_engine import (
6263
OPTEngine,
6364
OPTInt8Engine,
@@ -107,6 +108,7 @@
107108
BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
108109
BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
109110
BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
111+
BaseEngine.add_to_registry(MambaEngine.config_name, MambaEngine)
110112
BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
111113
BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
112114
BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import os
2+
from pathlib import Path
3+
from typing import Optional, Union
4+
5+
from transformers import AutoTokenizer, MambaForCausalLM
6+
7+
from xturing.engines.causal import CausalEngine
8+
9+
class MambaEngine(CausalEngine):
    """Causal-LM engine for the Hugging Face Mamba checkpoint.

    Loads ``state-spaces/mamba-2.8b-hf`` (model + tokenizer) from the Hub and
    hands both to the generic :class:`CausalEngine` base, optionally resuming
    from ``weights_path``.
    """

    config_name: str = "mamba_engine"

    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
        # Single Hub identifier used for both the weights and the tokenizer.
        hub_id = "state-spaces/mamba-2.8b-hf"
        tokenizer = AutoTokenizer.from_pretrained(hub_id)
        model = MambaForCausalLM.from_pretrained(hub_id)

        super().__init__(weights_path=weights_path, model=model, tokenizer=tokenizer)

    def save(self, saving_path: Union[str, Path]):
        """Write the model weights and tokenizer files to *saving_path*."""
        for artifact in (self.model, self.tokenizer):
            artifact.save_pretrained(saving_path)

src/xturing/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
Llama2LoraInt8,
4444
Llama2LoraKbit,
4545
)
46+
from xturing.models.mamba import Mamba
4647
from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
4748
from xturing.models.stable_diffusion import StableDiffusion
4849

@@ -88,6 +89,7 @@
8889
BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
8990
BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
9091
BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
92+
BaseModel.add_to_registry(Mamba.config_name, Mamba)
9193
BaseModel.add_to_registry(OPT.config_name, OPT)
9294
BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
9395
BaseModel.add_to_registry(OPTLora.config_name, OPTLora)

src/xturing/models/mamba.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import Optional
2+
3+
from xturing.engines.mamba_engine import MambaEngine
4+
from xturing.models.causal import CausalModel
5+
6+
7+
class Mamba(CausalModel):
    """Mamba causal language model, registered under config name ``"mamba"``.

    Thin wrapper that binds :class:`CausalModel` to :class:`MambaEngine`.
    """

    config_name: str = "mamba"

    def __init__(self, weights_path: Optional[str] = None):
        # The engine's registry key selects MambaEngine inside the base class.
        engine_name = MambaEngine.config_name
        super().__init__(engine_name, weights_path)

0 commit comments

Comments
 (0)