11from collections import deque
2- from collections .abc import Callable
2+ from collections .abc import Awaitable , Callable
33from dataclasses import dataclass
44from datetime import datetime
55from enum import Enum
1111import traceback
1212
1313from nonebot import logger
14- from openai import OpenAI
14+ from openai import AsyncOpenAI
1515
1616from .client import LLMClient
1717from .config import plugin_config
@@ -71,7 +71,7 @@ def __init__(self, siliconflow_api_key: str, id: str = "global", name: str = "te
7171 """
7272 self .global_memory : Memory = Memory (
7373 llm_client = LLMClient (
74- client = OpenAI (
74+ client = AsyncOpenAI (
7575 api_key = plugin_config .nyaturingtest_siliconflow_api_key ,
7676 base_url = "https://api.siliconflow.cn/v1" ,
7777 )
@@ -257,7 +257,7 @@ def load_session(self):
257257 compressed_message = session_data ["global_memory" ].get ("compressed_history" , "" ),
258258 messages = [Message .from_json (msg ) for msg in session_data ["global_memory" ].get ("messages" , [])],
259259 llm_client = LLMClient (
260- client = OpenAI (
260+ client = AsyncOpenAI (
261261 api_key = plugin_config .nyaturingtest_siliconflow_api_key ,
262262 base_url = "https://api.siliconflow.cn/v1" ,
263263 )
@@ -267,7 +267,7 @@ def load_session(self):
267267 logger .error (f"[Session { self .id } ] 恢复全局短时记忆失败: { e } " )
268268 self .global_memory = Memory (
269269 llm_client = LLMClient (
270- client = OpenAI (
270+ client = AsyncOpenAI (
271271 api_key = plugin_config .nyaturingtest_siliconflow_api_key ,
272272 base_url = "https://api.siliconflow.cn/v1" ,
273273 )
@@ -406,8 +406,8 @@ def __search_stage(self, messages_chunk: list[Message]) -> _SearchResult:
406406 mem_history = long_term_memory ,
407407 )
408408
409- def __feedback_stage (
410- self , messages_chunk : list [Message ], search_stage_result : _SearchResult , llm : Callable [[str ], str ]
409+ async def __feedback_stage (
410+ self , messages_chunk : list [Message ], search_stage_result : _SearchResult , llm : Callable [[str ], Awaitable [ str ] ]
411411 ):
412412 """
413413 反馈总结阶段
@@ -626,7 +626,7 @@ def __feedback_stage(
626626}}
627627```
628628"""
629- response = llm (prompt )
629+ response = await llm (prompt )
630630 response = re .sub (r"^```json\s*|\s*```$" , "" , response )
631631 logger .debug (f"反馈阶段llm返回:{ response } " )
632632 try :
@@ -734,11 +734,11 @@ def __feedback_stage(
734734 except Exception as e :
735735 raise ValueError (f"Feedback stage unexpected error: { e } in response: { response } " )
736736
737- def __chat_stage (
737+ async def __chat_stage (
738738 self ,
739739 search_stage_result : _SearchResult ,
740740 messages_chunk : list [Message ],
741- llm : Callable [[str ], str ],
741+ llm : Callable [[str ], Awaitable [ str ] ],
742742 ) -> list [str ]:
743743 """
744744 对话阶段
@@ -867,7 +867,7 @@ def __chat_stage(
867867 ]
868868}}
869869"""
870- response = llm (prompt )
870+ response = await llm (prompt )
871871 response = re .sub (r"^```json\s*|\s*```$" , "" , response )
872872 logger .debug (f"对话阶段llm返回:{ response } " )
873873 try :
@@ -884,29 +884,29 @@ def __chat_stage(
884884 except json .JSONDecodeError :
885885 raise ValueError ("LLM response is not valid JSON, response: " + response )
886886
887- def update (self , messages_chunk : list [Message ], llm : Callable [[str ], str ]) -> list [str ] | None :
887+ async def update (self , messages_chunk : list [Message ], llm : Callable [[str ], Awaitable [ str ] ]) -> list [str ] | None :
888888 """
889889 更新群聊消息
890890 """
891891 # 检索阶段
892892 search_stage_result = self .__search_stage (messages_chunk = messages_chunk )
893893 # 反馈阶段
894- self .__feedback_stage (messages_chunk = messages_chunk , search_stage_result = search_stage_result , llm = llm )
894+ await self .__feedback_stage (messages_chunk = messages_chunk , search_stage_result = search_stage_result , llm = llm )
895895 # 对话阶段
896896 match self .__chatting_state :
897897 case _ChattingState .ILDE :
898898 logger .debug ("nyabot潜水中..." )
899899 reply_messages = None
900900 case _ChattingState .BUBBLE :
901901 logger .debug ("nyabot冒泡中..." )
902- reply_messages = self .__chat_stage (
902+ reply_messages = await self .__chat_stage (
903903 search_stage_result = search_stage_result ,
904904 messages_chunk = messages_chunk ,
905905 llm = llm ,
906906 )
907907 case _ChattingState .ACTIVE :
908908 logger .debug ("nyabot对话中..." )
909- reply_messages = self .__chat_stage (
909+ reply_messages = await self .__chat_stage (
910910 search_stage_result = search_stage_result ,
911911 messages_chunk = messages_chunk ,
912912 llm = llm ,
0 commit comments