-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCragFlow.py
More file actions
369 lines (313 loc) · 14.7 KB
/
CragFlow.py
File metadata and controls
369 lines (313 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import os
import asyncio
import aiohttp
import logging
import operator
from typing import Annotated, TypedDict, List, Dict, Any
from dotenv import load_dotenv
from tenacity import retry, stop_after_attempt, wait_exponential
import re
load_dotenv(override=True)
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool
from pymilvus import Collection, AnnSearchRequest, RRFRanker
from services.embedding import BGEEmbeddingService
from services.constants import (
FIELD_CHUNK_ID,
FIELD_DOI,
FIELD_DENSE_VECTOR,
FIELD_SPARSE_VECTOR,
FIELD_YEAR,
FIELD_FIELD,
FIELD_TEXT,
get_collection_name,
)
logger = logging.getLogger(__name__)

# Shared streaming chat-model client; model name, key, and endpoint all come
# from environment variables loaded by load_dotenv above.
llm = ChatOpenAI(
    model=os.getenv('model'),
    api_key=os.getenv('api_key'),
    base_url=os.getenv("base_url"),
    temperature=0.1,  # low temperature: answers should stay factual and citation-bound
    streaming=True)

# Shared BGE embedding service used by retrieve_node for dense + sparse vectors.
bge_service = BGEEmbeddingService()
# GraphState is a TypedDict, so nodes read and write its fields with
# dict-style ['key'] indexing rather than attribute access.
class GraphState(TypedDict):
    """Shared state passed between the nodes of the CRAG LangGraph workflow."""
    messages: Annotated[list, operator.add]  # merged across node returns via operator.add
    question: str           # the original user question
    documents: List[Dict]   # retrieved chunks (Milvus hits and/or web-search fallback docs)
    rewritten_query: str    # question rewritten into keywords for external academic search
    evaluation_result: str  # 'pass' or 'fail' — routing verdict from evaluate_node
    meta_filters: str       # Milvus boolean filter expression applied during search
    generation: str         # final generated answer text
    user_id: str            # user ID, used to look up this user's long-term memories
    long_term_memories: str  # user history loaded from the PostgreSQL Memory table, injected for the LLM to reference
async def memory_inject_node(state: GraphState) -> Dict:
    """Load the current user's long-term memories from the PostgreSQL Memory table.

    These memories are extracted by the LLM after each conversation and stored;
    the generation node later splices them into its prompt so the model knows
    the user's historical preferences and background.
    """
    uid = state.get('user_id', '')
    if not uid:
        return {'long_term_memories': ''}

    # Imported lazily so this graph module has no hard import-time DB dependency.
    from services.database import AsyncSessionLocal, Memory
    from sqlalchemy import select

    # Newest 20 valid memories for this user, most recent first.
    stmt = (
        select(Memory)
        .where(Memory.user_id == uid, Memory.is_valid.is_(True))
        .order_by(Memory.created_at.desc())
        .limit(20)
    )
    async with AsyncSessionLocal() as session:
        rows = (await session.execute(stmt)).scalars().all()

    if not rows:
        return {'long_term_memories': ''}
    # Render each memory as a "- " bulleted line so the text drops cleanly into a prompt.
    return {'long_term_memories': "\n".join(f"- {row.content}" for row in rows)}
async def retrieve_node(state: GraphState) -> Dict:
    """Hybrid (dense + sparse) vector retrieval over the Milvus collection."""
    logger.info("Node [retrieve]: Hybrid vector search in Milvus")
    query = state['question']
    expr = state.get('meta_filters', '') or None  # None disables filtering

    # Tunables, overridable through the environment.
    per_branch_limit = int(os.getenv('RETRIEVAL_SEARCH_LIMIT', '50'))
    rrf_k = int(os.getenv('RETRIEVAL_RRF_K', '60'))
    final_limit = int(os.getenv('RETRIEVAL_RRF_LIMIT', '10'))

    # Embedding is CPU-bound; keep it off the event loop.
    vectors = await asyncio.to_thread(bge_service.encode_text, query)

    collection = Collection(get_collection_name())
    collection.load()

    branch_requests = [
        AnnSearchRequest(
            data=[vectors['dense']],
            anns_field=FIELD_DENSE_VECTOR,
            param={'metric_type': 'IP', 'params': {'ef': 200}},
            limit=per_branch_limit,
            expr=expr,
        ),
        AnnSearchRequest(
            data=[vectors['sparse']],
            anns_field=FIELD_SPARSE_VECTOR,
            param={'metric_type': 'IP', 'params': {'drop_ratio_search': 0.2}},
            limit=per_branch_limit,
            expr=expr,
        ),
    ]

    # Fuse the two branches with reciprocal-rank fusion, off the event loop.
    results = await asyncio.to_thread(
        collection.hybrid_search,
        reqs=branch_requests,
        rerank=RRFRanker(k=rrf_k),
        limit=final_limit,
        output_fields=[FIELD_CHUNK_ID, FIELD_DOI, FIELD_YEAR, FIELD_FIELD, FIELD_TEXT],
    )

    docs: List[Dict] = []
    if results:
        docs = [
            {
                'id': hit.entity.get(FIELD_CHUNK_ID),
                'doi': hit.entity.get(FIELD_DOI),
                'year': hit.entity.get(FIELD_YEAR),
                'field': hit.entity.get(FIELD_FIELD),
                'text': hit.entity.get(FIELD_TEXT),
                'distance': hit.distance,
            }
            for hit in results[0]
        ]

    logger.info(f"Node [retrieve]: recalled {len(docs)} chunks")
    if docs:
        logger.info(f"Node [retrieve]: top hit doi={docs[0].get('doi')} score={docs[0].get('distance'):.4f}")
    return {'documents': docs}
async def evaluate_node(state: GraphState) -> Dict:
    """LLM judge: decides whether retrieved chunks sufficiently answer the question."""
    logger.info("Node [evaluate]: Judging retrieval quality")
    docs = state.get('documents', [])
    if not docs:
        # Nothing recalled at all: skip the LLM call and route straight to web search.
        logger.warning("Node [evaluate]: no chunks recalled, forcing web search fallback")
        return {"evaluation_result": "fail"}

    # Truncate each chunk to 1000 chars to keep the judging prompt bounded.
    snippets = '\n\n'.join(
        f"Snippet{idx+1}:{doc['text'][:1000]}" for idx, doc in enumerate(docs)
    )
    judge = ChatPromptTemplate.from_template(
        "你是学术检索质量评估器。判断以下检索到的文献片段与用户问题是否相关。\n"
        "只要片段中包含与问题主题相关的信息,即判定为 pass;\n"
        "仅当片段内容与问题完全无关时,才判定为 fail。\n\n"
        "用户问题: {question}\n\n文献片段:\n{docs}\n\n"
        "仅回答 'pass' 或 'fail',禁止输出其他内容。"
    ) | llm

    # This node only feeds the in-graph router (pass -> generate directly,
    # fail -> rewrite -> web_search). Its verdict lives in graph state; it is
    # not a final answer persisted for the frontend.
    reply = await judge.ainvoke({'question': state['question'], 'docs': snippets})
    raw_result = reply.content.strip().lower()
    logger.info(f"Node [evaluate]: LLM raw response = '{raw_result}'")

    # Safety pass before writing to state: strip any <think> reasoning tags so
    # evaluation_result only ever holds usable verdict text.
    verdict_text = re.sub(r'<think>.*?</think>', '', raw_result, flags=re.DOTALL).strip()
    fail_markers = ('fail', '不相关', '无关', 'irrelevant', 'no relevant')
    final_result = 'fail' if any(marker in verdict_text for marker in fail_markers) else 'pass'
    logger.info(f"Node [evaluate]: clean result matched = {final_result.upper()}")
    return {"evaluation_result": final_result}
async def rewritten_node(state: GraphState) -> Dict:
    """Rewrites the user query into concise English keywords for Semantic Scholar."""
    logger.info("Node [rewrite]: Rewriting query for academic search")
    template = ChatPromptTemplate.from_template(
        "将以下用户问题转换为适合在 Semantic Scholar 学术引擎中查询的极简英文关键词(不超过 5 个词)。\n"
        "用户问题:{question}\n\n"
        "仅输出关键词,不要输出任何解释性文字。"  # keep the prompt strict: keywords only
    )
    # The output of this node is never shown to the user; it is the compact
    # search query consumed downstream by web_search_node.
    reply = await (template | llm).ainvoke({'question': state['question']})

    # Scrub <think> blocks before storing so rewritten_query stays directly usable.
    without_think = re.sub(
        r'<think>.*?</think>', '', reply.content.strip(), flags=re.DOTALL
    ).strip()
    # Drop any stray quoting the model may have added around the keywords.
    keywords = without_think.replace('"', '').replace("'", "")
    logger.info(f"Node [rewrite]: rewritten query = '{keywords}'")
    return {'rewritten_query': keywords}
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=16),
    reraise=True
)
async def _fetch_s2_data(query: str, api_key: str):
    """Low-level Semantic Scholar request wrapped in tenacity auto-retry.

    A 429 raises so the @retry decorator backs off exponentially and retries;
    any other non-200 status raises through raise_for_status().
    """
    headers = {}
    # Only send the key header when a real key is configured (skip the placeholder).
    if api_key and api_key != "your_api_key_here":
        headers['x-api-key'] = api_key

    # aiohttp URL-encodes spaces and special characters in params automatically.
    params = {
        "query": query,
        "limit": 3,
        "fields": "title,abstract,year,externalIds"
    }
    url = "https://api.semanticscholar.org/graph/v1/paper/search"

    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15)) as session:
        async with session.get(url, headers=headers, params=params) as resp:
            if resp.status == 200:
                return await resp.json()
            if resp.status == 429:
                logger.warning("Node [web_search]: 触发 429 限流,tenacity 将接管并退避重试...")
                raise Exception("Rate Limited 429")
            resp.raise_for_status()
            # Defensive fallback: raise_for_status() above should already have raised.
            raise Exception(f"Unexpected HTTP status: {resp.status}")
