|
| 1 | +import importlib.machinery |
| 2 | +import sys |
| 3 | +import types |
1 | 4 | from pathlib import Path |
2 | 5 |
|
| 6 | + |
def _make_module(name: str) -> types.ModuleType:
    """Build an empty in-memory module called *name*.

    A module created with ``types.ModuleType`` starts out with
    ``__spec__`` set to ``None``; this helper replaces it with a
    ``ModuleSpec`` that has no loader.
    NOTE(review): presumably this is so spec-based checks (for example
    ``importlib.util.find_spec``) accept the stub -- confirm against callers.

    Args:
        name: Fully qualified module name, e.g. ``"deepspeed.ops"``.

    Returns:
        The new module object. It is NOT added to ``sys.modules``; the
        caller is responsible for registering it.
    """
    module = types.ModuleType(name)
    module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None)
    return module
| 11 | + |
| 12 | + |
def _install_stub_modules():
    """Register lightweight stand-in modules in ``sys.modules``.

    Stubs are created for ``ai21``, ``cohere``, ``openai``, ``xturing`` and
    ``deepspeed`` (including ``deepspeed.ops`` and ``deepspeed.ops.adam``).
    Each stub is installed only when its name is not already present in
    ``sys.modules``, so a module that was imported before this call is left
    untouched.

    Returns:
        None. The function works purely by mutating ``sys.modules``.
    """
    if "ai21" not in sys.modules:
        ai21_module = _make_module("ai21")

        class _Completion:
            @staticmethod
            def execute(**_):
                # Accepts and ignores any keyword arguments.
                return {"prompt": {"text": ""}}

        ai21_module.api_key = None
        ai21_module.Completion = _Completion
        sys.modules["ai21"] = ai21_module

    if "cohere" not in sys.modules:
        cohere_module = _make_module("cohere")

        class _CohereError(Exception):
            pass

        class _Client:
            def __init__(self, *_args, **_kwargs):
                self.generations = [types.SimpleNamespace(text="")]

            def generate(self, **_):
                return types.SimpleNamespace(generations=self.generations)

        cohere_module.CohereError = _CohereError
        cohere_module.Client = _Client
        sys.modules["cohere"] = cohere_module

    if "openai" not in sys.modules:
        openai_module = _make_module("openai")

        class _Completion:
            @staticmethod
            def create(n=1, **_):
                # Build one namespace per choice; ``[obj] * n`` would alias a
                # single object n times, so mutating one choice would change
                # all of them.
                return {
                    "choices": [types.SimpleNamespace(text="") for _ in range(n)]
                }

        class _ChatCompletion:
            @staticmethod
            def create(**_):
                return {"choices": [{"message": {"content": ""}}]}

        openai_module.api_key = None
        openai_module.organization = None
        openai_module.Completion = _Completion
        openai_module.ChatCompletion = _ChatCompletion
        openai_module.error = types.SimpleNamespace(OpenAIError=Exception)
        sys.modules["openai"] = openai_module

    if "xturing" not in sys.modules:
        # Because the stub is already in ``sys.modules``, importing
        # ``xturing.<submodule>`` does not execute the package's own
        # ``__init__``; submodules are located through ``__path__`` instead.
        # NOTE(review): ``parents[3]`` assumes this file sits three
        # directories below the repository root -- confirm the layout.
        xturing_module = _make_module("xturing")
        xturing_module.__path__ = [
            str(Path(__file__).resolve().parents[3] / "src" / "xturing")
        ]
        sys.modules["xturing"] = xturing_module

    if "deepspeed" not in sys.modules:
        deepspeed_module = _make_module("deepspeed")
        ops_module = _make_module("deepspeed.ops")
        adam_module = _make_module("deepspeed.ops.adam")

        class _DeepSpeedCPUAdam:
            def __init__(self, *_, **__):
                pass

        adam_module.DeepSpeedCPUAdam = _DeepSpeedCPUAdam

        # The real import system binds every loaded submodule as an attribute
        # of its parent. Do the same here so that ``import deepspeed.ops.adam``
        # followed by ``deepspeed.ops.adam.DeepSpeedCPUAdam`` works, not only
        # ``from deepspeed.ops.adam import DeepSpeedCPUAdam``.
        ops_module.adam = adam_module
        deepspeed_module.ops = ops_module

        sys.modules["deepspeed"] = deepspeed_module
        sys.modules["deepspeed.ops"] = ops_module
        sys.modules["deepspeed.ops.adam"] = adam_module
