Skip to content

Commit e75243c

Browse files
committed
fix(chat_router): 修复遗漏了异步操作的地方 Fixes: #369 #368
1 parent 213fbe5 commit e75243c

File tree

2 files changed

+13
-18
lines changed

2 files changed

+13
-18
lines changed

docker/pull_image.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ else
2525
fi
2626

2727
# Pull image from mirror
28+
echo "Pulling image from mirror: $MIRROR_URL/$IMAGE_TAG"
2829
docker pull $MIRROR_URL/$IMAGE_TAG
2930

3031
# Tag image with original name

server/routers/chat_router.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def _norm_list(v):
131131
return result
132132

133133

134-
def _get_existing_message_ids(conv_mgr, thread_id):
134+
async def _get_existing_message_ids(conv_mgr, thread_id):
135135
"""获取已保存的消息ID集合"""
136-
existing_messages = conv_mgr.get_messages_by_thread_id(thread_id)
136+
existing_messages = await conv_mgr.get_messages_by_thread_id(thread_id)
137137
return {msg.extra_metadata["id"] for msg in existing_messages if msg.extra_metadata and "id" in msg.extra_metadata}
138138

139139

@@ -143,7 +143,7 @@ async def _save_ai_message(conv_mgr, thread_id, msg_dict):
143143
tool_calls_data = msg_dict.get("tool_calls", [])
144144

145145
# 保存AI消息
146-
ai_msg = conv_mgr.add_message_by_thread_id(
146+
ai_msg = await conv_mgr.add_message_by_thread_id(
147147
thread_id=thread_id,
148148
role="assistant",
149149
content=content,
@@ -155,7 +155,7 @@ async def _save_ai_message(conv_mgr, thread_id, msg_dict):
155155
if tool_calls_data:
156156
logger.debug(f"Saving {len(tool_calls_data)} tool calls from AI message")
157157
for tc in tool_calls_data:
158-
conv_mgr.add_tool_call(
158+
await conv_mgr.add_tool_call(
159159
message_id=ai_msg.id,
160160
tool_name=tc.get("name", "unknown"),
161161
tool_input=tc.get("args", {}),
@@ -166,7 +166,7 @@ async def _save_ai_message(conv_mgr, thread_id, msg_dict):
166166
logger.debug(f"Saved AI message {ai_msg.id} with {len(tool_calls_data)} tool calls")
167167

168168

169-
def _save_tool_message(conv_mgr, msg_dict):
169+
async def _save_tool_message(conv_mgr, msg_dict):
170170
"""保存工具执行结果"""
171171
tool_call_id = msg_dict.get("tool_call_id")
172172
content = msg_dict.get("content", "")
@@ -182,7 +182,7 @@ def _save_tool_message(conv_mgr, msg_dict):
182182
tool_output = str(content)
183183

184184
# 更新工具调用结果
185-
updated_tc = conv_mgr.update_tool_call_output(
185+
updated_tc = await conv_mgr.update_tool_call_output(
186186
langgraph_tool_call_id=tool_call_id,
187187
tool_output=tool_output,
188188
status="success",
@@ -238,7 +238,7 @@ async def save_partial_message(conv_mgr, thread_id, full_msg=None, error_message
238238
else:
239239
content = ""
240240

241-
saved_msg = conv_mgr.add_message_by_thread_id(
241+
saved_msg = await conv_mgr.add_message_by_thread_id(
242242
thread_id=thread_id,
243243
role="assistant",
244244
content=content,
@@ -271,7 +271,7 @@ async def save_messages_from_langgraph_state(
271271
return
272272

273273
logger.debug(f"Retrieved {len(messages)} messages from LangGraph state")
274-
existing_ids = _get_existing_message_ids(conv_mgr, thread_id)
274+
existing_ids = await _get_existing_message_ids(conv_mgr, thread_id)
275275

276276
for msg in messages:
277277
msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else {}
@@ -283,7 +283,7 @@ async def save_messages_from_langgraph_state(
283283
if msg_type == "ai":
284284
await _save_ai_message(conv_mgr, thread_id, msg_dict)
285285
elif msg_type == "tool":
286-
_save_tool_message(conv_mgr, msg_dict)
286+
await _save_tool_message(conv_mgr, msg_dict)
287287
else:
288288
logger.warning(f"Unknown message type: {msg_type}, skipping")
289289
continue
@@ -808,14 +808,11 @@ def make_resume_chunk(content=None, **kwargs):
808808
logger.warning(f"Client disconnected during resume: {e}")
809809

810810
# 保存中断消息到数据库
811-
new_db = db_manager.get_session()
812-
try:
811+
async with db_manager.get_async_session_context() as new_db:
813812
new_conv_manager = ConversationManager(new_db)
814813
await save_partial_message(
815814
new_conv_manager, thread_id, error_message="对话恢复已中断", error_type="resume_interrupted"
816815
)
817-
finally:
818-
new_db.close()
819816

820817
yield make_resume_chunk(status="interrupted", message="对话恢复已中断", meta=meta)
821818

@@ -824,14 +821,11 @@ def make_resume_chunk(content=None, **kwargs):
824821
logger.error(f"Error during resume: {e}, {traceback.format_exc()}")
825822

826823
# 保存错误消息到数据库
827-
new_db = db_manager.get_session()
828-
try:
824+
async with db_manager.get_async_session_context() as new_db:
829825
new_conv_manager = ConversationManager(new_db)
830826
await save_partial_message(
831827
new_conv_manager, thread_id, error_message=f"Error during resume: {e}", error_type="resume_error"
832828
)
833-
finally:
834-
new_db.close()
835829

836830
yield make_resume_chunk(message=f"Error during resume: {e}", status="error")
837831

@@ -1271,7 +1265,7 @@ async def submit_message_feedback(
12711265
raise
12721266
except Exception as e:
12731267
logger.error(f"Error submitting message feedback: {e}, {traceback.format_exc()}")
1274-
db.rollback()
1268+
await db.rollback()
12751269
raise HTTPException(status_code=500, detail=f"Failed to submit feedback: {str(e)}")
12761270

12771271

0 commit comments

Comments
 (0)