Skip to content

Commit d87e46c

Browse files
committed
feat: Update gemini to v2
1 parent 110ac15 commit d87e46c

File tree

6 files changed

+689
-23
lines changed

6 files changed

+689
-23
lines changed

src/ai/gemini/mod.rs

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ mod test {
100100
use super::types;
101101
use super::*;
102102

103-
// cargo test --package esp_assistant --bin esp_assistant -- ai::gemini::test::test_live_client --exact --show-output
103+
// cargo test --package echokit_server --bin echokit_server -- ai::gemini::test::test_live_client --exact --show-output
104104
#[tokio::test]
105105
async fn test_live_client() -> anyhow::Result<()> {
106106
env_logger::init();
@@ -121,24 +121,96 @@ mod test {
121121
)],
122122
}),
123123
input_audio_transcription: Some(types::AudioTranscriptionConfig {}),
124+
output_audio_transcription: None,
125+
realtime_input_config: Some(types::RealtimeInputConfig {
126+
automatic_activity_detection: Some(types::AutomaticActivityDetectionConfig {
127+
disabled: true,
128+
}),
129+
}),
124130
};
125131
client.setup(setup).await?;
126132
log::info!("Setup completed");
127133

128134
// let submit_data = std::fs::read("sample.pcm").unwrap();
129-
let data = std::fs::read("asr.fc012ccfcd71.wav").unwrap();
135+
let data = std::fs::read("tmp.wav").unwrap();
130136
let mut reader = wav_io::reader::Reader::from_vec(data).unwrap();
131137
let header = reader.read_header().unwrap();
132138
log::info!("WAV Header: {:?}", header);
133-
let x = reader.get_samples_f32().unwrap();
139+
let x = crate::util::get_samples_f32(&mut reader).unwrap();
134140
let x = wav_io::resample::linear(x, 1, header.sample_rate, 16000);
135-
let data = wav_io::convert_samples_f32_to_i16(&x);
141+
let submit_data = crate::util::convert_samples_f32_to_i16_bytes(&x);
142+
143+
// let input = types::RealtimeInput {
144+
// audio: None,
145+
// text: Some("你是谁".to_string()),
146+
// };
147+
client
148+
.send_realtime_input(types::RealtimeInput::ActivityStart {})
149+
.await?;
136150

137-
let mut submit_data = Vec::with_capacity(data.len() * 2);
138-
for sample in data {
139-
submit_data.extend_from_slice(&sample.to_le_bytes());
151+
let input = types::RealtimeInput::Audio(types::RealtimeAudio {
152+
data: types::Blob::new(submit_data),
153+
mime_type: "audio/pcm;rate=16000".to_string(),
154+
});
155+
log::info!("Sending realtime input");
156+
client.send_realtime_input(input).await?;
157+
log::info!("Sent realtime input");
158+
// client
159+
// .send_realtime_input(types::RealtimeInput::AudioStreamEnd(true))
160+
// .await?;
161+
client
162+
.send_realtime_input(types::RealtimeInput::ActivityEnd {})
163+
.await?;
164+
165+
log::info!("Sent realtime input");
166+
loop {
167+
let content = client.receive().await?;
168+
log::info!("Received content: {:?}", content);
169+
if let types::ServerContent::TurnComplete(true) = content {
170+
log::info!("Generation complete");
171+
break;
172+
}
140173
}
141174

175+
Ok(())
176+
}
177+
178+
// cargo test --package echokit_server --bin echokit_server -- ai::gemini::test::test_live_client_audio --exact --show-output
179+
#[tokio::test]
180+
async fn test_live_client_audio() -> anyhow::Result<()> {
181+
env_logger::init();
182+
let api_key = std::env::var("GEMINI_API_KEY").unwrap();
183+
// Never write the secret itself to the log; only confirm that a key was loaded.
log::info!("GEMINI_API_KEY loaded ({} bytes)", api_key.len());
184+
let mut client = LiveClient::connect(&api_key).await?;
185+
log::info!("Connected to Gemini Live Client");
186+
187+
let mut cfg = types::GenerationConfig::default();
188+
cfg.response_modalities = Some(vec![types::Modality::AUDIO]);
189+
190+
let setup = types::Setup {
191+
model: "models/gemini-2.0-flash-exp".to_string(),
192+
generation_config: Some(cfg),
193+
system_instruction: Some(types::Content {
194+
parts: vec![types::Parts::Text(
195+
"You are a helpful assistant and answer in a friendly tone.".to_string(),
196+
)],
197+
}),
198+
input_audio_transcription: Some(types::AudioTranscriptionConfig {}),
199+
output_audio_transcription: Some(types::AudioTranscriptionConfig {}),
200+
realtime_input_config: None,
201+
};
202+
client.setup(setup).await?;
203+
log::info!("Setup completed");
204+
205+
// let submit_data = std::fs::read("sample.pcm").unwrap();
206+
let data = std::fs::read("tmp.wav").unwrap();
207+
let mut reader = wav_io::reader::Reader::from_vec(data).unwrap();
208+
let header = reader.read_header().unwrap();
209+
log::info!("WAV Header: {:?}", header);
210+
let x = crate::util::get_samples_f32(&mut reader).unwrap();
211+
let x = wav_io::resample::linear(x, 1, header.sample_rate, 16000);
212+
let submit_data = crate::util::convert_samples_f32_to_i16_bytes(&x);
213+
142214
// let input = types::RealtimeInput {
143215
// audio: None,
144216
// text: Some("你是谁".to_string()),
@@ -147,12 +219,14 @@ mod test {
147219
data: types::Blob::new(submit_data),
148220
mime_type: "audio/pcm;rate=16000".to_string(),
149221
});
222+
log::info!("Sending realtime input");
150223
client.send_realtime_input(input).await?;
224+
log::info!("Sent realtime input");
151225
client
152226
.send_realtime_input(types::RealtimeInput::AudioStreamEnd(true))
153227
.await?;
154228

