@@ -1,5 +1,11 @@
-__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
+__all__ = [
+    "LogMelFeatureExtractionV1",
+    "LogMelFeatureExtractionV1Config",
+    "RasrCompatibleLogMelFeatureExtractionV1",
+    "RasrCompatibleLogMelFeatureExtractionV1Config",
+]
 
+import math
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
@@ -111,3 +117,112 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: |
         length = ((length - self.n_fft) // self.hop_length) + 1
 
         return feature_data, length.int()
+
+
+@dataclass
+class RasrCompatibleLogMelFeatureExtractionV1Config(ModelConfiguration):
+    """
+    Attributes:
+        sample_rate: audio sample rate in Hz
+        win_size: window size in seconds
+        hop_size: window shift in seconds
+        min_amp: minimum amplitude for safe log
+        num_filters: number of mel filters
+        alpha: preemphasis weight
+    """
+
+    sample_rate: int
+    win_size: float
+    hop_size: float
+    min_amp: float
+    num_filters: int
+    alpha: float = 1.0
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
+        assert self.num_filters > 0, "number of filters needs to be positive"
+        assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
+
+
+class RasrCompatibleLogMelFeatureExtractionV1(nn.Module):
+    """
+    RASR-compatible log-mel feature extraction using log10. Does not use torchaudio.
+    """
+
+    def __init__(self, cfg: RasrCompatibleLogMelFeatureExtractionV1Config):
+        super().__init__()
+
+        self.sample_rate = int(cfg.sample_rate)
+        self.hop_length = int(cfg.hop_size * cfg.sample_rate)
+        self.min_amp = cfg.min_amp
+        self.win_length = int(cfg.win_size * cfg.sample_rate)
+        self.n_fft = 2 ** math.ceil(
+            math.log2(self.win_length)
+        )  # smallest power of two which is greater than or equal to win_length
+        self.alpha = cfg.alpha
+
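+        # mel filterbank with shape [num_filters, n_fft // 2 + 1]: HTK-style mel scale, no filter normalization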
+        self.register_buffer(
+            "mel_basis",
+            torch.tensor(
+                filters.mel(
+                    sr=cfg.sample_rate,
+                    n_fft=self.n_fft,
+                    n_mels=cfg.num_filters,
+                    fmin=0,
+                    fmax=cfg.sample_rate // 2,
+                    htk=True,
+                    norm=None,
+                ),
+            ),
+        )
+        self.register_buffer(
+            "window", torch.hann_window(self.win_length, periodic=False, dtype=torch.float64).to(torch.float32)
+        )
+
+    def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param raw_audio: [B, T]
+        :param length: in samples [B]
+        :return: features as [B, T, F] and length in frames [B]
+        """
+        assert raw_audio.shape[1] > 0  # the same must hold for length
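+        # number of output frames: ceil((T - win_length) / hop_length) + 1, clamped to at least 1;
+        # res_size uses the (padded) batch length, res_length the per-sequence lengths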
+        res_size = max(raw_audio.shape[1] - self.win_length + self.hop_length - 1, 0) // self.hop_length + 1
+        res_length = (
+            torch.maximum(length - self.win_length + self.hop_length - 1, torch.zeros_like(length)) // self.hop_length
+            + 1
+        )
+
+        # preemphasize: x[t] -= alpha * x[t - 1]; the first sample has no predecessor and is set to zero
+        preemphasized = raw_audio.clone()
+        preemphasized[..., 1:] -= self.alpha * preemphasized[..., :-1]
+        preemphasized[..., 0] = 0.0
+
+        # valid size of the last frame and the required zero padding, per sequence
+        last_win_size = length - (res_length - 1) * self.hop_length  # [B]
+        last_pad = self.win_length - last_win_size  # [B]
+
+        # zero-pad the whole batch so unfold covers the final (partial) frame
+        last_pad_batch = self.win_length - (preemphasized.shape[1] - (res_size - 1) * self.hop_length)
+        padded = torch.nn.functional.pad(preemphasized, (0, last_pad_batch))
+
+        windowed = padded.unfold(1, size=self.win_length, step=self.hop_length)  # [B, T', W=self.win_length]
+
+        smoothed = windowed * self.window[None, None, :]  # [B, T', W]
+
+        # The last frame of a sequence may cover fewer than win_length valid samples;
+        # re-window it with a correspondingly shorter Hann window, zero-padded back to win_length.
+        for i, (last_w_size, last_p, res_l) in enumerate(zip(last_win_size, last_pad, res_length)):
+            last_win = torch.hann_window(last_w_size, periodic=False, dtype=torch.float64).to(
+                self.window.device, torch.float32
+            )
+            last_win = torch.nn.functional.pad(last_win, (0, last_p))  # [W]
+            smoothed[i, res_l - 1] = windowed[i, res_l - 1] * last_win[None, :]
+
+        # compute the amplitude spectrum using torch.fft.rfftn with RASR-specific scaling (1 / sample_rate)
+        fft = torch.fft.rfftn(smoothed, s=self.n_fft) / self.sample_rate  # [B, T', F=n_fft//2+1]
+        amplitude_spectrum = torch.abs(fft)  # [B, T', F=n_fft//2+1]
+
+        melspec = torch.einsum("...tf,mf->...tm", amplitude_spectrum, self.mel_basis)  # [B, T', F'=num_filters]
+        log_melspec = torch.log10(melspec + self.min_amp)
+
+        return log_melspec, res_length
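
For reference, a minimal usage sketch of the new module (illustrative only, not part of the diff; the parameter values are made up, and it assumes the two new classes are importable from this file together with its existing dependencies):

    import torch

    # illustrative settings: 16 kHz audio, 25 ms windows, 10 ms shift, 80 mel filters
    cfg = RasrCompatibleLogMelFeatureExtractionV1Config(
        sample_rate=16000,
        win_size=0.025,
        hop_size=0.01,
        min_amp=1e-10,
        num_filters=80,
        alpha=1.0,
    )
    feature_extraction = RasrCompatibleLogMelFeatureExtractionV1(cfg)

    raw_audio = torch.randn(2, 16000)       # [B, T], padded to the longest sequence
    lengths = torch.tensor([16000, 12000])  # valid samples per sequence
    features, feature_lengths = feature_extraction(raw_audio, lengths)
    # features: [B, T', 80]; with these settings T' = ceil((16000 - 400) / 160) + 1 = 99
    # feature_lengths: tensor([99, 74])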