Skip to content

Commit ff955bd

Browse files
committed
feat: 移除 Async的Checkpointer,并默认配置了 InMemorySaver
1 parent c8ea1e7 commit ff955bd

File tree

3 files changed

+11
-30
lines changed

3 files changed

+11
-30
lines changed

src/agents/chatbot/graph.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import os
21
import uuid
32
from pathlib import Path
43
from typing import Any, cast
54

65
from langchain_core.messages import AIMessage, ToolMessage
76
from langgraph.checkpoint.memory import InMemorySaver
8-
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver, aiosqlite
97
from langgraph.graph import END, START, StateGraph
108
from langgraph.prebuilt import ToolNode, tools_condition
119
from langgraph.runtime import Runtime
@@ -28,6 +26,7 @@ class ChatbotAgent(BaseAgent):
2826
def __init__(self, **kwargs):
2927
super().__init__(**kwargs)
3028
self.graph = None
29+
self.checkpointer = InMemorySaver()
3130
self.context_schema = Context
3231
self.workdir = Path(sys_config.save_dir) / "agents" / self.module_name
3332
self.workdir.mkdir(parents=True, exist_ok=True)
@@ -90,8 +89,8 @@ async def dynamic_tools_node(self, state: State, runtime: Runtime[Context]) -> d
9089

9190
async def get_graph(self, **kwargs):
9291
"""构建图"""
93-
# if self.graph:
94-
# return self.graph
92+
if self.graph:
93+
return self.graph
9594

9695
builder = StateGraph(State, context_schema=self.context_schema)
9796
builder.add_node("chatbot", self.llm_call)
@@ -104,27 +103,9 @@ async def get_graph(self, **kwargs):
104103
builder.add_edge("tools", "chatbot")
105104
builder.add_edge("chatbot", END)
106105

107-
# 创建数据库连接并确保设置 checkpointer
108-
try:
109-
sqlite_checkpointer = AsyncSqliteSaver(await self.get_async_conn())
110-
graph = builder.compile(checkpointer=sqlite_checkpointer, name=self.name)
111-
self.graph = graph
112-
return graph
113-
except Exception as e:
114-
logger.error(f"构建 Graph 设置 checkpointer 时出错: {e}, 尝试使用内存存储")
115-
# 即使出错也返回一个可用的图实例,只是无法保存历史
116-
checkpointer = InMemorySaver()
117-
graph = builder.compile(checkpointer=checkpointer, name=self.name)
118-
self.graph = graph
119-
return graph
120-
121-
async def get_async_conn(self) -> aiosqlite.Connection:
122-
"""获取异步数据库连接"""
123-
return await aiosqlite.connect(os.path.join(self.workdir, "aio_history.db"))
124-
125-
async def get_aio_memory(self) -> AsyncSqliteSaver:
126-
"""获取异步存储实例"""
127-
return AsyncSqliteSaver(await self.get_async_conn())
106+
graph = builder.compile(checkpointer=self.checkpointer, name=self.name)
107+
self.graph = graph
108+
return graph
128109

129110

130111
def main():

src/agents/common/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from abc import abstractmethod
44

55
from langgraph.graph.state import CompiledStateGraph
6+
from langgraph.checkpoint.memory import InMemorySaver
67

78
from src.agents.common.context import BaseContext
89
from src.utils import logger
@@ -18,6 +19,7 @@ class BaseAgent:
1819

1920
def __init__(self, **kwargs):
2021
self.graph = None # will be covered by get_graph
22+
self.checkpointer = InMemorySaver()
2123
self.context_schema = BaseContext
2224

2325
@property
@@ -52,7 +54,7 @@ async def stream_messages(self, messages: list[str], input_context=None, **kwarg
5254
graph = await self.get_graph()
5355
context = self.context_schema.from_file(module_name=self.module_name, input_context=input_context)
5456
logger.debug(f"stream_messages: {context}")
55-
# TODO Checkpointer 似乎还没有适配最新的 Context API
57+
# TODO Checkpointer 似乎还没有适配最新的 1.0 Context API
5658
input_config = {"configurable": input_context, "recursion_limit": 100}
5759
async for msg, metadata in graph.astream(
5860
{"messages": messages}, stream_mode="messages", context=context, config=input_config

src/agents/react/graph.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from pathlib import Path
22

33
from langchain_core.messages import AnyMessage, SystemMessage
4-
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver, aiosqlite
54
from langgraph.prebuilt import create_react_agent
65
from langgraph.runtime import get_runtime
76

@@ -37,8 +36,7 @@ async def get_graph(self, **kwargs):
3736

3837
available_tools = get_buildin_tools()
3938

40-
sqlite_checkpointer = AsyncSqliteSaver(await aiosqlite.connect(self.workdir / "react_history.db"))
41-
graph = create_react_agent(model, tools=available_tools, checkpointer=sqlite_checkpointer, prompt=prompt)
39+
graph = create_react_agent(model, tools=available_tools, prompt=prompt, checkpointer=self.checkpointer)
4240
self.graph = graph
43-
logger.info("ReActAgent使用SQLite checkpointer构建成功")
41+
logger.info("ReActAgent 使用内存 checkpointer 构建成功")
4442
return graph

0 commit comments

Comments
 (0)