Skip to content

Commit 37a83e5

Browse files
committed
feat: support elevenlabs tts
1 parent 5d8330c commit 37a83e5

File tree

6 files changed

+311
-0
lines changed

6 files changed

+311
-0
lines changed

src/ai/elevenlabs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod tts;

src/ai/elevenlabs/tts.rs

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
use std::fmt::Display;
2+
3+
use base64::prelude::*;
4+
use futures_util::{SinkExt, StreamExt};
5+
use reqwest_websocket::{RequestBuilderExt, WebSocket};
6+
7+
#[derive(Debug, serde::Deserialize)]
8+
pub struct Alignment {
9+
pub chars: Vec<String>,
10+
}
11+
12+
#[derive(Debug, serde::Deserialize)]
13+
pub struct Response {
14+
#[serde(default)]
15+
pub alignment: Option<Alignment>,
16+
#[serde(default)]
17+
pub audio: Option<String>,
18+
#[serde(default, rename = "isFinal")]
19+
pub is_final: Option<bool>,
20+
#[serde(default)]
21+
pub error: String,
22+
#[serde(default)]
23+
pub message: String,
24+
}
25+
26+
impl Response {
27+
pub fn is_error(&self) -> bool {
28+
!self.error.is_empty()
29+
}
30+
31+
pub fn get_audio_bytes(&self) -> Option<Vec<u8>> {
32+
let _ = self.alignment.as_ref()?;
33+
self.audio
34+
.as_ref()
35+
.and_then(|audio_base64| BASE64_STANDARD.decode(audio_base64).ok())
36+
}
37+
38+
pub fn is_final(&self) -> bool {
39+
self.is_final.unwrap_or(false)
40+
}
41+
}
42+
43+
/// Checks how `Response` deserializes frames with and without alignment.
#[test]
fn test_response_deserialize() {
    // Audio without alignment: not an error, not final, and no audio bytes
    // are exposed because `get_audio_bytes` requires alignment.
    let json_data = r#"
    {
        "alignment": null,
        "audio": "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAIA+AAACABAAZGF0YRAAAAAA",
        "isFinal": null
    }
    "#;

    let response: Response = serde_json::from_str(json_data).unwrap();
    println!("{:?}", response);
    assert!(!response.is_error());
    assert!(!response.is_final());
    assert!(response.get_audio_bytes().is_none());

    // Audio with an alignment object that carries the `chars` list: the
    // frame is final and its audio decodes.
    let json_data_with_audio = r#"
    {
        "alignment": {"chars": []},
        "audio": "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAIA+AAACABAAZGF0YRAAAAAA",
        "isFinal": true
    }
    "#;

    let response_with_audio: Response = serde_json::from_str(json_data_with_audio).unwrap();
    println!("{:?}", response_with_audio);
    assert!(!response_with_audio.is_error());
    assert!(response_with_audio.is_final());
    assert!(response_with_audio.get_audio_bytes().is_some());
}
73+
74+
pub struct ElevenlabsTTS {
75+
pub token: String,
76+
pub voice: String,
77+
websocket: WebSocket,
78+
}
79+
80+
/// Model identifier passed as the `model_id` query parameter of the stream-input URL.
const MODEL_ID: &str = "eleven_flash_v2_5";
81+
82+
/// Audio encodings that can be requested through the `output_format`
/// query parameter.
///
/// The `Display` implementation yields the wire value for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Rendered as `pcm_16000`.
    Pcm16000,
    /// Rendered as `pcm_24000`.
    Pcm24000,
}
86+
87+
impl Display for OutputFormat {
88+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89+
match self {
90+
OutputFormat::Pcm16000 => write!(f, "pcm_16000"),
91+
OutputFormat::Pcm24000 => write!(f, "pcm_24000"),
92+
}
93+
}
94+
}
95+
96+
impl ElevenlabsTTS {
97+
pub async fn new(
98+
token: String,
99+
voice: String,
100+
output_format: OutputFormat,
101+
) -> anyhow::Result<Self> {
102+
let url = format!(
103+
"wss://api.elevenlabs.io/v1/text-to-speech/{voice}/stream-input?model_id={MODEL_ID}&output_format={output_format}",
104+
);
105+
106+
let client = reqwest::Client::new();
107+
108+
let response = client
109+
.get(url)
110+
.header("xi-api-key", &token)
111+
.upgrade()
112+
.send()
113+
.await?;
114+
115+
let websocket = response.into_websocket().await?;
116+
117+
Ok(Self {
118+
token,
119+
voice,
120+
websocket,
121+
})
122+
}
123+
124+
pub async fn initialize_connection(&mut self) -> anyhow::Result<()> {
125+
let init_message = serde_json::json!({
126+
"text": " ",
127+
});
128+
129+
let message_json = serde_json::to_string(&init_message)?;
130+
self.websocket
131+
.send(reqwest_websocket::Message::Text(message_json))
132+
.await?;
133+
134+
Ok(())
135+
}
136+
137+
pub async fn send_text(&mut self, text: &str, flush: bool) -> anyhow::Result<()> {
138+
let text_message = serde_json::json!({
139+
"text": text,
140+
"flush": flush,
141+
});
142+
143+
let message_json = serde_json::to_string(&text_message)?;
144+
self.websocket
145+
.send(reqwest_websocket::Message::Text(message_json))
146+
.await?;
147+
148+
Ok(())
149+
}
150+
151+
pub async fn close_connection(&mut self) -> anyhow::Result<()> {
152+
let close_message = serde_json::json!({
153+
"text": "",
154+
});
155+
self.websocket
156+
.send(reqwest_websocket::Message::Text(close_message.to_string()))
157+
.await?;
158+
Ok(())
159+
}
160+
161+
pub async fn next_audio_response(&mut self) -> anyhow::Result<Option<Response>> {
162+
while let Some(message) = self.websocket.next().await {
163+
match message.map_err(|e| anyhow::anyhow!("Elevenlabs TTS WebSocket error: {}", e))? {
164+
reqwest_websocket::Message::Text(text) => {
165+
let response: Response = serde_json::from_str(&text).map_err(|e| {
166+
anyhow::anyhow!(
167+
"Failed to parse Elevenlabs TTS response: {}, error: {}",
168+
text,
169+
e
170+
)
171+
})?;
172+
173+
if response.is_error() {
174+
return Err(anyhow::anyhow!(
175+
"Elevenlabs TTS error: {}",
176+
response.message
177+
));
178+
}
179+
180+
if response.alignment.is_some() && response.audio.is_some() {
181+
log::trace!(
182+
"Elevenlabs TTS audio chunk received, size: {}",
183+
response.audio.as_ref().unwrap().len()
184+
);
185+
return Ok(Some(response));
186+
}
187+
188+
if response.is_final() {
189+
log::trace!("TTS stream ended");
190+
return Ok(None);
191+
}
192+
}
193+
reqwest_websocket::Message::Binary(_) => {}
194+
msg => {
195+
if cfg!(debug_assertions) {
196+
log::debug!("Received non-text message: {:?}", msg);
197+
}
198+
}
199+
}
200+
}
201+
Ok(None)
202+
}
203+
}
204+
205+
#[tokio::test]
206+
async fn test_elevenlabs_tts() {
207+
env_logger::init();
208+
let token = std::env::var("ELEVENLABS_API_KEY").unwrap();
209+
let voice = std::env::var("ELEVENLABS_VOICE_ID").unwrap();
210+
211+
let mut tts = ElevenlabsTTS::new(token, voice, OutputFormat::Pcm16000)
212+
.await
213+
.expect("Failed to create ElevenlabsTTS");
214+
215+
tts.send_text("Hello, this is a test of Elevenlabs TTS.", true)
216+
.await
217+
.expect("Failed to send text");
218+
219+
tts.close_connection()
220+
.await
221+
.expect("Failed to close connection");
222+
223+
while let Ok(Some(resp)) = tts.next_audio_response().await {
224+
if let Some(audio) = resp.get_audio_bytes() {
225+
println!("Received audio chunk of size: {}", audio.len());
226+
}
227+
}
228+
229+
tts.close_connection()
230+
.await
231+
.expect("Failed to close connection");
232+
}

