Skip to content

Commit 9da47ce

Browse files
committed
✨ feat(common): add Qwen model support and refactor react_agent code into common module
- Add `common` package with shared components for LangGraph agents - Implement Qwen model integration with DashScope API and region-based endpoints - Update context default model to Qwen with colon separator syntax - Refactor react_agent code: move context, tools, utils, models, prompts to common package - Modify load_chat_model to support colon-separated model spec and Qwen models - Update pyproject.toml dependencies to include langchain-qwq and common package - Enhance .env.example with new API keys and configuration for Qwen and tracing - Update tests with fixtures, centralized test data, and integration tests for Qwen and OpenAI - Add comprehensive e2e tests validating ReAct pattern with tool usage and streaming support - Improve unit tests for context configuration and error handling related to model loading
1 parent 2e15137 commit 9da47ce

File tree

17 files changed

+1157
-65
lines changed

17 files changed

+1157
-65
lines changed

.env.example

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
TAVILY_API_KEY=...
1+
# Region for model providers: prc, international
2+
REGION=
23

3-
# To separate your traces from other application
4-
LANGSMITH_PROJECT=react-agent
4+
# OpenAI Compatible models
5+
OPENAI_API_KEY=sk-...
6+
# OPENAI_API_BASE=
57

6-
# The following depend on your selected configuration
8+
# DashScope (Qwen models)
9+
DASHSCOPE_API_KEY=sk-...
710

8-
## LLM choice:
9-
ANTHROPIC_API_KEY=....
10-
FIREWORKS_API_KEY=...
11-
OPENAI_API_KEY=...
11+
# LangSmith (tracing)
12+
LANGCHAIN_TRACING_V2=true
13+
LANGCHAIN_PROJECT=langgraph-up-react
14+
LANGCHAIN_API_KEY=lsv2_sk_...
15+
16+
# Search Engines
17+
TAVILY_API_KEY=tvly-...

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ dependencies = [
1313
"langchain-openai>=0.1.22",
1414
"langchain-anthropic>=0.1.23",
1515
"langchain>=0.2.14",
16-
"langchain-fireworks>=0.1.7",
1716
"python-dotenv>=1.0.1",
1817
"langchain-tavily>=0.1",
18+
"langchain-qwq>=0.2.1",
1919
]
2020

2121

@@ -27,10 +27,11 @@ requires = ["setuptools>=73.0.0", "wheel"]
2727
build-backend = "setuptools.build_meta"
2828

2929
[tool.setuptools]
30-
packages = ["langgraph.templates.react_agent", "react_agent"]
30+
packages = ["langgraph.templates.react_agent", "react_agent", "common"]
3131
[tool.setuptools.package-dir]
3232
"langgraph.templates.react_agent" = "src/react_agent"
3333
"react_agent" = "src/react_agent"
34+
"common" = "src/common"
3435

3536

3637
[tool.setuptools.package-data]
@@ -71,4 +72,5 @@ dev = [
7172
"pytest-asyncio>=0.23.0",
7273
"langgraph-sdk>=0.1.0",
7374
"mypy>=1.17.1",
75+
"ruff>=0.9.10",
7476
]

