1
+ import os
2
+ import yaml
1
3
import numpy as np
2
4
import torch
5
+ from fastapi import FastAPI , File , UploadFile , Form , Header , HTTPException
6
+ from fastapi .responses import JSONResponse
3
7
from modelscope .pipelines import pipeline
4
8
from modelscope .utils .constant import Tasks
9
+ from db import VoiceprintDB
10
+ import uvicorn
11
+ import logging
12
+ import soundfile as sf
13
+ import librosa
14
+ import tempfile
5
15
6
- # 初始化
7
- sv_pipeline = pipeline (
8
- task = Tasks .speaker_verification , model = "iic/speech_campplus_sv_zh-cn_3dspeaker_16k"
16
# Configure process-wide logging: INFO level, "timestamp level message" layout.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
10
22
11
- voiceprints = {}
23
# Directory (relative to the working directory) that holds uploaded audio
# while it is being processed; created at import time if it does not exist.
TMP_DIR = "tmp"
os.makedirs(TMP_DIR, exist_ok=True)
12
26
27
def load_config():
    """
    Load the service configuration from ``data/.voiceprint.yaml``.

    Only the local YAML file is consulted; this function does not read
    environment variables.

    Returns:
        dict: The parsed top-level YAML mapping.

    Raises:
        RuntimeError: If the file does not exist, or if its content is not a
            YAML mapping (for example, the file is empty).
    """
    config_path = os.path.join("data", ".voiceprint.yaml")
    if not os.path.exists(config_path):
        logger.error("配置文件 data/.voiceprint.yaml 未找到,请先配置。")
        raise RuntimeError("请先配置 data/.voiceprint.yaml")
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # safe_load returns None for an empty file and may return a scalar or a
    # list; fail here with a clear message instead of a TypeError later on.
    if not isinstance(loaded, dict):
        logger.error("配置文件 data/.voiceprint.yaml 内容无效,应为YAML映射。")
        raise RuntimeError("data/.voiceprint.yaml 内容无效")
    return loaded
13
37
14
- def _to_numpy (x ):
15
- return x .cpu ().numpy () if torch .is_tensor (x ) else np .asarray (x )
38
# Read the configuration once at import time and extract the API token that
# request handlers compare against. Any failure is logged and re-raised so the
# process does not start half-configured.
try:
    config = load_config()
    API_TOKEN = config["server"]["token"]
except Exception as exc:
    logger.error("配置加载失败: %s", exc)
    raise
16
44
45
# Create the voiceprint database handle from the "mysql" section of the
# configuration; a failure is logged and aborts start-up.
try:
    db = VoiceprintDB(config["mysql"])
    logger.info("数据库连接成功。")
except Exception as exc:
    logger.error("数据库连接失败: %s", exc)
    raise
17
52
18
- def register_voiceprint (name , audio_path ):
19
- """登记声纹特征"""
20
- result = sv_pipeline ([audio_path ], output_emb = True )
21
- emb = _to_numpy (result ["embs" ][0 ]) # 1 条音频只取第 0 条
22
- voiceprints [name ] = emb
23
- print (f"已登记: { name } " )
53
# Build the speaker-verification pipeline once at import time.
# NOTE(review): the original comment states the model is thread-safe and
# recommends single-process deployment (e.g. gunicorn with one worker);
# that cannot be confirmed from this file.
try:
    sv_pipeline = pipeline(
        model="iic/speech_campplus_sv_zh-cn_3dspeaker_16k",
        task=Tasks.speaker_verification,
    )
    logger.info("声纹模型加载成功。")
except Exception as exc:
    logger.error("声纹模型加载失败: %s", exc)
    raise
24
62
63
+ def _to_numpy (x ):
64
+ """
65
+ 将torch tensor或其他类型转为numpy数组
66
+ """
67
+ return x .cpu ().numpy () if torch .is_tensor (x ) else np .asarray (x )
25
68
26
- def identify_speaker ( audio_path ):
27
- """识别声纹所属"""
28
- test_result = sv_pipeline ([ audio_path ], output_emb = True )
29
- test_emb = _to_numpy ( test_result [ "embs" ][ 0 ] )
69
# The FastAPI application object; the route decorators below attach the
# /register, /identify and / handlers to it.
app = FastAPI(
    title="3D-Speaker 声纹识别API",
    description="基于3D-Speaker的声纹注册与识别服务"
)
30
73
31
- similarities = {}
32
- for name , emb in voiceprints .items ():
33
- cos_sim = np .dot (test_emb , emb ) / (
34
- np .linalg .norm (test_emb ) * np .linalg .norm (emb )
35
- )
36
- similarities [name ] = cos_sim
74
def check_token(token: str = Header(...)):
    """
    Validate the API token supplied by the caller.

    Args:
        token: The token value taken from the request header.

    Raises:
        HTTPException: 401 when the token does not equal the configured
            ``API_TOKEN``.
    """
    # Local import keeps this block self-contained.
    import hmac

    # Compare in constant time so response timing does not leak how much of
    # the secret matched. A non-string configured token never matches a header
    # string, which is the same outcome the plain ``!=`` comparison gave.
    token_ok = isinstance(API_TOKEN, str) and hmac.compare_digest(
        token.encode("utf-8"), API_TOKEN.encode("utf-8")
    )
    if not token_ok:
        logger.warning("无效的接口令牌。")
        raise HTTPException(status_code=401, detail="无效的接口令牌")
37
81
38
- match_name = max (similarities , key = similarities .get )
39
- return match_name , similarities [match_name ], similarities
82
def ensure_16k_wav(audio_bytes):
    """
    Write WAV bytes to a temporary file under ``TMP_DIR`` and make it 16 kHz.

    The bytes are written as-is; if the decoded sample rate is not 16000 the
    audio is resampled (each channel separately) and the file is overwritten.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.

    Returns:
        str: Path of the temporary file. The caller must delete it.

    Raises:
        Exception: Whatever writing, decoding or resampling raises. The
            temporary file is removed before the exception propagates.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir=TMP_DIR) as tmpf:
            tmp_path = tmpf.name
            tmpf.write(audio_bytes)
        # Decode to find the original sample rate.
        data, sr = sf.read(tmp_path)
        if sr != 16000:
            if data.ndim == 1:
                data_rs = librosa.resample(data, orig_sr=sr, target_sr=16000)
            else:
                # Multi-channel input: resample each channel, then restore
                # the (samples, channels) layout.
                data_rs = np.vstack([
                    librosa.resample(data[:, ch], orig_sr=sr, target_sr=16000)
                    for ch in range(data.shape[1])
                ]).T
            sf.write(tmp_path, data_rs, 16000)
    except Exception:
        # The file was created with delete=False and the caller never receives
        # the path on failure, so it has to be removed here.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        raise
    return tmp_path
40
99
100
@app.post("/register", summary="声纹注册")
async def register(
    token: str = Header(..., description="接口令牌"),
    speaker_id: str = Form(..., description="说话人ID"),
    file: UploadFile = File(..., description="WAV音频文件")
):
    """
    Register a speaker's voiceprint.

    Args:
        token: API token (request header).
        speaker_id: Identifier of the speaker; surrounding whitespace is
            removed before it is stored.
        file: The speaker's audio file (WAV).

    Returns:
        dict: ``{"success": True, "msg": ...}`` on success.

    Raises:
        HTTPException: 401 for a bad token, 400 for an empty speaker ID,
            500 when embedding extraction or storage fails.
    """
    check_token(token)
    # /identify strips its candidate IDs, so store the stripped form here;
    # otherwise an ID with surrounding whitespace could never be matched.
    speaker_id = speaker_id.strip()
    if not speaker_id:
        logger.warning("说话人ID不能为空。")
        raise HTTPException(status_code=400, detail="说话人ID不能为空")
    audio_path = None
    try:
        audio_bytes = await file.read()
        audio_path = ensure_16k_wav(audio_bytes)
        result = sv_pipeline([audio_path], output_emb=True)
        # One audio file was submitted, so only the first embedding is used.
        emb = _to_numpy(result["embs"][0]).astype(np.float32)
        db.save_voiceprint(speaker_id, emb)
        logger.info(f"声纹注册成功: {speaker_id}")
        return {"success": True, "msg": f"已登记: {speaker_id}"}
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception(f"声纹注册失败: {e}")
        raise HTTPException(status_code=500, detail=f"声纹注册失败: {e}") from e
    finally:
        # Always remove the temporary audio file.
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
41
131
42
- if __name__ == "__main__" :
43
- register_voiceprint ("max_output_size" , "test//test0.wav" )
44
- register_voiceprint ("tts1" , "test//test1.wav" )
132
def _cosine_similarity(a, b):
    """Return the cosine similarity of two vectors, or 0.0 if either has zero norm."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)


@app.post("/identify", summary="声纹识别")
async def identify(
    token: str = Header(..., description="接口令牌"),
    speaker_ids: str = Form(..., description="候选说话人ID,逗号分隔"),
    file: UploadFile = File(..., description="WAV音频文件")
):
    """
    Identify which candidate speaker an audio sample belongs to.

    Args:
        token: API token (request header).
        speaker_ids: Comma-separated candidate speaker IDs.
        file: The audio file to identify (WAV).

    Returns:
        dict: ``{"speaker_id": ..., "score": ...}``. ``speaker_id`` is an
        empty string when no candidate voiceprint is stored or when the best
        score is below the 0.2 threshold.

    Raises:
        HTTPException: 401 for a bad token, 400 when no candidate ID is
            given, 500 when embedding extraction or lookup fails.
    """
    check_token(token)
    candidate_ids = [x.strip() for x in speaker_ids.split(",") if x.strip()]
    if not candidate_ids:
        logger.warning("候选说话人ID不能为空。")
        raise HTTPException(status_code=400, detail="候选说话人ID不能为空")
    audio_path = None
    try:
        audio_bytes = await file.read()
        audio_path = ensure_16k_wav(audio_bytes)
        result = sv_pipeline([audio_path], output_emb=True)
        # One audio file was submitted, so only the first embedding is used.
        test_emb = _to_numpy(result["embs"][0]).astype(np.float32)
        voiceprints = db.get_voiceprints(candidate_ids)
        if not voiceprints:
            logger.info("未找到候选说话人声纹。")
            return {"speaker_id": "", "score": 0.0}
        similarities = {
            name: _cosine_similarity(test_emb, emb)
            for name, emb in voiceprints.items()
        }
        match_name = max(similarities, key=similarities.get)
        match_score = similarities[match_name]
        if match_score < 0.2:
            logger.info(f"未识别到说话人,最高分: {match_score}")
            # Return the same response shape as the other "no match" path
            # instead of an implicit None (a null body).
            return {"speaker_id": "", "score": match_score}
        logger.info(f"识别到说话人: {match_name}, 分数: {match_score}")
        return {"speaker_id": match_name, "score": match_score}
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception(f"声纹识别失败: {e}")
        raise HTTPException(status_code=500, detail=f"声纹识别失败: {e}") from e
    finally:
        # Always remove the temporary audio file.
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
45
179
46
- test_file = "test//test2.wav"
47
- match_name , match_score , all_scores = identify_speaker (test_file )
180
@app.get("/", include_in_schema=False)
def root():
    """
    Root path: report that the service is running.
    """
    status_payload = {"msg": "3D-Speaker voiceprint API service running."}
    return JSONResponse(status_payload)
48
186
49
- print (f"\n 识别结果: { test_file } 属于 { match_name } " )
50
- print (f"匹配分数: { match_score :.4f} " )
51
- print ("\n 所有声纹对比分数:" )
52
- for name , score in all_scores .items ():
53
- print (f"{ name } : { score :.4f} " )
187
if __name__ == "__main__":
    # Script entry point: announce the listen address, then run uvicorn
    # with the host and port taken from the "server" config section.
    try:
        host = config['server']['host']
        port = config['server']['port']
        banner = "=" * 60
        logger.info(
            f"服务启动中,监听地址: {host}:{port},"
            f"文档: http://{host}:{port}/docs"
        )
        print(banner)
        print(f"3D-Speaker 声纹API服务已启动,访问: http://{host}:{port}/docs")
        print(banner)
        uvicorn.run("app:app", host=host, port=port)
    except KeyboardInterrupt:
        logger.info("收到中断信号,正在退出服务。")
0 commit comments