Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions examples/cpal_mycroft_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use cpal::SampleFormat;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use log::{debug, info, warn};
use oww_rs::config::SpeechUnlockType::OpenWakeWordHeyMycroft;
use oww_rs::mic::converters::i16_to_f32;
use oww_rs::mic::mic_config::find_best_config;
use oww_rs::mic::process_audio::resample_into_chunks;
use oww_rs::mic::resampler::make_resampler;
use oww_rs::oww::{OWW_MODEL_CHUNK_SIZE, OwwModel};
use std::process::exit;
use std::sync::{Arc, Mutex};

/// Example using the Hey Mycroft model with cpal microphone input
fn main() -> Result<(), anyhow::Error> {
env_logger::Builder::new()
.filter_level(log::LevelFilter::Info)
.init();
// Initialize CPAL
let host = cpal::default_host();
let device = host
.default_input_device()
.expect("No input device available");
match device.name() {
Ok(name) => {
debug!("Input device: {}", name);
}
Err(e) => {
warn!("Couldn't get mic device: {:?}", e);
exit(1);
}
}

let (mut config, sample_format) = find_best_config(&device).unwrap();
// Prefer 48000 Hz for best real-time performance if available
config.sample_rate = cpal::SampleRate(48000);
info!("Selected input config (forced 48kHz): {:?}", config);

// Create a buffer to store audio data
let buffer: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(vec![]));
let buffer_clone = buffer.clone();

// Store the original sample rate and channels
let original_sample_rate = config.sample_rate.0 as f32;
println!("{:?}", original_sample_rate);
let channels = 1;

// Create the input stream
let err_fn = |err| warn!("An error occurred on the input stream: {}", err);

let mut model = OwwModel::new(OpenWakeWordHeyMycroft, 0.5).unwrap();

let mut resampler = make_resampler(
original_sample_rate as _,
OWW_MODEL_CHUNK_SIZE as _,
channels,
)
.unwrap();

let stream = match sample_format {
SampleFormat::F32 => device.build_input_stream(
&config,
move |data: &[f32], _: &_| {
let chunks = resample_into_chunks(data, &buffer_clone, channels, &mut resampler);
for chunk in chunks {
let d = model.detection(chunk.data_f32.first().clone());
if d.detected {
println!("Result f32 {:?}", d);
} else {
println!("Anything else1 {:?}", d);
}
}
},
err_fn,
None,
)?,
SampleFormat::I16 => device.build_input_stream(
&config,
move |data: &[i16], _: &_| {
// Convert i16 to f32
let samples: Vec<f32> = data.iter().map(i16_to_f32).collect();
let chunks =
resample_into_chunks(&samples, &buffer_clone, channels, &mut resampler);
for chunk in chunks {
let d = model.detection(chunk.data_f32.first().clone());
if d.detected {
println!("Result i16 {:?}", d);
} else {
println!("Anything else2");
}
}
},
err_fn,
None,
)?,
SampleFormat::U16 => device.build_input_stream(
&config,
move |_data: &[u16], _: &_| {
panic!("U16 format is not supported");
},
err_fn,
None,
)?,
_ => return Err(anyhow::anyhow!("Unsupported sample format")),
};

stream.play()?;

println!("Recording and resampling to 16000 Hz... Press Enter to stop.");
let mut input = String::new();
std::io::stdin().read_line(&mut input)?;

Ok(())
}
Binary file added speech_models/hey_mycroft_v0.1.onnx
Binary file not shown.
4 changes: 2 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SpeechUnlockType {
OpenWakeWordAlexa,
OpenWakeWordHeyMycroft,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone)]
pub struct UnlockConfig {
pub unlock_type: SpeechUnlockType,
pub detection_threshold: f32,
Expand Down
9 changes: 6 additions & 3 deletions src/info.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::config::SpeechUnlockType::{OpenWakeWordAlexa};
use crate::config::SpeechUnlockType;
use crate::config::{UnlockConfig};
use log::{debug};

pub const OWW_CZ_NAME_AHOJ_HUGO: &str = "Český - Ahoj Hugo";
pub const OWW_CZ_NAME_ALEXA: &str = "Český - Alexa";
pub const OWW_CZ_NAME_HEY_MYCROFT: &str = "Český - Hey Mycroft";

#[derive(Debug)]
pub struct LanguageModel {
Expand All @@ -19,15 +20,17 @@ impl LanguageModel {

pub fn get_trigger_phases(unlock_config: &UnlockConfig) -> Vec<String> {
match unlock_config.unlock_type {
OpenWakeWordAlexa => vec!["Alexa".to_string()],
SpeechUnlockType::OpenWakeWordAlexa => vec!["Alexa".to_string()],
SpeechUnlockType::OpenWakeWordHeyMycroft => vec!["Hey Mycroft".to_string()],
}
}

pub fn set_unlock_model(language_model: &LanguageModel) -> Option<UnlockConfig> {
let unlock_config = UnlockConfig::default(); // load_config(unlock_config_file.clone());

let model_type = match language_model.name.as_str() {
OWW_CZ_NAME_ALEXA => OpenWakeWordAlexa,
OWW_CZ_NAME_ALEXA => SpeechUnlockType::OpenWakeWordAlexa,
OWW_CZ_NAME_HEY_MYCROFT => SpeechUnlockType::OpenWakeWordHeyMycroft,
_ => {
panic!("Unexpected language model {:?}", language_model);
}
Expand Down
8 changes: 6 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::config::SpeechUnlockType::{OpenWakeWordAlexa};
use crate::config::{SpeechUnlockType, UnlockConfig};
use crate::oww::OwwModel;

Expand All @@ -10,7 +9,12 @@ pub(crate) trait Model: Send + Sync {

pub fn new_model(config: UnlockConfig) -> Result<Box<dyn Model>, String> {
match config.unlock_type {
SpeechUnlockType::OpenWakeWordAlexa => new_oww_model(config, OpenWakeWordAlexa),
SpeechUnlockType::OpenWakeWordAlexa => {
new_oww_model(config, SpeechUnlockType::OpenWakeWordAlexa)
}
SpeechUnlockType::OpenWakeWordHeyMycroft => {
new_oww_model(config, SpeechUnlockType::OpenWakeWordHeyMycroft)
}
}
}

Expand Down
12 changes: 11 additions & 1 deletion src/oww/oww_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,21 @@ impl OwwModel {

pub fn new(model_type: SpeechUnlockType, threshold: f32) -> Result<OwwModel, String> {
let model_data = match model_type {
SpeechUnlockType::OpenWakeWordAlexa => &crate::oww::oww_model::SpeechModels::get("alexa.onnx").unwrap().data,
SpeechUnlockType::OpenWakeWordAlexa => {
&crate::oww::oww_model::SpeechModels::get("alexa.onnx")
.unwrap()
.data
}
SpeechUnlockType::OpenWakeWordHeyMycroft => {
&crate::oww::oww_model::SpeechModels::get("hey_mycroft_v0.1.onnx")
.unwrap()
.data
}
};

let model_unlock_word = match model_type {
SpeechUnlockType::OpenWakeWordAlexa => "Alexa".to_string(),
SpeechUnlockType::OpenWakeWordHeyMycroft => "Hey Mycroft".to_string(),
};
let detections_buffer = CircularBuffer::<DETECTION_BUFFER_SIZE, f32>::new();

Expand Down