Skip to content

Commit cee6a64

Browse files
committed
Merge branch 'main' of github.com:xerrors/Yuxi-Know
2 parents 27f71c0 + cb516b4 commit cee6a64

File tree

15 files changed

+492
-34
lines changed

15 files changed

+492
-34
lines changed

.gitignore

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@ cache
4040
.trae
4141
.pytest_cache
4242

43-
### (企业私有代码 - 仅忽略敏感配置,不忽略代码文件)
44-
# 移除了 *.private* 和 *_private 规则,允许 Git 本地管理
45-
# 通过 .git/info/exclude 或本地分支管理私有代码
4643
*.secret*
4744
*.nogit*
45+
*_private
46+
*.private
4847
# *.local* 保留用于本地配置文件
4948
*.local.py
5049
*.local.js

src/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import inspect
44
from pathlib import Path
55

6+
from server.utils.singleton import SingletonMeta
67
from src.agents.common.base import BaseAgent
78
from src.utils import logger
89

910

10-
class AgentManager:
11+
class AgentManager(metaclass=SingletonMeta):
1112
def __init__(self):
1213
self._classes = {}
1314
self._instances = {} # 存储已创建的 agent 实例

src/agents/common/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ async def stream_messages(self, messages: list[str], input_context=None, **kwarg
6767
):
6868
yield msg, metadata
6969

70+
async def invoke_messages(self, messages: list[str], input_context=None, **kwargs):
71+
graph = await self.get_graph()
72+
context = self.context_schema.from_file(module_name=self.module_name, input_context=input_context)
73+
logger.debug(f"invoke_messages: {context}")
74+
input_config = {"configurable": input_context, "recursion_limit": 100}
75+
msg = await graph.ainvoke({"messages": messages}, context=context, config=input_config)
76+
return msg
77+
7078
async def check_checkpointer(self):
7179
app = await self.get_graph()
7280
if not hasattr(app, "checkpointer") or app.checkpointer is None:

src/agents/common/state.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Define the state structures for the agent."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Sequence
6+
from dataclasses import dataclass, field
7+
from typing import Annotated
8+
9+
from langchain.messages import AnyMessage
10+
from langgraph.graph import add_messages
11+
12+
13+
@dataclass
14+
class BaseState:
15+
"""Defines the input state for the agent, representing a narrower interface to the outside world.
16+
17+
This class is used to define the initial state and structure of incoming data.
18+
"""
19+
20+
messages: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)

