Skip to content

Commit c660015

Browse files
windows和linux下测试无碍
1 parent 685a957 commit c660015

File tree

6 files changed

+290
-36
lines changed

6 files changed

+290
-36
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
__pycache__
22
logs
3-
.voiceprint.yaml
3+
/data

app.py

Lines changed: 183 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,202 @@
1+
import os
2+
import yaml
13
import numpy as np
24
import torch
5+
from fastapi import FastAPI, File, UploadFile, Form, Header, HTTPException
6+
from fastapi.responses import JSONResponse
37
from modelscope.pipelines import pipeline
48
from modelscope.utils.constant import Tasks
9+
from db import VoiceprintDB
10+
import uvicorn
11+
import logging
12+
import soundfile as sf
13+
import librosa
14+
import tempfile
515

6-
# 初始化
7-
sv_pipeline = pipeline(
8-
task=Tasks.speaker_verification, model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k"
16+
# 设置日志
17+
logging.basicConfig(
18+
level=logging.INFO,
19+
format="%(asctime)s %(levelname)s %(message)s"
920
)
21+
logger = logging.getLogger(__name__)
1022

11-
voiceprints = {}
23+
# 创建临时目录用于存放上传的音频文件
24+
TMP_DIR = "tmp"
25+
os.makedirs(TMP_DIR, exist_ok=True)
1226

27+
def load_config():
    """
    Load the service configuration from ``data/.voiceprint.yaml``.

    The path is resolved relative to the current working directory. Only the
    YAML file is consulted; no environment-variable override is implemented.

    Returns:
        dict: The parsed YAML document.

    Raises:
        RuntimeError: If the file does not exist, or if it does not contain a
            YAML mapping (for example, when the file is empty).
    """
    config_path = os.path.join("data", ".voiceprint.yaml")
    if not os.path.exists(config_path):
        logger.error("配置文件 data/.voiceprint.yaml 未找到,请先配置。")
        raise RuntimeError("请先配置 data/.voiceprint.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # safe_load returns None for an empty document; fail here with a clear
    # message instead of a later "NoneType is not subscriptable" error.
    if not isinstance(loaded, dict):
        logger.error("配置文件 data/.voiceprint.yaml 内容无效。")
        raise RuntimeError("配置文件 data/.voiceprint.yaml 内容无效")
    return loaded
1337

14-
def _to_numpy(x):
15-
return x.cpu().numpy() if torch.is_tensor(x) else np.asarray(x)
38+
try:
39+
config = load_config()
40+
API_TOKEN = config['server']['token']
41+
except Exception as e:
42+
logger.error(f"配置加载失败: {e}")
43+
raise
1644

45+
# 初始化数据库连接
46+
try:
47+
db = VoiceprintDB(config['mysql'])
48+
logger.info("数据库连接成功。")
49+
except Exception as e:
50+
logger.error(f"数据库连接失败: {e}")
51+
raise
1752

18-
def register_voiceprint(name, audio_path):
19-
"""登记声纹特征"""
20-
result = sv_pipeline([audio_path], output_emb=True)
21-
emb = _to_numpy(result["embs"][0]) # 1 条音频只取第 0 条
22-
voiceprints[name] = emb
23-
print(f"已登记: {name}")
53+
# 初始化声纹模型(线程安全,建议单进程部署,或用gunicorn单进程模式)
54+
try:
55+
sv_pipeline = pipeline(
56+
task=Tasks.speaker_verification, model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k"
57+
)
58+
logger.info("声纹模型加载成功。")
59+
except Exception as e:
60+
logger.error(f"声纹模型加载失败: {e}")
61+
raise
2462

63+
def _to_numpy(x):
    """
    Convert a torch tensor (or any array-like) to a NumPy array.

    Tensors are detached from the autograd graph and moved to the CPU first,
    because ``Tensor.numpy()`` raises for tensors that require grad or live on
    another device. Non-tensor inputs go through ``np.asarray``.

    Args:
        x: A ``torch.Tensor`` or any object accepted by ``np.asarray``.

    Returns:
        np.ndarray: The data as a NumPy array.
    """
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)
2568

26-
def identify_speaker(audio_path):
27-
"""识别声纹所属"""
28-
test_result = sv_pipeline([audio_path], output_emb=True)
29-
test_emb = _to_numpy(test_result["embs"][0])
69+
app = FastAPI(
70+
title="3D-Speaker 声纹识别API",
71+
description="基于3D-Speaker的声纹注册与识别服务"
72+
)
3073

31-
similarities = {}
32-
for name, emb in voiceprints.items():
33-
cos_sim = np.dot(test_emb, emb) / (
34-
np.linalg.norm(test_emb) * np.linalg.norm(emb)
35-
)
36-
similarities[name] = cos_sim
74+
def check_token(token: str = Header(...)):
    """
    Validate the API token supplied by the client.

    The comparison is done in constant time so the configured secret cannot be
    probed through response timing. A configured token that YAML parsed as an
    integer (e.g. an unquoted ``token: 123456``) is compared by its decimal
    text; any other non-string configured value rejects every request.

    Args:
        token: Token value taken from the ``token`` request header.

    Raises:
        HTTPException: 401 when the token does not match the configured one.
    """
    # Local import keeps this function self-contained.
    import hmac

    expected = API_TOKEN
    if isinstance(expected, int) and not isinstance(expected, bool):
        expected = str(expected)
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    is_valid = isinstance(expected, str) and hmac.compare_digest(
        str(token).encode("utf-8"), expected.encode("utf-8")
    )
    if not is_valid:
        logger.warning("无效的接口令牌。")
        raise HTTPException(status_code=401, detail="无效的接口令牌")
3781

38-
match_name = max(similarities, key=similarities.get)
39-
return match_name, similarities[match_name], similarities
82+
def ensure_16k_wav(audio_bytes):
    """
    Write uploaded audio bytes to a temporary WAV file resampled to 16 kHz.

    The bytes are stored in a new file under ``TMP_DIR``. If the decoded
    sample rate is not 16000 Hz, every channel is resampled with librosa and
    the file is rewritten in place. The file is not deleted automatically on
    success (``delete=False``), so the caller must remove the returned path.

    Args:
        audio_bytes: Raw content of an audio file readable by ``soundfile``.

    Returns:
        str: Path of the temporary 16 kHz audio file.

    Raises:
        Exception: Whatever ``soundfile`` / ``librosa`` raise for unreadable
            audio. The temporary file is removed before the error propagates.
    """
    tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir=TMP_DIR)
    tmp_path = tmpf.name
    try:
        with tmpf:
            tmpf.write(audio_bytes)
        # Read back to find the original sample rate.
        data, sr = sf.read(tmp_path)
        if sr != 16000:
            if data.ndim == 1:
                data_rs = librosa.resample(data, orig_sr=sr, target_sr=16000)
            else:
                # soundfile returns (frames, channels); resample per channel.
                data_rs = np.vstack([
                    librosa.resample(data[:, ch], orig_sr=sr, target_sr=16000)
                    for ch in range(data.shape[1])
                ]).T
            sf.write(tmp_path, data_rs, 16000)
    except Exception:
        # The caller never receives the path on failure, so clean up here.
        try:
            os.remove(tmp_path)
        except OSError:
            # Best effort: do not mask the original error.
            pass
        raise
    return tmp_path
