Skip to content

Commit fe2065b

Browse files
committed
chore: improve BandpassFilter
1 parent 0cae7c4 commit fe2065b

File tree

2 files changed

+269
-69
lines changed

2 files changed

+269
-69
lines changed

rust/src/preprocessing/filters.rs

Lines changed: 132 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,26 @@ pub struct BandpassFilter {
204204
center_freq: f32,
205205
bandwidth: f32,
206206
sample_rate: u32,
207-
order: usize, // Filter order (number of biquad sections = order/2)
208-
a_coeffs: Vec<f32>, // Feedback coefficients
209-
b_coeffs: Vec<f32>, // Feedforward coefficients
207+
order: usize, // Filter order (must be even)
208+
biquad_coeffs: Vec<BiquadCoeffs>, // Coefficients for each biquad section
209+
biquad_states: std::sync::RwLock<Vec<BiquadState>>, // State variables for each biquad section
210+
}
211+
212+
/// Coefficients for a single biquad section
213+
#[derive(Clone, Debug)]
214+
struct BiquadCoeffs {
215+
b0: f32,
216+
b1: f32,
217+
b2: f32, // Feedforward coefficients
218+
a1: f32,
219+
a2: f32, // Feedback coefficients (a0 normalized to 1)
220+
}
221+
222+
/// State variables for a single biquad section (Direct Form II Transposed)
223+
#[derive(Clone, Debug)]
224+
struct BiquadState {
225+
z1: f32, // First delay element
226+
z2: f32, // Second delay element
210227
}
211228