src/agents/common/toolagent.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from abc import abstractmethod
2+
from typing import Any, cast
3+
4+
from langchain.messages import AIMessage, ToolMessage
5+
from langgraph.prebuilt import ToolNode
6+
from langgraph.runtime import Runtime
7+
8+
from src.agents.common.base import BaseAgent
9+
from src.agents.common.mcp import get_mcp_tools
10+
from src.agents.common.models import load_chat_model
11+
from src.utils import logger
12+
13+
from .state import BaseState
14+
from .context import BaseContext
15+
16+
class ToolAgent(BaseAgent):
17+
name = "ToolAgent"
18+
description = "具有工具调用能力的Agent"
19+
20+
def __init__(self, **kwargs):
21+
super().__init__(**kwargs)
22+
self.graph = None
23+
self.checkpointer = None
24+
self.context_schema = BaseContext
25+
self.agent_tools = None
26+
27+
28+
# TODO:[修改建议] _get_invoke_tools,llm_call,dynamic_tools_node这类针对工具调用的功能大多数Agent都能用得到
29+
# 可以通过一个ToolAgent类继承BaseAgent,通过重写抽象方法获取tools,通过继承BaseState和BaseContext获取配置
30+
# 必要时可通过重写以下方法实现其他逻辑
31+
@abstractmethod
32+
def get_tools(self):
33+
logger.error(f"get_tools() is not implemented in {self.__class__.__name__}")
34+
return []
35+
36+
37+
async def _get_invoke_tools(self, selected_tools: list[str], selected_mcps: list[str]):
38+
"""根据配置获取工具。
39+
默认不使用任何工具。
40+
如果配置为列表,则使用列表中的工具。
41+
"""
42+
enabled_tools = []
43+
self.agent_tools = self.agent_tools or self.get_tools()
44+
if selected_tools and isinstance(selected_tools, list) and len(selected_tools) > 0:
45+
# 使用配置中指定的工具
46+
enabled_tools = [tool for tool in self.agent_tools if tool.name in selected_tools]
47+
48+
if selected_mcps and isinstance(selected_mcps, list) and len(selected_mcps) > 0:
49+
for mcp in selected_mcps:
50+
enabled_tools.extend(await get_mcp_tools(mcp))
51+
52+
return enabled_tools
53+
54+
async def llm_call(self, state: BaseState, runtime: Runtime[BaseContext] = None) -> dict[str, Any]:
55+
"""调用 llm 模型 - 异步版本以支持异步工具"""
56+
model = load_chat_model(runtime.context.model)
57+
58+
# 这里要根据配置动态获取工具
59+
available_tools = await self._get_invoke_tools(runtime.context.tools, runtime.context.mcps)
60+
logger.info(f"LLM binded ({len(available_tools)}) available_tools: {[tool.name for tool in available_tools]}")
61+
62+
if available_tools:
63+
model = model.bind_tools(available_tools)
64+
65+
# 使用异步调用
66+
response = cast(
67+
AIMessage,
68+
await model.ainvoke([{"role": "system", "content": runtime.context.system_prompt}, *state.messages]),
69+
)
70+
return {"messages": [response]}
71+
72+
async def dynamic_tools_node(self, state: BaseState, runtime: Runtime[BaseContext]) -> dict[str, list[ToolMessage]]:
73+
"""Execute tools dynamically based on configuration.
74+
75+
This function gets the available tools based on the current configuration
76+
and executes the requested tool calls from the last message.
77+
"""
78+
# Get available tools based on configuration
79+
available_tools = await self._get_invoke_tools(runtime.context.tools, runtime.context.mcps)
80+
81+
# Create a ToolNode with the available tools
82+
tool_node = ToolNode(available_tools)
83+
84+
# Execute the tool node
85+
result = await tool_node.ainvoke(state)
86+
87+
return cast(dict[str, list[ToolMessage]], result)

src/agents/common/tools.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,54 @@
55
from langchain.tools import tool
66
from langchain_core.tools import StructuredTool
77
from langchain_tavily import TavilySearch
8+
from langgraph.types import interrupt
89
from pydantic import BaseModel, Field
910

1011
from src import config, graph_base, knowledge_base
1112
from src.utils import logger
1213

14+
# TODO[修改建议]:前端需要通过interrupt进行交互,点击是或否来批准执行
15+
# 返回中断点:
16+
# is_approved : bool = True 或者 False
17+
# resume_command = Command(resume=is_approved)
18+
# stream = graph.stream(resume_command, config=config, stream_mode="messages")
19+
# graph.invoke(resume_command, config=config)
20+
@tool(name_or_callable="人工审批工具", description="请求人工审批工具,用于在执行重要操作前获得人类确认。")
21+
def get_approved_user_goal(
22+
operation_description: str,
23+
)->dict:
24+
"""
25+
请求人工审批,在执行重要操作前获得人类确认。
26+
27+
Args:
28+
operation_description: 需要审批的操作描述,例如 "调用知识库工具"
29+
Returns:
30+
dict: 包含审批结果的字典,格式为 {"approved": bool, "message": str}
31+
"""
32+
# 构建详细的中断信息
33+
interrupt_info = {
34+
"question": f"是否批准以下操作?",
35+
"operation": operation_description,
36+
}
37+
38+
# 触发人工审批
39+
is_approved = interrupt(interrupt_info)
40+
41+
# 返回审批结果
42+
if is_approved:
43+
result = {
44+
"approved": True,
45+
"message": f"✅ 操作已批准:{operation_description}",
46+
}
47+
print(f"✅ 人工审批通过: {operation_description}")
48+
else:
49+
result = {
50+
"approved": False,
51+
"message": f"❌ 操作被拒绝:{operation_description}",
52+
}
53+
print(f"❌ 人工审批被拒绝: {operation_description}")
54+
55+
return result
1356