155-
log::info!("Sent realtime input");
229+
log::info!("Sent realtime AudioStreamEnd");
156230
loop {
157231
let content = client.receive().await?;
158232
log::info!("Received content: {:?}", content);

src/ai/gemini/types.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ pub enum Parts {
2323
#[serde(rename_all = "camelCase")]
2424
pub struct AudioTranscriptionConfig {}
2525

26+
/// Settings for automatic activity detection; field names serialize in camelCase.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
27+
#[serde(rename_all = "camelCase")]
28+
pub struct AutomaticActivityDetectionConfig {
29+
/// When `true`, automatic detection is switched off.
/// NOTE(review): callers presumably then delimit turns themselves with
/// `RealtimeInput::ActivityStart` / `ActivityEnd` — confirm against the API docs.
pub disabled: bool,
30+
}
31+
32+
/// Realtime-input options carried in `Setup::realtime_input_config`; field names
/// serialize in camelCase.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
33+
#[serde(rename_all = "camelCase")]
34+
pub struct RealtimeInputConfig {
35+
/// Optional activity-detection settings; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
36+
pub automatic_activity_detection: Option<AutomaticActivityDetectionConfig>,
37+
}
38+
2639
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
2740
#[serde(rename_all = "camelCase")]
2841
pub struct Setup {
@@ -33,6 +46,10 @@ pub struct Setup {
3346
pub system_instruction: Option<Content>,
3447
#[serde(skip_serializing_if = "Option::is_none")]
3548
pub input_audio_transcription: Option<AudioTranscriptionConfig>,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub output_audio_transcription: Option<AudioTranscriptionConfig>,
51+
#[serde(skip_serializing_if = "Option::is_none")]
52+
pub realtime_input_config: Option<RealtimeInputConfig>,
3653
}
3754

3855
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
@@ -70,7 +87,7 @@ impl Default for Modality {
7087
}
7188
}
7289

73-
#[derive(Debug, Clone)]
90+
#[derive(Clone)]
7491
pub struct Blob(Vec<u8>);
7592
impl Blob {
7693
pub fn new(data: Vec<u8>) -> Self {
@@ -81,6 +98,13 @@ impl Blob {
8198
self.0
8299
}
83100
}
101+
102+
// Hand-written `Debug` that prints only the payload length, so debug-formatting a
// `Blob` never dumps the raw bytes.
impl std::fmt::Debug for Blob {
103+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104+
// Rendered as `Blob(len=N)`, where N is the length of the wrapped `Vec<u8>`.
write!(f, "Blob(len={})", self.0.len())
105+
}
106+
}
107+
84108
impl serde::Serialize for Blob {
85109
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
86110
where
@@ -111,6 +135,10 @@ pub enum RealtimeInput {
111135
Text(String),
112136
#[serde(rename = "audioStreamEnd")]
113137
AudioStreamEnd(bool),
138+
#[serde(rename = "activityEnd")]
139+
ActivityEnd {},
140+
#[serde(rename = "activityStart")]
141+
ActivityStart {},
114142
}
115143

116144
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -124,6 +152,8 @@ pub struct RealtimeAudio {
124152
pub enum ServerContent {
125153
#[serde(rename = "inputTranscription")]
126154
InputTranscription { text: String },
155+
#[serde(rename = "outputTranscription")]
156+
OutputTranscription { text: String },
127157
#[serde(rename = "modelTurn")]
128158
ModelTurn(Content),
129159
#[serde(rename = "generationComplete")]
@@ -178,6 +208,8 @@ mod test {
178208
)],
179209
}),
180210
input_audio_transcription: None,
211+
output_audio_transcription: None,
212+
realtime_input_config: None,
181213
};
182214
let serialized = serde_json::to_string(&setup).unwrap();
183215
assert!(serialized.contains("gemini-2.0-flash-live-001"));

src/main.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,10 @@ async fn routes(
118118
.layer(axum::Extension(ws_setting.clone()))
119119
.layer(axum::Extension(record_config.clone()));
120120

121+
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
122+
121123
match config.config {
122124
config::AIConfig::Stable { llm, tts, asr } => {
123-
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
124125
// let tool_set = tool_set;
125126
tokio::spawn(async move {
126127
if let Err(e) = crate::services::ws::stable::run_session_manager(
@@ -131,21 +132,44 @@ async fn routes(
131132
log::error!("Stable session manager exited with error: {}", e);
132133
}
133134
});
134-
135-
router = router
136-
.route("/ws/{id}", any(services::v2_mixed_handler))
137-
.route("/v2/stable_ws/{id}", any(services::ws::stable::ws_handler))
138-
.layer(axum::Extension(Arc::new(
139-
services::ws::stable::StableWsSetting {
140-
sessions: tx,
141-
hello_wav,
142-
},
143-
)))
144-
.layer(axum::Extension(record_config.clone()));
145135
}
146-
_ => {}
136+
config::AIConfig::GeminiAndTTS { gemini, tts } => {
137+
tokio::spawn(async move {
138+
if let Err(e) = crate::services::ws::stable::gemini::run_session_manager(
139+
&gemini,
140+
Some(&tts),
141+
rx,
142+
)
143+
.await
144+
{
145+
log::error!("Gemini session manager exited with error: {}", e);
146+
}
147+
});
148+
}
149+
config::AIConfig::Gemini { gemini } => {
150+
// let tool_set = tool_set;
151+
tokio::spawn(async move {
152+
if let Err(e) =
153+
crate::services::ws::stable::gemini::run_session_manager(&gemini, None, rx)
154+
.await
155+
{
156+
log::error!("Gemini session manager exited with error: {}", e);
157+
}
158+
});
159+
}
147160
}
148161

162+
router = router
163+
.route("/ws/{id}", any(services::v2_mixed_handler))
164+
.route("/v2/stable_ws/{id}", any(services::ws::stable::ws_handler))
165+
.layer(axum::Extension(Arc::new(
166+
services::ws::stable::StableWsSetting {
167+
sessions: tx,
168+
hello_wav,
169+
},
170+
)))
171+
.layer(axum::Extension(record_config.clone()));
172+
149173
if let Some(real_config) = real_config {
150174
log::info!(
151175
"Adding realtime WebSocket handler with config: {:?}",

src/services/ws.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,9 @@ async fn submit_to_gemini_and_tts(
756756
tx.send(WsCommand::AsrResult(vec![asr_text.clone()]))?;
757757
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
758758
}
759+
gemini::types::ServerContent::OutputTranscription { text } => {
760+
// Output transcription is not forwarded to the client yet. Log it rather than
// panic, so a single server message of this kind cannot abort the session.
log::debug!("gemini output transcription: {text}");
761+
}
759762
gemini::types::ServerContent::Timeout => {}
760763
gemini::types::ServerContent::GoAway {} => {
761764
log::warn!("gemini GoAway");
@@ -886,6 +889,9 @@ async fn submit_to_gemini(
886889
// If the input transcription is not empty, we can use it as the ASR result
887890
tx.send(WsCommand::AsrResult(vec![message]))?;
888891
}
892+
gemini::types::ServerContent::OutputTranscription { text } => {
893+
// Output transcription is not consumed on this path yet. Log it rather than
// panic, so a single server message of this kind cannot abort the session.
log::debug!("gemini output transcription: {text}");
894+
}
889895
gemini::types::ServerContent::Timeout => {
890896
log::warn!("gemini timeout");
891897
tx.send(WsCommand::AsrResult(vec![]))?;
@@ -1027,6 +1033,8 @@ async fn handle_audio(
10271033
generation_config: Some(generation_config),
10281034
system_instruction,
10291035
input_audio_transcription: Some(gemini::types::AudioTranscriptionConfig {}),
1036+
output_audio_transcription: None,
1037+
realtime_input_config: None,
10301038
};
10311039

10321040
submit_to_gemini_and_tts(&pool, &mut client, &mut ws_tx, setup, rx).await?;
@@ -1055,6 +1063,8 @@ async fn handle_audio(
10551063
generation_config: Some(generation_config),
10561064
system_instruction,
10571065
input_audio_transcription: Some(gemini::types::AudioTranscriptionConfig {}),
1066+
output_audio_transcription: None,
1067+
realtime_input_config: None,
10581068
};
10591069

10601070
client.setup(setup).await?;

0 commit comments

Comments
 (0)