Skip to content

Commit d865824

Browse files
authored
Merge pull request #22 from second-state/feat/responses-api
Feat/responses api
2 parents ec02a6f + 5da7f68 commit d865824

File tree

12 files changed

+2092
-303
lines changed

12 files changed

+2092
-303
lines changed

.github/workflows/rust_ci.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: Continuous Integration
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
paths-ignore:
8+
- "**/README.md"
9+
- ".github/workflows/**"
10+
- "setup/**"
11+
- "workshops/**"
12+
- "*.sh"
13+
pull_request:
14+
workflow_dispatch:
15+
16+
jobs:
17+
build-x86_64:
18+
name: Build (x86_64)
19+
runs-on: ubuntu-latest
20+
steps:
21+
- name: Install base dependencies
22+
shell: bash
23+
run: |
24+
set -euo pipefail
25+
export DEBIAN_FRONTEND=noninteractive
26+
sudo apt-get update
27+
sudo apt-get install -y --no-install-recommends \
28+
ca-certificates \
29+
pkg-config \
30+
libssl-dev \
31+
build-essential \
32+
git \
33+
curl \
34+
unzip
35+
36+
- name: Check out repository
37+
uses: actions/checkout@v4
38+
39+
- name: Set up Rust toolchain
40+
uses: dtolnay/rust-toolchain@stable
41+
42+
- name: Build binary
43+
run: cargo build --release

src/ai/gemini/mod.rs

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ mod test {
100100
use super::types;
101101
use super::*;
102102

103-
// cargo test --package esp_assistant --bin esp_assistant -- ai::gemini::test::test_live_client --exact --show-output
103+
// cargo test --package echokit_server --bin echokit_server -- ai::gemini::test::test_live_client --exact --show-output
104104
#[tokio::test]
105105
async fn test_live_client() -> anyhow::Result<()> {
106106
env_logger::init();
@@ -121,24 +121,96 @@ mod test {
121121
)],
122122
}),
123123
input_audio_transcription: Some(types::AudioTranscriptionConfig {}),
124+
output_audio_transcription: None,
125+
realtime_input_config: Some(types::RealtimeInputConfig {
126+
automatic_activity_detection: Some(types::AutomaticActivityDetectionConfig {
127+
disabled: true,
128+
}),
129+
}),
124130
};
125131
client.setup(setup).await?;
126132
log::info!("Setup completed");
127133

128134
// let submit_data = std::fs::read("sample.pcm").unwrap();
129-
let data = std::fs::read("asr.fc012ccfcd71.wav").unwrap();
135+
let data = std::fs::read("tmp.wav").unwrap();
130136
let mut reader = wav_io::reader::Reader::from_vec(data).unwrap();
131137
let header = reader.read_header().unwrap();
132138
log::info!("WAV Header: {:?}", header);
133-
let x = reader.get_samples_f32().unwrap();
139+
let x = crate::util::get_samples_f32(&mut reader).unwrap();
134140
let x = wav_io::resample::linear(x, 1, header.sample_rate, 16000);
135-
let data = wav_io::convert_samples_f32_to_i16(&x);
141+
let submit_data = crate::util::convert_samples_f32_to_i16_bytes(&x);
142+
143+
// let input = types::RealtimeInput {
144+
// audio: None,
145+
// text: Some("你是谁".to_string()),
146+
// };
147+
client
148+
.send_realtime_input(types::RealtimeInput::ActivityStart {})
149+
.await?;
136150

