Commit 933c6c1

Revert "Extend feature extraction module to allow for RASR compatible logmel features (#40)" (#41)
This reverts commit c363a01. The previous commit breaks existing hashes/setups.
1 parent c363a01 commit 933c6c1
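
The reverted change (#40) had added an alternative spectrum path for RASR compatibility: instead of calling torch.stft, it framed the raw waveform with Tensor.unfold, applied the Hann window per frame, and took a real FFT (torch.fft.rfftn in the removed code). A rough standalone sketch of that framing approach, assuming only torch; the function name and shapes are illustrative and not part of the repository, and torch.fft.rfft over the last axis stands in for rfftn:

import torch


def framed_power_spectrum(raw_audio: torch.Tensor, win_length: int, hop_length: int, n_fft: int) -> torch.Tensor:
    """
    Power spectrum via explicit framing + real FFT, roughly what the removed RFFTN path computed.

    :param raw_audio: [B, T]
    :return: power spectrum as [B, F=n_fft//2+1, T']
    """
    window = torch.hann_window(win_length)
    frames = raw_audio.unfold(1, size=win_length, step=hop_length)  # [B, T', win_length]
    frames = frames * window  # broadcast the window over all frames
    spectrum = torch.fft.rfft(frames, n=n_fft)  # complex, [B, T', n_fft//2+1]
    return spectrum.abs().pow(2).transpose(1, 2)  # [B, F, T'], matching the torch.stft layout

Unlike torch.stft with center=True, such framing drops trailing samples that do not fill a full window, which is why the removed code also computed the output lengths as ((length - win_length) // hop_length) + 1.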

1 file changed: 23 additions & 59 deletions

@@ -1,23 +1,15 @@
 __all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union, Literal
-from enum import Enum
+from typing import Optional, Tuple
 
 from librosa import filters
 import torch
 from torch import nn
-import numpy as np
-from numpy.typing import DTypeLike
 
 from i6_models.config import ModelConfiguration
 
 
-class SpectrumType(Enum):
-    STFT = 1
-    RFFTN = 2
-
-
 @dataclass
 class LogMelFeatureExtractionV1Config(ModelConfiguration):
     """
@@ -30,10 +22,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
         min_amp: minimum amplitude for safe log
         num_filters: number of mel windows
         center: centered STFT with automatic padding
-        periodic: whether the window is assumed to be periodic
-        htk: whether use HTK formula instead of Slaney
-        norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
-        spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's
     """
 
     sample_rate: int
@@ -45,11 +33,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
     num_filters: int
     center: bool
     n_fft: Optional[int] = None
-    periodic: bool = True
-    htk: bool = False
-    norm: Optional[Union[Literal["slaney"], float]] = "slaney"
-    dtype: DTypeLike = np.float32
-    spectrum_type: SpectrumType = SpectrumType.STFT
 
     def __post_init__(self) -> None:
         super().__post_init__()
@@ -79,7 +62,6 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
         self.min_amp = cfg.min_amp
         self.n_fft = cfg.n_fft
         self.win_length = int(cfg.win_size * cfg.sample_rate)
-        self.spectrum_type = cfg.spectrum_type
 
         self.register_buffer(
             "mel_basis",
@@ -90,60 +72,42 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
                     n_mels=cfg.num_filters,
                     fmin=cfg.f_min,
                     fmax=cfg.f_max,
-                    htk=cfg.htk,
-                    norm=cfg.norm,
-                    dtype=cfg.dtype,
-                ),
+                )
             ),
         )
-        self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))
+        self.register_buffer("window", torch.hann_window(self.win_length))
 
     def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param raw_audio: [B, T]
         :param length in samples: [B]
         :return features as [B,T,F] and length in frames [B]
         """
-        if self.spectrum_type == SpectrumType.STFT:
-            power_spectrum = (
-                torch.abs(
-                    torch.stft(
-                        raw_audio,
-                        n_fft=self.n_fft,
-                        hop_length=self.hop_length,
-                        win_length=self.win_length,
-                        window=self.window,
-                        center=self.center,
-                        pad_mode="constant",
-                        return_complex=True,
-                    )
+        power_spectrum = (
+            torch.abs(
+                torch.stft(
+                    raw_audio,
+                    n_fft=self.n_fft,
+                    hop_length=self.hop_length,
+                    win_length=self.win_length,
+                    window=self.window,
+                    center=self.center,
+                    pad_mode="constant",
+                    return_complex=True,
                 )
-                ** 2
             )
-        elif self.spectrum_type == SpectrumType.RFFTN:
-            windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)  # [B, T', W=win_length]
-            smoothed = windowed * self.window.unsqueeze(0)  # [B, T', W]
-
-            # Compute power spectrum using torch.fft.rfftn
-            power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2  # [B, T', F=n_fft//2+1]
-            power_spectrum = power_spectrum.transpose(1, 2)  # [B, F, T']
-        else:
-            raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
-
+            ** 2
+        )
         if len(power_spectrum.size()) == 2:
             # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
-            power_spectrum = torch.unsqueeze(power_spectrum, 0)  # [B, F, T']
-        melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)  # [B, F'=num_filters, T']
+            power_spectrum = torch.unsqueeze(power_spectrum, 0)
+        melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
         log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
-        feature_data = torch.transpose(log_melspec, 1, 2)  # [B, T', F']
+        feature_data = torch.transpose(log_melspec, 1, 2)
 
-        if self.spectrum_type == SpectrumType.STFT:
-            if self.center:
-                length = (length // self.hop_length) + 1
-            else:
-                length = ((length - self.n_fft) // self.hop_length) + 1
-        elif self.spectrum_type == SpectrumType.RFFTN:
-            length = ((length - self.win_length) // self.hop_length) + 1
+        if self.center:
+            length = (length // self.hop_length) + 1
         else:
-            raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
+            length = ((length - self.n_fft) // self.hop_length) + 1
+
         return feature_data, length.int()
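
For reference, a minimal usage sketch of the module as restored by this revert. Field names that are only implied by the diff (hop_size for the frame shift; f_min, f_max and min_amp as suggested by cfg.f_min, cfg.f_max and cfg.min_amp in __init__) are assumptions, and the numeric values are purely illustrative:

import torch

cfg = LogMelFeatureExtractionV1Config(
    sample_rate=16000,
    win_size=0.025,  # 25 ms analysis window -> win_length = 400 samples
    hop_size=0.010,  # 10 ms frame shift -> hop_length = 160 samples (assumed field name)
    f_min=60,
    f_max=7600,
    min_amp=1e-10,
    num_filters=80,
    center=False,
    n_fft=400,
)
feature_extraction = LogMelFeatureExtractionV1(cfg)

raw_audio = torch.randn(4, 16000)  # [B, T]: one second of audio per sequence
lengths = torch.full((4,), 16000)  # length in samples per sequence
features, out_lengths = feature_extraction(raw_audio, lengths)
# features: [B, T', num_filters]; with center=False the frame count is
# ((16000 - 400) // 160) + 1 = 98, matching the restored length formula above.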

0 commit comments