11__all__ = ["LogMelFeatureExtractionV1" , "LogMelFeatureExtractionV1Config" ]
22
33from dataclasses import dataclass
4- from typing import Optional , Tuple , Union , Literal
5- from enum import Enum
4+ from typing import Optional , Tuple
65
76from librosa import filters
87import torch
98from torch import nn
10- import numpy as np
11- from numpy .typing import DTypeLike
129
1310from i6_models .config import ModelConfiguration
1411
1512
16- class SpectrumType (Enum ):
17- STFT = 1
18- RFFTN = 2
19-
20-
2113@dataclass
2214class LogMelFeatureExtractionV1Config (ModelConfiguration ):
2315 """
@@ -30,10 +22,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
3022 min_amp: minimum amplitude for safe log
3123 num_filters: number of mel windows
3224 center: centered STFT with automatic padding
33- periodic: whether the window is assumed to be periodic
34- htk: whether use HTK formula instead of Slaney
35- norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
36- spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's
3725 """
3826
3927 sample_rate : int
@@ -45,11 +33,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
4533 num_filters : int
4634 center : bool
4735 n_fft : Optional [int ] = None
48- periodic : bool = True
49- htk : bool = False
50- norm : Optional [Union [Literal ["slaney" ], float ]] = "slaney"
51- dtype : DTypeLike = np .float32
52- spectrum_type : SpectrumType = SpectrumType .STFT
5336
5437 def __post_init__ (self ) -> None :
5538 super ().__post_init__ ()
@@ -79,7 +62,6 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
7962 self .min_amp = cfg .min_amp
8063 self .n_fft = cfg .n_fft
8164 self .win_length = int (cfg .win_size * cfg .sample_rate )
82- self .spectrum_type = cfg .spectrum_type
8365
8466 self .register_buffer (
8567 "mel_basis" ,
@@ -90,60 +72,42 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
9072 n_mels = cfg .num_filters ,
9173 fmin = cfg .f_min ,
9274 fmax = cfg .f_max ,
93- htk = cfg .htk ,
94- norm = cfg .norm ,
95- dtype = cfg .dtype ,
96- ),
75+ )
9776 ),
9877 )
99- self .register_buffer ("window" , torch .hann_window (self .win_length , periodic = cfg . periodic ))
78+ self .register_buffer ("window" , torch .hann_window (self .win_length ))
10079
def forward(self, raw_audio: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute log10 mel-filterbank features from raw audio via a power STFT.

    :param raw_audio: audio samples as [B, T]
    :param length: length in samples as [B]
    :return: features as [B, T', F'] and length in frames as [B]
    """
    # Power spectrum: squared magnitude of the complex STFT, [B, F=n_fft//2+1, T']
    power_spectrum = (
        torch.abs(
            torch.stft(
                raw_audio,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=self.center,
                pad_mode="constant",
                return_complex=True,
            )
        )
        ** 2
    )
    if power_spectrum.dim() == 2:
        # torch.stft can return [F, T'] without a batch axis; add it back so the
        # transpose over axes 1 and 2 below is valid.
        power_spectrum = torch.unsqueeze(power_spectrum, 0)  # [B, F, T']

    # Project the frequency bins onto the mel filters: [B, F'=num_filters, T']
    melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
    # Clamp before the log so zero-energy bins do not produce -inf
    log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
    feature_data = torch.transpose(log_melspec, 1, 2)  # [B, T', F']

    # Convert sample lengths to frame lengths, mirroring the framing used by torch.stft
    if self.center:
        # Centered STFT pads both sides, giving one frame per hop plus one
        length = (length // self.hop_length) + 1
    else:
        # Without padding, only windows that fit fully into the signal are counted
        length = ((length - self.n_fft) // self.hop_length) + 1

    return feature_data, length.int()
0 commit comments