1
1
import os
2
+ import socket
2
3
import yaml
3
4
import numpy as np
4
5
import torch
6
+ import uuid
5
7
from fastapi import FastAPI , File , UploadFile , Form , Header , HTTPException
6
8
from fastapi .responses import JSONResponse
7
9
from modelscope .pipelines import pipeline
14
16
import tempfile
15
17
16
18
# Configure logging
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Create the temporary directory used to store uploaded audio files
TMP_DIR = "tmp"
os.makedirs(TMP_DIR, exist_ok=True)
26
25
26
+
27
27
def load_config():
    """
    Load the service configuration from ``data/.voiceprint.yaml``.

    If ``server.authorization`` is missing, empty, or shorter than 32
    characters, a new UUID4 string is generated, stored in the returned
    configuration and written back to the YAML file.

    Only the YAML file is read here; no environment variables are consulted
    by this function.

    Returns:
        dict: The parsed configuration, always containing a ``server``
        mapping with an ``authorization`` entry.

    Raises:
        RuntimeError: If the configuration file does not exist, or if its
            top-level YAML value is not a mapping.
    """
    config_path = os.path.join("data", ".voiceprint.yaml")
    if not os.path.exists(config_path):
        logger.error("配置文件 data/.voiceprint.yaml 未找到,请先配置。")
        raise RuntimeError("请先配置 data/.voiceprint.yaml")

    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as an
        # empty mapping so the authorization bootstrap below still works.
        config = yaml.safe_load(f) or {}

    if not isinstance(config, dict):
        raise RuntimeError("配置文件 data/.voiceprint.yaml 格式错误,顶层必须是映射。")

    # Make sure the "server" section exists; a bare `server:` key parses to
    # None and is handled the same way as a missing key.
    if config.get("server") is None:
        config["server"] = {}

    authorization = config["server"].get("authorization", "")

    # Generate a new token when the configured one is empty or too short.
    if not authorization or len(str(authorization)) < 32:
        new_authorization = str(uuid.uuid4())
        config["server"]["authorization"] = new_authorization

        # Persist the generated token so it survives restarts.
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(config, f, default_flow_style=False, allow_unicode=True)

        # NOTE(review): this writes the secret to the log output; confirm
        # that is acceptable for the deployment's log retention.
        logger.info(f"已自动生成新的authorization密钥: {new_authorization}")
        logger.info("配置文件已更新,请妥善保管此密钥")

    return config
59
+
37
60
38
61
# Load the configuration once at import time and keep the API token handy.
try:
    config = load_config()
    API_TOKEN = config["server"]["authorization"]
except Exception as exc:
    # Log the failure for operators, then let the original error propagate.
    logger.error(f"配置加载失败: {exc}")
    raise
44
67
45
68
# 初始化数据库连接
46
69
try :
47
- db = VoiceprintDB (config [' mysql' ])
70
+ db = VoiceprintDB (config [" mysql" ])
48
71
logger .info ("数据库连接成功。" )
49
72
except Exception as e :
50
73
logger .error (f"数据库连接失败: { e } " )
@@ -53,32 +76,48 @@ def load_config():
53
76
# Initialise the voiceprint model (thread-safe; single-process deployment is
# recommended, or gunicorn in single-worker mode)
try:
    sv_pipeline = pipeline(
        model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k",
        task=Tasks.speaker_verification,
    )
    logger.info("声纹模型加载成功。")
except Exception as exc:
    # Log the failure for operators, then let the original error propagate.
    logger.error(f"声纹模型加载失败: {exc}")
    raise
62
86
87
+
63
88
def _to_numpy (x ):
64
89
"""
65
90
将torch tensor或其他类型转为numpy数组
66
91
"""
67
92
return x .cpu ().numpy () if torch .is_tensor (x ) else np .asarray (x )
68
93
94
+
69
95
# FastAPI application object; title and description are passed to FastAPI.
app = FastAPI(
    description="基于3D-Speaker的声纹注册与识别服务",
    title="3D-Speaker 声纹识别API",
)
73
98
99
+
74
100
def check_token(token: str = Header(...)):
    """
    Validate the API token supplied by the client.

    The token must equal ``"Bearer " + API_TOKEN``.

    Args:
        token: The raw header value sent by the client.

    Raises:
        HTTPException: With status 401 when the token does not match.
    """
    import hmac  # local import keeps this block self-contained

    expected = "Bearer " + API_TOKEN
    # Compare as bytes in constant time: compare_digest rejects non-ASCII
    # str arguments, and a plain != leaks timing information about the secret.
    token_ok = isinstance(token, str) and hmac.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    )
    if not token_ok:
        logger.warning("无效的接口令牌。")
        raise HTTPException(status_code=401, detail="无效的接口令牌")