1457
@tool(name_or_callable="查询知识图谱", description="使用这个工具可以查询知识图谱中包含的三元组信息。")
1558
def query_knowledge_graph(query: Annotated[str, "The keyword to query knowledge graph."]) -> Any:
@@ -31,6 +74,7 @@ def get_static_tools() -> list:
3174
"""注册静态工具"""
3275
static_tools = [
3376
query_knowledge_graph,
77+
get_approved_user_goal
3478
]
3579

3680
# 检查是否启用网页搜索

src/agents/multiAgent/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .graph import SampleMultiAgent
2+
3+
__all__ = ["SampleMultiAgent"]

src/agents/multiAgent/context.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from dataclasses import dataclass, field
2+
from typing import Annotated
3+
4+
from src.agents.common.context import BaseContext
5+
from src.agents.common.mcp import MCP_SERVERS
6+
from src.agents.common.tools import gen_tool_info
7+
8+
from .tools import get_tools
9+
10+
11+
@dataclass(kw_only=True)
12+
class Context(BaseContext):
13+
tools: Annotated[list[dict], {"__template_metadata__": {"kind": "tools"}}] = field(
14+
default_factory=list,
15+
metadata={
16+
"name": "工具",
17+
"options": gen_tool_info(get_tools()), # 这里的选择是所有的工具
18+
"description": "工具列表",
19+
},
20+
)
21+
22+
mcps: list[str] = field(
23+
default_factory=list,
24+
metadata={"name": "MCP服务器", "options": list(MCP_SERVERS.keys()), "description": "MCP服务器列表"},
25+
)

src/agents/multiAgent/graph.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from langgraph.graph import END, START, StateGraph
2+
from langgraph.prebuilt import tools_condition
3+
4+
from src.agents.common.toolagent import ToolAgent
5+
6+
from .context import Context
7+
from .state import State
8+
from .tools import get_tools
9+
10+
11+
class SampleMultiAgent(ToolAgent):
12+
name = "MultiAgent智能体"
13+
description = "Supervisor智能体,具有调用其他子智能体的能力(在工具中添加)"
14+
15+
# TODO[已完成]: 通过将其他agent封装为工具的方式添加了多智能体调度
16+
'''
17+
你是一个多智能体核心,通过多智能体调用的方式帮助用户完成一系列任务:
18+
19+
1.当你需要知识库问答功能时,请调用对话聊天智能体实现
20+
2.当你需要加密计算的时候,请调用加密计算智能体实现
21+
'''
22+
23+
def __init__(self, **kwargs):
24+
super().__init__(**kwargs)
25+
self.graph = None
26+
self.checkpointer = None
27+
self.context_schema = Context
28+
self.agent_tools = None
29+
30+
def get_tools(self):
31+
return get_tools()
32+
33+
async def get_graph(self, **kwargs):
34+
"""构建图"""
35+
if self.graph:
36+
return self.graph
37+
38+
builder = StateGraph(State, context_schema=self.context_schema)
39+
builder.add_node("chatbot", self.llm_call)
40+
builder.add_node("tools", self.dynamic_tools_node)
41+
builder.add_edge(START, "chatbot")
42+
builder.add_conditional_edges(
43+
"chatbot",
44+
tools_condition,
45+
)
46+
builder.add_edge("tools", "chatbot")
47+
builder.add_edge("chatbot", END)
48+
49+
self.checkpointer = await self._get_checkpointer()
50+
graph = builder.compile(checkpointer=self.checkpointer, name=self.name)
51+
self.graph = graph
52+
return graph
53+
54+
55+
def main():
56+
pass
57+
58+
59+
if __name__ == "__main__":
60+
main()
61+
# asyncio.run(main())

src/agents/multiAgent/state.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Define the state structures for the agent."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Sequence
6+
from dataclasses import dataclass, field
7+
from typing import Annotated
8+
9+
from langchain.messages import AnyMessage
10+
from langgraph.graph import add_messages
11+
12+
from src.agents.common.state import BaseState
13+
14+
15+
@dataclass
16+
class State(BaseState):
17+
"""Defines the input state for the agent, representing a narrower interface to the outside world.
18+
19+
This class is used to define the initial state and structure of incoming data.
20+
"""
21+
22+
messages: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)

0 commit comments

Comments
 (0)