async def web_search_node(state: GraphState) -> Dict:
    """Queries Semantic Scholar API as a fallback knowledge source.

    Appends the fetched abstracts to the documents already in state. If the
    request still fails after _fetch_s2_data's retries, the existing documents
    are returned unchanged so generation can proceed from local context.
    """
    logger.info("Node [web_search]: Querying Semantic Scholar")
    query = state.get('rewritten_query', state.get('question', ''))
    api_key = os.getenv('SEMANTIC_SCHOLAR_API_KEY', '').strip()
    try:
        data = await _fetch_s2_data(query, api_key)
        web_docs = []
        for item in data.get('data', []):
            # Bug fix: the API spells the key 'DOI' (uppercase), and
            # 'externalIds' itself may be JSON null — guard both cases.
            doi = (item.get('externalIds') or {}).get('DOI', 'Unknown')
            text = (
                f"Title: {item.get('title')}\n"
                f"Abstract: {item.get('abstract')}\n"
                f"Year: {item.get('year')}"
            )
            web_docs.append({
                # Defensive lookup in case paperId is missing from the payload.
                'id': f"web_{item.get('paperId', 'unknown')}",
                'doi': doi,
                'text': text,
                'source': "Semantic Scholar"
            })
        logger.info(f"Node [web_search]: fetched {len(web_docs)} external abstracts")
        # Return a NEW list rather than extending the list held in graph state,
        # so this node has no hidden in-place side effect on shared state.
        return {'documents': state.get('documents', []) + web_docs}
    except Exception as e:
        logger.error(f"Node [web_search]: external search failed after retries: {e}")
        return {'documents': state.get('documents', [])}
