|
2 | 2 | import json |
3 | 3 | import traceback |
4 | 4 | import uuid |
5 | | -from pathlib import Path |
6 | 5 |
|
7 | | -from fastapi import APIRouter, Body, Depends, HTTPException |
| 6 | +from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, File |
8 | 7 | from fastapi.responses import StreamingResponse |
9 | 8 | from langchain.messages import AIMessageChunk, HumanMessage |
10 | 9 | from langgraph.types import Command |
|
22 | 21 | from src.agents.common.tools import gen_tool_info, get_buildin_tools |
23 | 22 | from src.models import select_model |
24 | 23 | from src.plugins.guard import content_guard |
| 24 | +from src.services.doc_converter import ( |
| 25 | + ATTACHMENT_ALLOWED_EXTENSIONS, |
| 26 | + MAX_ATTACHMENT_SIZE_BYTES, |
| 27 | + convert_upload_to_markdown, |
| 28 | +) |
| 29 | +from src.utils.datetime_utils import utc_isoformat |
25 | 30 | from src.utils.logging_config import logger |
26 | 31 |
|
27 | 32 | chat = APIRouter(prefix="/chat", tags=["chat"]) |
@@ -156,6 +161,25 @@ def _save_tool_message(conv_mgr, msg_dict): |
156 | 161 | logger.warning(f"Tool call {tool_call_id} not found for update") |
157 | 162 |
|
158 | 163 |
|
| 164 | +def _require_user_conversation(conv_mgr: ConversationManager, thread_id: str, user_id: str) -> Conversation: |
| 165 | + conversation = conv_mgr.get_conversation_by_thread_id(thread_id) |
| 166 | + if not conversation or conversation.user_id != str(user_id) or conversation.status == "deleted": |
| 167 | + raise HTTPException(status_code=404, detail="对话线程不存在") |
| 168 | + return conversation |
| 169 | + |
| 170 | + |
| 171 | +def _serialize_attachment(record: dict) -> dict: |
| 172 | + return { |
| 173 | + "file_id": record.get("file_id"), |
| 174 | + "file_name": record.get("file_name"), |
| 175 | + "file_type": record.get("file_type"), |
| 176 | + "file_size": record.get("file_size", 0), |
| 177 | + "status": record.get("status", "parsed"), |
| 178 | + "uploaded_at": record.get("uploaded_at"), |
| 179 | + "truncated": record.get("truncated", False), |
| 180 | + } |
| 181 | + |
| 182 | + |
159 | 183 | async def save_messages_from_langgraph_state( |
160 | 184 | agent_instance, |
161 | 185 | thread_id, |
@@ -313,7 +337,8 @@ async def get_agent(current_user: User = Depends(get_required_user)): |
313 | 337 | "description": agent_info.get("description", ""), |
314 | 338 | "examples": agent_info.get("examples", []), |
315 | 339 | "configurable_items": agent_info.get("configurable_items", []), |
316 | | - "has_checkpointer": agent_info.get("has_checkpointer", False) |
| 340 | + "has_checkpointer": agent_info.get("has_checkpointer", False), |
| 341 | + "capabilities": agent_info.get("capabilities", []) # 智能体能力列表 |
317 | 342 | } |
318 | 343 | for agent_info in agents_info |
319 | 344 | ] |
@@ -401,6 +426,15 @@ async def stream_messages(): |
401 | 426 | except Exception as e: |
402 | 427 | logger.error(f"Error saving user message: {e}") |
403 | 428 |
|
| 429 | + try: |
| 430 | + assert thread_id, "thread_id is required" |
| 431 | + attachments = conv_manager.get_attachments_by_thread_id(thread_id) |
| 432 | + input_context["attachments"] = attachments |
| 433 | + logger.debug(f"Loaded {len(attachments)} attachments for thread_id={thread_id}") |
| 434 | + except Exception as e: |
| 435 | + logger.error(f"Error loading attachments for thread_id={thread_id}: {e}") |
| 436 | + input_context["attachments"] = [] |
| 437 | + |
404 | 438 | try: |
405 | 439 | full_msg = None |
406 | 440 | async for msg, metadata in agent.stream_messages(messages, input_context=input_context): |
@@ -739,6 +773,26 @@ class ThreadResponse(BaseModel): |
739 | 773 | updated_at: str |
740 | 774 |
|
741 | 775 |
|
| 776 | +class AttachmentResponse(BaseModel): |
| 777 | + file_id: str |
| 778 | + file_name: str |
| 779 | + file_type: str | None = None |
| 780 | + file_size: int |
| 781 | + status: str |
| 782 | + uploaded_at: str |
| 783 | + truncated: bool | None = False |
| 784 | + |
| 785 | + |
| 786 | +class AttachmentLimits(BaseModel): |
| 787 | + allowed_extensions: list[str] |
| 788 | + max_size_bytes: int |
| 789 | + |
| 790 | + |
| 791 | +class AttachmentListResponse(BaseModel): |
| 792 | + attachments: list[AttachmentResponse] |
| 793 | + limits: AttachmentLimits |
| 794 | + |
| 795 | + |
742 | 796 | # ============================================================================= |
743 | 797 | # > === 会话管理分组 === |
744 | 798 | # ============================================================================= |
@@ -859,6 +913,75 @@ async def update_thread( |
859 | 913 | } |
860 | 914 |
|
861 | 915 |
|
| 916 | +@chat.post("/thread/{thread_id}/attachments", response_model=AttachmentResponse) |
| 917 | +async def upload_thread_attachment( |
| 918 | + thread_id: str, |
| 919 | + file: UploadFile = File(...), |
| 920 | + db: Session = Depends(get_db), |
| 921 | + current_user: User = Depends(get_required_user), |
| 922 | +): |
| 923 | + """上传并解析附件为 Markdown,附加到指定对话线程。""" |
| 924 | + conv_manager = ConversationManager(db) |
| 925 | + conversation = _require_user_conversation(conv_manager, thread_id, str(current_user.id)) |
| 926 | + |
| 927 | + try: |
| 928 | + conversion = await convert_upload_to_markdown(file) |
| 929 | + except ValueError as exc: |
| 930 | + raise HTTPException(status_code=400, detail=str(exc)) from exc |
| 931 | + except Exception as exc: # noqa: BLE001 |
| 932 | + logger.error(f"附件解析失败: {exc}") |
| 933 | + raise HTTPException(status_code=500, detail="附件解析失败,请稍后重试") from exc |
| 934 | + |
| 935 | + attachment_record = { |
| 936 | + "file_id": conversion.file_id, |
| 937 | + "file_name": conversion.file_name, |
| 938 | + "file_type": conversion.file_type, |
| 939 | + "file_size": conversion.file_size, |
| 940 | + "status": "parsed", |
| 941 | + "markdown": conversion.markdown, |
| 942 | + "uploaded_at": utc_isoformat(), |
| 943 | + "truncated": conversion.truncated, |
| 944 | + } |
| 945 | + conv_manager.add_attachment(conversation.id, attachment_record) |
| 946 | + |
| 947 | + return _serialize_attachment(attachment_record) |
| 948 | + |
| 949 | + |
| 950 | +@chat.get("/thread/{thread_id}/attachments", response_model=AttachmentListResponse) |
| 951 | +async def list_thread_attachments( |
| 952 | + thread_id: str, |
| 953 | + db: Session = Depends(get_db), |
| 954 | + current_user: User = Depends(get_required_user), |
| 955 | +): |
| 956 | + """列出当前对话线程的所有附件元信息。""" |
| 957 | + conv_manager = ConversationManager(db) |
| 958 | + conversation = _require_user_conversation(conv_manager, thread_id, str(current_user.id)) |
| 959 | + attachments = conv_manager.get_attachments(conversation.id) |
| 960 | + return { |
| 961 | + "attachments": [_serialize_attachment(item) for item in attachments], |
| 962 | + "limits": { |
| 963 | + "allowed_extensions": sorted(ATTACHMENT_ALLOWED_EXTENSIONS), |
| 964 | + "max_size_bytes": MAX_ATTACHMENT_SIZE_BYTES, |
| 965 | + }, |
| 966 | + } |
| 967 | + |
| 968 | + |
| 969 | +@chat.delete("/thread/{thread_id}/attachments/{file_id}") |
| 970 | +async def delete_thread_attachment( |
| 971 | + thread_id: str, |
| 972 | + file_id: str, |
| 973 | + db: Session = Depends(get_db), |
| 974 | + current_user: User = Depends(get_required_user), |
| 975 | +): |
| 976 | + """移除指定附件。""" |
| 977 | + conv_manager = ConversationManager(db) |
| 978 | + conversation = _require_user_conversation(conv_manager, thread_id, str(current_user.id)) |
| 979 | + removed = conv_manager.remove_attachment(conversation.id, file_id) |
| 980 | + if not removed: |
| 981 | + raise HTTPException(status_code=404, detail="附件不存在或已被删除") |
| 982 | + return {"message": "附件已删除"} |
| 983 | + |
| 984 | + |
862 | 985 | # ============================================================================= |
863 | 986 | # > === 消息反馈分组 === |
864 | 987 | # ============================================================================= |
|
0 commit comments