Skip to content

Commit d2c8a24

Browse files
JackTemaki and vieting authored
add LogMelFeatureExtraction (#26)
Adds PyTorch-based log-mel feature extraction that is compatible with the librosa-based feature extraction in RETURNN. Co-authored-by: vieting <[email protected]>
1 parent 18d3f1a commit d2c8a24

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
2+
3+
from dataclasses import dataclass
4+
from typing import Optional, Tuple
5+
6+
from librosa import filters
7+
import torch
8+
from torch import nn
9+
10+
from i6_models.config import ModelConfiguration
11+
12+
13+
@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
    """
    Configuration for :class:`LogMelFeatureExtractionV1`.

    All values are validated in ``__post_init__``; an ``AssertionError`` is
    raised for inconsistent settings.

    Attributes:
        sample_rate: audio sample rate in Hz
        win_size: window size in seconds
        hop_size: window shift in seconds
        f_min: minimum filter frequency in Hz (0 is allowed)
        f_max: maximum filter frequency in Hz, at most half of the sample rate
        min_amp: minimum amplitude for safe log
        num_filters: number of mel windows
        center: centered STFT with automatic padding
        n_fft: FFT size in samples; if None, it is set to the window size in samples
    """

    sample_rate: int
    win_size: float
    hop_size: float
    f_min: int
    f_max: int
    min_amp: float
    num_filters: int
    center: bool
    n_fft: Optional[int] = None

    def __post_init__(self) -> None:
        """Validate the settings and fill in ``n_fft`` if it was not given."""
        super().__post_init__()
        assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate"
        # f_min == 0 is a valid lower filter edge, only negative frequencies are rejected
        assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive"
        assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
        assert self.num_filters > 0, "number of filters needs to be positive"
        assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
        if self.n_fft is None:
            # if n_fft is not given, set n_fft to the window size (in samples)
            self.n_fft = int(self.win_size * self.sample_rate)
        else:
            assert self.n_fft >= self.win_size * self.sample_rate, "n_fft cannot be smaller than the window size"
49+
50+
51+
class LogMelFeatureExtractionV1(nn.Module):
    """
    Librosa-compatible log-mel feature extraction using log10. Does not use torchaudio.

    Using it wrapped with torch.no_grad() is recommended if no gradient is needed
    """

    def __init__(self, cfg: LogMelFeatureExtractionV1Config):
        """
        :param cfg: feature extraction settings; ``cfg.n_fft`` is expected to be set
            (the config fills it in from the window size when it is not given)
        """
        super().__init__()
        win_length = int(cfg.win_size * cfg.sample_rate)
        hop_length = int(cfg.hop_size * cfg.sample_rate)

        # Plain Python ints for torch.stft and the frame-count arithmetic, so that
        # no 0-dim (possibly non-CPU) tensors have to be converted on every call.
        self._n_fft = int(cfg.n_fft)
        self._win_length = win_length
        self._hop_length = hop_length
        self.center = cfg.center

        # The buffers keep their original names and order so that the keys of the
        # module's state dict stay unchanged.
        self.register_buffer("n_fft", torch.tensor(cfg.n_fft))
        self.register_buffer("window", torch.hann_window(win_length))
        self.register_buffer("hop_length", torch.tensor(hop_length))
        self.register_buffer("min_amp", torch.tensor(cfg.min_amp))
        self.register_buffer(
            "mel_basis",
            torch.tensor(
                filters.mel(
                    sr=cfg.sample_rate,
                    # must match the STFT size, otherwise the number of frequency
                    # bins differs from the spectrum whenever n_fft > window size
                    n_fft=cfg.n_fft,
                    n_mels=cfg.num_filters,
                    fmin=cfg.f_min,
                    fmax=cfg.f_max,
                )
            ),
        )

    def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param raw_audio: [B, T]
        :param length in samples: [B]
        :return features as [B,T,F] and length in frames [B]
        """
        spectrum = torch.stft(
            raw_audio,
            n_fft=self._n_fft,
            hop_length=self._hop_length,
            # explicit window length: the window may be shorter than n_fft
            win_length=self._win_length,
            window=self.window,
            center=self.center,
            pad_mode="constant",
            return_complex=True,
        )
        power_spectrum = torch.abs(spectrum) ** 2
        if power_spectrum.dim() == 2:
            # a result without batch axis ([F, T']) gets the batch axis added back
            power_spectrum = torch.unsqueeze(power_spectrum, 0)
        melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
        # clip from below before the log to avoid log10(0)
        log_melspec = torch.log10(torch.max(self.min_amp, melspec))
        feature_data = torch.transpose(log_melspec, 1, 2)

        if self.center:
            # centered STFT pads n_fft // 2 samples on both sides; for even n_fft
            # this reduces to (length // hop_length) + 1
            effective_length = length + 2 * (self._n_fft // 2)
        else:
            effective_length = length
        num_frames = ((effective_length - self._n_fft) // self._hop_length) + 1

        return feature_data, num_frames.int()

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
typeguard
2-
torch
2+
torch
3+
librosa

tests/test_feature_extraction.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import copy
2+
import numpy
3+
import torch
4+
5+
from librosa.feature import melspectrogram
6+
7+
from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1, LogMelFeatureExtractionV1Config
8+
9+
10+
def test_logmel_librosa_compatibility():
    """Check that the PyTorch log-mel features match the librosa reference."""
    num_samples = 50000
    # fixed seed so that a failure can be reproduced
    rng = numpy.random.default_rng(42)
    audio = rng.random(num_samples, dtype=numpy.float32)

    librosa_mel = melspectrogram(
        y=audio,
        sr=16000,
        n_fft=int(0.05 * 16000),
        hop_length=int(0.0125 * 16000),
        win_length=int(0.05 * 16000),
        fmin=60,
        fmax=7600,
        n_mels=80,
    )
    librosa_log_mel = numpy.log10(numpy.maximum(librosa_mel, 1e-10))

    fe_cfg = LogMelFeatureExtractionV1Config(
        sample_rate=16000,
        win_size=0.05,
        hop_size=0.0125,
        f_min=60,
        f_max=7600,
        min_amp=1e-10,
        num_filters=80,
        center=True,
    )
    fe = LogMelFeatureExtractionV1(cfg=fe_cfg)
    audio_tensor = torch.unsqueeze(torch.tensor(audio), 0)  # [B, T]
    audio_length = torch.tensor([num_samples])  # [B]
    pytorch_log_mel, frame_length = fe(audio_tensor, audio_length)

    # librosa returns [F, T], the module returns [B, T, F]
    librosa_log_mel = torch.tensor(librosa_log_mel).transpose(0, 1)
    assert torch.allclose(librosa_log_mel, pytorch_log_mel, atol=1e-06)
    # the reported number of frames has to agree with the feature tensor
    assert int(frame_length[0]) == pytorch_log_mel.size()[1]
41+
42+
43+
def test_logmel_length():
    """Check that the returned frame count matches the feature time axis, with and without centering."""
    fe_center_cfg = LogMelFeatureExtractionV1Config(
        sample_rate=16000,
        win_size=0.05,
        hop_size=0.0125,
        f_min=60,
        f_max=7600,
        min_amp=1e-10,
        num_filters=80,
        center=True,
    )
    fe_center = LogMelFeatureExtractionV1(cfg=fe_center_cfg)
    fe_no_center_cfg = copy.deepcopy(fe_center_cfg)
    fe_no_center_cfg.center = False
    fe_no_center = LogMelFeatureExtractionV1(cfg=fe_no_center_cfg)

    # fixed seed so that a failing length can be reproduced
    rng = numpy.random.default_rng(0)
    for _ in range(10):
        num_samples = int(rng.integers(10000, 50000))
        audio = torch.unsqueeze(torch.tensor(rng.random(num_samples, dtype=numpy.float32)), 0)  # [B, T]
        audio_length = torch.tensor([num_samples])  # [B]

        mel_center, length_center = fe_center(audio, audio_length)
        assert torch.all(mel_center.size()[1] == length_center)
        mel_no_center, length_no_center = fe_no_center(audio, audio_length)
        assert torch.all(mel_no_center.size()[1] == length_no_center)

0 commit comments

Comments
 (0)