async def generate_node(state: GraphState) -> Dict:
    """Generates the final cited answer, streaming tokens through LangGraph."""
    logger.info("Node [generate]: Generating final answer")
    # generate is the node of the CRAG chain that actually composes the final
    # answer; memory_inject / retrieve / evaluate / rewrite / web_search all
    # exist to prepare context and routing for this step.
    docs = state['documents']
    memories = state.get('long_term_memories', '')  # user memories injected upstream

    # Assemble the numbered reference context used for [Source X] citations.
    parts = []
    for idx, doc in enumerate(docs):
        doi_info = f"(DOI:{doc.get('doi')})" if doc.get('doi') != 'Unknown' else ""
        parts.append(f"[Source{idx+1}]{doi_info}:\n{doc['text']}\n\n")
    context_str = "".join(parts)

    # When long-term memories exist, give them a dedicated prompt section so
    # the LLM knows this user's background.
    memory_section = f"\n【你对这位用户的了解】\n{memories}\n\n" if memories else ""

    prompt = ChatPromptTemplate.from_template(
        "作为资深架构师和科研助手,请基于以下提供的学术文献资料回答用户问题。\n\n"
        "{memory_section}"
        "【绝对规则】\n"
        "1. 所有的主张必须基于提供的文献,严禁捏造(幻觉)。\n"
        "2. 必须在引用信息的句末严格标明引文来源,格式为 [Source X]。\n"
        "3. 若文献中包含数学公式、变量,请使用标准的 LaTeX 格式。\n\n"
        "【绝对禁忌(Critical Rules)】:\n"
        "1. 绝对不允许说“我无法进行联网搜索”、“我作为一个AI模型没有网络权限”等废话。\n"
        "2. 绝对不允许说“根据您提供的文献资料”。\n"
        "3. 请表现得就像是你自己通过大脑直接查到了这些实时信息一样,语气要自然、自信。\n"
        "文献资料:\n{context}\n\n"
        "用户问题: {question}\n\n回答:"
    )
    chain = prompt | llm

    # Accumulate the streamed chunks into one string for graph state; the
    # chat endpoint in main.py separately subscribes to the graph's message
    # stream and pushes <think>-filtered fragments to the frontend via SSE.
    answer = ""
    async for chunk in chain.astream({
        'context': context_str,
        'question': state['question'],
        'memory_section': memory_section,
    }):
        answer += chunk.content

    # Final scrub so the stored generation carries no leftover <think> content.
    return {'generation': re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL).strip()}