src/ai/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use rmcp::{
1010

1111
/// 阿里百炼
1212
pub mod bailian;
13+
pub mod elevenlabs;
1314
pub mod gemini;
1415
pub mod openai;
1516
pub mod store;

src/config.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ pub struct CosyVoiceTTS {
9191
pub version: CosyVoiceVersion,
9292
}
9393

94+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
95+
pub struct ElevenlabsTTS {
96+
pub token: String,
97+
pub voice: String,
98+
}
99+
94100
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
95101
#[serde(tag = "platform")]
96102
pub enum TTSConfig {
@@ -99,6 +105,7 @@ pub enum TTSConfig {
99105
Groq(GroqTTS),
100106
StreamGSV(StreamGSV),
101107
CosyVoice(CosyVoiceTTS),
108+
Elevenlabs(ElevenlabsTTS),
102109
}
103110

104111
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]

src/services/realtime_ws.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use uuid::Uuid;
1515
use crate::{
1616
ai::{
1717
bailian::cosyvoice,
18+
elevenlabs,
1819
openai::realtime::*,
1920
vad::{VadRealtimeClient, VadRealtimeEvent},
2021
ChatSession,
@@ -133,6 +134,7 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
133134
TTSConfig::CosyVoice(cosyvoice) => {
134135
cosyvoice.speaker.clone().unwrap_or("default".to_string())
135136
}
137+
TTSConfig::Elevenlabs(elevenlabs_tts) => elevenlabs_tts.voice.clone(),
136138
};
137139

138140
session.config.turn_detection = Some(turn_detection.clone());
@@ -1225,5 +1227,40 @@ async fn tts_and_send(
12251227
}
12261228
Ok(())
12271229
}
1230+
crate::config::TTSConfig::Elevenlabs(elevenlabs_tts) => {
1231+
let mut tts = elevenlabs::tts::ElevenlabsTTS::new(
1232+
elevenlabs_tts.token.clone(),
1233+
elevenlabs_tts.voice.clone(),
1234+
elevenlabs::tts::OutputFormat::Pcm24000,
1235+
)
1236+
.await
1237+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS init error: {e}"))?;
1238+
1239+
tts.initialize_connection()
1240+
.await
1241+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS connection error: {e}"))?;
1242+
1243+
tts.send_text(&text, true)
1244+
.await
1245+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS send text error: {e}"))?;
1246+
1247+
tts.close_connection()
1248+
.await
1249+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS close connection error: {e}"))?;
1250+
1251+
while let Ok(Some(resp)) = tts.next_audio_response().await {
1252+
tx.send(ServerEvent::ResponseAudioDelta {
1253+
event_id: Uuid::new_v4().to_string(),
1254+
response_id: response_id.clone(),
1255+
item_id: item_id.clone().unwrap_or_default(),
1256+
output_index: 0,
1257+
content_index: 1,
1258+
delta: resp.audio.unwrap(),
1259+
})
1260+
.await
1261+
.map_err(|e| anyhow::anyhow!("send audio error: {e}"))?;
1262+
}
1263+
Ok(())
1264+
}
12281265
}
12291266
}

