22import asyncio
33import os
44import traceback
5+ import textwrap
56from collections .abc import Mapping
67from urllib .parse import quote , unquote
78
@@ -235,18 +236,13 @@ async def run_ingest(context: TaskContext):
235236 progress = 5.0 + (idx / total ) * 90.0 # 5% ~ 95%
236237 await context .set_progress (progress , f"正在处理第 { idx } /{ total } 个文档" )
237238
238- # 处理单个文档
239239 try :
240240 result = await knowledge_base .add_content (db_id , [item ], params = params )
241241 processed_items .extend (result )
242242 except Exception as doc_error :
243- # 处理单个文档处理的所有异常(包括超时)
244243 logger .error (f"Document processing failed for { item } : { doc_error } " )
245-
246- # 判断是否是超时异常
247244 error_type = "timeout" if isinstance (doc_error , TimeoutError ) else "processing_error"
248245 error_msg = "处理超时" if isinstance (doc_error , TimeoutError ) else "处理失败"
249-
250246 processed_items .append (
251247 {
252248 "item" : item ,
@@ -815,6 +811,199 @@ async def get_knowledge_base_query_params(db_id: str, current_user: User = Depen
815811 return {"message" : f"获取知识库查询参数失败 { e } " , "params" : {}}
816812
817813
814+ # =============================================================================
815+ # === AI生成示例问题 ===
816+ # =============================================================================
817+
818+
819+ SAMPLE_QUESTIONS_SYSTEM_PROMPT = """你是一个专业的知识库问答测试专家。
820+
821+ 你的任务是根据知识库中的文件列表,生成有价值的测试问题。
822+
823+ 要求:
824+ 1. 问题要具体、有针对性,基于文件名称和类型推测可能的内容
825+ 2. 问题要涵盖不同方面和难度
826+ 3. 问题要简洁明了,适合用于检索测试
827+ 4. 问题要多样化,包括事实查询、概念解释、操作指导等
828+ 5. 问题长度控制在10-30字之间
829+ 6. 直接返回JSON数组格式,不要其他说明
830+
831+ 返回格式:
832+ ```json
833+ {
834+ "questions": [
835+ "问题1?",
836+ "问题2?",
837+ "问题3?"
838+ ]
839+ }
840+ ```
841+ """
842+
843+
844+ @knowledge .post ("/databases/{db_id}/sample-questions" )
845+ async def generate_sample_questions (
846+ db_id : str ,
847+ request_body : dict = Body (...),
848+ current_user : User = Depends (get_admin_user ),
849+ ):
850+ """
851+ AI生成针对知识库的测试问题
852+
853+ Args:
854+ db_id: 知识库ID
855+ request_body: 请求体,包含 count 字段
856+
857+ Returns:
858+ 生成的问题列表
859+ """
860+ try :
861+ from src .models import select_model
862+ import json
863+
864+ # 从请求体中提取参数
865+ count = request_body .get ("count" , 10 )
866+
867+ # 获取知识库信息
868+ db_info = knowledge_base .get_database_info (db_id )
869+ if not db_info :
870+ raise HTTPException (status_code = 404 , detail = f"知识库 { db_id } 不存在" )
871+
872+ db_name = db_info .get ("name" , "" )
873+ all_files = db_info .get ("files" , {})
874+
875+ if not all_files :
876+ raise HTTPException (status_code = 400 , detail = "知识库中没有文件" )
877+
878+ # 收集文件信息
879+ files_info = []
880+ for file_id , file_info in all_files .items ():
881+ files_info .append (
882+ {
883+ "filename" : file_info .get ("filename" , "" ),
884+ "type" : file_info .get ("type" , "" ),
885+ }
886+ )
887+
888+ # 构建AI提示词
889+ system_prompt = SAMPLE_QUESTIONS_SYSTEM_PROMPT
890+
891+ # 构建用户消息
892+ files_text = "\n " .join (
893+ [
894+ f"- { f ['filename' ]} ({ f ['type' ]} )"
895+ for f in files_info [:20 ] # 最多列举20个文件
896+ ]
897+ )
898+
899+ file_count_text = f"(共{ len (files_info )} 个文件)" if len (files_info ) > 20 else ""
900+
901+ user_message = textwrap .dedent (f"""请为知识库"{ db_name } "生成{ count } 个测试问题。
902+
903+ 知识库文件列表{ file_count_text } :
904+ { files_text }
905+
906+ 请根据这些文件的名称和类型,生成{ count } 个有价值的测试问题。""" )
907+
908+ # 调用AI生成
909+ logger .info (f"开始生成知识库问题,知识库: { db_name } , 文件数量: { len (files_info )} , 问题数量: { count } " )
910+
911+ # 选择模型并调用
912+ model = select_model ()
913+ messages = [{"role" : "system" , "content" : system_prompt }, {"role" : "user" , "content" : user_message }]
914+ response = model .call (messages , stream = False )
915+
916+ # 解析AI返回的JSON
917+ try :
918+ # 提取JSON内容
919+ content = response .content if hasattr (response , "content" ) else str (response )
920+
921+ # 尝试从markdown代码块中提取JSON
922+ if "```json" in content :
923+ json_start = content .find ("```json" ) + 7
924+ json_end = content .find ("```" , json_start )
925+ content = content [json_start :json_end ].strip ()
926+ elif "```" in content :
927+ json_start = content .find ("```" ) + 3
928+ json_end = content .find ("```" , json_start )
929+ content = content [json_start :json_end ].strip ()
930+
931+ questions_data = json .loads (content )
932+ questions = questions_data .get ("questions" , [])
933+
934+ if not questions or not isinstance (questions , list ):
935+ raise ValueError ("AI返回的问题格式不正确" )
936+
937+ logger .info (f"成功生成{ len (questions )} 个问题" )
938+
939+ # 保存问题到知识库元数据
940+ try :
941+ async with knowledge_base ._metadata_lock :
942+ # 确保知识库元数据存在
943+ if db_id not in knowledge_base .global_databases_meta :
944+ knowledge_base .global_databases_meta [db_id ] = {}
945+ # 保存问题到对应知识库
946+ knowledge_base .global_databases_meta [db_id ]["sample_questions" ] = questions
947+ knowledge_base ._save_global_metadata ()
948+ logger .info (f"成功保存 { len (questions )} 个问题到知识库 { db_id } " )
949+ except Exception as save_error :
950+ logger .error (f"保存问题失败: { save_error } " )
951+
952+ return {
953+ "message" : "success" ,
954+ "questions" : questions ,
955+ "count" : len (questions ),
956+ "db_id" : db_id ,
957+ "db_name" : db_name ,
958+ }
959+
960+ except json .JSONDecodeError as e :
961+ logger .error (f"AI返回的JSON解析失败: { e } , 原始内容: { content } " )
962+ raise HTTPException (status_code = 500 , detail = f"AI返回格式错误: { str (e )} " )
963+
964+ except HTTPException :
965+ raise
966+ except Exception as e :
967+ logger .error (f"生成知识库问题失败: { e } , { traceback .format_exc ()} " )
968+ raise HTTPException (status_code = 500 , detail = f"生成问题失败: { str (e )} " )
969+
970+
971+ @knowledge .get ("/databases/{db_id}/sample-questions" )
972+ async def get_sample_questions (db_id : str , current_user : User = Depends (get_admin_user )):
973+ """
974+ 获取知识库的测试问题
975+
976+ Args:
977+ db_id: 知识库ID
978+
979+ Returns:
980+ 问题列表
981+ """
982+ try :
983+ # 直接从全局元数据中读取
984+ if db_id not in knowledge_base .global_databases_meta :
985+ raise HTTPException (status_code = 404 , detail = f"知识库 { db_id } 不存在" )
986+
987+ db_meta = knowledge_base .global_databases_meta [db_id ]
988+ questions = db_meta .get ("sample_questions" , [])
989+
990+ if not questions :
991+ raise HTTPException (status_code = 404 , detail = "该知识库还没有生成测试问题" )
992+
993+ return {
994+ "message" : "success" ,
995+ "questions" : questions ,
996+ "count" : len (questions ),
997+ "db_id" : db_id ,
998+ }
999+
1000+ except HTTPException :
1001+ raise
1002+ except Exception as e :
1003+ logger .error (f"获取知识库问题失败: { e } , { traceback .format_exc ()} " )
1004+ raise HTTPException (status_code = 500 , detail = f"获取问题失败: { str (e )} " )
1005+
1006+
8181007# =============================================================================
8191008# === 文件管理分组 ===
8201009# =============================================================================
@@ -838,7 +1027,7 @@ async def upload_file(
8381027 if ext == ".jsonl" :
8391028 if allow_jsonl is not True or db_id is not None :
8401029 raise HTTPException (status_code = 400 , detail = f"Unsupported file type: { ext } " )
841- elif not is_supported_file_extension (file .filename ):
1030+ elif not ( is_supported_file_extension (file .filename ) or ext == ".zip" ):
8421031 raise HTTPException (status_code = 400 , detail = f"Unsupported file type: { ext } " )
8431032
8441033 # 根据db_id获取上传路径,如果db_id为None则使用默认路径
0 commit comments