@@ -131,9 +131,9 @@ def _norm_list(v):
131131 return result
132132
133133
134- def _get_existing_message_ids (conv_mgr , thread_id ):
134+ async def _get_existing_message_ids (conv_mgr , thread_id ):
135135 """获取已保存的消息ID集合"""
136- existing_messages = conv_mgr .get_messages_by_thread_id (thread_id )
136+ existing_messages = await conv_mgr .get_messages_by_thread_id (thread_id )
137137 return {msg .extra_metadata ["id" ] for msg in existing_messages if msg .extra_metadata and "id" in msg .extra_metadata }
138138
139139
@@ -143,7 +143,7 @@ async def _save_ai_message(conv_mgr, thread_id, msg_dict):
143143 tool_calls_data = msg_dict .get ("tool_calls" , [])
144144
145145 # 保存AI消息
146- ai_msg = conv_mgr .add_message_by_thread_id (
146+ ai_msg = await conv_mgr .add_message_by_thread_id (
147147 thread_id = thread_id ,
148148 role = "assistant" ,
149149 content = content ,
@@ -155,7 +155,7 @@ async def _save_ai_message(conv_mgr, thread_id, msg_dict):
155155 if tool_calls_data :
156156 logger .debug (f"Saving { len (tool_calls_data )} tool calls from AI message" )
157157 for tc in tool_calls_data :
158- conv_mgr .add_tool_call (
158+ await conv_mgr .add_tool_call (
159159 message_id = ai_msg .id ,
160160 tool_name = tc .get ("name" , "unknown" ),
161161 tool_input = tc .get ("args" , {}),
@@ -166,7 +166,7 @@ async def _save_ai_message(conv_mgr, thread_id, msg_dict):
166166 logger .debug (f"Saved AI message { ai_msg .id } with { len (tool_calls_data )} tool calls" )
167167
168168
169- def _save_tool_message (conv_mgr , msg_dict ):
169+ async def _save_tool_message (conv_mgr , msg_dict ):
170170 """保存工具执行结果"""
171171 tool_call_id = msg_dict .get ("tool_call_id" )
172172 content = msg_dict .get ("content" , "" )
@@ -182,7 +182,7 @@ def _save_tool_message(conv_mgr, msg_dict):
182182 tool_output = str (content )
183183
184184 # 更新工具调用结果
185- updated_tc = conv_mgr .update_tool_call_output (
185+ updated_tc = await conv_mgr .update_tool_call_output (
186186 langgraph_tool_call_id = tool_call_id ,
187187 tool_output = tool_output ,
188188 status = "success" ,
@@ -238,7 +238,7 @@ async def save_partial_message(conv_mgr, thread_id, full_msg=None, error_message
238238 else :
239239 content = ""
240240
241- saved_msg = conv_mgr .add_message_by_thread_id (
241+ saved_msg = await conv_mgr .add_message_by_thread_id (
242242 thread_id = thread_id ,
243243 role = "assistant" ,
244244 content = content ,
@@ -271,7 +271,7 @@ async def save_messages_from_langgraph_state(
271271 return
272272
273273 logger .debug (f"Retrieved { len (messages )} messages from LangGraph state" )
274- existing_ids = _get_existing_message_ids (conv_mgr , thread_id )
274+ existing_ids = await _get_existing_message_ids (conv_mgr , thread_id )
275275
276276 for msg in messages :
277277 msg_dict = msg .model_dump () if hasattr (msg , "model_dump" ) else {}
@@ -283,7 +283,7 @@ async def save_messages_from_langgraph_state(
283283 if msg_type == "ai" :
284284 await _save_ai_message (conv_mgr , thread_id , msg_dict )
285285 elif msg_type == "tool" :
286- _save_tool_message (conv_mgr , msg_dict )
286+ await _save_tool_message (conv_mgr , msg_dict )
287287 else :
288288 logger .warning (f"Unknown message type: { msg_type } , skipping" )
289289 continue
@@ -808,14 +808,11 @@ def make_resume_chunk(content=None, **kwargs):
808808 logger .warning (f"Client disconnected during resume: { e } " )
809809
810810 # 保存中断消息到数据库
811- new_db = db_manager .get_session ()
812- try :
811+ async with db_manager .get_async_session_context () as new_db :
813812 new_conv_manager = ConversationManager (new_db )
814813 await save_partial_message (
815814 new_conv_manager , thread_id , error_message = "对话恢复已中断" , error_type = "resume_interrupted"
816815 )
817- finally :
818- new_db .close ()
819816
820817 yield make_resume_chunk (status = "interrupted" , message = "对话恢复已中断" , meta = meta )
821818
@@ -824,14 +821,11 @@ def make_resume_chunk(content=None, **kwargs):
824821 logger .error (f"Error during resume: { e } , { traceback .format_exc ()} " )
825822
826823 # 保存错误消息到数据库
827- new_db = db_manager .get_session ()
828- try :
824+ async with db_manager .get_async_session_context () as new_db :
829825 new_conv_manager = ConversationManager (new_db )
830826 await save_partial_message (
831827 new_conv_manager , thread_id , error_message = f"Error during resume: { e } " , error_type = "resume_error"
832828 )
833- finally :
834- new_db .close ()
835829
836830 yield make_resume_chunk (message = f"Error during resume: { e } " , status = "error" )
837831
@@ -1271,7 +1265,7 @@ async def submit_message_feedback(
12711265 raise
12721266 except Exception as e :
12731267 logger .error (f"Error submitting message feedback: { e } , { traceback .format_exc ()} " )
1274- db .rollback ()
1268+ await db .rollback ()
12751269 raise HTTPException (status_code = 500 , detail = f"Failed to submit feedback: { str (e )} " )
12761270
12771271
0 commit comments