
Commit 100f473

feat: add Qwen3 SFT test and docs update
1 parent 7ea38e0 commit 100f473

File tree

2 files changed: 226 additions, 1 deletion


README.md

Lines changed: 12 additions & 1 deletion
@@ -20,7 +20,7 @@
 
 ___
 
-`xTuring` makes it simple, fast, and cost‑efficient to fine‑tune open‑source LLMs (e.g., GPT‑OSS, LLaMA/LLaMA 2, Falcon, GPT‑J, GPT‑2, OPT, Bloom, Cerebras, Galactica) on your own data — locally or in your private cloud.
+`xTuring` makes it simple, fast, and cost‑efficient to fine‑tune open‑source LLMs (e.g., GPT‑OSS, LLaMA/LLaMA 2, Falcon, Qwen3, GPT‑J, GPT‑2, OPT, Bloom, Cerebras, Galactica) on your own data — locally or in your private cloud.
 
 Why xTuring:
 - Simple API for data prep, training, and inference
@@ -162,6 +162,17 @@ outputs = model.generate(dataset = dataset, batch_size=10)
 
 ```
 
+7. __Qwen3 0.6B supervised fine-tuning__ – The lightweight Qwen3 0.6B checkpoint now has first-class support (registry, configs, docs, and examples) so you can launch SFT/LoRA jobs immediately.
+```python
+from xturing.datasets import InstructionDataset
+from xturing.models import BaseModel
+
+dataset = InstructionDataset("./examples/models/llama/alpaca_data")
+model = BaseModel.create("qwen3_0_6b_lora")
+model.finetune(dataset=dataset)
+```
+> See `examples/models/qwen3/qwen3_lora_finetune.py` for a runnable script.
+
 An exploration of the [Llama LoRA INT4 working example](examples/features/int4_finetuning/LLaMA_lora_int4.ipynb) is recommended for an understanding of its application.
 
 For an extended insight, consider examining the [GenericModel working example](examples/features/generic/generic_model.py) available in the repository.
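
The README addition above stops after `finetune`. Below is a minimal follow-on sketch (not part of this commit) that runs batched inference and saves the result; `generate(dataset=..., batch_size=...)` mirrors the call shown in the surrounding README section, and the output directory is only illustrative.

```python
from xturing.datasets import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("./examples/models/llama/alpaca_data")

# Fine-tune the Qwen3 0.6B LoRA variant registered by this commit.
model = BaseModel.create("qwen3_0_6b_lora")
model.finetune(dataset=dataset)

# Batched generation over the same dataset, mirroring the README's earlier example.
outputs = model.generate(dataset=dataset, batch_size=10)

# Persist the fine-tuned weights/adapters; the path is illustrative.
model.save("./qwen3_0_6b_lora_finetuned")
```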

tests/xturing/models/test_qwen_model.py

Lines changed: 214 additions & 0 deletions
@@ -1,6 +1,91 @@
+import importlib.machinery
+import sys
+import types
 from pathlib import Path
 
+
+def _make_module(name):
+    module = types.ModuleType(name)
+    module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None)
+    return module
+
+
+def _install_stub_modules():
+    if "ai21" not in sys.modules:
+        ai21_module = _make_module("ai21")
+
+        class _Completion:
+            @staticmethod
+            def execute(**_):
+                return {"prompt": {"text": ""}}
+
+        ai21_module.api_key = None
+        ai21_module.Completion = _Completion
+        sys.modules["ai21"] = ai21_module
+
+    if "cohere" not in sys.modules:
+        cohere_module = _make_module("cohere")
+
+        class _CohereError(Exception):
+            pass
+
+        class _Client:
+            def __init__(self, *_args, **_kwargs):
+                self.generations = [types.SimpleNamespace(text="")]
+
+            def generate(self, **_):
+                return types.SimpleNamespace(generations=self.generations)
+
+        cohere_module.CohereError = _CohereError
+        cohere_module.Client = _Client
+        sys.modules["cohere"] = cohere_module
+
+    if "openai" not in sys.modules:
+        openai_module = _make_module("openai")
+
+        class _Completion:
+            @staticmethod
+            def create(n=1, **_):
+                return {"choices": [types.SimpleNamespace(text="")] * n}
+
+        class _ChatCompletion:
+            @staticmethod
+            def create(**_):
+                return {"choices": [{"message": {"content": ""}}]}
+
+        openai_module.api_key = None
+        openai_module.organization = None
+        openai_module.Completion = _Completion
+        openai_module.ChatCompletion = _ChatCompletion
+        openai_module.error = types.SimpleNamespace(OpenAIError=Exception)
+        sys.modules["openai"] = openai_module
+
+    if "xturing" not in sys.modules:
+        xturing_module = _make_module("xturing")
+        xturing_module.__path__ = [
+            str(Path(__file__).resolve().parents[3] / "src" / "xturing")
+        ]
+        sys.modules["xturing"] = xturing_module
+
+    if "deepspeed" not in sys.modules:
+        deepspeed_module = _make_module("deepspeed")
+        ops_module = _make_module("deepspeed.ops")
+        adam_module = _make_module("deepspeed.ops.adam")
+
+        class _DeepSpeedCPUAdam:
+            def __init__(self, *_, **__):
+                pass
+
+        adam_module.DeepSpeedCPUAdam = _DeepSpeedCPUAdam
+        sys.modules["deepspeed"] = deepspeed_module
+        sys.modules["deepspeed.ops"] = ops_module
+        sys.modules["deepspeed.ops.adam"] = adam_module
+
+
+_install_stub_modules()
+
 from xturing.config.read_config import read_yaml
+from xturing.engines.base import BaseEngine
 from xturing.engines.qwen_engine import (
     Qwen3Engine,
     Qwen3Int8Engine,
@@ -16,6 +101,9 @@
     Qwen3LoraInt8,
     Qwen3LoraKbit,
 )
+from xturing.preprocessors.base import BasePreprocessor
+from xturing.trainers.base import BaseTrainer
+from xturing.trainers.lightning_trainer import LightningTrainer
 
 
 def test_qwen3_model_registry_entries_present():
@@ -66,3 +154,129 @@ def test_qwen3_config_entries_exist():
     assert "qwen3_0_6b_int8" in finetuning_config
     assert "qwen3_0_6b_lora_int8" in finetuning_config
     assert "qwen3_0_6b_lora_kbit" in finetuning_config
+
+
+def test_qwen3_lora_instruction_sft(monkeypatch):
+    class DummyInstructionDataset:
+        config_name = "instruction_dataset"
+
+        def __init__(self, payload):
+            self.payload = payload
+            self._meta = type("Meta", (), {})()
+
+        @property
+        def meta(self):
+            return self._meta
+
+        def __len__(self):
+            return len(self.payload["instruction"])
+
+        def __getitem__(self, idx):
+            return {key: values[idx] for key, values in self.payload.items()}
+
+    class DummyTokenizer:
+        eos_token_id = 0
+        pad_token_id = 0
+        pad_token = "<pad>"
+
+        def __call__(self, _):
+            return {"input_ids": [0], "attention_mask": [1]}
+
+        def pad(self, samples, padding=True, max_length=None, return_tensors=None):
+            batch_size = len(samples)
+            return {
+                "input_ids": [[0] for _ in range(batch_size)],
+                "attention_mask": [[1] for _ in range(batch_size)],
+            }
+
+    class DummyModel:
+        def to(self, *_):
+            return self
+
+        def eval(self):
+            return self
+
+        def train(self):
+            return self
+
+    class DummyEngine:
+        def __init__(self, *_, **__):
+            self.model = DummyModel()
+            self.tokenizer = DummyTokenizer()
+
+        def save(self, *_):
+            return None
+
+    class DummyCollator:
+        def __init__(self, *_, **__):
+            self.calls = 0
+
+        def __call__(self, batches):
+            self.calls += 1
+            batch_size = len(batches)
+            return {
+                "input_ids": [[0] for _ in range(batch_size)],
+                "targets": [[0] for _ in range(batch_size)],
+            }
+
+    trainers = []
+
+    class DummyTrainer:
+        def __init__(
+            self,
+            engine,
+            dataset,
+            collate_fn,
+            num_epochs,
+            batch_size,
+            learning_rate,
+            optimizer_name,
+            use_lora=False,
+            use_deepspeed=False,
+            logger=True,
+        ):
+            self.engine = engine
+            self.dataset = dataset
+            self.collate_fn = collate_fn
+            self.num_epochs = num_epochs
+            self.batch_size = batch_size
+            self.learning_rate = learning_rate
+            self.optimizer_name = optimizer_name
+            self.use_lora = use_lora
+            self.use_deepspeed = use_deepspeed
+            self.logger = logger
+            self.fit_called = False
+            trainers.append(self)
+
+        def fit(self):
+            self.fit_called = True
+            batch = self.collate_fn([self.dataset[0]])
+            assert "input_ids" in batch
+            assert len(batch["input_ids"]) == 1
+
+    monkeypatch.setitem(BaseEngine.registry, Qwen3LoraEngine.config_name, DummyEngine)
+    monkeypatch.setitem(
+        BasePreprocessor.registry, DummyInstructionDataset.config_name, DummyCollator
+    )
+    monkeypatch.setitem(
+        BaseTrainer.registry, LightningTrainer.config_name, DummyTrainer
+    )
+
+    dataset = DummyInstructionDataset(
+        {
+            "instruction": [
+                "Rewrite the sentence in simple terms.",
+                "Translate to English.",
+            ],
+            "text": [
+                "Quantum entanglement exhibits spooky action.",
+                "Bonjour, comment ca va?",
+            ],
+            "target": ["Particles can stay linked.", "Hello, how are you?"],
+        }
+    )
+
+    model = BaseModel.create("qwen3_0_6b_lora")
+    model.finetune(dataset=dataset)
+
+    assert trainers and trainers[0].fit_called
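
For context on how the new test stays hermetic: it stubs the optional third-party imports (`ai21`, `cohere`, `openai`, `deepspeed`), points the `xturing` package at `src/xturing`, and monkeypatches the engine, preprocessor, and trainer registries with dummies, so the full `BaseModel.create("qwen3_0_6b_lora")` to `finetune` path runs without downloading the Qwen3 checkpoint. A minimal sketch for running just this file, assuming `pytest` is installed in the environment:

```python
# Run only the new Qwen3 tests from the repository root
# (the path matches the file added in this commit).
import pytest

raise SystemExit(pytest.main(["tests/xturing/models/test_qwen_model.py", "-q"]))
```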
