
Commit c363a01

Extend feature extraction module to allow for RASR compatible logmel features (#40)
1 parent 8f4c364 commit c363a01

File tree

1 file changed (+59, -23 lines)


i6_models/primitives/feature_extraction.py

Lines changed: 59 additions & 23 deletions
@@ -1,15 +1,23 @@
 __all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, Literal
+from enum import Enum

 from librosa import filters
 import torch
 from torch import nn
+import numpy as np
+from numpy.typing import DTypeLike

 from i6_models.config import ModelConfiguration


+class SpectrumType(Enum):
+    STFT = 1
+    RFFTN = 2
+
+
 @dataclass
 class LogMelFeatureExtractionV1Config(ModelConfiguration):
     """
@@ -22,6 +30,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
     min_amp: minimum amplitude for safe log
     num_filters: number of mel windows
     center: centered STFT with automatic padding
+    periodic: whether the window is assumed to be periodic
+    htk: whether use HTK formula instead of Slaney
+    norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
+    spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's
     """

     sample_rate: int
@@ -33,6 +45,11 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
     num_filters: int
     center: bool
     n_fft: Optional[int] = None
+    periodic: bool = True
+    htk: bool = False
+    norm: Optional[Union[Literal["slaney"], float]] = "slaney"
+    dtype: DTypeLike = np.float32
+    spectrum_type: SpectrumType = SpectrumType.STFT

     def __post_init__(self) -> None:
         super().__post_init__()
@@ -62,6 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
         self.min_amp = cfg.min_amp
         self.n_fft = cfg.n_fft
         self.win_length = int(cfg.win_size * cfg.sample_rate)
+        self.spectrum_type = cfg.spectrum_type

         self.register_buffer(
             "mel_basis",
@@ -72,42 +90,60 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
                 n_mels=cfg.num_filters,
                 fmin=cfg.f_min,
                 fmax=cfg.f_max,
-            )
+                htk=cfg.htk,
+                norm=cfg.norm,
+                dtype=cfg.dtype,
+            ),
             ),
         )
-        self.register_buffer("window", torch.hann_window(self.win_length))
+        self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))

     def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param raw_audio: [B, T]
         :param length in samples: [B]
         :return features as [B,T,F] and length in frames [B]
         """
-        power_spectrum = (
-            torch.abs(
-                torch.stft(
-                    raw_audio,
-                    n_fft=self.n_fft,
-                    hop_length=self.hop_length,
-                    win_length=self.win_length,
-                    window=self.window,
-                    center=self.center,
-                    pad_mode="constant",
-                    return_complex=True,
-                )
-            )
-            ** 2
-        )
+        if self.spectrum_type == SpectrumType.STFT:
+            power_spectrum = (
+                torch.abs(
+                    torch.stft(
+                        raw_audio,
+                        n_fft=self.n_fft,
+                        hop_length=self.hop_length,
+                        win_length=self.win_length,
+                        window=self.window,
+                        center=self.center,
+                        pad_mode="constant",
+                        return_complex=True,
+                    )
+                )
+                ** 2
+            )
+        elif self.spectrum_type == SpectrumType.RFFTN:
+            windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)  # [B, T', W=win_length]
+            smoothed = windowed * self.window.unsqueeze(0)  # [B, T', W]
+
+            # Compute power spectrum using torch.fft.rfftn
+            power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2  # [B, T', F=n_fft//2+1]
+            power_spectrum = power_spectrum.transpose(1, 2)  # [B, F, T']
+        else:
+            raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
+
         if len(power_spectrum.size()) == 2:
             # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
-            power_spectrum = torch.unsqueeze(power_spectrum, 0)
-        melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
+            power_spectrum = torch.unsqueeze(power_spectrum, 0)  # [B, F, T']
+        melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)  # [B, F'=num_filters, T']
         log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
-        feature_data = torch.transpose(log_melspec, 1, 2)
+        feature_data = torch.transpose(log_melspec, 1, 2)  # [B, T', F']

-        if self.center:
-            length = (length // self.hop_length) + 1
+        if self.spectrum_type == SpectrumType.STFT:
+            if self.center:
+                length = (length // self.hop_length) + 1
+            else:
+                length = ((length - self.n_fft) // self.hop_length) + 1
+        elif self.spectrum_type == SpectrumType.RFFTN:
+            length = ((length - self.win_length) // self.hop_length) + 1
         else:
-            length = ((length - self.n_fft) // self.hop_length) + 1
-
+            raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
         return feature_data, length.int()