diff --git a/README.md b/README.md index 67d05f29..13c72fc5 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ ___ -`xTuring` makes it simple, fast, and cost‑efficient to fine‑tune open‑source LLMs (e.g., GPT‑OSS, LLaMA/LLaMA 2, Falcon, GPT‑J, GPT‑2, OPT, Bloom, Cerebras, Galactica) on your own data — locally or in your private cloud. +`xTuring` makes it simple, fast, and cost‑efficient to fine‑tune open‑source LLMs (e.g., GPT‑OSS, LLaMA/LLaMA 2, Falcon, Qwen3, GPT‑J, GPT‑2, OPT, Bloom, Cerebras, Galactica) on your own data — locally or in your private cloud. Why xTuring: - Simple API for data prep, training, and inference @@ -162,6 +162,17 @@ outputs = model.generate(dataset = dataset, batch_size=10) ``` +7. __Qwen3 0.6B supervised fine-tuning__ – The lightweight Qwen3 0.6B checkpoint now has first-class support (registry, configs, docs, and examples) so you can launch SFT/LoRA jobs immediately. +```python +from xturing.datasets import InstructionDataset +from xturing.models import BaseModel + +dataset = InstructionDataset("./examples/models/llama/alpaca_data") +model = BaseModel.create("qwen3_0_6b_lora") +model.finetune(dataset=dataset) +``` +> See `examples/models/qwen3/qwen3_lora_finetune.py` for a runnable script. + An exploration of the [Llama LoRA INT4 working example](examples/features/int4_finetuning/LLaMA_lora_int4.ipynb) is recommended for an understanding of its application. For an extended insight, consider examining the [GenericModel working example](examples/features/generic/generic_model.py) available in the repository. @@ -290,7 +301,7 @@ Replace `` with a local directory or a Hugging Face model like `face - [x] Dataset generation using self-instruction - [x] Low-precision LoRA fine-tuning and unsupervised fine-tuning - [x] INT8 low-precision fine-tuning support -- [x] OpenAI, Cohere and AI21 Studio model APIs for dataset generation +- [x] OpenAI, Cohere, and Claude model APIs for dataset generation - [x] Added fine-tuned checkpoints for some models to the hub - [x] INT4 LLaMA LoRA fine-tuning demo - [x] INT4 LLaMA LoRA fine-tuning with INT4 generation diff --git a/docs/docs/advanced/generate.md b/docs/docs/advanced/generate.md index 38d3edb9..ac27ac86 100644 --- a/docs/docs/advanced/generate.md +++ b/docs/docs/advanced/generate.md @@ -26,23 +26,23 @@ engine = Davinci("your-api-key") - - - ```python - from xturing.model_apis.cohere import Medium - engine = Medium("your-api-key") - ``` - - - - - ```python - from xturing.model_apis.ai21 import J2Grande - engine = J2Grande("your-api-key") - ``` - - - + + + ```python + from xturing.model_apis.cohere import Medium + engine = Medium("your-api-key") + ``` + + + + + ```python + from xturing.model_apis.claude import ClaudeSonnet + engine = ClaudeSonnet("your-api-key") + ``` + + + ## From no data diff --git a/examples/datasets/create_alpaca_dataset.ipynb b/examples/datasets/create_alpaca_dataset.ipynb index 81914f4c..668896a9 100644 --- a/examples/datasets/create_alpaca_dataset.ipynb +++ b/examples/datasets/create_alpaca_dataset.ipynb @@ -42,9 +42,9 @@ "#\n", "# engine = Medium(\"your-api-key\")\n", "\n", - "# Alternatively, you can use AI21 to generate dataset\n", + "# Alternatively, you can use Claude to generate dataset\n", "\n", - "# from xturing.model_apis.ai21 import J2Grande\n", + "# from xturing.model_apis.claude import ClaudeSonnet\n", "#\n", "# engine = J2Grande(\"your-api-key\")" ], @@ -100,4 +100,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/examples/datasets/create_instruction_dataset_from_files.ipynb b/examples/datasets/create_instruction_dataset_from_files.ipynb index 15ecc291..62ee00f3 100644 --- a/examples/datasets/create_instruction_dataset_from_files.ipynb +++ b/examples/datasets/create_instruction_dataset_from_files.ipynb @@ -46,9 +46,9 @@ "#\n", "# engine = Medium(\"your-api-key\")\n", "\n", - "# Alternatively, you can use AI21 to generate dataset\n", + "# Alternatively, you can use Claude to generate dataset\n", "\n", - "# from xturing.model_apis.ai21 import J2Grande\n", + "# from xturing.model_apis.claude import ClaudeSonnet\n", "#\n", "# engine = J2Grande(\"your-api-key\")" ] @@ -124,4 +124,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/examples/features/dataset_generation/create_alpaca_dataset.ipynb b/examples/features/dataset_generation/create_alpaca_dataset.ipynb index 81914f4c..668896a9 100644 --- a/examples/features/dataset_generation/create_alpaca_dataset.ipynb +++ b/examples/features/dataset_generation/create_alpaca_dataset.ipynb @@ -42,9 +42,9 @@ "#\n", "# engine = Medium(\"your-api-key\")\n", "\n", - "# Alternatively, you can use AI21 to generate dataset\n", + "# Alternatively, you can use Claude to generate dataset\n", "\n", - "# from xturing.model_apis.ai21 import J2Grande\n", + "# from xturing.model_apis.claude import ClaudeSonnet\n", "#\n", "# engine = J2Grande(\"your-api-key\")" ], @@ -100,4 +100,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9d0dd62a..537672fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ keywords = [ dependencies = [ "torch >= 1.9.0", "pytorch-lightning", - "transformers>=4.53.0", + "transformers>=4.36.0", "datasets==2.14.5", "pyarrow >= 8.0.0, < 21.0.0", "scipy >= 1.0.0", @@ -54,8 +54,8 @@ dependencies = [ "gradio>=5.31.0", "click", "wget", - "ai21", "cohere", + "anthropic", "ipywidgets", "openai >= 0.27.0", "pydantic >= 1.10.0", diff --git a/src/xturing/model_apis/__init__.py b/src/xturing/model_apis/__init__.py index eccca014..4fced5ed 100644 --- a/src/xturing/model_apis/__init__.py +++ b/src/xturing/model_apis/__init__.py @@ -1,6 +1,5 @@ -from xturing.model_apis.ai21 import AI21TextGenerationAPI -from xturing.model_apis.ai21 import J2Grande as AI21J2Grande from xturing.model_apis.base import BaseApi, TextGenerationAPI +from xturing.model_apis.claude import ClaudeSonnet, ClaudeTextGenerationAPI from xturing.model_apis.cohere import CohereTextGenerationAPI from xturing.model_apis.cohere import Medium as CohereMedium from xturing.model_apis.openai import ChatGPT as OpenAIChatGPT @@ -9,8 +8,8 @@ BaseApi.add_to_registry(OpenAITextGenerationAPI.config_name, OpenAITextGenerationAPI) BaseApi.add_to_registry(CohereTextGenerationAPI.config_name, CohereTextGenerationAPI) -BaseApi.add_to_registry(AI21TextGenerationAPI.config_name, AI21TextGenerationAPI) +BaseApi.add_to_registry(ClaudeTextGenerationAPI.config_name, ClaudeTextGenerationAPI) BaseApi.add_to_registry(OpenAIDavinci.config_name, OpenAIDavinci) BaseApi.add_to_registry(OpenAIChatGPT.config_name, OpenAIChatGPT) BaseApi.add_to_registry(CohereMedium.config_name, CohereMedium) -BaseApi.add_to_registry(AI21J2Grande.config_name, AI21J2Grande) +BaseApi.add_to_registry(ClaudeSonnet.config_name, ClaudeSonnet) diff --git a/src/xturing/model_apis/ai21.py b/src/xturing/model_apis/ai21.py deleted file mode 100644 index 04c10adf..00000000 --- a/src/xturing/model_apis/ai21.py +++ /dev/null @@ -1,70 +0,0 @@ -import time -from datetime import datetime - -import ai21 - -from xturing.model_apis.base import TextGenerationAPI - - -class AI21TextGenerationAPI(TextGenerationAPI): - config_name = "ai21" - - def __init__(self, engine, api_key): - super().__init__(engine, api_key=api_key, request_batch_size=1) - ai21.api_key = api_key - - def generate_text( - self, - prompts, - max_tokens, - temperature, - top_p, - stop_sequences, - retries=3, - **kwargs, - ): - response = None - retry_cnt = 0 - backoff_time = 30 - while retry_cnt <= retries: - try: - response = ai21.Completion.execute( - model=self.engine, - prompt=prompts[0], - numResults=1, - maxTokens=max_tokens, - temperature=temperature, - topKReturn=0, - topP=top_p, - stopSequences=stop_sequences, - ) - break - except Exception as e: - print(f"AI21Error: {e}.") - print(f"Retrying in {backoff_time} seconds...") - time.sleep(backoff_time) - backoff_time *= 1.5 - retry_cnt += 1 - - predicts = { - "choices": [ - { - "text": response["prompt"]["text"], - "finish_reason": "eos", - } - ] - } - - data = { - "prompt": prompts, - "response": predicts, - "created_at": str(datetime.now()), - } - return [data] - - -class J2Grande(AI21TextGenerationAPI): - config_name = "ai21_j2_grande" - - def __init__(self, api_key): - super().__init__(engine="j2-grande", api_key=api_key) diff --git a/src/xturing/model_apis/claude.py b/src/xturing/model_apis/claude.py new file mode 100644 index 00000000..eea76822 --- /dev/null +++ b/src/xturing/model_apis/claude.py @@ -0,0 +1,131 @@ +import time +from datetime import datetime + +try: + from anthropic import ( + APIConnectionError as AnthropicAPIConnectionError, + APIError as AnthropicAPIError, + Anthropic, + RateLimitError as AnthropicRateLimitError, + ) +except ModuleNotFoundError as import_err: # pragma: no cover - optional dependency + Anthropic = None + AnthropicAPIError = AnthropicAPIConnectionError = AnthropicRateLimitError = Exception + _ANTHROPIC_IMPORT_ERROR = import_err +else: # pragma: no cover - dependency import paths exercised in runtime envs + _ANTHROPIC_IMPORT_ERROR = None + +from xturing.model_apis.base import TextGenerationAPI + + +class ClaudeTextGenerationAPI(TextGenerationAPI): + config_name = "claude" + + def __init__(self, model, api_key, request_batch_size=1): + self._ensure_dependency() + super().__init__(engine=model, api_key=api_key, request_batch_size=request_batch_size) + self._client = Anthropic(api_key=api_key) + + @staticmethod + def _ensure_dependency(): + if Anthropic is None: + message = ( + "The anthropic SDK is required for ClaudeTextGenerationAPI. " + "Install it with `pip install anthropic`." + ) + raise ModuleNotFoundError(message) from _ANTHROPIC_IMPORT_ERROR + + def _make_request(self, prompt, max_tokens, temperature, top_p, stop_sequences): + params = { + "model": self.engine, + "max_tokens": max_tokens, + "temperature": temperature, + "messages": [{"role": "user", "content": prompt}], + } + if top_p is not None: + params["top_p"] = top_p + if stop_sequences: + params["stop_sequences"] = stop_sequences + return self._client.messages.create(**params) + + @staticmethod + def _render_response(response): + if response is None: + return None + text_chunks = [] + for block in getattr(response, "content", []): + if getattr(block, "type", None) == "text": + text_chunks.append(getattr(block, "text", "")) + predicts = { + "choices": [ + { + "text": "".join(text_chunks), + "finish_reason": getattr(response, "stop_reason", "eos"), + } + ] + } + return predicts + + def generate_text( + self, + prompts, + max_tokens, + temperature, + top_p=None, + frequency_penalty=None, + presence_penalty=None, + stop_sequences=None, + logprobs=None, + n=1, + best_of=1, + retries=3, + **kwargs, + ): + if not isinstance(prompts, list): + prompts = [prompts] + + results = [] + for prompt in prompts: + response = None + retry_cnt = 0 + backoff_time = 30 + while retry_cnt <= retries: + try: + response = self._make_request( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + ) + break + except ( + AnthropicAPIError, + AnthropicAPIConnectionError, + AnthropicRateLimitError, + ) as e: + print(f"ClaudeError: {e}.") + print(f"Retrying in {backoff_time} seconds...") + time.sleep(backoff_time) + backoff_time *= 1.5 + retry_cnt += 1 + + data = { + "prompt": prompt, + "response": self._render_response(response), + "created_at": str(datetime.now()), + } + results.append(data) + + return results + + +class ClaudeSonnet(ClaudeTextGenerationAPI): + config_name = "claude_3_sonnet" + + def __init__(self, api_key, request_batch_size=1): + super().__init__( + model="claude-3-sonnet-20240229", + api_key=api_key, + request_batch_size=request_batch_size, + ) diff --git a/tests/xturing/model_apis/__init__.py b/tests/xturing/model_apis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/xturing/model_apis/test_claude_api.py b/tests/xturing/model_apis/test_claude_api.py new file mode 100644 index 00000000..0353d5e7 --- /dev/null +++ b/tests/xturing/model_apis/test_claude_api.py @@ -0,0 +1,323 @@ +import unittest +from unittest.mock import MagicMock, patch + +import pytest + + +class TestClaudeTextGenerationAPI: + """Test suite for ClaudeTextGenerationAPI""" + + def test_missing_anthropic_dependency(self): + """Test that missing anthropic package raises ModuleNotFoundError""" + with patch.dict("sys.modules", {"anthropic": None}): + # Force reimport to trigger the import error path + import importlib + import sys + + # Remove from cache if present + if "xturing.model_apis.claude" in sys.modules: + del sys.modules["xturing.model_apis.claude"] + + # This should work but _ensure_dependency should fail + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with pytest.raises(ModuleNotFoundError, match="anthropic SDK is required"): + ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + def test_initialization(self): + """Test ClaudeTextGenerationAPI initialization""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + request_batch_size=5, + ) + + assert api.engine == "claude-3-sonnet-20240229" + assert api.api_key == "test-key" + assert api.request_batch_size == 5 + mock_anthropic.assert_called_once_with(api_key="test-key") + + def test_claude_sonnet_initialization(self): + """Test ClaudeSonnet convenience class initialization""" + from xturing.model_apis.claude import ClaudeSonnet + + with patch("xturing.model_apis.claude.Anthropic"): + api = ClaudeSonnet(api_key="test-key", request_batch_size=3) + + assert api.engine == "claude-3-sonnet-20240229" + assert api.api_key == "test-key" + assert api.request_batch_size == 3 + assert api.config_name == "claude_3_sonnet" + + def test_make_request_basic(self): + """Test _make_request with basic parameters""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + mock_client = MagicMock() + mock_anthropic.return_value = mock_client + + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + api._make_request( + prompt="Hello, world!", + max_tokens=100, + temperature=0.7, + top_p=None, + stop_sequences=None, + ) + + mock_client.messages.create.assert_called_once_with( + model="claude-3-sonnet-20240229", + max_tokens=100, + temperature=0.7, + messages=[{"role": "user", "content": "Hello, world!"}], + ) + + def test_make_request_with_optional_params(self): + """Test _make_request with optional parameters""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + mock_client = MagicMock() + mock_anthropic.return_value = mock_client + + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + api._make_request( + prompt="Hello, world!", + max_tokens=100, + temperature=0.7, + top_p=0.9, + stop_sequences=["STOP", "END"], + ) + + mock_client.messages.create.assert_called_once_with( + model="claude-3-sonnet-20240229", + max_tokens=100, + temperature=0.7, + top_p=0.9, + stop_sequences=["STOP", "END"], + messages=[{"role": "user", "content": "Hello, world!"}], + ) + + def test_render_response_success(self): + """Test _render_response with successful response""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + # Mock response object + mock_response = MagicMock() + mock_text_block = MagicMock() + mock_text_block.type = "text" + mock_text_block.text = "This is a response" + mock_response.content = [mock_text_block] + mock_response.stop_reason = "end_turn" + + result = ClaudeTextGenerationAPI._render_response(mock_response) + + assert result == { + "choices": [ + { + "text": "This is a response", + "finish_reason": "end_turn", + } + ] + } + + def test_render_response_multiple_blocks(self): + """Test _render_response with multiple text blocks""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + # Mock response with multiple text blocks + mock_response = MagicMock() + mock_block1 = MagicMock() + mock_block1.type = "text" + mock_block1.text = "Part 1 " + + mock_block2 = MagicMock() + mock_block2.type = "text" + mock_block2.text = "Part 2" + + mock_response.content = [mock_block1, mock_block2] + mock_response.stop_reason = "max_tokens" + + result = ClaudeTextGenerationAPI._render_response(mock_response) + + assert result == { + "choices": [ + { + "text": "Part 1 Part 2", + "finish_reason": "max_tokens", + } + ] + } + + def test_render_response_none(self): + """Test _render_response with None response""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + result = ClaudeTextGenerationAPI._render_response(None) + assert result is None + + def test_generate_text_single_prompt(self): + """Test generate_text with single prompt""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + mock_client = MagicMock() + mock_anthropic.return_value = mock_client + + # Mock response + mock_response = MagicMock() + mock_text_block = MagicMock() + mock_text_block.type = "text" + mock_text_block.text = "Generated text" + mock_response.content = [mock_text_block] + mock_response.stop_reason = "end_turn" + + mock_client.messages.create.return_value = mock_response + + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + results = api.generate_text( + prompts="Test prompt", + max_tokens=100, + temperature=0.7, + ) + + assert len(results) == 1 + assert results[0]["prompt"] == "Test prompt" + assert results[0]["response"]["choices"][0]["text"] == "Generated text" + assert "created_at" in results[0] + + def test_generate_text_multiple_prompts(self): + """Test generate_text with multiple prompts""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + mock_client = MagicMock() + mock_anthropic.return_value = mock_client + + # Mock response + mock_response = MagicMock() + mock_text_block = MagicMock() + mock_text_block.type = "text" + mock_text_block.text = "Generated text" + mock_response.content = [mock_text_block] + mock_response.stop_reason = "end_turn" + + mock_client.messages.create.return_value = mock_response + + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + results = api.generate_text( + prompts=["Prompt 1", "Prompt 2", "Prompt 3"], + max_tokens=100, + temperature=0.7, + ) + + assert len(results) == 3 + assert results[0]["prompt"] == "Prompt 1" + assert results[1]["prompt"] == "Prompt 2" + assert results[2]["prompt"] == "Prompt 3" + + def test_generate_text_with_retry(self): + """Test generate_text retry logic on API errors""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + with patch("time.sleep"): # Mock sleep to speed up test + mock_client = MagicMock() + mock_anthropic.return_value = mock_client + + # Mock successful response + mock_response = MagicMock() + mock_text_block = MagicMock() + mock_text_block.type = "text" + mock_text_block.text = "Generated text" + mock_response.content = [mock_text_block] + mock_response.stop_reason = "end_turn" + + # First call fails, second succeeds + from anthropic import RateLimitError + + mock_client.messages.create.side_effect = [ + RateLimitError("Rate limit exceeded", response=None, body=None), + mock_response, + ] + + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + results = api.generate_text( + prompts="Test prompt", + max_tokens=100, + temperature=0.7, + retries=3, + ) + + assert len(results) == 1 + assert results[0]["response"]["choices"][0]["text"] == "Generated text" + # Should have been called twice (1 failure + 1 success) + assert mock_client.messages.create.call_count == 2 + + def test_generate_text_max_retries_exceeded(self): + """Test generate_text when max retries exceeded""" + from xturing.model_apis.claude import ClaudeTextGenerationAPI + + with patch("xturing.model_apis.claude.Anthropic") as mock_anthropic: + with patch("time.sleep"): # Mock sleep to speed up test + mock_client = MagicMock() + mock_anthropic.return_value = mock_client + + # Always fail + from anthropic import APIError + + mock_client.messages.create.side_effect = APIError( + "API Error", response=None, body=None + ) + + api = ClaudeTextGenerationAPI( + model="claude-3-sonnet-20240229", + api_key="test-key", + ) + + results = api.generate_text( + prompts="Test prompt", + max_tokens=100, + temperature=0.7, + retries=2, + ) + + assert len(results) == 1 + assert results[0]["prompt"] == "Test prompt" + assert results[0]["response"] is None + # Should have been called 3 times (initial + 2 retries) + assert mock_client.messages.create.call_count == 3 + + def test_config_names(self): + """Test that config names are set correctly""" + from xturing.model_apis.claude import ClaudeSonnet, ClaudeTextGenerationAPI + + assert ClaudeTextGenerationAPI.config_name == "claude" + assert ClaudeSonnet.config_name == "claude_3_sonnet" diff --git a/tests/xturing/models/test_qwen_model.py b/tests/xturing/models/test_qwen_model.py index 82112c0c..7d16ad30 100644 --- a/tests/xturing/models/test_qwen_model.py +++ b/tests/xturing/models/test_qwen_model.py @@ -1,6 +1,97 @@ +import importlib.machinery +import sys +import types from pathlib import Path + +def _make_module(name): + module = types.ModuleType(name) + module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None) + return module + + +def _install_stub_modules(): + if "cohere" not in sys.modules: + cohere_module = _make_module("cohere") + + class _CohereError(Exception): + pass + + class _Client: + def __init__(self, *_args, **_kwargs): + self.generations = [types.SimpleNamespace(text="")] + + def generate(self, **_): + return types.SimpleNamespace(generations=self.generations) + + cohere_module.CohereError = _CohereError + cohere_module.Client = _Client + sys.modules["cohere"] = cohere_module + + if "openai" not in sys.modules: + openai_module = _make_module("openai") + + class _Completion: + @staticmethod + def create(n=1, **_): + return {"choices": [types.SimpleNamespace(text="")] * n} + + class _ChatCompletion: + @staticmethod + def create(**_): + return {"choices": [{"message": {"content": ""}}]} + + openai_module.api_key = None + openai_module.organization = None + openai_module.Completion = _Completion + openai_module.ChatCompletion = _ChatCompletion + openai_module.error = types.SimpleNamespace(OpenAIError=Exception) + sys.modules["openai"] = openai_module + + if "anthropic" not in sys.modules: + anthropic_module = _make_module("anthropic") + + class _Messages: + def create(self, **_): + content_block = types.SimpleNamespace(type="text", text="") + return types.SimpleNamespace(content=[content_block], stop_reason="stop") + + class _Anthropic: + def __init__(self, *_args, **_kwargs): + self.messages = _Messages() + + anthropic_module.Anthropic = _Anthropic + anthropic_module.APIError = Exception + anthropic_module.APIConnectionError = Exception + anthropic_module.RateLimitError = Exception + sys.modules["anthropic"] = anthropic_module + + if "xturing" not in sys.modules: + xturing_module = _make_module("xturing") + xturing_module.__path__ = [ + str(Path(__file__).resolve().parents[3] / "src" / "xturing") + ] + sys.modules["xturing"] = xturing_module + + if "deepspeed" not in sys.modules: + deepspeed_module = _make_module("deepspeed") + ops_module = _make_module("deepspeed.ops") + adam_module = _make_module("deepspeed.ops.adam") + + class _DeepSpeedCPUAdam: + def __init__(self, *_, **__): + pass + + adam_module.DeepSpeedCPUAdam = _DeepSpeedCPUAdam + sys.modules["deepspeed"] = deepspeed_module + sys.modules["deepspeed.ops"] = ops_module + sys.modules["deepspeed.ops.adam"] = adam_module + + +_install_stub_modules() + from xturing.config.read_config import read_yaml +from xturing.engines.base import BaseEngine from xturing.engines.qwen_engine import ( Qwen3Engine, Qwen3Int8Engine, @@ -16,6 +107,9 @@ Qwen3LoraInt8, Qwen3LoraKbit, ) +from xturing.preprocessors.base import BasePreprocessor +from xturing.trainers.base import BaseTrainer +from xturing.trainers.lightning_trainer import LightningTrainer def test_qwen3_model_registry_entries_present(): @@ -66,3 +160,129 @@ def test_qwen3_config_entries_exist(): assert "qwen3_0_6b_int8" in finetuning_config assert "qwen3_0_6b_lora_int8" in finetuning_config assert "qwen3_0_6b_lora_kbit" in finetuning_config + + +def test_qwen3_lora_instruction_sft(monkeypatch): + class DummyInstructionDataset: + config_name = "instruction_dataset" + + def __init__(self, payload): + self.payload = payload + self._meta = type("Meta", (), {})() + + @property + def meta(self): + return self._meta + + def __len__(self): + return len(self.payload["instruction"]) + + def __getitem__(self, idx): + return {key: values[idx] for key, values in self.payload.items()} + + class DummyTokenizer: + eos_token_id = 0 + pad_token_id = 0 + pad_token = "" + + def __call__(self, _): + return {"input_ids": [0], "attention_mask": [1]} + + def pad(self, samples, padding=True, max_length=None, return_tensors=None): + batch_size = len(samples) + return { + "input_ids": [[0] for _ in range(batch_size)], + "attention_mask": [[1] for _ in range(batch_size)], + } + + class DummyModel: + def to(self, *_): + return self + + def eval(self): + return self + + def train(self): + return self + + class DummyEngine: + def __init__(self, *_, **__): + self.model = DummyModel() + self.tokenizer = DummyTokenizer() + + def save(self, *_): + return None + + class DummyCollator: + def __init__(self, *_, **__): + self.calls = 0 + + def __call__(self, batches): + self.calls += 1 + batch_size = len(batches) + return { + "input_ids": [[0] for _ in range(batch_size)], + "targets": [[0] for _ in range(batch_size)], + } + + trainers = [] + + class DummyTrainer: + def __init__( + self, + engine, + dataset, + collate_fn, + num_epochs, + batch_size, + learning_rate, + optimizer_name, + use_lora=False, + use_deepspeed=False, + logger=True, + ): + self.engine = engine + self.dataset = dataset + self.collate_fn = collate_fn + self.num_epochs = num_epochs + self.batch_size = batch_size + self.learning_rate = learning_rate + self.optimizer_name = optimizer_name + self.use_lora = use_lora + self.use_deepspeed = use_deepspeed + self.logger = logger + self.fit_called = False + trainers.append(self) + + def fit(self): + self.fit_called = True + batch = self.collate_fn([self.dataset[0]]) + assert "input_ids" in batch + assert len(batch["input_ids"]) == 1 + + monkeypatch.setitem(BaseEngine.registry, Qwen3LoraEngine.config_name, DummyEngine) + monkeypatch.setitem( + BasePreprocessor.registry, DummyInstructionDataset.config_name, DummyCollator + ) + monkeypatch.setitem( + BaseTrainer.registry, LightningTrainer.config_name, DummyTrainer + ) + + dataset = DummyInstructionDataset( + { + "instruction": [ + "Rewrite the sentence in simple terms.", + "Translate to English.", + ], + "text": [ + "Quantum entanglement exhibits spooky action.", + "Bonjour, comment ca va?", + ], + "target": ["Particles can stay linked.", "Hello, how are you?"], + } + ) + + model = BaseModel.create("qwen3_0_6b_lora") + model.finetune(dataset=dataset) + + assert trainers and trainers[0].fit_called