| 83 | + |
| 84 | + |
# Run at import time, ahead of the ``xturing`` imports that follow, so those
# imports are resolved against the stub entries placed in ``sys.modules``.
_install_stub_modules()
| 86 | + |
3 | 87 | from xturing.config.read_config import read_yaml |
| 88 | +from xturing.engines.base import BaseEngine |
4 | 89 | from xturing.engines.qwen_engine import ( |
5 | 90 | Qwen3Engine, |
6 | 91 | Qwen3Int8Engine, |
|
16 | 101 | Qwen3LoraInt8, |
17 | 102 | Qwen3LoraKbit, |
18 | 103 | ) |
| 104 | +from xturing.preprocessors.base import BasePreprocessor |
| 105 | +from xturing.trainers.base import BaseTrainer |
| 106 | +from xturing.trainers.lightning_trainer import LightningTrainer |
19 | 107 |
|
20 | 108 |
|
21 | 109 | def test_qwen3_model_registry_entries_present(): |
@@ -66,3 +154,129 @@ def test_qwen3_config_entries_exist(): |
66 | 154 | assert "qwen3_0_6b_int8" in finetuning_config |
67 | 155 | assert "qwen3_0_6b_lora_int8" in finetuning_config |
68 | 156 | assert "qwen3_0_6b_lora_kbit" in finetuning_config |
| 157 | + |
| 158 | + |
def test_qwen3_lora_instruction_sft(monkeypatch):
    """Run ``finetune`` on the ``qwen3_0_6b_lora`` model using test doubles.

    The engine, preprocessor and trainer registry entries are replaced with
    the dummy classes defined below (via ``monkeypatch.setitem``, which is
    undone automatically when the test ends). The test then checks that a
    trainer was constructed and that its ``fit`` method was called.
    """

    class DummyInstructionDataset:
        """Minimal dataset backed by a dict of equally long column lists."""

        config_name = "instruction_dataset"

        def __init__(self, payload):
            self.payload = payload
            # Bare attribute-less object used as the dataset's ``meta``.
            self._meta = type("Meta", (), {})()

        @property
        def meta(self):
            return self._meta

        def __len__(self):
            return len(self.payload["instruction"])

        def __getitem__(self, idx):
            # One row: the idx-th value of every column.
            return {key: values[idx] for key, values in self.payload.items()}

    class DummyTokenizer:
        """Tokenizer double that returns fixed single-token encodings."""

        eos_token_id = 0
        pad_token_id = 0
        pad_token = "<pad>"

        def __call__(self, _):
            return {"input_ids": [0], "attention_mask": [1]}

        def pad(self, samples, padding=True, max_length=None, return_tensors=None):
            batch_size = len(samples)
            return {
                "input_ids": [[0] for _ in range(batch_size)],
                "attention_mask": [[1] for _ in range(batch_size)],
            }

    class DummyModel:
        """Model double whose ``to``/``eval``/``train`` just return ``self``."""

        def to(self, *_):
            return self

        def eval(self):
            return self

        def train(self):
            return self

    class DummyEngine:
        """Engine double exposing a dummy ``model`` and ``tokenizer``."""

        def __init__(self, *_, **__):
            self.model = DummyModel()
            self.tokenizer = DummyTokenizer()

        def save(self, *_):
            return None

    class DummyCollator:
        """Collator double that counts calls and emits fixed-shape batches."""

        def __init__(self, *_, **__):
            self.calls = 0

        def __call__(self, batches):
            self.calls += 1
            batch_size = len(batches)
            return {
                "input_ids": [[0] for _ in range(batch_size)],
                "targets": [[0] for _ in range(batch_size)],
            }

    # Every DummyTrainer instance appends itself here on construction so the
    # test can inspect it after ``finetune`` returns.
    trainers = []

    class DummyTrainer:
        """Trainer double that records its arguments and whether fit() ran."""

        def __init__(
            self,
            engine,
            dataset,
            collate_fn,
            num_epochs,
            batch_size,
            learning_rate,
            optimizer_name,
            use_lora=False,
            use_deepspeed=False,
            logger=True,
        ):
            self.engine = engine
            self.dataset = dataset
            self.collate_fn = collate_fn
            self.num_epochs = num_epochs
            self.batch_size = batch_size
            self.learning_rate = learning_rate
            self.optimizer_name = optimizer_name
            self.use_lora = use_lora
            self.use_deepspeed = use_deepspeed
            self.logger = logger
            self.fit_called = False
            trainers.append(self)

        def fit(self):
            self.fit_called = True
            # Push the first sample through the collate function as a
            # one-element batch and check the shape of what comes back.
            batch = self.collate_fn([self.dataset[0]])
            assert "input_ids" in batch
            assert len(batch["input_ids"]) == 1

    monkeypatch.setitem(BaseEngine.registry, Qwen3LoraEngine.config_name, DummyEngine)
    monkeypatch.setitem(
        BasePreprocessor.registry, DummyInstructionDataset.config_name, DummyCollator
    )
    monkeypatch.setitem(
        BaseTrainer.registry, LightningTrainer.config_name, DummyTrainer
    )

    dataset = DummyInstructionDataset(
        {
            "instruction": [
                "Rewrite the sentence in simple terms.",
                "Translate to English.",
            ],
            "text": [
                "Quantum entanglement exhibits spooky action.",
                "Bonjour, comment ca va?",
            ],
            "target": ["Particles can stay linked.", "Hello, how are you?"],
        }
    )

    # NOTE(review): ``BaseModel`` is not among the imports visible in this
    # excerpt -- confirm it is imported at module level.
    model = BaseModel.create("qwen3_0_6b_lora")
    model.finetune(dataset=dataset)

    # Two separate asserts so a failure says which condition broke.
    assert trainers, "no DummyTrainer was constructed during finetune()"
    assert trainers[0].fit_called, "DummyTrainer.fit() was never called"
0 commit comments