
Commit e662466

style: ruff format
1 parent e1a0c6b commit e662466

2 files changed, +26 -37 lines

server/routers/evaluation_router.py

Lines changed: 2 additions & 7 deletions
@@ -14,11 +14,7 @@
 
 @evaluation.get("/databases/{db_id}/benchmarks/{benchmark_id}")
 async def get_evaluation_benchmark_by_db(
-    db_id: str,
-    benchmark_id: str,
-    page: int = 1,
-    page_size: int = 10,
-    current_user: User = Depends(get_admin_user)
+    db_id: str, benchmark_id: str, page: int = 1, page_size: int = 10, current_user: User = Depends(get_admin_user)
 ):
     """Get benchmark detail by db_id (supports pagination)"""
     from src.services.evaluation_service import EvaluationService
@@ -54,15 +50,14 @@ async def delete_evaluation_benchmark(benchmark_id: str, current_user: User = De
         raise HTTPException(status_code=500, detail=f"Failed to delete evaluation benchmark: {str(e)}")
 
 
-
 @evaluation.get("/databases/{db_id}/results/{task_id}")
 async def get_evaluation_results_by_db(
     db_id: str,
     task_id: str,
     page: int = 1,
     page_size: int = 20,
     error_only: bool = False,
-    current_user: User = Depends(get_admin_user)
+    current_user: User = Depends(get_admin_user),
 ):
     """Get evaluation results (with db_id, supports pagination)"""
     from src.services.evaluation_service import EvaluationService
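A note on why ruff format collapses the first signature above but keeps the second exploded: the formatter joins a parameter list onto one line whenever it fits within the configured line length, and otherwise keeps one parameter per line and normalizes a trailing comma after the last one. Since the collapsed line is well over ruff's default 88 characters, the project presumably configures a larger limit; the exact value is an assumption here. A minimal sketch of the two behaviors, with hypothetical functions:

```python
# Short enough to fit on one line, so ruff format joins the parameters:
def fits(a: int = 1, b: int = 2, c: int = 3) -> int:
    return a + b + c


# Too long for one line, so ruff format keeps the exploded layout and
# adds a trailing comma after the final parameter:
def does_not_fit(
    first_parameter_with_a_long_name: str,
    second_parameter_with_a_long_name: str,
    third_parameter_with_a_long_name: str = "a fairly long default value",
    fourth_parameter_with_a_long_name: str = "another fairly long default value",
) -> str:
    return first_parameter_with_a_long_name
```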

src/services/evaluation_service.py

Lines changed: 24 additions & 30 deletions
@@ -7,7 +7,6 @@
 from typing import Any
 
 from server.services.tasker import TaskContext, tasker
-from src import config
 from src.knowledge import knowledge_base
 from src.models import select_model
 from src.utils import logger
@@ -136,7 +135,9 @@ async def get_benchmark_detail(self, benchmark_id: str) -> dict[str, Any]:
             logger.error(f"Failed to get benchmark detail: {e}")
             raise
 
-    async def get_benchmark_detail_by_db(self, db_id: str, benchmark_id: str, page: int = 1, page_size: int = 10) -> dict[str, Any]:
+    async def get_benchmark_detail_by_db(
+        self, db_id: str, benchmark_id: str, page: int = 1, page_size: int = 10
+    ) -> dict[str, Any]:
         """Get benchmark detail by db_id (supports pagination)"""
         try:
             kb_instance = knowledge_base.get_kb(db_id)
@@ -174,17 +175,19 @@ async def get_benchmark_detail_by_db(self, db_id: str, benchmark_id: str, page:
             total_pages = (total_questions + page_size - 1) // page_size
 
             meta_with_q = meta.copy()
-            meta_with_q.update({
-                "questions": questions,
-                "pagination": {
-                    "current_page": page,
-                    "page_size": page_size,
-                    "total_questions": total_questions,
-                    "total_pages": total_pages,
-                    "has_next": page < total_pages,
-                    "has_prev": page > 1
+            meta_with_q.update(
+                {
+                    "questions": questions,
+                    "pagination": {
+                        "current_page": page,
+                        "page_size": page_size,
+                        "total_questions": total_questions,
+                        "total_pages": total_pages,
+                        "has_next": page < total_pages,
+                        "has_prev": page > 1,
+                    },
                 }
-            })
+            )
             return meta_with_q
         except Exception as e:
             logger.error(f"Failed to get benchmark detail: {e}")
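For context on the unchanged total_pages line in the hunk above: `(total_questions + page_size - 1) // page_size` is the standard integer ceiling-division idiom, so a partially filled final page still counts as a page. A quick self-contained check:

```python
def total_pages(total_questions: int, page_size: int) -> int:
    # Ceiling division without floats: rounds total_questions / page_size up.
    return (total_questions + page_size - 1) // page_size


assert total_pages(0, 10) == 0
assert total_pages(95, 10) == 10  # the last page holds only 5 questions
assert total_pages(100, 10) == 10
assert total_pages(101, 10) == 11
```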
@@ -226,8 +229,8 @@ async def generate_benchmark(self, db_id: str, params: dict[str, Any], created_b
         return {"task_id": task_id, "message": "Benchmark generation task submitted"}
 
     async def _generate_benchmark_task(self, context: TaskContext):
-        import random
         import math
+        import random
 
         await context.set_progress(0, "Initializing")
 
@@ -346,16 +349,16 @@ def cosine(a, b, na, nb):
             prompt = (
                 "Based on the context below, generate one question that the context can answer accurately, along with a gold answer. "
                 "Return only a single JSON object, with no other text. "
-                "The keys are query, gold_answer, gold_chunk_ids. gold_chunk_ids must be a subset of the context chunk IDs above.\n\nContext:\n"
-                + context_text
-                + "\n"
+                "The keys are query, gold_answer, gold_chunk_ids. gold_chunk_ids must be a subset of the context chunk IDs above.\n\n"
+                "Context:\n" + context_text + "\n"
             )
 
             try:
                 resp = await asyncio.to_thread(llm.call, prompt, False)
                 content = resp.content if resp else ""
 
                 import json_repair
+
                 obj = json_repair.loads(content)
                 q = obj.get("query")
                 a = obj.get("gold_answer")
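An aside on the json_repair call above (only the blank line after the import is new in this hunk): json_repair.loads is a tolerant drop-in for json.loads, which matters here because LLM output is often not strictly valid JSON. A small illustrative sketch; the malformed input is made up, and the exact repairs applied depend on the library version:

```python
import json_repair

# Single quotes and a trailing comma would make json.loads raise;
# json_repair attempts to repair the string first, then parses it.
raw = "{'query': 'What is X?', 'gold_answer': 'X is ...', 'gold_chunk_ids': ['c1'],}"
obj = json_repair.loads(raw)
print(obj.get("query"))  # -> What is X?
```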
@@ -587,7 +590,7 @@ def update_result_file(status="running", completed=0, metrics=None, interim=None
                     prompt = (
                         f"Based on the following context, answer the user's question.\n\n"
                         f"Context: {context_text}\n\n"
-                        f"User question: {question_data["query"]}\n\n"
+                        f"User question: {question_data['query']}\n\n"
                         "Answer the question accurately based on the context.\n\n"
                         "If the context lacks the relevant information, answer "Insufficient information to answer".\n\n"
                     )
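The one-character change above is a compatibility fix, not just style: before Python 3.12 (PEP 701), an f-string could not reuse its own quote character inside a replacement field, so the double-quoted subscript was a SyntaxError on older interpreters. A minimal sketch:

```python
question_data = {"query": "What is X?"}

# SyntaxError before Python 3.12, where a replacement field could not
# contain the f-string's own quote character:
#   prompt = f"User question: {question_data["query"]}\n\n"

# Parses on all supported Python versions:
prompt = f"User question: {question_data['query']}\n\n"
print(prompt)
```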
@@ -762,12 +765,7 @@ async def get_evaluation_history(self, db_id: str) -> list[dict[str, Any]]:
     # Index and fallback logic removed; lookups go through db_id uniformly
 
     async def get_evaluation_results_by_db(
-        self,
-        db_id: str,
-        task_id: str,
-        page: int = 1,
-        page_size: int = 20,
-        error_only: bool = False
+        self, db_id: str, task_id: str, page: int = 1, page_size: int = 20, error_only: bool = False
     ) -> dict[str, Any]:
         result_file_path = os.path.join(self._get_result_dir(db_id), f"{task_id}.json")
         if not os.path.exists(result_file_path):
@@ -800,11 +798,7 @@ async def get_evaluation_results_by_db(
 
                 # Check whether retrieval metrics are clearly low
                 metrics = item.get("metrics", {})
-                has_low_recall = any(
-                    metrics.get(k, 1.0) < 0.3
-                    for k in metrics
-                    if k.startswith("recall@")
-                )
+                has_low_recall = any(metrics.get(k, 1.0) < 0.3 for k in metrics if k.startswith("recall@"))
                 if has_low_recall:
                     filtered_results.append(item)
             all_results = filtered_results
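The collapsed has_low_recall line is behavior-preserving: any() short-circuits on the first recall@k value below 0.3, and keys that do not start with "recall@" are filtered out before the comparison. A small sketch with made-up metric values:

```python
metrics = {"recall@1": 0.2, "recall@5": 0.6, "precision@5": 0.1}

# True as soon as one recall@k metric is below 0.3; precision@5 is ignored
# because the generator's filter only keeps keys starting with "recall@".
has_low_recall = any(metrics.get(k, 1.0) < 0.3 for k in metrics if k.startswith("recall@"))
print(has_low_recall)  # -> True, because recall@1 == 0.2
```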
@@ -831,8 +825,8 @@ async def get_evaluation_results_by_db(
                 "page_size": page_size,
                 "total": total,
                 "total_pages": (total + page_size - 1) // page_size,
-                "error_only": error_only
-            }
+                "error_only": error_only,
+            },
         }
 
         # Non-paginated request: return full data (kept for backward compatibility)