src/common/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Shared components for LangGraph agents."""
2+
3+
from . import prompts
4+
from .context import Context
5+
from .models import create_qwen_model, get_supported_qwen_models
6+
from .tools import TOOLS, web_search
7+
from .utils import load_chat_model
8+
9+
__all__ = [
10+
"Context",
11+
"create_qwen_model",
12+
"get_supported_qwen_models",
13+
"TOOLS",
14+
"web_search",
15+
"load_chat_model",
16+
"prompts",
17+
]
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ class Context:
2222
)
2323

2424
model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
25-
default="anthropic/claude-3-5-sonnet-20240620",
25+
default="qwen:qwen-plus",
2626
metadata={
2727
"description": "The name of the language model to use for the agent's main interactions. "
28-
"Should be in the form: provider/model-name."
28+
"Should be in the form: provider:model-name."
2929
},
3030
)
3131

3232
max_search_results: int = field(
33-
default=10,
33+
default=5,
3434
metadata={
3535
"description": "The maximum number of search results to return for each search query."
3636
},

src/common/models.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Custom model integrations for ReAct agent."""
2+
3+
import os
4+
from typing import Any, List, Optional, Union
5+
6+
from langchain_qwq import ChatQwen, ChatQwQ
7+
8+
9+
def create_qwen_model(
10+
model_name: str,
11+
api_key: Optional[str] = None,
12+
base_url: Optional[str] = None,
13+
region: Optional[str] = None,
14+
**kwargs: Any,
15+
) -> Union[ChatQwQ, ChatQwen]:
16+
"""Create a Qwen model with proper configuration.
17+
18+
Args:
19+
model_name: The model name (e.g., 'qwq-32b-preview', 'qwen-plus')
20+
api_key: DashScope API key (defaults to env var DASHSCOPE_API_KEY)
21+
base_url: Custom base URL for API (optional)
22+
region: Region setting ('prc' for China, 'international' for global)
23+
Defaults to env var REGION
24+
**kwargs: Additional model parameters
25+
26+
Returns:
27+
Configured ChatQwQ instance for QwQ/QvQ models or ChatQwen for other Qwen models
28+
"""
29+
# Get API key from env if not provided
30+
if api_key is None:
31+
api_key = os.getenv("DASHSCOPE_API_KEY")
32+
33+
# Get region from env if not provided
34+
if region is None:
35+
region = os.getenv("REGION")
36+
37+
# Set base URL based on region if not explicitly provided
38+
if base_url is None and region:
39+
if region.lower() == "prc":
40+
# China mainland endpoint
41+
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
42+
elif region.lower() == "international":
43+
# International endpoint
44+
base_url = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
45+
46+
# Create model configuration
47+
config = {"model": model_name, "api_key": api_key, **kwargs}
48+
49+
if base_url:
50+
config["base_url"] = base_url
51+
52+
# Select the appropriate chat model based on model name
53+
# Use ChatQwQ for QwQ and QvQ models, ChatQwen for other Qwen models
54+
if model_name.startswith(("qwq", "qvq")):
55+
return ChatQwQ(**config)
56+
else:
57+
return ChatQwen(**config)
58+
59+
60+
def get_supported_qwen_models() -> List[str]:
61+
"""Get list of supported Qwen models."""
62+
return [
63+
"qwen-plus",
64+
"qwen-turbo",
65+
"qwen-max",
66+
"qwq-32b-preview",
67+
"qvq-72b-preview",
68+
# Add more Qwen models as they become available
69+
]
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from langchain_tavily import TavilySearch
1212
from langgraph.runtime import get_runtime
1313

14-
from react_agent.context import Context
14+
from common.context import Context
1515

1616