81
107
108
+
109
def get_local_ip():
    """
    Return the local IPv4 address used for outbound traffic.

    A UDP socket is connected to 8.8.8.8:80 and the socket's own address is
    read back. If any socket operation fails, ``"127.0.0.1"`` is returned.

    Returns:
        str: The detected local IP address, or ``"127.0.0.1"`` on failure.
    """
    try:
        # The context manager closes the socket even when connect() or
        # getsockname() raises, so no descriptor is leaked on the error path.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            # Connect to Google's DNS server address to select a route
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"
119
+
120
+
82
121
def ensure_16k_wav (audio_bytes ):
83
122
"""
84
123
将任意采样率的wav bytes转为16kHz wav临时文件,返回文件路径
@@ -93,15 +132,21 @@ def ensure_16k_wav(audio_bytes):
93
132
if data .ndim == 1 :
94
133
data_rs = librosa .resample (data , orig_sr = sr , target_sr = 16000 )
95
134
else :
96
- data_rs = np .vstack ([librosa .resample (data [:, ch ], orig_sr = sr , target_sr = 16000 ) for ch in range (data .shape [1 ])]).T
135
+ data_rs = np .vstack (
136
+ [
137
+ librosa .resample (data [:, ch ], orig_sr = sr , target_sr = 16000 )
138
+ for ch in range (data .shape [1 ])
139
+ ]
140
+ ).T
97
141
sf .write (tmp_path , data_rs , 16000 )
98
142
return tmp_path
99
143
144
+
100
145
@app .post ("/register" , summary = "声纹注册" )
101
146
async def register (
102
147
authorization : str = Header (..., description = "接口令牌" , alias = "authorization" ),
103
148
speaker_id : str = Form (..., description = "说话人ID" ),
104
- file : UploadFile = File (..., description = "WAV音频文件" )
149
+ file : UploadFile = File (..., description = "WAV音频文件" ),
105
150
):
106
151
"""
107
152
注册声纹接口
@@ -129,11 +174,12 @@ async def register(
129
174
if audio_path and os .path .exists (audio_path ):
130
175
os .remove (audio_path )
131
176
177
+
132
178
@app .post ("/identify" , summary = "声纹识别" )
133
179
async def identify (
134
180
authorization : str = Header (..., description = "接口令牌" , alias = "authorization" ),
135
181
speaker_ids : str = Form (..., description = "候选说话人ID,逗号分隔" ),
136
- file : UploadFile = File (..., description = "WAV音频文件" )
182
+ file : UploadFile = File (..., description = "WAV音频文件" ),
137
183
):
138
184
"""
139
185
声纹识别接口
@@ -160,14 +206,16 @@ async def identify(
160
206
logger .info ("未找到候选说话人声纹。" )
161
207
return {"speaker_id" : "" , "score" : 0.0 }
162
208
similarities = {
163
- name : float (np .dot (test_emb , emb ) / (np .linalg .norm (test_emb ) * np .linalg .norm (emb )))
209
+ name : float (
210
+ np .dot (test_emb , emb ) / (np .linalg .norm (test_emb ) * np .linalg .norm (emb ))
211
+ )
164
212
for name , emb in voiceprints .items ()
165
213
}
166
214
match_name = max (similarities , key = similarities .get )
167
215
match_score = similarities [match_name ]
168
216
if match_score < 0.2 :
169
217
logger .info (f"未识别到说话人,最高分: { match_score } " )
170
- return
218
+ return
171
219
logger .info (f"识别到说话人: { match_name } , 分数: { match_score } " )
172
220
return {"speaker_id" : match_name , "score" : match_score }
173
221
except Exception as e :
@@ -177,26 +225,31 @@ async def identify(
177
225
if audio_path and os .path .exists (audio_path ):
178
226
os .remove (audio_path )
179
227
228
+
180
229
@app.get("/", include_in_schema=False)
def root():
    """
    Root path: return a short message saying the service is running.
    """
    payload = {"msg": "3D-Speaker voiceprint API service running."}
    return JSONResponse(payload)
186
235
236
+
187
237
if __name__ == "__main__":
    try:
        # Read the listen address once and reuse it below.
        server_cfg = config["server"]
        host = server_cfg["host"]
        port = server_cfg["port"]
        logger.info(
            f"服务启动中,监听地址: {host}:{port},"
            f"文档: http://{host}:{port}/docs"
        )

        # Console banner showing the detected local IP instead of the
        # configured host.
        banner = "=" * 60
        print(banner)
        local_ip = get_local_ip()
        print(f"3D-Speaker 声纹API服务已启动,访问: http://{local_ip}:{port}/docs")
        print(banner)

        uvicorn.run("app:app", host=host, port=port)
    except KeyboardInterrupt:
        logger.info("收到中断信号,正在退出服务。")
0 commit comments