Skip to content

Commit 4e58ba4

Browse files
committed
update:密钥自动生成
1 parent 38beb9c commit 4e58ba4

File tree

4 files changed

+100
-28
lines changed

4 files changed

+100
-28
lines changed

README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,23 @@ CREATE TABLE voiceprints (
5858
```
5959
- 复制 `voiceprint.yaml``data/.voiceprint.yaml`
6060

61-
4. 启动
61+
4. 修改配置
62+
修改`data/.voiceprint.yaml`连接数据库的IP、用户名和密码
63+
```
64+
mysql:
65+
database: voiceprint_db
66+
# 你的mysql所在的局域网ip
67+
host: "127.0.0.1"
68+
# 密码
69+
password: 123456
70+
# 端口
71+
port: 3306
72+
# 用户名
73+
user: root
74+
```
75+
76+
77+
5. 启动
6278
```
6379
python app.py
6480
```

app.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2+
import socket
23
import yaml
34
import numpy as np
45
import torch
6+
import uuid
57
from fastapi import FastAPI, File, UploadFile, Form, Header, HTTPException
68
from fastapi.responses import JSONResponse
79
from modelscope.pipelines import pipeline
@@ -14,37 +16,58 @@
1416
import tempfile
1517

1618
# 设置日志
17-
logging.basicConfig(
18-
level=logging.INFO,
19-
format="%(asctime)s %(levelname)s %(message)s"
20-
)
19+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
2120
logger = logging.getLogger(__name__)
2221

2322
# 创建临时目录用于存放上传的音频文件
2423
TMP_DIR = "tmp"
2524
os.makedirs(TMP_DIR, exist_ok=True)
2625

26+
2727
def load_config():
2828
"""
2929
加载配置文件,优先读取环境变量(适合Docker部署),否则读取本地yaml。
30+
如果authorization不足32位或为空,自动生成UUID并更新配置文件。
3031
"""
3132
config_path = os.path.join("data", ".voiceprint.yaml")
3233
if not os.path.exists(config_path):
3334
logger.error("配置文件 data/.voiceprint.yaml 未找到,请先配置。")
3435
raise RuntimeError("请先配置 data/.voiceprint.yaml")
36+
3537
with open(config_path, "r", encoding="utf-8") as f:
36-
return yaml.safe_load(f)
38+
config = yaml.safe_load(f)
39+
40+
# 检查authorization字段
41+
if "server" not in config:
42+
config["server"] = {}
43+
44+
authorization = config["server"].get("authorization", "")
45+
46+
# 如果authorization为空或长度不足32位,生成新的UUID
47+
if not authorization or len(str(authorization)) < 32:
48+
new_authorization = str(uuid.uuid4())
49+
config["server"]["authorization"] = new_authorization
50+
51+
# 更新配置文件
52+
with open(config_path, "w", encoding="utf-8") as f:
53+
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
54+
55+
logger.info(f"已自动生成新的authorization密钥: {new_authorization}")
56+
logger.info("配置文件已更新,请妥善保管此密钥")
57+
58+
return config
59+
3760

3861
try:
3962
config = load_config()
40-
API_TOKEN = config['server']['authorization']
63+
API_TOKEN = config["server"]["authorization"]
4164
except Exception as e:
4265
logger.error(f"配置加载失败: {e}")
4366
raise
4467

4568
# 初始化数据库连接
4669
try:
47-
db = VoiceprintDB(config['mysql'])
70+
db = VoiceprintDB(config["mysql"])
4871
logger.info("数据库连接成功。")
4972
except Exception as e:
5073
logger.error(f"数据库连接失败: {e}")
@@ -53,32 +76,48 @@ def load_config():
5376
# 初始化声纹模型(线程安全,建议单进程部署,或用gunicorn单进程模式)
5477
try:
5578
sv_pipeline = pipeline(
56-
task=Tasks.speaker_verification, model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k"
79+
task=Tasks.speaker_verification,
80+
model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k",
5781
)
5882
logger.info("声纹模型加载成功。")
5983
except Exception as e:
6084
logger.error(f"声纹模型加载失败: {e}")
6185
raise
6286

87+
6388
def _to_numpy(x):
6489
"""
6590
将torch tensor或其他类型转为numpy数组
6691
"""
6792
return x.cpu().numpy() if torch.is_tensor(x) else np.asarray(x)
6893

94+
6995
app = FastAPI(
70-
title="3D-Speaker 声纹识别API",
71-
description="基于3D-Speaker的声纹注册与识别服务"
96+
title="3D-Speaker 声纹识别API", description="基于3D-Speaker的声纹注册与识别服务"
7297
)
7398

99+
74100
def check_token(token: str = Header(...)):
75101
"""
76102
校验接口令牌
77103
"""
78-
if token != API_TOKEN:
104+
if token != "Bearer " + API_TOKEN:
79105
logger.warning("无效的接口令牌。")
80106
raise HTTPException(status_code=401, detail="无效的接口令牌")
81107

108+
109+
def get_local_ip():
110+
try:
111+
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
112+
# Connect to Google's DNS servers
113+
s.connect(("8.8.8.8", 80))
114+
local_ip = s.getsockname()[0]
115+
s.close()
116+
return local_ip
117+
except Exception as e:
118+
return "127.0.0.1"
119+
120+
82121
def ensure_16k_wav(audio_bytes):
83122
"""
84123
将任意采样率的wav bytes转为16kHz wav临时文件,返回文件路径
@@ -93,15 +132,21 @@ def ensure_16k_wav(audio_bytes):
93132
if data.ndim == 1:
94133
data_rs = librosa.resample(data, orig_sr=sr, target_sr=16000)
95134
else:
96-
data_rs = np.vstack([librosa.resample(data[:, ch], orig_sr=sr, target_sr=16000) for ch in range(data.shape[1])]).T
135+
data_rs = np.vstack(
136+
[
137+
librosa.resample(data[:, ch], orig_sr=sr, target_sr=16000)
138+
for ch in range(data.shape[1])
139+
]
140+
).T
97141
sf.write(tmp_path, data_rs, 16000)
98142
return tmp_path
99143

144+
100145
@app.post("/register", summary="声纹注册")
101146
async def register(
102147
authorization: str = Header(..., description="接口令牌", alias="authorization"),
103148
speaker_id: str = Form(..., description="说话人ID"),
104-
file: UploadFile = File(..., description="WAV音频文件")
149+
file: UploadFile = File(..., description="WAV音频文件"),
105150
):
106151
"""
107152
注册声纹接口
@@ -129,11 +174,12 @@ async def register(
129174
if audio_path and os.path.exists(audio_path):
130175
os.remove(audio_path)
131176

177+
132178
@app.post("/identify", summary="声纹识别")
133179
async def identify(
134180
authorization: str = Header(..., description="接口令牌", alias="authorization"),
135181
speaker_ids: str = Form(..., description="候选说话人ID,逗号分隔"),
136-
file: UploadFile = File(..., description="WAV音频文件")
182+
file: UploadFile = File(..., description="WAV音频文件"),
137183
):
138184
"""
139185
声纹识别接口
@@ -160,14 +206,16 @@ async def identify(
160206
logger.info("未找到候选说话人声纹。")
161207
return {"speaker_id": "", "score": 0.0}
162208
similarities = {
163-
name: float(np.dot(test_emb, emb) / (np.linalg.norm(test_emb) * np.linalg.norm(emb)))
209+
name: float(
210+
np.dot(test_emb, emb) / (np.linalg.norm(test_emb) * np.linalg.norm(emb))
211+
)
164212
for name, emb in voiceprints.items()
165213
}
166214
match_name = max(similarities, key=similarities.get)
167215
match_score = similarities[match_name]
168216
if match_score < 0.2:
169217
logger.info(f"未识别到说话人,最高分: {match_score}")
170-
return
218+
return
171219
logger.info(f"识别到说话人: {match_name}, 分数: {match_score}")
172220
return {"speaker_id": match_name, "score": match_score}
173221
except Exception as e:
@@ -177,26 +225,31 @@ async def identify(
177225
if audio_path and os.path.exists(audio_path):
178226
os.remove(audio_path)
179227

228+
180229
@app.get("/", include_in_schema=False)
181230
def root():
182231
"""
183232
根路径,返回服务运行信息
184233
"""
185234
return JSONResponse({"msg": "3D-Speaker voiceprint API service running."})
186235

236+
187237
if __name__ == "__main__":
188238
try:
189239
logger.info(
190240
f"服务启动中,监听地址: {config['server']['host']}:{config['server']['port']},"
191241
f"文档: http://{config['server']['host']}:{config['server']['port']}/docs"
192242
)
193-
print("="*60)
194-
print(f"3D-Speaker 声纹API服务已启动,访问: http://{config['server']['host']}:{config['server']['port']}/docs")
195-
print("="*60)
243+
print("=" * 60)
244+
local_ip = get_local_ip()
245+
print(
246+
f"3D-Speaker 声纹API服务已启动,访问: http://{local_ip}:{config['server']['port']}/docs"
247+
)
248+
print("=" * 60)
196249
uvicorn.run(
197250
"app:app",
198-
host=config['server']['host'],
199-
port=config['server']['port'],
251+
host=config["server"]["host"],
252+
port=config["server"]["port"],
200253
)
201254
except KeyboardInterrupt:
202-
logger.info("收到中断信号,正在退出服务。")
255+
logger.info("收到中断信号,正在退出服务。")

db.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@ def __init__(self, config):
1212
1313
:param config: dict,包含数据库连接信息(host, port, user, password, database)
1414
"""
15+
# 确保密码字段是字符串类型
16+
password = str(config['password']) if config['password'] is not None else ""
17+
1518
self.conn = pymysql.connect(
1619
host=config['host'],
1720
port=config['port'],
1821
user=config['user'],
19-
password=config['password'],
22+
password=password,
2023
database=config['database'],
2124
charset='utf8mb4',
2225
autocommit=True

voiceprint.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ server:
22
# 服务监听地址,0.0.0.0 表示所有网卡
33
host: 0.0.0.0
44
# 服务监听端口
5-
port: 8004
6-
# 接口访问令牌,调用API时需在header中携带
7-
authorization: "Bearer ac1ab7b959989135c030157ee5b73eb5"
5+
port: 8005
6+
# 接口访问令牌,会随机生成,如果为空,会自动生成
7+
authorization:
88

99
mysql:
1010
# MySQL数据库主机地址
11-
host: "localhost"
11+
host: "127.0.0.1"
1212
# 端口
1313
port: 3306
1414
# 用户名

0 commit comments

Comments
 (0)