Skip to content

Commit 889690c

Browse files
committed
Optimize voice interruption
1 parent 254b5a9 commit 889690c

File tree

3 files changed

+91
-35
lines changed

3 files changed

+91
-35
lines changed

src/app.rs

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ pub enum Event {
1515
MicAudioChunk(Vec<i16>),
1616
MicAudioEnd,
1717
MicInterrupt(Vec<i16>),
18+
MicInterruptWaitTimeout,
1819
}
1920

2021
#[allow(dead_code)]
2122
impl Event {
23+
pub const IDLE: &'static str = "idle";
2224
pub const GAIA: &'static str = "gaia";
2325
pub const NO: &'static str = "no";
2426
pub const YES: &'static str = "yes";
@@ -41,6 +43,7 @@ async fn select_evt(
4143
server: &mut Server,
4244
notify: &tokio::sync::Notify,
4345
wait_notify: bool,
46+
timeout: std::time::Duration,
4447
) -> Option<Event> {
4548
let s_fut = async {
4649
if wait_notify {
@@ -50,38 +53,52 @@ async fn select_evt(
5053
server.recv().await
5154
}
5255
};
56+
let timeout_event = if timeout == INTERNAL_TIMEOUT {
57+
Some(Event::MicInterruptWaitTimeout)
58+
} else {
59+
Some(Event::Event(Event::IDLE))
60+
};
61+
62+
let timeout_f = tokio::time::sleep(timeout);
5363

5464
tokio::select! {
65+
_ = timeout_f => {
66+
log::info!("Event select timeout");
67+
timeout_event
68+
}
5569
Some(evt) = evt_rx.recv() => {
5670
match &evt {
5771
Event::Event(_) => {
58-
log::info!("Received event: {:?}", evt);
72+
log::info!("[Select] Received event: {:?}", evt);
5973
},
6074
Event::MicAudioEnd => {
61-
log::info!("Received MicAudioEnd");
75+
log::info!("[Select] Received MicAudioEnd");
6276
},
6377
Event::MicAudioChunk(data) => {
64-
log::debug!("Received MicAudioChunk with {} bytes", data.len());
78+
log::debug!("[Select] Received MicAudioChunk with {} bytes", data.len());
6579
},
6680
Event::ServerEvent(_) => {
67-
log::info!("Received ServerEvent: {:?}", evt);
81+
log::info!("[Select] Received ServerEvent: {:?}", evt);
6882
},
6983
Event::MicInterrupt(data) => {
70-
log::info!("Received MicInterrupt with {} samples", data.len());
84+
log::info!("[Select] Received MicInterrupt with {} samples", data.len());
85+
},
86+
Event::MicInterruptWaitTimeout => {
87+
log::info!("[Select] Received MicInterruptWaitTimeout");
7188
}
7289
}
7390
Some(evt)
7491
}
7592
Ok(msg) = s_fut => {
7693
match msg {
7794
Event::ServerEvent(ServerEvent::AudioChunk { .. })=>{
78-
log::debug!("Received AudioChunk");
95+
log::debug!("[Select] Received AudioChunk");
7996
}
8097
Event::ServerEvent(ServerEvent::HelloChunk { .. })=>{
81-
log::debug!("Received HelloChunk");
98+
log::debug!("[Select] Received HelloChunk");
8299
}
83100
_=> {
84-
log::debug!("Received message: {:?}", msg);
101+
log::debug!("[Select] Received message: {:?}", msg);
85102
}
86103
}
87104
Some(msg)
@@ -131,6 +148,8 @@ impl DownloadMetrics {
131148
}
132149

133150
const SPEED_LIMIT: f64 = 1.5;
151+
const INTERNAL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1);
152+
const NORMAL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
134153

135154
pub async fn main_work<'d>(
136155
mut server: Server,
@@ -170,8 +189,10 @@ pub async fn main_work<'d>(
170189
let mut wait_notify = false;
171190
let mut init_hello = false;
172191
let mut allow_interrupt = false;
192+
let mut timeout = NORMAL_TIMEOUT;
173193

174-
while let Some(evt) = select_evt(&mut evt_rx, &mut server, &notify, wait_notify).await {
194+
while let Some(evt) = select_evt(&mut evt_rx, &mut server, &notify, wait_notify, timeout).await
195+
{
175196
match evt {
176197
Event::Event(Event::GAIA | Event::K0) => {
177198
log::info!("Received event: gaia");
@@ -235,6 +256,14 @@ pub async fn main_work<'d>(
235256
gui.display_flush().unwrap();
236257
}
237258
Event::Event(Event::YES | Event::K1) => {}
259+
Event::Event(Event::IDLE) => {
260+
if state == State::Listening {
261+
state = State::Idle;
262+
gui.state = "Idle".to_string();
263+
gui.display_flush().unwrap();
264+
server.close().await?;
265+
}
266+
}
238267
Event::Event(Event::NOTIFY) => {
239268
log::info!("Received notify event");
240269
wait_notify = false;
@@ -259,15 +288,7 @@ pub async fn main_work<'d>(
259288
log::info!("Submitted StartChat command");
260289
}
261290
start_submit = true;
262-
let audio_buffer_u8 = unsafe {
263-
std::slice::from_raw_parts(
264-
audio_buffer.as_ptr() as *const u8,
265-
audio_buffer.len() * 2,
266-
)
267-
};
268-
server
269-
.send_client_audio_chunk(bytes::Bytes::from(audio_buffer_u8))
270-
.await?;
291+
server.send_client_audio_chunk_i16(audio_buffer).await?;
271292
audio_buffer = Vec::with_capacity(8192);
272293
}
273294
}
@@ -279,15 +300,7 @@ pub async fn main_work<'d>(
279300
}
280301
if submit_audio > 0.5 {
281302
if !audio_buffer.is_empty() {
282-
let audio_buffer_u8 = unsafe {
283-
std::slice::from_raw_parts(
284-
audio_buffer.as_ptr() as *const u8,
285-
audio_buffer.len() * 2,
286-
)
287-
};
288-
server
289-
.send_client_audio_chunk(bytes::Bytes::from(audio_buffer_u8))
290-
.await?;
303+
server.send_client_audio_chunk_i16(audio_buffer).await?;
291304
audio_buffer = Vec::with_capacity(8192);
292305
}
293306
server
@@ -321,10 +334,11 @@ pub async fn main_work<'d>(
321334
continue;
322335
}
323336

324-
if (interrupt_data.len() as f32 / 16000.0) < 1.2 {
337+
let interrupt_audio_sec = interrupt_data.len() as f32 / 16000.0;
338+
if interrupt_audio_sec < 1.2 {
325339
log::info!(
326340
"Interrupt audio too short ({} s), ignoring",
327-
interrupt_data.len() as f32 / 16000.0
341+
interrupt_audio_sec
328342
);
329343
continue;
330344
}
@@ -340,12 +354,40 @@ pub async fn main_work<'d>(
340354
server.reconnect_with_retry(3).await?;
341355

342356
start_submit = false;
343-
submit_audio = 0.0;
344-
audio_buffer = Vec::with_capacity(8192);
357+
submit_audio = interrupt_audio_sec;
358+
audio_buffer = interrupt_data;
345359

346360
state = State::Listening;
347361
gui.state = "Listening...".to_string();
348362
gui.display_flush().unwrap();
363+
timeout = INTERNAL_TIMEOUT;
364+
}
365+
Event::MicInterruptWaitTimeout => {
366+
log::info!("Received MicInterruptWaitTimeout");
367+
timeout = NORMAL_TIMEOUT;
368+
if start_submit {
369+
log::info!("Already started submit, ignoring timeout");
370+
continue;
371+
}
372+
server
373+
.send_client_command(protocol::ClientCommand::StartChat)
374+
.await?;
375+
log::info!("Submitted StartChat command due to interrupt timeout");
376+
377+
server.send_client_audio_chunk_i16(audio_buffer).await?;
378+
server
379+
.send_client_command(protocol::ClientCommand::Submit)
380+
.await?;
381+
log::info!("Submitted audio");
382+
need_compute = metrics.is_timeout();
383+
384+
audio_buffer = Vec::with_capacity(8192);
385+
submit_audio = 0.0;
386+
start_submit = false;
387+
wait_notify = false;
388+
state = State::Waiting;
389+
gui.state = "Waiting...".to_string();
390+
gui.display_flush().unwrap();
349391
}
350392
Event::ServerEvent(ServerEvent::ASR { text }) => {
351393
log::info!("Received ASR: {:?}", text);

src/audio.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ fn afe_worker(afe_handle: Arc<AFE>, tx: MicTx, trigger_mean_value: f32) -> anyho
149149
crate::print_stack_high();
150150
let mut speech = false;
151151
let mut cache_buffer = Vec::with_capacity(16000);
152+
let mut vol = VOL_NUM.load(std::sync::atomic::Ordering::Relaxed) as f32 / 100.0;
153+
let mut trigger_mean_value_ = trigger_mean_value * vol;
152154
loop {
153155
let playing = PLAYING.load(std::sync::atomic::Ordering::Relaxed);
154156
let result = afe_handle.fetch();
@@ -163,6 +165,8 @@ fn afe_worker(afe_handle: Arc<AFE>, tx: MicTx, trigger_mean_value: f32) -> anyho
163165
if result.speech {
164166
if !speech {
165167
log::info!("Speech started");
168+
vol = VOL_NUM.load(std::sync::atomic::Ordering::Relaxed) as f32 / 100.0;
169+
trigger_mean_value_ = trigger_mean_value * vol;
166170
}
167171
speech = true;
168172
log::debug!("Speech detected, sending {} bytes", result.data.len());
@@ -184,7 +188,7 @@ fn afe_worker(afe_handle: Arc<AFE>, tx: MicTx, trigger_mean_value: f32) -> anyho
184188
.map(|x| x.abs() as f32 / len)
185189
.sum::<f32>();
186190

187-
if mean > trigger_mean_value || !playing {
191+
if mean > trigger_mean_value_ || !playing {
188192
log::info!("Sending cached {} s, mean:{}", len / 16000.0, mean);
189193
tx.blocking_send(crate::app::Event::MicInterrupt(cache_buffer))
190194
.map_err(|_| anyhow::anyhow!("Failed to send data"))?;
@@ -194,7 +198,7 @@ fn afe_worker(afe_handle: Arc<AFE>, tx: MicTx, trigger_mean_value: f32) -> anyho
194198
"Dropping cached {} s, mean:{} below trigger {}",
195199
len / 16000.0,
196200
mean,
197-
trigger_mean_value
201+
trigger_mean_value_
198202
);
199203
cache_buffer.clear();
200204
}
@@ -386,6 +390,7 @@ impl SendBuffer {
386390
}
387391

388392
static PLAYING: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
393+
static VOL_NUM: std::sync::atomic::AtomicU8 = std::sync::atomic::AtomicU8::new(50);
389394

390395
fn audio_task_run(
391396
rx: &mut tokio::sync::mpsc::UnboundedReceiver<AudioEvent>,
@@ -407,7 +412,7 @@ fn audio_task_run(
407412
let mut allow_speech = false;
408413
let mut speech = false;
409414

410-
send_buffer.volume = 0.2;
415+
send_buffer.volume = 0.5;
411416

412417
loop {
413418
if let Ok(event) = rx.try_recv() {
@@ -442,6 +447,7 @@ fn audio_task_run(
442447
}
443448
AudioEvent::VolSet(vol) => {
444449
send_buffer.volume = vol;
450+
VOL_NUM.store((vol * 100.0) as u8, std::sync::atomic::Ordering::Relaxed);
445451
}
446452
}
447453
}
@@ -657,7 +663,7 @@ impl BoardsAudioWorker {
657663

658664
let _afe_r = std::thread::Builder::new()
659665
.stack_size(8 * 1024)
660-
.spawn(|| afe_worker(afe_handle_, tx, 300.0))?;
666+
.spawn(|| afe_worker(afe_handle_, tx, 200.0))?;
661667

662668
audio_task_run(&mut rx, &mut fn_read, &mut fn_write, &afe_handle)
663669
}

src/ws.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ impl Server {
8888
self.send(msg).await
8989
}
9090

91+
pub async fn send_client_audio_chunk_i16(&mut self, chunk: Vec<i16>) -> anyhow::Result<()> {
92+
let audio_buffer_u8 =
93+
unsafe { std::slice::from_raw_parts(chunk.as_ptr() as *const u8, chunk.len() * 2) };
94+
95+
self.send_client_audio_chunk(bytes::Bytes::from(audio_buffer_u8))
96+
.await
97+
}
98+
9199
pub async fn recv(&mut self) -> anyhow::Result<Event> {
92100
let msg = self
93101
.ws

0 commit comments

Comments
 (0)