Commit a706b5b

feat(models): add qwen3 0.6b variants
1 parent dd8e0dc commit a706b5b

10 files changed, +355 −9 lines changed

docs/docs/overview/quickstart/finetune_guide.md

Lines changed: 26 additions & 9 deletions

@@ -42,15 +42,32 @@ Start by loading the instruction dataset and initializing the model of your choice.
 
 <Test instruction={'Instruction'}/>
 
-A list of all the supported models can be found [here](/overview/supported_models).
-
-
-
-Next, we need to start the fine-tuning
-
-```python
-model.finetune(dataset=instruction_dataset)
-```
+A list of all the supported models can be found [here](/overview/supported_models).
+
+
+
+Next, we need to start the fine-tuning
+
+```python
+model.finetune(dataset=instruction_dataset)
+```
+
+### Example: Qwen3 0.6B with LoRA
+
+```python
+from xturing.datasets import InstructionDataset
+from xturing.models import BaseModel
+
+instruction_dataset = InstructionDataset("/path/to/your/dataset")
+model = BaseModel.create("qwen3_0_6b_lora")
+model.finetune(dataset=instruction_dataset)
+```
+
+You can find a runnable script at `examples/models/qwen3/qwen3_lora_finetune.py`.
+
+```bash
+xturing finetune --model qwen3_0_6b_lora --data-dir /path/to/your/dataset
+```
 
 <!-- Finally, let us test how our fine-tuned model performs using the `.generate()` function.
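The commented-out section that follows in the guide covers testing the fine-tuned model with `.generate()`. A short, hedged continuation of the snippet above, using only the `generate` and `save` calls that appear in this commit's example script (the prompt text, dataset path, and output directory are illustrative):

```python
from xturing.datasets import InstructionDataset
from xturing.models import BaseModel

# Same setup as the guide's Qwen3 0.6B LoRA example.
instruction_dataset = InstructionDataset("/path/to/your/dataset")
model = BaseModel.create("qwen3_0_6b_lora")
model.finetune(dataset=instruction_dataset)

# Sanity-check the fine-tuned adapters, then persist them for later reuse.
output = model.generate(texts=["Why are smaller language models becoming popular?"])
print(f"Generated output: {output}")
model.save("./qwen3_lora_weights")
```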

docs/docs/overview/supported_models.md

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ description: Models Supported by xTuring
 | GPT-2 | gpt2 |||||
 | LLaMA 7B | llama |||||
 | LLaMA2 | llama2 |||||
+| Qwen3 0.6B | qwen3_0_6b |||||
 | OPT 1.3B | opt |||||
 
 ### Memory-efficient versions
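The key in the table's second column is the string passed to `BaseModel.create`. A minimal sketch of how the Qwen3 variants registered by this commit are selected (all five keys come from the registry additions below; choose one depending on available memory):

```python
from xturing.models import BaseModel

# Full-precision base model.
model = BaseModel.create("qwen3_0_6b")

# Memory-efficient variants registered by this commit:
# model = BaseModel.create("qwen3_0_6b_lora")       # LoRA adapters
# model = BaseModel.create("qwen3_0_6b_int8")       # 8-bit weights
# model = BaseModel.create("qwen3_0_6b_lora_int8")  # LoRA + 8-bit weights
# model = BaseModel.create("qwen3_0_6b_lora_kbit")  # LoRA + k-bit quantization
```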
examples/models/qwen3/qwen3_lora_finetune.py

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+"""Minimal example showing how to fine-tune Qwen3-0.6B with LoRA using xTuring."""
+from pathlib import Path
+
+from xturing.datasets.instruction_dataset import InstructionDataset
+from xturing.models import BaseModel
+
+# Reuse the small Alpaca-style dataset that ships with the repo. Replace this path
+# with your own instruction dataset when running real experiments.
+DATASET_DIR = Path(__file__).parent.parent / "llama" / "alpaca_data"
+
+# Location where the LoRA adapters will be stored once training finishes.
+OUTPUT_DIR = Path(__file__).parent / "qwen3_lora_weights"
+
+
+def main():
+    instruction_dataset = InstructionDataset(str(DATASET_DIR))
+
+    # Initialize the Qwen3 0.6B model with a LoRA adapter head.
+    model = BaseModel.create("qwen3_0_6b_lora")
+
+    # Launch fine-tuning with the default configuration (see
+    # src/xturing/config/finetuning_config.yaml for the exact hyper-parameters).
+    model.finetune(dataset=instruction_dataset)
+
+    # Run a quick generation to sanity-check the adapter before saving.
+    output = model.generate(texts=["Why are smaller language models becoming popular?"])
+    print(f"Generated output: {output}")
+
+    # Persist the adapter and tokenizer so the run can be resumed or deployed later.
+    model.save(str(OUTPUT_DIR))
+    print(f"Saved fine-tuned weights to {OUTPUT_DIR}")
+
+if __name__ == "__main__":
+    main()

src/xturing/config/finetuning_config.yaml

Lines changed: 32 additions & 0 deletions

@@ -302,6 +302,38 @@ mamba:
   learning_rate: 5e-5
   weight_decay: 0.01
 
+qwen3_0_6b:
+  learning_rate: 5e-5
+  weight_decay: 0.01
+  num_train_epochs: 3
+
+qwen3_0_6b_lora:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 4
+
+qwen3_0_6b_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 4
+  max_length: 256
+
+qwen3_0_6b_lora_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
+qwen3_0_6b_lora_kbit:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 4
+  max_length: 512
+
 opt:
   learning_rate: 5e-5
   weight_decay: 0.01
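These defaults are keyed by the same `config_name` strings used in the model registry. A hedged sketch of how to inspect them from a repository checkout (assumes PyYAML is installed and the relative path below matches your working directory):

```python
import yaml  # PyYAML

# Path is relative to the repository root; adjust if running from elsewhere.
with open("src/xturing/config/finetuning_config.yaml") as f:
    finetuning_defaults = yaml.safe_load(f)

# Print the hyper-parameters this commit registers for the LoRA variant.
# Expected keys, per the diff above: learning_rate, weight_decay, num_train_epochs, batch_size.
print(finetuning_defaults["qwen3_0_6b_lora"])
```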

src/xturing/config/generation_config.yaml

Lines changed: 29 additions & 0 deletions

@@ -316,6 +316,35 @@ llama2_lora_kbit:
 mamba:
   do_sample: false
 
+# Contrastive search
+qwen3_0_6b:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
+# Contrastive search
+qwen3_0_6b_lora:
+  penalty_alpha: 0.6
+  top_k: 4
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+qwen3_0_6b_int8:
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+qwen3_0_6b_lora_int8:
+  max_new_tokens: 256
+  do_sample: false
+
+# Greedy search
+qwen3_0_6b_lora_kbit:
+  max_new_tokens: 256
+  do_sample: false
+
 # Contrastive search
 opt:
   penalty_alpha: 0.6
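The full-precision and LoRA variants default to contrastive search, while the quantized variants use greedy decoding. These keys map directly to Hugging Face `generate()` arguments; a hedged sketch of the equivalent call made directly with `transformers` (xTuring applies these settings internally, and the prompt is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)

inputs = tokenizer("Why are smaller language models becoming popular?", return_tensors="pt")

# Contrastive search, mirroring the qwen3_0_6b generation defaults above.
outputs = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=256,
    do_sample=False,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```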

src/xturing/engines/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -71,6 +71,13 @@
     LlamaLoraKbitEngine,
 )
 from xturing.engines.mamba_engine import MambaEngine
+from xturing.engines.qwen_engine import (
+    Qwen3Engine,
+    Qwen3Int8Engine,
+    Qwen3LoraEngine,
+    Qwen3LoraInt8Engine,
+    Qwen3LoraKbitEngine,
+)
 from xturing.engines.opt_engine import (
     OPTEngine,
     OPTInt8Engine,

@@ -135,6 +142,11 @@
 BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
 BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
 BaseEngine.add_to_registry(MambaEngine.config_name, MambaEngine)
+BaseEngine.add_to_registry(Qwen3Engine.config_name, Qwen3Engine)
+BaseEngine.add_to_registry(Qwen3Int8Engine.config_name, Qwen3Int8Engine)
+BaseEngine.add_to_registry(Qwen3LoraEngine.config_name, Qwen3LoraEngine)
+BaseEngine.add_to_registry(Qwen3LoraInt8Engine.config_name, Qwen3LoraInt8Engine)
+BaseEngine.add_to_registry(Qwen3LoraKbitEngine.config_name, Qwen3LoraKbitEngine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
 BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
 BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
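The registrations above follow a simple registry pattern: each engine's `config_name` string is mapped to its class at import time so instances can later be built from string keys. A simplified, self-contained sketch of that mechanism (not xTuring's actual implementation, purely illustrative):

```python
from typing import Dict, Type


class BaseEngine:
    # Maps config_name strings to engine classes registered at import time.
    registry: Dict[str, Type["BaseEngine"]] = {}
    config_name: str = "base"

    @classmethod
    def add_to_registry(cls, name: str, engine_cls: Type["BaseEngine"]) -> None:
        cls.registry[name] = engine_cls

    @classmethod
    def create(cls, name: str, **kwargs) -> "BaseEngine":
        return cls.registry[name](**kwargs)


class Qwen3EngineStub(BaseEngine):
    config_name = "qwen3_0_6b_engine"


BaseEngine.add_to_registry(Qwen3EngineStub.config_name, Qwen3EngineStub)
engine = BaseEngine.create("qwen3_0_6b_engine")
print(type(engine).__name__)  # -> Qwen3EngineStub
```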

src/xturing/engines/qwen_engine.py

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+from pathlib import Path
+from typing import Optional, Union
+
+from xturing.engines.causal import CausalEngine, CausalLoraEngine, CausalLoraKbitEngine
+
+_DEFAULT_MODEL_NAME = "Qwen/Qwen3-0.6B"
+_TARGET_MODULES = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+
+class Qwen3Engine(CausalEngine):
+    config_name: str = "qwen3_0_6b_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name=_DEFAULT_MODEL_NAME,
+            weights_path=weights_path,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class Qwen3LoraEngine(CausalLoraEngine):
+    config_name: str = "qwen3_0_6b_lora_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name=_DEFAULT_MODEL_NAME,
+            weights_path=weights_path,
+            target_modules=_TARGET_MODULES,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class Qwen3Int8Engine(CausalEngine):
+    config_name: str = "qwen3_0_6b_int8_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name=_DEFAULT_MODEL_NAME,
+            weights_path=weights_path,
+            load_8bit=True,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class Qwen3LoraInt8Engine(CausalLoraEngine):
+    config_name: str = "qwen3_0_6b_lora_int8_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name=_DEFAULT_MODEL_NAME,
+            weights_path=weights_path,
+            load_8bit=True,
+            target_modules=_TARGET_MODULES,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+
+class Qwen3LoraKbitEngine(CausalLoraKbitEngine):
+    config_name: str = "qwen3_0_6b_lora_kbit_engine"
+
+    def __init__(self, weights_path: Optional[Union[str, Path]] = None):
+        super().__init__(
+            model_name=_DEFAULT_MODEL_NAME,
+            weights_path=weights_path,
+            target_modules=_TARGET_MODULES,
+            trust_remote_code=True,
+        )
+
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
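The `_TARGET_MODULES` list covers every attention and MLP projection in the Qwen3 block. xTuring's `CausalLoraEngine` wires the adapters up internally; as a rough, hedged equivalent using the `peft` library directly (the rank and alpha values below are illustrative, not the values xTuring uses):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)

# Attach LoRA adapters to the same projections listed in _TARGET_MODULES.
lora_config = LoraConfig(
    r=8,            # illustrative rank
    lora_alpha=16,  # illustrative scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()
```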

src/xturing/models/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -56,6 +56,13 @@
     Llama2LoraKbit,
 )
 from xturing.models.mamba import Mamba
+from xturing.models.qwen import (
+    Qwen3,
+    Qwen3Int8,
+    Qwen3Lora,
+    Qwen3LoraInt8,
+    Qwen3LoraKbit,
+)
 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
 from xturing.models.stable_diffusion import StableDiffusion
 

@@ -112,6 +119,11 @@
 BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
 BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
 BaseModel.add_to_registry(Mamba.config_name, Mamba)
+BaseModel.add_to_registry(Qwen3.config_name, Qwen3)
+BaseModel.add_to_registry(Qwen3Int8.config_name, Qwen3Int8)
+BaseModel.add_to_registry(Qwen3Lora.config_name, Qwen3Lora)
+BaseModel.add_to_registry(Qwen3LoraInt8.config_name, Qwen3LoraInt8)
+BaseModel.add_to_registry(Qwen3LoraKbit.config_name, Qwen3LoraKbit)
 BaseModel.add_to_registry(OPT.config_name, OPT)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
 BaseModel.add_to_registry(OPTLora.config_name, OPTLora)

src/xturing/models/qwen.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+from typing import Optional
+
+from xturing.engines.qwen_engine import (
+    Qwen3Engine,
+    Qwen3Int8Engine,
+    Qwen3LoraEngine,
+    Qwen3LoraInt8Engine,
+    Qwen3LoraKbitEngine,
+)
+from xturing.models.causal import (
+    CausalInt8Model,
+    CausalLoraInt8Model,
+    CausalLoraKbitModel,
+    CausalLoraModel,
+    CausalModel,
+)
+
+
+class Qwen3(CausalModel):
+    config_name: str = "qwen3_0_6b"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(Qwen3Engine.config_name, weights_path)
+
+
+class Qwen3Lora(CausalLoraModel):
+    config_name: str = "qwen3_0_6b_lora"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(Qwen3LoraEngine.config_name, weights_path)
+
+
+class Qwen3Int8(CausalInt8Model):
+    config_name: str = "qwen3_0_6b_int8"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(Qwen3Int8Engine.config_name, weights_path)
+
+
+class Qwen3LoraInt8(CausalLoraInt8Model):
+    config_name: str = "qwen3_0_6b_lora_int8"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(Qwen3LoraInt8Engine.config_name, weights_path)
+
+
+class Qwen3LoraKbit(CausalLoraKbitModel):
+    config_name: str = "qwen3_0_6b_lora_kbit"
+
+    def __init__(self, weights_path: Optional[str] = None):
+        super().__init__(Qwen3LoraKbitEngine.config_name, weights_path)
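Each wrapper accepts an optional `weights_path`, which points the underlying engine at previously saved weights instead of the Hugging Face checkpoint. A brief, hedged sketch of creating a variant via the registry versus reloading a saved fine-tune (the directory is whatever was passed to `model.save()` earlier; exact reload behavior depends on the causal engine's weight-loading logic):

```python
from xturing.models import BaseModel
from xturing.models.qwen import Qwen3LoraKbit

# Fresh model from the Hugging Face checkpoint, resolved via the registry key.
model = BaseModel.create("qwen3_0_6b_lora_kbit")

# Or point the wrapper at weights produced by an earlier model.save("./qwen3_lora_weights").
resumed = Qwen3LoraKbit(weights_path="./qwen3_lora_weights")
output = resumed.generate(texts=["Summarize what LoRA adapters do."])
print(output)
```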
