Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/fft_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub use rustfft::num_complex::Complex;

// For microfft backend, we define our own Complex type
#[cfg(feature = "microfft-backend")]
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Complex<T> {
pub re: T,
pub im: T,
Expand Down Expand Up @@ -144,7 +144,7 @@ impl FftNum for f32 {}
impl FftNum for f64 {}

/// Trait abstracting FFT operations for both forward and inverse transforms
pub trait FftBackend<T: FftNum>: Send + Sync {
pub trait FftBackend<T: FftNum>: Send + Sync + core::fmt::Debug {
/// Process FFT in-place
fn process(&self, buffer: &mut [Complex<T>]);

Expand Down Expand Up @@ -183,6 +183,14 @@ mod rustfft_impl {
fft: Arc<dyn Fft<T>>,
}

impl<T: rustfft::FftNum> core::fmt::Debug for RustFftWrapper<T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("RustFftWrapper")
.field("fft_size", &self.fft.len())
.finish()
}
}

impl<T: FftNum> FftBackend<T> for RustFftWrapper<T> {
fn process(&self, buffer: &mut [Complex<T>]) {
// Safety: rustfft::num_complex::Complex and our re-exported Complex
Expand Down Expand Up @@ -243,10 +251,12 @@ mod microfft_impl {

/// Wrapper around microfft that implements our FftBackend trait
/// Note: microfft only supports f32 and power-of-2 sizes up to 4096
#[derive(Debug, Clone)]
struct MicroFftForward {
size: usize,
}

#[derive(Debug, Clone)]
struct MicroFftInverse {
size: usize,
}
Expand Down
19 changes: 13 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub mod prelude {
};
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReconstructionMode {
/// Overlap-Add: normalize by sum(w), requires COLA condition
Ola,
Expand All @@ -75,7 +75,7 @@ pub enum ReconstructionMode {
Wola,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowType {
Hann,
Hamming,
Expand Down Expand Up @@ -122,14 +122,14 @@ impl<T: Float + fmt::Display + fmt::Debug> fmt::Display for ConfigError<T> {
#[cfg(feature = "std")]
impl<T: Float + fmt::Display + fmt::Debug> std::error::Error for ConfigError<T> {}

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PadMode {
Reflect,
Zero,
Edge,
}

#[derive(Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct StftConfig<T: Float> {
pub fft_size: usize,
pub hop_size: usize,
Expand Down Expand Up @@ -286,6 +286,7 @@ impl<T: Float + FromPrimitive + fmt::Debug> StftConfig<T> {
}

/// Builder for StftConfig with fluent API
#[derive(Debug, Clone, PartialEq)]
pub struct StftConfigBuilder<T: Float> {
fft_size: Option<usize>,
hop_size: Option<usize>,
Expand Down Expand Up @@ -384,7 +385,7 @@ fn generate_window<T: Float + FromPrimitive>(window_type: WindowType, size: usiz
}
}

#[derive(Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct SpectrumFrame<T: Float> {
pub freq_bins: usize,
pub data: Vec<Complex<T>>,
Expand Down Expand Up @@ -474,7 +475,7 @@ impl<T: Float> SpectrumFrame<T> {
}
}

#[derive(Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Spectrum<T: Float> {
pub num_frames: usize,
pub freq_bins: usize,
Expand Down Expand Up @@ -609,6 +610,7 @@ impl<T: Float> Spectrum<T> {
}
}

#[derive(Debug, Clone)]
pub struct BatchStft<T: Float + FftNum> {
config: StftConfig<T>,
window: Vec<T>,
Expand Down Expand Up @@ -821,6 +823,7 @@ impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchStft<T> {
}
}

#[derive(Debug, Clone)]
pub struct BatchIstft<T: Float + FftNum> {
config: StftConfig<T>,
window: Vec<T>,
Expand Down Expand Up @@ -1085,6 +1088,7 @@ impl<T: Float + FftNum + FromPrimitive + fmt::Debug> BatchIstft<T> {
}
}

#[derive(Debug, Clone)]
pub struct StreamingStft<T: Float + FftNum> {
config: StftConfig<T>,
window: Vec<T>,
Expand Down Expand Up @@ -1235,6 +1239,7 @@ impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingStft<T> {
}

/// Multi-channel streaming STFT processor with independent state per channel.
#[derive(Debug, Clone)]
pub struct MultiChannelStreamingStft<T: Float + FftNum> {
processors: Vec<StreamingStft<T>>,
}
Expand Down Expand Up @@ -1333,6 +1338,7 @@ where
}
}

#[derive(Debug, Clone)]
pub struct StreamingIstft<T: Float + FftNum> {
config: StftConfig<T>,
window: Vec<T>,
Expand Down Expand Up @@ -1568,6 +1574,7 @@ impl<T: Float + FftNum + FromPrimitive + fmt::Debug> StreamingIstft<T> {
}

/// Multi-channel streaming iSTFT processor with independent state per channel.
#[derive(Debug, Clone)]
pub struct MultiChannelStreamingIstft<T: Float + FftNum> {
processors: Vec<StreamingIstft<T>>,
}
Expand Down
12 changes: 7 additions & 5 deletions src/mel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use core::fmt;
use num_traits::Float;

/// Mel scale variant for frequency conversion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MelScale {
/// HTK mel scale formula: 2595 * log10(1 + hz/700)
Htk,
Expand All @@ -25,7 +25,7 @@ pub enum MelScale {
}

/// Normalization method for mel filterbank.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MelNorm {
/// No normalization
None,
Expand All @@ -35,7 +35,7 @@ pub enum MelNorm {
}

/// Configuration for mel spectrogram computation.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct MelConfig<T: Float> {
/// Number of mel bands (default: 80 for speech/Whisper)
pub n_mels: usize,
Expand Down Expand Up @@ -138,7 +138,7 @@ pub fn mel_to_hz<T: Float>(mel: T, scale: MelScale) -> T {
///
/// Stores triangular filters as a sparse matrix where each mel bin
/// has weights for the relevant STFT frequency bins.
#[derive(Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct MelFilterbank<T: Float> {
/// Number of mel bands
pub n_mels: usize,
Expand Down Expand Up @@ -302,7 +302,7 @@ impl<T: Float + fmt::Debug> MelFilterbank<T> {
/// Mel-scale spectrum data structure.
///
/// Stores mel spectrogram as (num_frames x n_mels) in row-major order.
#[derive(Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct MelSpectrum<T: Float> {
/// Number of time frames
pub num_frames: usize,
Expand Down Expand Up @@ -498,6 +498,7 @@ impl<T: Float> MelSpectrum<T> {
/// Batch mel spectrogram processor.
///
/// Converts STFT Spectrum to mel-scale representation.
#[derive(Debug, Clone)]
pub struct BatchMelSpectrogram<T: Float> {
filterbank: MelFilterbank<T>,
use_power: bool,
Expand Down Expand Up @@ -596,6 +597,7 @@ impl<T: Float + fmt::Debug> BatchMelSpectrogram<T> {
/// Streaming mel spectrogram processor.
///
/// Processes individual STFT frames into mel-scale frames.
#[derive(Debug, Clone)]
pub struct StreamingMelSpectrogram<T: Float> {
filterbank: MelFilterbank<T>,
use_power: bool,
Expand Down