137-
let mut submit_data = Vec::with_capacity(data.len() * 2);
138-
for sample in data {
139-
submit_data.extend_from_slice(&sample.to_le_bytes());
151+
let input = types::RealtimeInput::Audio(types::RealtimeAudio {
152+
data: types::Blob::new(submit_data),
153+
mime_type: "audio/pcm;rate=16000".to_string(),
154+
});
155+
log::info!("Sending realtime input");
156+
client.send_realtime_input(input).await?;
157+
log::info!("Sent realtime input");
158+
// client
159+
// .send_realtime_input(types::RealtimeInput::AudioStreamEnd(true))
160+
// .await?;
161+
client
162+
.send_realtime_input(types::RealtimeInput::ActivityEnd {})
163+
.await?;
164+
165+
log::info!("Sent realtime input");
166+
loop {
167+
let content = client.receive().await?;
168+
log::info!("Received content: {:?}", content);
169+
if let types::ServerContent::TurnComplete(true) = content {
170+
log::info!("Generation complete");
171+
break;
172+
}
140173
}
141174

175+
Ok(())
176+
}
177+
178+
// cargo test --package echokit_server --bin echokit_server -- ai::gemini::test::test_live_client_audio --exact --show-output
179+
#[tokio::test]
180+
async fn test_live_client_audio() -> anyhow::Result<()> {
181+
env_logger::init();
182+
let api_key = std::env::var("GEMINI_API_KEY").unwrap();
183+
log::info!("api_key={api_key}");
184+
let mut client = LiveClient::connect(&api_key).await?;
185+
log::info!("Connected to Gemini Live Client");
186+
187+
let mut cfg = types::GenerationConfig::default();
188+
cfg.response_modalities = Some(vec![types::Modality::AUDIO]);
189+
190+
let setup = types::Setup {
191+
model: "models/gemini-2.0-flash-exp".to_string(),
192+
generation_config: Some(cfg),
193+
system_instruction: Some(types::Content {
194+
parts: vec![types::Parts::Text(
195+
"You are a helpful assistant and answer in a friendly tone.".to_string(),
196+
)],
197+
}),
198+
input_audio_transcription: Some(types::AudioTranscriptionConfig {}),
199+
output_audio_transcription: Some(types::AudioTranscriptionConfig {}),
200+
realtime_input_config: None,
201+
};
202+
client.setup(setup).await?;
203+
log::info!("Setup completed");
204+
205+
// let submit_data = std::fs::read("sample.pcm").unwrap();
206+
let data = std::fs::read("tmp.wav").unwrap();
207+
let mut reader = wav_io::reader::Reader::from_vec(data).unwrap();
208+
let header = reader.read_header().unwrap();
209+
log::info!("WAV Header: {:?}", header);
210+
let x = crate::util::get_samples_f32(&mut reader).unwrap();
211+
let x = wav_io::resample::linear(x, 1, header.sample_rate, 16000);
212+
let submit_data = crate::util::convert_samples_f32_to_i16_bytes(&x);
213+
142214
// let input = types::RealtimeInput {
143215
// audio: None,
144216
// text: Some("你是谁".to_string()),
@@ -147,12 +219,14 @@ mod test {
147219
data: types::Blob::new(submit_data),
148220
mime_type: "audio/pcm;rate=16000".to_string(),
149221
});
222+
log::info!("Sending realtime input");
150223
client.send_realtime_input(input).await?;
224+
log::info!("Sent realtime input");
151225
client
152226
.send_realtime_input(types::RealtimeInput::AudioStreamEnd(true))
153227
.await?;
154228

155-
log::info!("Sent realtime input");
229+
log::info!("Sent realtime AudioStreamEnd");
156230
loop {
157231
let content = client.receive().await?;
158232
log::info!("Received content: {:?}", content);

src/ai/gemini/types.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ pub enum Parts {
2323
#[serde(rename_all = "camelCase")]
2424
pub struct AudioTranscriptionConfig {}
2525

26+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
27+
#[serde(rename_all = "camelCase")]
28+
pub struct AutomaticActivityDetectionConfig {
29+
pub disabled: bool,
30+
}
31+
32+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
33+
#[serde(rename_all = "camelCase")]
34+
pub struct RealtimeInputConfig {
35+
#[serde(skip_serializing_if = "Option::is_none")]
36+
pub automatic_activity_detection: Option<AutomaticActivityDetectionConfig>,
37+
}
38+
2639
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
2740
#[serde(rename_all = "camelCase")]
2841
pub struct Setup {
@@ -33,6 +46,10 @@ pub struct Setup {
3346
pub system_instruction: Option<Content>,
3447
#[serde(skip_serializing_if = "Option::is_none")]
3548
pub input_audio_transcription: Option<AudioTranscriptionConfig>,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub output_audio_transcription: Option<AudioTranscriptionConfig>,
51+
#[serde(skip_serializing_if = "Option::is_none")]
52+
pub realtime_input_config: Option<RealtimeInputConfig>,
3653
}
3754

3855
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
@@ -70,7 +87,7 @@ impl Default for Modality {
7087
}
7188
}
7289

73-
#[derive(Debug, Clone)]
90+
#[derive(Clone)]
7491
pub struct Blob(Vec<u8>);
7592
impl Blob {
7693
pub fn new(data: Vec<u8>) -> Self {
@@ -81,6 +98,13 @@ impl Blob {
8198
self.0
8299
}
83100
}
101+
102+
impl std::fmt::Debug for Blob {
103+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104+
write!(f, "Blob(len={})", self.0.len())
105+
}
106+
}
107+
84108
impl serde::Serialize for Blob {
85109
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
86110
where
@@ -111,6 +135,10 @@ pub enum RealtimeInput {
111135
Text(String),
112136
#[serde(rename = "audioStreamEnd")]
113137
AudioStreamEnd(bool),
138+
#[serde(rename = "activityEnd")]
139+
ActivityEnd {},
140+
#[serde(rename = "activityStart")]
141+
ActivityStart {},
114142
}
115143

116144
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@@ -124,6 +152,8 @@ pub struct RealtimeAudio {
124152
pub enum ServerContent {
125153
#[serde(rename = "inputTranscription")]
126154
InputTranscription { text: String },
155+
#[serde(rename = "outputTranscription")]
156+
OutputTranscription { text: String },
127157
#[serde(rename = "modelTurn")]
128158
ModelTurn(Content),
129159
#[serde(rename = "generationComplete")]
@@ -178,6 +208,8 @@ mod test {
178208
)],
179209
}),
180210
input_audio_transcription: None,
211+
output_audio_transcription: None,
212+
realtime_input_config: None,
181213
};
182214
let serialized = serde_json::to_string(&setup).unwrap();
183215
assert!(serialized.contains("gemini-2.0-flash-live-001"));

0 commit comments

Comments
 (0)