def edge_evaluate_node(state: GraphState) -> str:
    """Conditional-edge router: 'generate' when retrieval passed, else 'rewrite'."""
    return 'generate' if state["evaluation_result"] == "pass" else 'rewrite'
async def build_production_crag():
    """Assemble and compile the CRAG workflow with a Postgres checkpointer.

    Overall flow:
      1. memory_inject: load the user's long-term memories
      2. retrieve: local Milvus hybrid search first
      3. evaluate: LLM judges whether the recall can answer the question
      4. pass -> generate: answer directly
      5. fail -> rewrite -> web_search -> generate: rewrite the query,
         pull external abstracts, then answer
    """
    workflow = StateGraph(GraphState)

    nodes = {
        'memory_inject': memory_inject_node,
        'retrieve': retrieve_node,
        'evaluate': evaluate_node,
        'rewrite': rewritten_node,
        'web_search': web_search_node,
        'generate': generate_node,
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point('memory_inject')
    workflow.add_edge('memory_inject', 'retrieve')
    workflow.add_edge('retrieve', 'evaluate')
    workflow.add_conditional_edges(
        'evaluate',
        edge_evaluate_node,
        {'generate': 'generate', 'rewrite': 'rewrite'}
    )
    workflow.add_edge('rewrite', 'web_search')
    workflow.add_edge('web_search', 'generate')
    workflow.add_edge('generate', END)

    # The checkpointer speaks plain psycopg, so strip SQLAlchemy's "+psycopg"
    # driver suffix from the DSN before handing it to the pool.
    dsn = os.getenv("DATABASE_URL_SYNC", "").replace("+psycopg", "")
    pool = AsyncConnectionPool(conninfo=dsn, open=False)
    await pool.open()
    checkpointer = AsyncPostgresSaver(pool)
    await checkpointer.setup()  # creates the checkpoint tables if absent
    return workflow.compile(checkpointer=checkpointer)