Skip to content

Commit ce431b3

Browse files
committed
♻️ 减少检索,索引,压缩频率
1 parent d72bc53 commit ce431b3

File tree

4 files changed

+96
-90
lines changed

4 files changed

+96
-90
lines changed

src/nonebot_plugin_nyaturingtest/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ async def do_set_role(matcher: type[Matcher], group_id: int, name: str, role: st
305305
task.add_done_callback(_tasks.discard)
306306
async with group_states[group_id].lock:
307307
state = group_states[group_id]
308-
state.session.set_role(name=name, role=role)
308+
await state.session.set_role(name=name, role=role)
309309
await matcher.finish(f"角色已设为: {name}\n设定: {role}")
310310

311311

@@ -413,7 +413,7 @@ async def do_reset(matcher: type[Matcher], group_id: int):
413413
task.add_done_callback(_tasks.discard)
414414
async with group_states[group_id].lock:
415415
state = group_states[group_id]
416-
state.session.reset()
416+
await state.session.reset()
417417
await matcher.finish("已重置会话")
418418

419419

src/nonebot_plugin_nyaturingtest/hippo_mem.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
import os
33
import shutil
44

5-
from nonebot import logger
6-
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
7-
85
from hipporag import HippoRAG
6+
from nonebot import logger
7+
from transformers.models.auto.tokenization_auto import AutoTokenizer
98

109

1110
class HippoMemory:
@@ -128,7 +127,7 @@ def retrieve(self, queries: list[str], k: int = 5) -> list[str]:
128127

129128

130129
def _split_text_by_tokens(
131-
text: str, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, max_tokens=8192, overlap=100
130+
text: str, tokenizer, max_tokens=8192, overlap=100
132131
) -> list[str]:
133132
"""
134133
按照指定的最大 token 数量和重叠数量将文本分割成多个块

src/nonebot_plugin_nyaturingtest/mem.py

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
from collections import deque
3+
from collections.abc import Callable
24
from dataclasses import dataclass
35
from datetime import datetime
46

@@ -57,63 +59,40 @@ class MemoryRecord:
5759

5860

5961
class Memory:
60-
"""
61-
短时记忆
62-
"""
63-
6462
def __init__(
6563
self,
6664
llm_client: LLMClient,
6765
compressed_message: str | None = None,
6866
messages: list[Message] | None = None,
6967
length_limit: int = 10,
7068
):
71-
"""
72-
初始化记忆
73-
参数
74-
75-
- length_limit: 记忆消息长度限制
76-
"""
7769
self.__length_limit = length_limit
78-
79-
if compressed_message:
80-
self.__compressed_message = compressed_message
81-
"""
82-
压缩后的旧消息
83-
"""
84-
else:
85-
self.__compressed_message = ""
86-
87-
if messages:
88-
self.__messages = deque(messages, maxlen=length_limit * 5) # 有4/5空间不可直接访问,用于放置待压缩消息
89-
"""
90-
记忆消息列表
91-
"""
92-
else:
93-
self.__messages: deque[Message] = deque(
94-
maxlen=length_limit * 5
95-
) # 有4/5空间不可直接访问,用于放置待压缩消息
96-
70+
self.__compressed_message = compressed_message or ""
71+
self.__messages = deque(messages, maxlen=length_limit * 5) if messages else deque(maxlen=length_limit * 5)
9772
self.__llm_client = llm_client
98-
"""
99-
用于压缩消息的 LLM 客户端
100-
"""
73+
self.__compress_counter = 0
74+
self.__compress_task: asyncio.Task | None = None # 压缩任务句柄
10175

10276
def related_users(self) -> list[str]:
10377
"""
10478
获取相关用户列表
10579
"""
10680
return list({message.user_name for message in self.__messages})
10781

108-
async def compress_message(self):
82+
async def clear(self) -> None:
10983
"""
110-
压缩历史消息
84+
清除所有记忆
11185
"""
112-
history_messages = [f"{message.user_name}: {message.content}" for message in self.__messages][
113-
: self.__length_limit
114-
]
86+
self.__messages.clear()
87+
self.__compressed_message = ""
88+
self.__compress_counter = 0
89+
await self.__cancel_compress_task()
90+
logger.info("已清除所有记忆")
91+
92+
async def __compress_message(self, after_compress: Callable[[], None] | None = None):
93+
history_messages = [f"{msg.user_name}: {msg.content}" for msg in self.__messages]
11594
prompt = f"""
116-
请将以下消息分参与的话题压缩,保留
95+
请将以下消息分参与的话题压缩,提取
11796
11897
- 话题简要内容
11998
- 参与者和它们的发言总结
@@ -130,22 +109,19 @@ async def compress_message(self):
130109
"""
131110
try:
132111
response = await self.__llm_client.generate_response(prompt, model="Qwen/Qwen3-8B")
112+
if after_compress:
113+
after_compress()
133114
if response:
134115
self.__compressed_message = response
135116
logger.info(f"压缩消息成功: {response}")
136117
else:
137118
logger.warning("压缩消息失败,原因未知")
119+
except asyncio.CancelledError:
120+
logger.info("压缩任务被取消")
121+
raise
138122
except Exception as e:
139123
logger.error(f"压缩消息时发生错误: {e}")
140124

141-
def clear(self) -> None:
142-
"""
143-
清除所有记忆
144-
"""
145-
self.__messages.clear()
146-
self.__compressed_message = ""
147-
logger.info("已清除所有记忆")
148-
149125
def access(self) -> MemoryRecord:
150126
"""
151127
访问记忆
@@ -155,11 +131,28 @@ def access(self) -> MemoryRecord:
155131
compressed_history=self.__compressed_message,
156132
)
157133

158-
def update(self, message_chunk: list[Message]):
134+
async def __cancel_compress_task(self):
159135
"""
160-
更新记忆
161-
参数
162-
163-
- message_chunk: 消息块
136+
取消压缩任务
164137
"""
138+
if self.__compress_task and not self.__compress_task.done():
139+
self.__compress_task.cancel()
140+
try:
141+
await self.__compress_task
142+
except asyncio.CancelledError:
143+
pass
144+
145+
async def update(self, message_chunk: list[Message], after_compress: Callable[[], None] | None = None):
165146
self.__messages.extend(message_chunk)
147+
148+
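# Compress once every (self.__length_limit + 1) calls to update(): the counter counts calls, not individual messages in message_chunk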
149+
if self.__compress_counter < self.__length_limit:
150+
self.__compress_counter += 1
151+
return
152+
153+
self.__compress_counter = 0
154+
# 如果有正在执行的压缩任务,先取消它
155+
await self.__cancel_compress_task()
156+
157+
# 开启新的压缩任务
158+
self.__compress_task = asyncio.create_task(self.__compress_message(after_compress=after_compress))

0 commit comments

Comments
 (0)