Skip to content

Commit a15dad4

Browse files
kuacakuacaalbertz
andauthored
Add RASR compatible feature extraction (#44)
Co-authored-by: Albert Zeyer <[email protected]>
1 parent ea4354c commit a15dad4

File tree

2 files changed

+933
-2
lines changed

2 files changed

+933
-2
lines changed

i6_models/primitives/feature_extraction.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1-
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
1+
__all__ = [
2+
"LogMelFeatureExtractionV1",
3+
"LogMelFeatureExtractionV1Config",
4+
"RasrCompatibleLogMelFeatureExtractionV1",
5+
"RasrCompatibleLogMelFeatureExtractionV1Config",
6+
]
27

8+
import math
39
from dataclasses import dataclass
410
from typing import Optional, Tuple
511

@@ -111,3 +117,112 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
111117
length = ((length - self.n_fft) // self.hop_length) + 1
112118

113119
return feature_data, length.int()
120+
121+
122+
@dataclass
123+
class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
124+
"""
125+
Attributes:
126+
sample_rate: audio sample rate in Hz
127+
win_size: window size in seconds
128+
hop_size: window shift in seconds
129+
min_amp: minimum amplitude for safe log
130+
num_filters: number of mel windows
131+
alpha: preemphasis weight
132+
"""
133+
134+
sample_rate: int
135+
win_size: float
136+
hop_size: float
137+
min_amp: float
138+
num_filters: int
139+
alpha: float = 1.0
140+
141+
def __post_init__(self) -> None:
142+
super().__post_init__()
143+
assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
144+
assert self.num_filters > 0, "number of filters needs to be positive"
145+
assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
146+
147+
148+
class RasrCompatibleLogMelFeatureExtractionV1(nn.Module):
149+
"""
150+
Rasr-compatible log-mel feature extraction using log10. Does not use torchaudio.
151+
"""
152+
153+
def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
154+
super().__init__()
155+
156+
self.sample_rate = int(cfg.sample_rate)
157+
self.hop_length = int(cfg.hop_size * cfg.sample_rate)
158+
self.min_amp = cfg.min_amp
159+
self.win_length = int(cfg.win_size * cfg.sample_rate)
160+
self.n_fft = 2 ** math.ceil(
161+
math.log2(self.win_length)
162+
) # smallest power if two which is greater than or equal to win_length
163+
self.alpha = cfg.alpha
164+
165+
self.register_buffer(
166+
"mel_basis",
167+
torch.tensor(
168+
filters.mel(
169+
sr=cfg.sample_rate,
170+
n_fft=self.n_fft,
171+
n_mels=cfg.num_filters,
172+
fmin=0,
173+
fmax=cfg.sample_rate // 2,
174+
htk=True,
175+
norm=None,
176+
),
177+
),
178+
)
179+
self.register_buffer(
180+
"window", torch.hann_window(self.win_length, periodic=False, dtype=torch.float64).to(torch.float32)
181+
)
182+
183+
def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
184+
"""
185+
:param raw_audio: [B, T]
186+
:param length: in samples [B]
187+
:return features as [B,T,F] and length in frames [B]
188+
"""
189+
assert raw_audio.shape[1] > 0 # also same for length
190+
res_size = max(raw_audio.shape[1] - self.win_length + self.hop_length - 1, 0) // self.hop_length + 1
191+
res_length = (
192+
torch.maximum(length - self.win_length + self.hop_length - 1, torch.zeros_like(length)) // self.hop_length
193+
+ 1
194+
)
195+
196+
# preemphasize
197+
preemphasized = raw_audio.clone()
198+
preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
199+
preemphasized[..., 0] = 0.0
200+
201+
# zero pad for the last frame of each sequence in the batch
202+
last_win_size = length - (res_length - 1) * self.hop_length # [B]
203+
last_pad = self.win_length - last_win_size # [B]
204+
205+
# zero pad for the whole batch
206+
last_pad_batch = self.win_length - (preemphasized.shape[1] - (res_size - 1) * self.hop_length)
207+
padded = torch.nn.functional.pad(preemphasized, (0, last_pad_batch))
208+
209+
windowed = padded.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=self.win_length]
210+
211+
smoothed = windowed * self.window[None, None, :] # [B, T', W]
212+
213+
# The last window might be shorter. Will use a shorter Hanning window then. Need to fix that.
214+
for i, (last_w_size, last_p, res_l) in enumerate(zip(last_win_size, last_pad, res_length)):
215+
last_win = torch.hann_window(last_w_size, periodic=False, dtype=torch.float64).to(
216+
self.window.device, torch.float32
217+
)
218+
last_win = torch.nn.functional.pad(last_win, (0, last_p)) # [W]
219+
smoothed[i, res_l - 1] = windowed[i, res_l - 1] * last_win[None, :]
220+
221+
# compute amplitude spectrum using torch.fft.rfftn with Rasr specific scaling
222+
fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate # [B, T', F=n_fft//2+1]
223+
amplitude_spectrum = torch.abs(fft) # [B, T', F=n_fft//2+1]
224+
225+
melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis) # [B, T', F'=num_filters]
226+
log_melspec = torch.log10(melspec + self.min_amp)
227+
228+
return log_melspec, res_length

0 commit comments

Comments
 (0)