1- import os
21import uuid
32from pathlib import Path
43from typing import Any , cast
54
65from langchain_core .messages import AIMessage , ToolMessage
76from langgraph .checkpoint .memory import InMemorySaver
8- from langgraph .checkpoint .sqlite .aio import AsyncSqliteSaver , aiosqlite
97from langgraph .graph import END , START , StateGraph
108from langgraph .prebuilt import ToolNode , tools_condition
119from langgraph .runtime import Runtime
@@ -28,6 +26,7 @@ class ChatbotAgent(BaseAgent):
2826 def __init__ (self , ** kwargs ):
2927 super ().__init__ (** kwargs )
3028 self .graph = None
29+ self .checkpointer = InMemorySaver ()
3130 self .context_schema = Context
3231 self .workdir = Path (sys_config .save_dir ) / "agents" / self .module_name
3332 self .workdir .mkdir (parents = True , exist_ok = True )
@@ -90,8 +89,8 @@ async def dynamic_tools_node(self, state: State, runtime: Runtime[Context]) -> d
9089
9190 async def get_graph (self , ** kwargs ):
9291 """构建图"""
93- # if self.graph:
94- # return self.graph
92+ if self .graph :
93+ return self .graph
9594
9695 builder = StateGraph (State , context_schema = self .context_schema )
9796 builder .add_node ("chatbot" , self .llm_call )
@@ -104,27 +103,9 @@ async def get_graph(self, **kwargs):
104103 builder .add_edge ("tools" , "chatbot" )
105104 builder .add_edge ("chatbot" , END )
106105
107- # 创建数据库连接并确保设置 checkpointer
108- try :
109- sqlite_checkpointer = AsyncSqliteSaver (await self .get_async_conn ())
110- graph = builder .compile (checkpointer = sqlite_checkpointer , name = self .name )
111- self .graph = graph
112- return graph
113- except Exception as e :
114- logger .error (f"构建 Graph 设置 checkpointer 时出错: { e } , 尝试使用内存存储" )
115- # 即使出错也返回一个可用的图实例,只是无法保存历史
116- checkpointer = InMemorySaver ()
117- graph = builder .compile (checkpointer = checkpointer , name = self .name )
118- self .graph = graph
119- return graph
120-
121- async def get_async_conn (self ) -> aiosqlite .Connection :
122- """获取异步数据库连接"""
123- return await aiosqlite .connect (os .path .join (self .workdir , "aio_history.db" ))
124-
125- async def get_aio_memory (self ) -> AsyncSqliteSaver :
126- """获取异步存储实例"""
127- return AsyncSqliteSaver (await self .get_async_conn ())
106+ graph = builder .compile (checkpointer = self .checkpointer , name = self .name )
107+ self .graph = graph
108+ return graph
128109
129110
130111def main ():
0 commit comments