Skip to content

Commit 73f8a09

Browse files
authored
Merge pull request #2948 from xinnan-tech/py_test_typing
Py test typing
2 parents 87b99e0 + 795dcec commit 73f8a09

36 files changed

+309
-151
lines changed

main/xiaozhi-server/core/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878

7979
self.read_config_from_api = self.config.get("read_config_from_api", False)
8080

81-
self.websocket = None
81+
self.websocket: websockets.ServerConnection | None = None
8282
self.headers = None
8383
self.device_id = None
8484
self.client_ip = None
@@ -169,7 +169,7 @@ def __init__(
169169
# 初始化提示词管理器
170170
self.prompt_manager = PromptManager(self.config, self.logger)
171171

172-
async def handle_connection(self, ws):
172+
async def handle_connection(self, ws: websockets.ServerConnection):
173173
try:
174174
# 获取运行中的事件循环(必须在异步上下文中)
175175
self.loop = asyncio.get_running_loop()

main/xiaozhi-server/core/handle/abortHandle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import json
2+
from typing import TYPE_CHECKING
23

4+
if TYPE_CHECKING:
5+
from core.connection import ConnectionHandler
36
TAG = __name__
47

58

6-
async def handleAbortMessage(conn):
9+
async def handleAbortMessage(conn: "ConnectionHandler"):
710
conn.logger.bind(tag=TAG).info("Abort message received")
811
# 设置成打断状态,会自动打断llm、tts任务
912
conn.client_abort = True

main/xiaozhi-server/core/handle/helloHandle.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
import uuid
44
import random
55
import asyncio
6+
from typing import TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
from core.connection import ConnectionHandler
610
from core.utils.dialogue import Message
711
from core.utils.util import audio_to_data
812
from core.providers.tts.dto.dto import SentenceType
913
from core.utils.wakeup_word import WakeupWordsConfig
1014
from core.handle.sendAudioHandle import sendAudioMessage, send_tts_message
1115
from core.utils.util import remove_punctuation_and_length, opus_datas_to_wav_bytes
12-
from core.providers.tools.device_mcp import (
13-
MCPClient,
14-
send_mcp_initialize_message
15-
)
16+
from core.providers.tools.device_mcp import MCPClient, send_mcp_initialize_message
1617

1718
TAG = __name__
1819

@@ -38,7 +39,7 @@
3839
_wakeup_response_lock = asyncio.Lock()
3940

4041

41-
async def handleHelloMessage(conn, msg_json):
42+
async def handleHelloMessage(conn: "ConnectionHandler", msg_json):
4243
"""处理hello消息"""
4344
audio_params = msg_json.get("audio_params")
4445
if audio_params:
@@ -59,7 +60,7 @@ async def handleHelloMessage(conn, msg_json):
5960
await conn.websocket.send(json.dumps(conn.welcome_msg))
6061

6162

62-
async def checkWakeupWords(conn, text):
63+
async def checkWakeupWords(conn: "ConnectionHandler", text):
6364
enable_wakeup_words_response_cache = conn.config[
6465
"enable_wakeup_words_response_cache"
6566
]
@@ -120,7 +121,7 @@ async def checkWakeupWords(conn, text):
120121
return True
121122

122123

123-
async def wakeupWordsResponse(conn):
124+
async def wakeupWordsResponse(conn: "ConnectionHandler"):
124125
if not conn.tts:
125126
return
126127

main/xiaozhi-server/core/handle/intentHandler.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import json
22
import uuid
33
import asyncio
4+
from typing import TYPE_CHECKING
5+
6+
if TYPE_CHECKING:
7+
from core.connection import ConnectionHandler
48
from core.utils.dialogue import Message
59
from core.providers.tts.dto.dto import ContentType
610
from core.handle.helloHandle import checkWakeupWords
@@ -12,10 +16,10 @@
1216
TAG = __name__
1317

1418

15-
async def handle_user_intent(conn, text):
19+
async def handle_user_intent(conn: "ConnectionHandler", text):
1620
# 预处理输入文本,处理可能的JSON格式
1721
try:
18-
if text.strip().startswith('{') and text.strip().endswith('}'):
22+
if text.strip().startswith("{") and text.strip().endswith("}"):
1923
parsed_data = json.loads(text)
2024
if isinstance(parsed_data, dict) and "content" in parsed_data:
2125
text = parsed_data["content"] # 提取content用于意图分析
@@ -45,7 +49,7 @@ async def handle_user_intent(conn, text):
4549
return await process_intent_result(conn, intent_result, text)
4650

4751

48-
async def check_direct_exit(conn, text):
52+
async def check_direct_exit(conn: "ConnectionHandler", text):
4953
"""检查是否有明确的退出命令"""
5054
_, text = remove_punctuation_and_length(text)
5155
cmd_exit = conn.cmd_exit
@@ -58,7 +62,7 @@ async def check_direct_exit(conn, text):
5862
return False
5963

6064

61-
async def analyze_intent_with_llm(conn, text):
65+
async def analyze_intent_with_llm(conn: "ConnectionHandler", text):
6266
"""使用LLM分析用户意图"""
6367
if not hasattr(conn, "intent") or not conn.intent:
6468
conn.logger.bind(tag=TAG).warning("意图识别服务未初始化")
@@ -75,7 +79,9 @@ async def analyze_intent_with_llm(conn, text):
7579
return None
7680

7781

78-
async def process_intent_result(conn, intent_result, original_text):
82+
async def process_intent_result(
83+
conn: "ConnectionHandler", intent_result, original_text
84+
):
7985
"""处理意图识别结果"""
8086
try:
8187
# 尝试将结果解析为JSON
@@ -94,24 +100,26 @@ async def process_intent_result(conn, intent_result, original_text):
94100
if function_name == "result_for_context":
95101
await send_stt_message(conn, original_text)
96102
conn.client_abort = False
97-
103+
98104
def process_context_result():
99105
conn.dialogue.put(Message(role="user", content=original_text))
100-
106+
101107
from core.utils.current_time import get_current_time_info
102108

103-
current_time, today_date, today_weekday, lunar_date = get_current_time_info()
104-
109+
current_time, today_date, today_weekday, lunar_date = (
110+
get_current_time_info()
111+
)
112+
105113
# 构建带上下文的基础提示
106114
context_prompt = f"""当前时间:{current_time}
107115
今天日期:{today_date} ({today_weekday})
108116
今天农历:{lunar_date}
109117
110118
请根据以上信息回答用户的问题:{original_text}"""
111-
119+
112120
response = conn.intent.replyResult(context_prompt, original_text)
113121
speak_txt(conn, response)
114-
122+
115123
conn.executor.submit(process_context_result)
116124
return True
117125

@@ -188,7 +196,7 @@ def process_function_call():
188196
return False
189197

190198

191-
def speak_txt(conn, text):
199+
def speak_txt(conn: "ConnectionHandler", text):
192200
# 记录文本
193201
conn.tts_MessageText = text
194202

main/xiaozhi-server/core/handle/receiveAudioHandle.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import time
22
import json
33
import asyncio
4+
from typing import TYPE_CHECKING
5+
6+
if TYPE_CHECKING:
7+
from core.connection import ConnectionHandler
48
from core.utils.util import audio_to_data
59
from core.handle.abortHandle import handleAbortMessage
610
from core.handle.intentHandler import handle_user_intent
@@ -10,7 +14,7 @@
1014
TAG = __name__
1115

1216

13-
async def handleAudioMessage(conn, audio):
17+
async def handleAudioMessage(conn: "ConnectionHandler", audio):
1418
# 当前片段是否有人说话
1519
have_voice = conn.vad.is_vad(conn, audio)
1620
# 如果设备刚刚被唤醒,短暂忽略VAD检测
@@ -30,13 +34,13 @@ async def handleAudioMessage(conn, audio):
3034
await conn.asr.receive_audio(conn, audio, have_voice)
3135

3236

33-
async def resume_vad_detection(conn):
37+
async def resume_vad_detection(conn: "ConnectionHandler"):
3438
# 等待2秒后恢复VAD检测
3539
await asyncio.sleep(2)
3640
conn.just_woken_up = False
3741

3842

39-
async def startToChat(conn, text):
43+
async def startToChat(conn: "ConnectionHandler", text):
4044
# 检查输入是否是JSON格式(包含说话人信息)
4145
speaker_name = None
4246
language_tag = None
@@ -96,7 +100,7 @@ async def startToChat(conn, text):
96100
conn.executor.submit(conn.chat, actual_text)
97101

98102

99-
async def no_voice_close_connect(conn, have_voice):
103+
async def no_voice_close_connect(conn: "ConnectionHandler", have_voice):
100104
if have_voice:
101105
conn.last_activity_time = time.time() * 1000
102106
return
@@ -123,7 +127,7 @@ async def no_voice_close_connect(conn, have_voice):
123127
await startToChat(conn, prompt)
124128

125129

126-
async def max_out_size(conn):
130+
async def max_out_size(conn: "ConnectionHandler"):
127131
# 播放超出最大输出字数的提示
128132
conn.client_abort = False
129133
text = "不好意思,我现在有点事情要忙,明天这个时候我们再聊,约好了哦!明天不见不散,拜拜!"
@@ -134,7 +138,7 @@ async def max_out_size(conn):
134138
conn.close_after_chat = True
135139

136140

137-
async def check_bind_device(conn):
141+
async def check_bind_device(conn: "ConnectionHandler"):
138142
if conn.bind_code:
139143
# 确保bind_code是6位数字
140144
if len(conn.bind_code) != 6:

main/xiaozhi-server/core/handle/reportHandle.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111

1212
import time
1313
import opuslib_next
14+
from typing import TYPE_CHECKING
15+
16+
if TYPE_CHECKING:
17+
from core.connection import ConnectionHandler
1418

1519
from config.manage_api_client import report as manage_report
1620

1721
TAG = __name__
1822

1923

20-
async def report(conn, type, text, opus_data, report_time):
24+
async def report(conn: "ConnectionHandler", type, text, opus_data, report_time):
2125
"""执行聊天记录上报操作
2226
2327
Args:
@@ -45,7 +49,7 @@ async def report(conn, type, text, opus_data, report_time):
4549
conn.logger.bind(tag=TAG).error(f"聊天记录上报失败: {e}")
4650

4751

48-
def opus_to_wav(conn, opus_data):
52+
def opus_to_wav(conn: "ConnectionHandler", opus_data):
4953
"""将Opus数据转换为WAV格式的字节流
5054
5155
Args:
@@ -100,7 +104,7 @@ def opus_to_wav(conn, opus_data):
100104
conn.logger.bind(tag=TAG).debug(f"释放decoder资源时出错: {e}")
101105

102106

103-
def enqueue_tts_report(conn, text, opus_data):
107+
def enqueue_tts_report(conn: "ConnectionHandler", text, opus_data):
104108
if not conn.read_config_from_api or conn.need_bind or not conn.report_tts_enable:
105109
return
106110
if conn.chat_history_conf == 0:
@@ -128,7 +132,7 @@ def enqueue_tts_report(conn, text, opus_data):
128132
conn.logger.bind(tag=TAG).error(f"加入TTS上报队列失败: {text}, {e}")
129133

130134

131-
def enqueue_asr_report(conn, text, opus_data):
135+
def enqueue_asr_report(conn: "ConnectionHandler", text, opus_data):
132136
if not conn.read_config_from_api or conn.need_bind or not conn.report_asr_enable:
133137
return
134138
if conn.chat_history_conf == 0:

main/xiaozhi-server/core/handle/sendAudioHandle.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import json
22
import time
33
import asyncio
4+
from typing import TYPE_CHECKING
5+
6+
if TYPE_CHECKING:
7+
from core.connection import ConnectionHandler
48
from core.utils import textUtils
59
from core.utils.util import audio_to_data
610
from core.providers.tts.dto.dto import SentenceType
@@ -13,7 +17,7 @@
1317
PRE_BUFFER_COUNT = 5
1418

1519

16-
async def sendAudioMessage(conn, sentenceType, audios, text):
20+
async def sendAudioMessage(conn: "ConnectionHandler", sentenceType, audios, text):
1721
if conn.tts.tts_audio_first_sentence:
1822
conn.logger.bind(tag=TAG).info(f"发送第一段语音: {text}")
1923
conn.tts.tts_audio_first_sentence = False
@@ -47,7 +51,7 @@ async def sendAudioMessage(conn, sentenceType, audios, text):
4751
await conn.close()
4852

4953

50-
async def _wait_for_audio_completion(conn):
54+
async def _wait_for_audio_completion(conn: "ConnectionHandler"):
5155
"""
5256
等待音频队列清空并等待预缓冲包播放完成
5357
@@ -70,7 +74,9 @@ async def _wait_for_audio_completion(conn):
7074
conn.logger.bind(tag=TAG).debug("音频发送完成")
7175

7276

73-
async def _send_to_mqtt_gateway(conn, opus_packet, timestamp, sequence):
77+
async def _send_to_mqtt_gateway(
78+
conn: "ConnectionHandler", opus_packet, timestamp, sequence
79+
):
7480
"""
7581
发送带16字节头部的opus数据包给mqtt_gateway
7682
Args:
@@ -92,7 +98,9 @@ async def _send_to_mqtt_gateway(conn, opus_packet, timestamp, sequence):
9298
await conn.websocket.send(complete_packet)
9399

94100

95-
async def sendAudio(conn, audios, frame_duration=AUDIO_FRAME_DURATION):
101+
async def sendAudio(
102+
conn: "ConnectionHandler", audios, frame_duration=AUDIO_FRAME_DURATION
103+
):
96104
"""
97105
发送音频包,使用 AudioRateController 进行精确的流量控制
98106
@@ -121,7 +129,9 @@ async def sendAudio(conn, audios, frame_duration=AUDIO_FRAME_DURATION):
121129
)
122130

123131

124-
def _get_or_create_rate_controller(conn, frame_duration, is_single_packet):
132+
def _get_or_create_rate_controller(
133+
conn: "ConnectionHandler", frame_duration, is_single_packet
134+
):
125135
"""
126136
获取或创建 RateController 和 flow_control
127137
@@ -177,7 +187,7 @@ def _get_or_create_rate_controller(conn, frame_duration, is_single_packet):
177187
return conn.audio_rate_controller, conn.audio_flow_control
178188

179189

180-
def _start_background_sender(conn, rate_controller, flow_control):
190+
def _start_background_sender(conn: "ConnectionHandler", rate_controller, flow_control):
181191
"""
182192
启动后台发送循环任务
183193
@@ -201,7 +211,7 @@ async def send_callback(packet):
201211

202212

203213
async def _send_audio_with_rate_control(
204-
conn, audio_list, rate_controller, flow_control, send_delay
214+
conn: "ConnectionHandler", audio_list, rate_controller, flow_control, send_delay
205215
):
206216
"""
207217
使用 rate_controller 发送音频包
@@ -233,7 +243,7 @@ async def _send_audio_with_rate_control(
233243
rate_controller.add_audio(packet)
234244

235245

236-
async def _do_send_audio(conn, opus_packet, flow_control):
246+
async def _do_send_audio(conn: "ConnectionHandler", opus_packet, flow_control):
237247
"""
238248
执行实际的音频发送
239249
"""
@@ -254,7 +264,7 @@ async def _do_send_audio(conn, opus_packet, flow_control):
254264
flow_control["sequence"] = sequence + 1
255265

256266

257-
async def send_tts_message(conn, state, text=None):
267+
async def send_tts_message(conn: "ConnectionHandler", state, text=None):
258268
"""发送 TTS 状态消息"""
259269
if text is None and state == "sentence_start":
260270
return
@@ -281,7 +291,7 @@ async def send_tts_message(conn, state, text=None):
281291
await conn.websocket.send(json.dumps(message))
282292

283293

284-
async def send_stt_message(conn, text):
294+
async def send_stt_message(conn: "ConnectionHandler", text):
285295
"""发送 STT 状态消息"""
286296
end_prompt_str = conn.config.get("end_prompt", {}).get("prompt")
287297
if end_prompt_str and end_prompt_str == text:

0 commit comments

Comments
 (0)