4099

100+
@app.post("/register", summary="声纹注册")
async def register(
    token: str = Header(..., description="接口令牌"),
    speaker_id: str = Form(..., description="说话人ID"),
    file: UploadFile = File(..., description="WAV音频文件")
):
    """
    Register (or overwrite) the voiceprint of a speaker.

    Args:
        token: API token (request header).
        speaker_id: Speaker identifier. Surrounding whitespace is stripped so
            it matches the stripped candidate ids used by ``/identify``.
        file: Audio file of the speaker (WAV).

    Returns:
        dict: ``{"success": True, "msg": ...}`` on success.

    Raises:
        HTTPException: 401 for a bad token, 400 for a blank ``speaker_id``,
            500 when embedding extraction or storage fails.
    """
    check_token(token)
    # Validate outside the try block so the 400 is not turned into a 500.
    speaker_id = speaker_id.strip()
    if not speaker_id:
        logger.warning("说话人ID不能为空。")
        raise HTTPException(status_code=400, detail="说话人ID不能为空")
    audio_path = None
    try:
        audio_bytes = await file.read()
        audio_path = ensure_16k_wav(audio_bytes)
        result = sv_pipeline([audio_path], output_emb=True)
        # One audio file was submitted, so take the first embedding.
        emb = _to_numpy(result["embs"][0]).astype(np.float32)
        db.save_voiceprint(speaker_id, emb)
        logger.info(f"声纹注册成功: {speaker_id}")
        return {"success": True, "msg": f"已登记: {speaker_id}"}
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"声纹注册失败: {e}")
        raise HTTPException(status_code=500, detail=f"声纹注册失败: {e}") from e
    finally:
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
41131

