11__all__ = ["LogMelFeatureExtractionV1" , "LogMelFeatureExtractionV1Config" ]
22
33from dataclasses import dataclass
4- from typing import Optional , Tuple
4+ from typing import Optional , Tuple , Union , Literal
5+ from enum import Enum
56
67from librosa import filters
78import torch
89from torch import nn
10+ import numpy as np
11+ from numpy .typing import DTypeLike
912
1013from i6_models .config import ModelConfiguration
1114
1215
class SpectrumType(Enum):
    """
    Selects how the power spectrum is computed in the feature extraction forward pass.
    """

    # torch.stft applied directly to the raw audio
    STFT = 1
    # explicit framing and windowing of the audio followed by torch.fft.rfftn
    RFFTN = 2
19+
20+
1321@dataclass
1422class LogMelFeatureExtractionV1Config (ModelConfiguration ):
1523 """
@@ -22,6 +30,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
2230 min_amp: minimum amplitude for safe log
2331 num_filters: number of mel windows
2432 center: centered STFT with automatic padding
33+ periodic: whether the window is assumed to be periodic
34+ htk: whether to use the HTK formula instead of the Slaney formula
35+ norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
36+ spectrum_type: apply torch.stft to the raw audio input (default), or torch.fft.rfftn to windowed audio to make the features compatible with RASR's
2537 """
2638
2739 sample_rate : int
@@ -33,6 +45,11 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
3345 num_filters : int
3446 center : bool
3547 n_fft : Optional [int ] = None
48+ periodic : bool = True
49+ htk : bool = False
50+ norm : Optional [Union [Literal ["slaney" ], float ]] = "slaney"
51+ dtype : DTypeLike = np .float32
52+ spectrum_type : SpectrumType = SpectrumType .STFT
3653
3754 def __post_init__ (self ) -> None :
3855 super ().__post_init__ ()
@@ -62,6 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
6279 self .min_amp = cfg .min_amp
6380 self .n_fft = cfg .n_fft
6481 self .win_length = int (cfg .win_size * cfg .sample_rate )
82+ self .spectrum_type = cfg .spectrum_type
6583
6684 self .register_buffer (
6785 "mel_basis" ,
@@ -72,42 +90,60 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
7290 n_mels = cfg .num_filters ,
7391 fmin = cfg .f_min ,
7492 fmax = cfg .f_max ,
75- )
93+ htk = cfg .htk ,
94+ norm = cfg .norm ,
95+ dtype = cfg .dtype ,
96+ ),
7697 ),
7798 )
78- self .register_buffer ("window" , torch .hann_window (self .win_length ))
99+ self .register_buffer ("window" , torch .hann_window (self .win_length , periodic = cfg . periodic ))
79100
def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute log-mel features from raw audio.

    The power spectrum is computed according to ``self.spectrum_type``:
    either with ``torch.stft`` on the raw audio, or by framing the audio with
    ``unfold``, applying the window and calling ``torch.fft.rfftn`` on each frame.
    ``self.center`` is only consulted in the STFT case.

    :param raw_audio: [B, T]
    :param length in samples: [B]
    :return features as [B,T,F] and length in frames [B]
    :raises ValueError: if ``self.spectrum_type`` is not a known spectrum type
    """
    if self.spectrum_type == SpectrumType.STFT:
        power_spectrum = (
            torch.abs(
                torch.stft(
                    raw_audio,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode="constant",
                    return_complex=True,
                )
            )
            ** 2
        )
        if self.center:
            length = (length // self.hop_length) + 1
        else:
            length = ((length - self.n_fft) // self.hop_length) + 1
    elif self.spectrum_type == SpectrumType.RFFTN:
        windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)  # [B, T', W=win_length]
        smoothed = windowed * self.window.unsqueeze(0)  # [B, T', W]

        # Compute power spectrum using torch.fft.rfftn over the frame axis only;
        # `s` and `dim` are given as tuples as specified by the rfftn signature.
        power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=(self.n_fft,), dim=(-1,))) ** 2  # [B, T', F=n_fft//2+1]
        power_spectrum = power_spectrum.transpose(1, 2)  # [B, F, T']

        # one frame per full window, matching the number of frames produced by unfold
        length = ((length - self.win_length) // self.hop_length) + 1
    else:
        raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")

    if len(power_spectrum.size()) == 2:
        # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
        power_spectrum = torch.unsqueeze(power_spectrum, 0)  # [B, F, T']
    melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)  # [B, F'=num_filters, T']
    # clamp before the log so that values below min_amp do not produce -inf
    log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
    feature_data = torch.transpose(log_melspec, 1, 2)  # [B, T', F']

    return feature_data, length.int()
0 commit comments