212229
impl BandpassFilter {
@@ -238,21 +255,49 @@ impl BandpassFilter {
238255
/// ```
239256
pub fn new(center_freq: f32, bandwidth: f32) -> Self {
240257
let sample_rate = 48000; // Default sample rate
241-
let order = 2; // Default 4th order filter (2 biquad sections)
258+
let order = 2; // Default 2nd order filter (1 biquad section)
242259

243260
let mut filter = Self {
244261
center_freq,
245262
bandwidth,
246263
sample_rate,
247264
order,
248-
a_coeffs: Vec::new(),
249-
b_coeffs: Vec::new(),
265+
biquad_coeffs: Vec::new(),
266+
biquad_states: std::sync::RwLock::new(Vec::new()),
250267
};
251268

252269
filter.compute_coefficients();
253270
filter
254271
}
255272

273+
/// Reset the filter's internal state
274+
///
275+
/// Clears all delay elements and state variables, allowing the filter
276+
/// to start processing from a clean state. This is useful when processing
277+
/// discontinuous signals or when you want to avoid transients from previous processing.
278+
///
279+
/// ### Examples
280+
///
281+
/// ```
282+
/// use rust_photoacoustic::preprocessing::filters::{Filter, BandpassFilter};
283+
///
284+
/// let filter = BandpassFilter::new(1000.0, 200.0);
285+
/// let signal1 = vec![1.0, 0.5, -0.3];
286+
/// let _output1 = filter.apply(&signal1);
287+
///
288+
/// // Reset state before processing new signal
289+
/// filter.reset_state();
290+
/// let signal2 = vec![0.8, -0.2, 0.4];
291+
/// let _output2 = filter.apply(&signal2);
292+
/// ```
293+
pub fn reset_state(&self) {
294+
let mut states = self.biquad_states.write().unwrap();
295+
for state in states.iter_mut() {
296+
state.z1 = 0.0;
297+
state.z2 = 0.0;
298+
}
299+
}
300+
256301
/// Set the sample rate for the filter
257302
///
258303
/// Updates the sample rate and recomputes the filter coefficients accordingly.
@@ -442,47 +487,71 @@ impl BandpassFilter {
442487
/// - Multiple sections are cascaded to achieve higher orders
443488
/// - Coefficients are normalized for optimal numerical precision
444489
fn compute_coefficients(&mut self) {
445-
// Clear existing coefficients
446-
self.a_coeffs.clear();
447-
self.b_coeffs.clear();
490+
// Clear existing coefficients and states
491+
self.biquad_coeffs.clear();
492+
self.biquad_states.write().unwrap().clear();
448493

449-
// Convert to angular frequency
450494
let fs = self.sample_rate as f32;
451-
let w0 = 2.0 * std::f32::consts::PI * self.center_freq / fs;
452-
// Q factor calculation (relates to bandwidth)
453-
let q = self.center_freq / self.bandwidth;
454-
let alpha = w0.sin() / (2.0 * q);
455-
456-
// Calculate biquad coefficients for a single second-order section
457-
// For a bandpass filter, we have:
458-
let b0 = alpha;
459-
let b1 = 0.0;
460-
let b2 = -alpha;
461-
let a0 = 1.0 + alpha;
462-
let a1 = -2.0 * w0.cos();
463-
let a2 = 1.0 - alpha;
464-
465-
// Normalize by a0
466-
let b0_norm = b0 / a0;
467-
let b1_norm = b1 / a0;
468-
let b2_norm = b2 / a0;
469-
let a1_norm = a1 / a0;
470-
let a2_norm = a2 / a0;
471-
472-
// For higher order filters, we'd cascade multiple biquad sections
473-
// For simplicity, we're implementing just one second-order section
474-
// In a real implementation, we'd calculate multiple sections based on the order
475-
476-
// For now, we'll just duplicate the same coefficients for each section
477-
for _ in 0..(self.order / 2) {
478-
// Each biquad section has 3 b coeffs and 3 a coeffs (with a0 normalized to 1)
479-
self.b_coeffs.push(b0_norm);
480-
self.b_coeffs.push(b1_norm);
481-
self.b_coeffs.push(b2_norm);
482-
483-
// a0 is always normalized to 1.0, so we don't store it
484-
self.a_coeffs.push(a1_norm);
485-
self.a_coeffs.push(a2_norm);
495+
let fc = self.center_freq;
496+
let bw = self.bandwidth;
497+
498+
// Number of biquad sections
499+
let n_sections = self.order / 2;
500+
501+
// For Butterworth bandpass filter, we'll create cascaded sections
502+
// Each section is a 2nd-order bandpass with slightly different Q factors
503+
// to achieve the overall Butterworth response
504+
505+
for k in 0..n_sections {
506+
// Calculate Q factor for this section to achieve Butterworth response
507+
// For higher order filters, distribute Q values appropriately
508+
let section_q = if n_sections == 1 {
509+
fc / bw // Standard Q for single section
510+
} else {
511+
// For multiple sections, use modified Q to maintain overall response
512+
let butterworth_q_factor = 1.0
513+
/ (2.0
514+
* (std::f32::consts::PI * (2.0 * k as f32 + 1.0)
515+
/ (4.0 * n_sections as f32))
516+
.sin());
517+
(fc / bw) * butterworth_q_factor
518+
};
519+
520+
// Calculate biquad coefficients using the standard bandpass formula
521+
let w0 = 2.0 * std::f32::consts::PI * fc / fs;
522+
let alpha = w0.sin() / (2.0 * section_q);
523+
524+
// Bandpass filter coefficients
525+
let b0 = alpha;
526+
let b1 = 0.0;
527+
let b2 = -alpha;
528+
let a0 = 1.0 + alpha;
529+
let a1 = -2.0 * w0.cos();
530+
let a2 = 1.0 - alpha;
531+
532+
// Normalize by a0 and store
533+
self.biquad_coeffs.push(BiquadCoeffs {
534+
b0: b0 / a0,
535+
b1: b1 / a0,
536+
b2: b2 / a0,
537+
a1: a1 / a0,
538+
a2: a2 / a0,
539+
});
540+
541+
// Initialize state variables for this section
542+
self.biquad_states
543+
.write()
544+
.unwrap()
545+
.push(BiquadState { z1: 0.0, z2: 0.0 });
546+
}
547+
548+
// Apply gain correction for multiple sections
549+
if n_sections > 1 {
550+
let gain_correction = (n_sections as f32).sqrt();
551+
for coeffs in &mut self.biquad_coeffs {
552+
coeffs.b0 *= gain_correction;
553+
coeffs.b2 *= gain_correction;
554+
}
486555
}
487556
}
488557
}
@@ -524,38 +593,32 @@ impl Filter for BandpassFilter {
524593
let mut filtered = Vec::with_capacity(signal.len());
525594

526595
// Ensure we have calculated coefficients
527-
if self.a_coeffs.is_empty() || self.b_coeffs.is_empty() {
596+
if self.biquad_coeffs.is_empty() {
528597
// Return the original signal if no coefficients are available
529598
return signal.to_vec();
530599
}
531600

532-
// Number of biquad sections
533-
let n_sections = self.order / 2;
534-
535-
// Initialize state variables for Direct Form II Transposed structure
536-
let mut z1 = vec![0.0f32; n_sections]; // z^-1 state for each section
537-
let mut z2 = vec![0.0f32; n_sections]; // z^-2 state for each section
601+
// Acquire write lock on states
602+
let mut states = self.biquad_states.write().unwrap();
538603

539604
// Process each sample through the cascade of biquad sections
540605
for &x in signal {
541606
let mut y = x;
542607

543608
// Apply each biquad section in cascade
544-
for section in 0..n_sections {
545-
// Get coefficients for this section
546-
let b0 = self.b_coeffs[section * 3];
547-
let b1 = self.b_coeffs[section * 3 + 1];
548-
let b2 = self.b_coeffs[section * 3 + 2];
549-
let a1 = self.a_coeffs[section * 2];
550-
let a2 = self.a_coeffs[section * 2 + 1];
551-
609+
for (section, coeffs) in self.biquad_coeffs.iter().enumerate() {
552610
// Direct Form II Transposed biquad implementation
553-
let y_section = b0 * y + z1[section];
554-
z1[section] = b1 * y - a1 * y_section + z2[section];
555-
z2[section] = b2 * y - a2 * y_section;
611+
let state = &mut states[section];
612+
613+
// Calculate output
614+
let y_out = coeffs.b0 * y + state.z1;
615+
616+
// Update state variables
617+
state.z1 = coeffs.b1 * y - coeffs.a1 * y_out + state.z2;
618+
state.z2 = coeffs.b2 * y - coeffs.a2 * y_out;
556619

557620
// Output of this section becomes input to the next section
558-
y = y_section;
621+
y = y_out;
559622
}
560623

561624
filtered.push(y);
@@ -1307,8 +1370,8 @@ mod tests {
13071370
assert_eq!(filter.bandwidth, 200.0);
13081371
assert_eq!(filter.sample_rate, 48000);
13091372
assert_eq!(filter.order, 2);
1310-
assert!(!filter.a_coeffs.is_empty());
1311-
assert!(!filter.b_coeffs.is_empty());
1373+
assert!(!filter.biquad_coeffs.is_empty());
1374+
assert!(!filter.biquad_states.read().unwrap().is_empty());
13121375
}
13131376

13141377
#[test]
@@ -1321,9 +1384,9 @@ mod tests {
13211384
fn test_bandpass_filter_with_order() {
13221385
let filter = BandpassFilter::new(1000.0, 200.0).with_order(6);
13231386
assert_eq!(filter.order, 6);
1324-
// Should have 3 sections (6/2), each with 3 b coeffs and 2 a coeffs
1325-
assert_eq!(filter.b_coeffs.len(), 9); // 3 sections * 3 coeffs
1326-
assert_eq!(filter.a_coeffs.len(), 6); // 3 sections * 2 coeffs
1387+
// Should have 3 sections (6/2)
1388+
assert_eq!(filter.biquad_coeffs.len(), 3); // 3 biquad sections
1389+
assert_eq!(filter.biquad_states.read().unwrap().len(), 3); // 3 state variables
13271390
}
13281391

13291392
#[test]

0 commit comments

Comments
 (0)