42-
if __name__ == "__main__":
43-
register_voiceprint("max_output_size", "test//test0.wav")
44-
register_voiceprint("tts1", "test//test1.wav")
132+
def _cosine_similarity(a, b):
    """Return the cosine similarity of two vectors, or 0.0 if either has zero norm."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # Avoid NaN, which cannot be serialised in a JSON response.
        return 0.0
    return float(np.dot(a, b) / denom)


@app.post("/identify", summary="声纹识别")
async def identify(
    token: str = Header(..., description="接口令牌"),
    speaker_ids: str = Form(..., description="候选说话人ID,逗号分隔"),
    file: UploadFile = File(..., description="WAV音频文件")
):
    """
    Identify which candidate speaker an audio sample belongs to.

    Args:
        token: API token (request header).
        speaker_ids: Comma-separated candidate speaker ids.
        file: Audio file to identify (WAV).

    Returns:
        dict: ``{"speaker_id": <id>, "score": <cosine similarity>}``.
        ``speaker_id`` is an empty string when no candidate has a stored
        voiceprint (score 0.0) or when the best score is below the 0.2
        threshold (score is then the best score found).

    Raises:
        HTTPException: 401 for a bad token, 400 when no candidate id is given,
            500 when embedding extraction or the lookup fails.
    """
    check_token(token)
    candidate_ids = [x.strip() for x in speaker_ids.split(",") if x.strip()]
    if not candidate_ids:
        logger.warning("候选说话人ID不能为空。")
        raise HTTPException(status_code=400, detail="候选说话人ID不能为空")
    min_score = 0.2  # best score below this counts as "no match"
    audio_path = None
    try:
        audio_bytes = await file.read()
        audio_path = ensure_16k_wav(audio_bytes)
        result = sv_pipeline([audio_path], output_emb=True)
        test_emb = _to_numpy(result["embs"][0]).astype(np.float32)
        voiceprints = db.get_voiceprints(candidate_ids)
        if not voiceprints:
            logger.info("未找到候选说话人声纹。")
            return {"speaker_id": "", "score": 0.0}
        similarities = {
            name: _cosine_similarity(test_emb, emb)
            for name, emb in voiceprints.items()
        }
        match_name = max(similarities, key=similarities.get)
        match_score = similarities[match_name]
        if match_score < min_score:
            logger.info(f"未识别到说话人,最高分: {match_score}")
            # Same shape as the other "no match" response, not a null body.
            return {"speaker_id": "", "score": match_score}
        logger.info(f"识别到说话人: {match_name}, 分数: {match_score}")
        return {"speaker_id": match_name, "score": match_score}
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"声纹识别失败: {e}")
        raise HTTPException(status_code=500, detail=f"声纹识别失败: {e}") from e
    finally:
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
45179

46-
test_file = "test//test2.wav"
47-
match_name, match_score, all_scores = identify_speaker(test_file)
180+
@app.get("/", include_in_schema=False)
def root():
    """
    Landing route: return a short message saying the service is running.
    """
    payload = {"msg": "3D-Speaker voiceprint API service running."}
    return JSONResponse(payload)
48186

49-
print(f"\n识别结果: {test_file} 属于 {match_name}")
50-
print(f"匹配分数: {match_score:.4f}")
51-
print("\n所有声纹对比分数:")
52-
for name, score in all_scores.items():
53-
print(f"{name}: {score:.4f}")
187+
if __name__ == "__main__":
    # Script entry point: announce the listen address and start uvicorn.
    try:
        host = config['server']['host']
        port = config['server']['port']
        logger.info(
            f"服务启动中,监听地址: {host}:{port},"
            f"文档: http://{host}:{port}/docs"
        )
        print("=" * 60)
        print(f"3D-Speaker 声纹API服务已启动,访问: http://{host}:{port}/docs")
        print("=" * 60)
        # Pass the app object, not the "app:app" import string. The string
        # form makes uvicorn import this file again under the name "app",
        # repeating the module-level model load and database connection.
        uvicorn.run(
            app,
            host=host,
            port=port,
        )
    except KeyboardInterrupt:
        logger.info("收到中断信号,正在退出服务。")

db.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pymysql
2+
import numpy as np
3+
4+
class VoiceprintDB:
    """
    Voiceprint storage backed by a MySQL table named ``voiceprints``.

    Embeddings are stored as raw float32 bytes in the ``feature_vector``
    column, keyed by ``speaker_id``.
    """

    def __init__(self, config):
        """
        Open the database connection.

        :param config: dict with the connection settings
            (host, port, user, password, database)
        """
        self.conn = pymysql.connect(
            host=config['host'],
            port=config['port'],
            user=config['user'],
            password=config['password'],
            database=config['database'],
            charset='utf8mb4',
            autocommit=True
        )

    def _cursor(self):
        """Return a cursor, reconnecting first if the connection was dropped."""
        # The connection lives as long as the process; the server may close
        # it while idle, so check it before every operation.
        self.conn.ping(reconnect=True)
        return self.conn.cursor()

    def save_voiceprint(self, speaker_id, emb):
        """
        Insert or update the voiceprint of one speaker.

        :param speaker_id: str, speaker identifier
        :param emb: array-like embedding vector; it is converted to float32
            so that it round-trips through ``get_voiceprints``
        """
        # get_voiceprints always decodes float32, so encode the same dtype.
        payload = np.asarray(emb, dtype=np.float32).tobytes()
        with self._cursor() as cursor:
            sql = """
                INSERT INTO voiceprints (speaker_id, feature_vector)
                VALUES (%s, %s)
                ON DUPLICATE KEY UPDATE feature_vector=VALUES(feature_vector)
            """
            cursor.execute(sql, (speaker_id, payload))

    def get_voiceprints(self, speaker_ids=None):
        """
        Fetch the voiceprints of the given speakers.

        :param speaker_ids: list[str] of speaker ids; when empty or None,
            every stored voiceprint is returned
        :return: dict mapping speaker_id to a float32 np.ndarray; ids with no
            stored voiceprint are absent
        """
        with self._cursor() as cursor:
            if speaker_ids:
                # One placeholder per id keeps the query parameterised.
                format_strings = ','.join(['%s'] * len(speaker_ids))
                sql = f"SELECT speaker_id, feature_vector FROM voiceprints WHERE speaker_id IN ({format_strings})"
                cursor.execute(sql, tuple(speaker_ids))
            else:
                sql = "SELECT speaker_id, feature_vector FROM voiceprints"
                cursor.execute(sql)
            results = cursor.fetchall()
        # Decode the stored bytes back into NumPy vectors.
        return {row[0]: np.frombuffer(row[1], dtype=np.float32) for row in results}

requirements.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,10 @@ transformers==4.52.4
77
torch==2.2.2
88
sentencepiece==0.2.0
99
soundfile==0.13.1
10-
torchaudio==2.2.2
10+
torchaudio==2.2.2
11+
pyyaml==6.0.1
12+
fastapi==0.110.2
13+
uvicorn==0.29.0
14+
PyMySQL==1.1.0
15+
python-multipart==0.0.9
16+
librosa==0.10.1

test_user.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import requests
2+
3+
token = "123456"  # replace with your real token
base_url = "http://192.168.4.82:8000"  # must match the server's host and port


def register_samples():
    """Register three sample speakers from test/test0.wav .. test/test2.wav."""
    for i in range(3):
        wav_path = f"test/test{i}.wav"
        speaker_id = f"user_{i}"
        data = {'speaker_id': speaker_id}
        headers = {'token': token}
        # Close the upload file as soon as the request finishes.
        with open(wav_path, 'rb') as wav_file:
            files = {'file': wav_file}
            resp = requests.post(f"{base_url}/register", files=files, data=data, headers=headers)
        print(f"注册 {speaker_id}:", resp.json())


def identify_sample():
    """Send test/test2.wav to /identify against the three sample speakers."""
    wav_path = "test/test2.wav"
    candidate_ids = "user_0,user_1,user_2"
    data = {'speaker_ids': candidate_ids}
    headers = {'token': token}
    with open(wav_path, 'rb') as wav_file:
        files = {'file': wav_file}
        resp = requests.post(f"{base_url}/identify", files=files, data=data, headers=headers)
    print("识别结果:", resp.json())


# Guarded so that importing this file (e.g. test collection picking up the
# test_*.py name) does not fire network requests.
if __name__ == "__main__":
    # Uncomment to register the sample speakers once before identifying.
    # register_samples()
    identify_sample()

voiceprint.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
server:
2+
# 服务监听地址,0.0.0.0 表示所有网卡
3+
host: 0.0.0.0
4+
# 服务监听端口
5+
port: 8004
6+
# 接口访问令牌,调用API时需在header中携带
7+
token: "your_api_token"
8+
9+
mysql:
10+
# MySQL数据库主机地址
11+
host: "localhost"
12+
# 端口
13+
port: 3306
14+
# 用户名
15+
user: "root"
16+
# 用户密码
17+
password: "your_password"
18+
# 数据库名
19+
database: "voiceprint_db"

0 commit comments

Comments
 (0)