17-
async def search(query: str) -> Optional[dict[str, Any]]:
17+
async def web_search(query: str) -> Optional[dict[str, Any]]:
1818
"""Search for general web results.
1919
2020
This function performs a search using the Tavily search engine, which is designed
@@ -26,4 +26,4 @@ async def search(query: str) -> Optional[dict[str, Any]]:
2626
return cast(dict[str, Any], await wrapped.ainvoke({"query": query}))
2727

2828

29-
TOOLS: List[Callable[..., Any]] = [search]
29+
TOOLS: List[Callable[..., Any]] = [web_search]
Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
"""Utility & helper functions."""
22

3+
from typing import Union
4+
35
from langchain.chat_models import init_chat_model
46
from langchain_core.language_models import BaseChatModel
57
from langchain_core.messages import BaseMessage
8+
from langchain_qwq import ChatQwen, ChatQwQ
9+
10+
from .models import create_qwen_model
611

712

813
def get_message_text(msg: BaseMessage) -> str:
@@ -17,11 +22,19 @@ def get_message_text(msg: BaseMessage) -> str:
1722
return "".join(txts).strip()
1823

1924

20-
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
25+
def load_chat_model(
26+
fully_specified_name: str,
27+
) -> Union[BaseChatModel, ChatQwQ, ChatQwen]:
2128
"""Load a chat model from a fully specified name.
2229
2330
Args:
24-
fully_specified_name (str): String in the format 'provider/model'.
31+
fully_specified_name (str): String in the format 'provider:model'.
2532
"""
26-
provider, model = fully_specified_name.split("/", maxsplit=1)
33+
provider, model = fully_specified_name.split(":", maxsplit=1)
34+
35+
# Handle Qwen models specially with dashscope integration
36+
if provider.lower() == "qwen":
37+
return create_qwen_model(model)
38+
39+
# Use standard langchain initialization for other providers
2740
return init_chat_model(model, model_provider=provider)

src/react_agent/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from langgraph.prebuilt import ToolNode
1212
from langgraph.runtime import Runtime
1313

14-
from react_agent.context import Context
14+
from common.context import Context
15+
from common.tools import TOOLS
16+
from common.utils import load_chat_model
1517
from react_agent.state import InputState, State
16-
from react_agent.tools import TOOLS
17-
from react_agent.utils import load_chat_model
1818

1919
# Define the function that calls the model
2020

tests/conftest.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22

33
import os
44
from pathlib import Path
5+
from unittest.mock import AsyncMock, Mock
56

67
import pytest
78
from dotenv import load_dotenv
9+
from langchain_core.messages import AIMessage, HumanMessage
10+
from langgraph_sdk import get_client
11+
12+
from tests.test_data import TestModels, TestUrls
813

914

1015
@pytest.fixture(scope="session", autouse=True)
@@ -24,3 +29,124 @@ def load_env():
2429

2530
if missing_keys:
2631
pytest.skip(f"Missing required environment variables: {missing_keys}")
32+
33+
34+
@pytest.fixture
35+
def mock_model():
36+
"""Create a mock model for testing."""
37+
mock = Mock()
38+
mock.bind_tools.return_value = mock
39+
mock.ainvoke = AsyncMock()
40+
return mock
41+
42+
43+
@pytest.fixture
44+
def sample_ai_response():
45+
"""Create a sample AI response for testing."""
46+
return AIMessage(
47+
content="This is a test response from the AI model.", tool_calls=[]
48+
)
49+
50+
51+
@pytest.fixture
52+
def sample_human_message():
53+
"""Create a sample human message for testing."""
54+
return HumanMessage(content="This is a test message from human.")
55+
56+
57+
@pytest.fixture
58+
def test_context():
59+
"""Create a test context with common configuration."""
60+
from common.context import Context
61+
62+
return Context(
63+
model=TestModels.QWEN_PLUS,
64+
system_prompt="You are a helpful AI assistant for testing.",
65+
)
66+
67+
68+
@pytest.fixture
69+
async def langgraph_client():
70+
"""Create a LangGraph client for e2e testing."""
71+
return get_client(url=TestUrls.LANGGRAPH_LOCAL)
72+
73+
74+
@pytest.fixture
75+
async def assistant_id(langgraph_client):
76+
"""Get the first available assistant ID for testing."""
77+
assistants = await langgraph_client.assistants.search()
78+
if not assistants:
79+
pytest.skip("No assistants found for e2e testing")
80+
return assistants[0]["assistant_id"]
81+
82+
83+
@pytest.fixture
84+
async def test_thread(langgraph_client):
85+
"""Create a test thread for e2e testing."""
86+
thread = await langgraph_client.threads.create()
87+
yield thread
88+
# Cleanup could be added here if needed
89+
90+
91+
class TestHelpers:
92+
"""Helper methods for common test operations."""
93+
94+
@staticmethod
95+
def assert_valid_response(
96+
messages: list, expected_content: str | None = None, min_messages: int = 2
97+
):
98+
"""Assert that a response has valid structure and content."""
99+
assert isinstance(messages, list), "Messages should be a list"
100+
assert len(messages) >= min_messages, (
101+
f"Should have at least {min_messages} messages"
102+
)
103+
104+
# Check final message structure
105+
final_message = messages[-1]
106+
assert isinstance(final_message, dict) or hasattr(final_message, "content"), (
107+
"Final message should have content attribute"
108+
)
109+
110+
if expected_content:
111+
content = str(
112+
getattr(final_message, "content", final_message.get("content", ""))
113+
).lower()
114+
assert expected_content.lower() in content, (
115+
f"Expected '{expected_content}' in response content: {content[:200]}..."
116+
)
117+
118+
@staticmethod
119+
def assert_tool_usage(messages: list, tool_name: str = "web_search"):
120+
"""Assert that a specific tool was used in the conversation."""
121+
tool_found = False
122+
for msg in messages:
123+
msg_dict = msg if isinstance(msg, dict) else msg.__dict__
124+
125+
# Check for tool calls
126+
if msg_dict.get("tool_calls"):
127+
for call in msg_dict["tool_calls"]:
128+
if isinstance(call, dict) and call.get("name") == tool_name:
129+
tool_found = True
130+
break
131+
132+
# Check for tool messages
133+
if msg_dict.get("type") == "tool" and tool_name in str(
134+
msg_dict.get("name", "")
135+
):
136+
tool_found = True
137+
break
138+
139+
assert tool_found, f"Tool '{tool_name}' should have been used in conversation"
140+
141+
@staticmethod
142+
def create_input_state(content: str):
143+
"""Create an InputState for testing."""
144+
from react_agent.state import InputState
145+
146+
return InputState(messages=[HumanMessage(content=content)])
147+
148+
149+
@pytest.fixture
150+
def test_helpers():
151+
"""Provide test helper methods."""
152+
return TestHelpers

0 commit comments

Comments
 (0)