src/services/ws.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use futures_util::StreamExt;
1717
use crate::{
1818
ai::{
1919
bailian::cosyvoice,
20+
elevenlabs,
2021
gemini::{
2122
self,
2223
types::{Blob, GenerationConfig, RealtimeAudio},
@@ -328,6 +329,35 @@ async fn tts_and_send(pool: &WsSetting, tx: &mut WsTx, text: String) -> anyhow::
328329
}
329330
Ok(())
330331
}
332+
crate::config::TTSConfig::Elevenlabs(elevenlabs_tts) => {
333+
let mut tts = elevenlabs::tts::ElevenlabsTTS::new(
334+
elevenlabs_tts.token.clone(),
335+
elevenlabs_tts.voice.clone(),
336+
elevenlabs::tts::OutputFormat::Pcm16000,
337+
)
338+
.await
339+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS init error: {e}"))?;
340+
341+
tts.initialize_connection()
342+
.await
343+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS initialize connection error: {e}"))?;
344+
345+
tts.send_text(&text, true)
346+
.await
347+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS send text error: {e}"))?;
348+
349+
tts.close_connection()
350+
.await
351+
.map_err(|e| anyhow::anyhow!("Elevenlabs TTS close connection error: {e}"))?;
352+
353+
while let Ok(Some(resp)) = tts.next_audio_response().await {
354+
if let Some(audio) = resp.get_audio_bytes() {
355+
tx.send(WsCommand::Audio(audio))
356+
.map_err(|e| anyhow::anyhow!("send audio error: {e}"))?;
357+
}
358+
}
359+
Ok(())
360+
}
331361
}
332362
}
333363

@@ -1111,16 +1141,19 @@ async fn process_command(ws: &mut WebSocket, cmd: WsCommand) -> anyhow::Result<(
11111141
ws.send(Message::binary(action)).await?;
11121142
}
11131143
WsCommand::StartAudio(text) => {
1144+
log::trace!("StartAudio: {text:?}");
11141145
let start_audio = rmp_serde::to_vec(&crate::protocol::ServerEvent::StartAudio { text })
11151146
.expect("Failed to serialize StartAudio ServerEvent");
11161147
ws.send(Message::binary(start_audio)).await?;
11171148
}
11181149
WsCommand::Audio(data) => {
1150+
log::trace!("Audio chunk size: {}", data.len());
11191151
let start_audio = rmp_serde::to_vec(&crate::protocol::ServerEvent::AudioChunk { data })
11201152
.expect("Failed to serialize StartAudio ServerEvent");
11211153
ws.send(Message::binary(start_audio)).await?;
11221154
}
11231155
WsCommand::EndAudio => {
1156+
log::trace!("EndAudio");
11241157
let end_audio = rmp_serde::to_vec(&crate::protocol::ServerEvent::EndAudio)
11251158
.expect("Failed to serialize EndAudio ServerEvent");
11261159
ws.send(Message::binary(end_audio)).await?;

0 commit comments

Comments
 (0)