
Commit f9703f0

♻️ Change the OpenAI client to the async version
1 parent a5ab993 commit f9703f0

File tree: 4 files changed, +28 −28 lines


src/nonebot_plugin_nyaturingtest/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -21,7 +21,7 @@
 from nonebot.params import CommandArg
 from nonebot.permission import SUPERUSER
 from nonebot.plugin import PluginMetadata
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from .client import LLMClient
 from .config import Config, plugin_config
@@ -58,7 +58,7 @@ class GroupState:
     )
     messages_chunk: list[MMessage] = field(default_factory=list)
     client = LLMClient(
-        client=OpenAI(
+        client=AsyncOpenAI(
             api_key=plugin_config.nyaturingtest_chat_openai_api_key,
             base_url=plugin_config.nyaturingtest_chat_openai_base_url,
         )
@@ -84,7 +84,7 @@ async def spawn_state(state: GroupState):
         messages_chunk = state.messages_chunk.copy()
         state.messages_chunk.clear()
         try:
-            responses = state.session.update(
+            responses = await state.session.update(
                 messages_chunk=messages_chunk, llm=lambda x: llm_response(state.client, x)
             )
         except Exception as e:
@@ -461,9 +461,9 @@ async def handle_list_groups_pm():
     await list_groups_pm.finish(msg)
 
 
-def llm_response(client: LLMClient, message: str) -> str:
+async def llm_response(client: LLMClient, message: str) -> str:
     try:
-        result = client.generate_response(prompt=message, model=plugin_config.nyaturingtest_chat_openai_model)
+        result = await client.generate_response(prompt=message, model=plugin_config.nyaturingtest_chat_openai_model)
         if result:
             return result
         else:
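
One detail worth noting in the `spawn_state` hunk above: the callback stays a plain `lambda`, but because `llm_response` is now a coroutine function, `lambda x: llm_response(state.client, x)` returns an un-awaited coroutine that `Session.update` awaits internally. That matches the `Callable[[str], Awaitable[str]]` signature introduced in session.py below. A minimal, self-contained sketch of that pattern (`EchoClient`, `update` and `main` are illustrative stand-ins, not part of the commit):

    # Illustration only: a sync lambda wrapping a coroutine function satisfies
    # Callable[[str], Awaitable[str]], because calling it returns a coroutine object.
    import asyncio
    from collections.abc import Awaitable, Callable


    class EchoClient:  # stand-in for LLMClient
        async def generate_response(self, prompt: str, model: str) -> str | None:
            await asyncio.sleep(0)  # pretend to do async I/O
            return f"[{model}] {prompt}"


    async def llm_response(client: EchoClient, message: str) -> str:
        return await client.generate_response(prompt=message, model="demo-model") or ""


    async def update(llm: Callable[[str], Awaitable[str]]) -> str:
        return await llm("hello")  # the awaiting happens here, as in Session.update


    async def main() -> None:
        client = EchoClient()
        # Same shape as `llm=lambda x: llm_response(state.client, x)` in the diff.
        print(await update(lambda x: llm_response(client, x)))


    asyncio.run(main())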

src/nonebot_plugin_nyaturingtest/client.py

Lines changed: 4 additions & 4 deletions

@@ -1,12 +1,12 @@
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 
 class LLMClient:
-    def __init__(self, client: OpenAI):
+    def __init__(self, client: AsyncOpenAI):
         self.client = client
 
-    def generate_response(self, prompt: str, model: str) -> str | None:
-        response = self.client.chat.completions.create(
+    async def generate_response(self, prompt: str, model: str) -> str | None:
+        response = await self.client.chat.completions.create(
             messages=[{"role": "user", "content": prompt}],
             model=model,
             temperature=0.5,
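
With the client now async, callers have to await `generate_response` from inside a coroutine. A minimal usage sketch under stated assumptions: the API key, base URL, and model name below are placeholders, not values from the plugin's configuration:

    # Hedged usage sketch for the async LLMClient shown above;
    # credentials and model name are placeholders.
    import asyncio

    from openai import AsyncOpenAI

    from nonebot_plugin_nyaturingtest.client import LLMClient


    async def main() -> None:
        client = LLMClient(client=AsyncOpenAI(api_key="sk-...", base_url="https://api.example.com/v1"))
        reply = await client.generate_response(prompt="ping", model="some-chat-model")
        print(reply)


    asyncio.run(main())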

src/nonebot_plugin_nyaturingtest/session.py

Lines changed: 15 additions & 15 deletions

@@ -1,5 +1,5 @@
 from collections import deque
-from collections.abc import Callable
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
@@ -11,7 +11,7 @@
 import traceback
 
 from nonebot import logger
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from .client import LLMClient
 from .config import plugin_config
@@ -71,7 +71,7 @@ def __init__(self, siliconflow_api_key: str, id: str = "global", name: str = "te
         """
         self.global_memory: Memory = Memory(
             llm_client=LLMClient(
-                client=OpenAI(
+                client=AsyncOpenAI(
                     api_key=plugin_config.nyaturingtest_siliconflow_api_key,
                     base_url="https://api.siliconflow.cn/v1",
                 )
@@ -257,7 +257,7 @@ def load_session(self):
                 compressed_message=session_data["global_memory"].get("compressed_history", ""),
                 messages=[Message.from_json(msg) for msg in session_data["global_memory"].get("messages", [])],
                 llm_client=LLMClient(
-                    client=OpenAI(
+                    client=AsyncOpenAI(
                         api_key=plugin_config.nyaturingtest_siliconflow_api_key,
                         base_url="https://api.siliconflow.cn/v1",
                     )
@@ -267,7 +267,7 @@ def load_session(self):
             logger.error(f"[Session {self.id}] 恢复全局短时记忆失败: {e}")
             self.global_memory = Memory(
                 llm_client=LLMClient(
-                    client=OpenAI(
+                    client=AsyncOpenAI(
                         api_key=plugin_config.nyaturingtest_siliconflow_api_key,
                         base_url="https://api.siliconflow.cn/v1",
                     )
@@ -406,8 +406,8 @@ def __search_stage(self, messages_chunk: list[Message]) -> _SearchResult:
             mem_history=long_term_memory,
         )
 
-    def __feedback_stage(
-        self, messages_chunk: list[Message], search_stage_result: _SearchResult, llm: Callable[[str], str]
+    async def __feedback_stage(
+        self, messages_chunk: list[Message], search_stage_result: _SearchResult, llm: Callable[[str], Awaitable[str]]
     ):
         """
         反馈总结阶段
@@ -626,7 +626,7 @@ def __feedback_stage(
         }}
         ```
         """
-        response = llm(prompt)
+        response = await llm(prompt)
         response = re.sub(r"^```json\s*|\s*```$", "", response)
         logger.debug(f"反馈阶段llm返回:{response}")
         try:
@@ -734,11 +734,11 @@ def __feedback_stage(
         except Exception as e:
             raise ValueError(f"Feedback stage unexpected error: {e} in response: {response}")
 
-    def __chat_stage(
+    async def __chat_stage(
         self,
         search_stage_result: _SearchResult,
         messages_chunk: list[Message],
-        llm: Callable[[str], str],
+        llm: Callable[[str], Awaitable[str]],
     ) -> list[str]:
         """
         对话阶段
@@ -867,7 +867,7 @@ def __chat_stage(
             ]
         }}
         """
-        response = llm(prompt)
+        response = await llm(prompt)
        response = re.sub(r"^```json\s*|\s*```$", "", response)
         logger.debug(f"对话阶段llm返回:{response}")
         try:
@@ -884,29 +884,29 @@ def __chat_stage(
         except json.JSONDecodeError:
             raise ValueError("LLM response is not valid JSON, response: " + response)
 
-    def update(self, messages_chunk: list[Message], llm: Callable[[str], str]) -> list[str] | None:
+    async def update(self, messages_chunk: list[Message], llm: Callable[[str], Awaitable[str]]) -> list[str] | None:
         """
         更新群聊消息
         """
         # 检索阶段
         search_stage_result = self.__search_stage(messages_chunk=messages_chunk)
         # 反馈阶段
-        self.__feedback_stage(messages_chunk=messages_chunk, search_stage_result=search_stage_result, llm=llm)
+        await self.__feedback_stage(messages_chunk=messages_chunk, search_stage_result=search_stage_result, llm=llm)
         # 对话阶段
         match self.__chatting_state:
             case _ChattingState.ILDE:
                 logger.debug("nyabot潜水中...")
                 reply_messages = None
             case _ChattingState.BUBBLE:
                 logger.debug("nyabot冒泡中...")
-                reply_messages = self.__chat_stage(
+                reply_messages = await self.__chat_stage(
                     search_stage_result=search_stage_result,
                     messages_chunk=messages_chunk,
                     llm=llm,
                 )
             case _ChattingState.ACTIVE:
                 logger.debug("nyabot对话中...")
-                reply_messages = self.__chat_stage(
+                reply_messages = await self.__chat_stage(
                     search_stage_result=search_stage_result,
                     messages_chunk=messages_chunk,
                     llm=llm,
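
The signature change from `Callable[[str], str]` to `Callable[[str], Awaitable[str]]` is what lets the long-running model calls yield back to NoneBot's event loop: an awaited request no longer blocks other handlers, so updates for several groups can overlap. A small illustration of that property with a stand-in coroutine (`fake_update` is not part of the commit; it only mimics an awaited chat-completion call):

    # Illustration only: two slow "LLM" calls overlap instead of running back to back.
    import asyncio
    import time


    async def fake_update(group_id: str) -> str:
        await asyncio.sleep(1)  # stands in for an awaited chat-completion request
        return f"reply for {group_id}"


    async def main() -> None:
        start = time.perf_counter()
        replies = await asyncio.gather(fake_update("group-a"), fake_update("group-b"))
        print(replies, f"~{time.perf_counter() - start:.1f}s")  # about 1s, not 2s


    asyncio.run(main())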

src/nonebot_plugin_nyaturingtest/vlm.py

Lines changed: 4 additions & 4 deletions

@@ -1,4 +1,4 @@
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 
 class SiliconFlowVLM:
@@ -31,7 +31,7 @@ def __init__(
         max_retries: 最大重试次数
         retry_delay: 重试延迟(秒)
         """
-        self.client = OpenAI(
+        self.client = AsyncOpenAI(
             api_key=api_key,
             base_url=endpoint,
         )
@@ -40,7 +40,7 @@ def __init__(
         self.max_retries = max_retries
         self.retry_delay = retry_delay
 
-    def request(
+    async def request(
         self,
         prompt: str,
         image_base64: str,
@@ -49,7 +49,7 @@ def request(
         """
         让vlm根据图片和文本提示词生成描述
         """
-        responese = self.client.chat.completions.create(
+        responese = await self.client.chat.completions.create(
             model=self.model,
             messages=[
                 {
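
`SiliconFlowVLM.request` is now a coroutine as well, so its retry behaviour (the `max_retries` / `retry_delay` parameters documented in `__init__`) should sleep without blocking the event loop, i.e. with `await asyncio.sleep(...)` rather than `time.sleep(...)`. The retry loop itself is not shown in this diff, so the helper below is only a sketch of what an async retry wrapper could look like, not the plugin's implementation:

    # Hedged sketch: a generic async retry helper in the spirit of
    # max_retries / retry_delay; retry_async is illustrative, not from the commit.
    import asyncio
    from collections.abc import Awaitable, Callable
    from typing import TypeVar

    T = TypeVar("T")


    async def retry_async(op: Callable[[], Awaitable[T]], max_retries: int, retry_delay: float) -> T:
        last_error: Exception | None = None
        for _ in range(max_retries):
            try:
                return await op()
            except Exception as e:  # real code would catch the specific API errors
                last_error = e
                await asyncio.sleep(retry_delay)  # non-blocking, unlike time.sleep
        raise RuntimeError(f"all {max_retries} attempts failed: {last_error}")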
