
Commit 72ae8ac

Merge branch 'main' into zeineldeen_att_decoder

2 parents 1002af4 + 48ade42

3 files changed (+149, -8 lines)

i6_models/primitives/feature_extraction.py

Lines changed: 10 additions & 7 deletions
@@ -37,7 +37,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
     def __post_init__(self) -> None:
         super().__post_init__()
         assert self.f_max <= self.sample_rate // 2, "f_max can not be larger than half of the sample rate"
-        assert self.f_min > 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive"
+        assert self.f_min >= 0 and self.f_max > 0 and self.sample_rate > 0, "frequencies need to be positive"
         assert self.win_size > 0 and self.hop_size > 0, "window settings need to be positive"
         assert self.num_filters > 0, "number of filters needs to be positive"
         assert self.hop_size <= self.win_size, "using a larger hop size than window size does not make sense"
@@ -57,23 +57,25 @@ class LogMelFeatureExtractionV1(nn.Module):
 
     def __init__(self, cfg: LogMelFeatureExtractionV1Config):
         super().__init__()
-        self.register_buffer("n_fft", torch.tensor(cfg.n_fft))
-        self.register_buffer("window", torch.hann_window(int(cfg.win_size * cfg.sample_rate)))
-        self.register_buffer("hop_length", torch.tensor(int(cfg.hop_size * cfg.sample_rate)))
-        self.register_buffer("min_amp", torch.tensor(cfg.min_amp))
         self.center = cfg.center
+        self.hop_length = int(cfg.hop_size * cfg.sample_rate)
+        self.min_amp = cfg.min_amp
+        self.n_fft = cfg.n_fft
+        self.win_length = int(cfg.win_size * cfg.sample_rate)
+
         self.register_buffer(
             "mel_basis",
             torch.tensor(
                 filters.mel(
                     sr=cfg.sample_rate,
-                    n_fft=int(cfg.sample_rate * cfg.win_size),
+                    n_fft=cfg.n_fft,
                     n_mels=cfg.num_filters,
                     fmin=cfg.f_min,
                     fmax=cfg.f_max,
                 )
             ),
         )
+        self.register_buffer("window", torch.hann_window(self.win_length))
 
     def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -87,6 +89,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
             raw_audio,
             n_fft=self.n_fft,
             hop_length=self.hop_length,
+            win_length=self.win_length,
             window=self.window,
             center=self.center,
             pad_mode="constant",
@@ -99,7 +102,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
             # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
             power_spectrum = torch.unsqueeze(power_spectrum, 0)
         melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
-        log_melspec = torch.log10(torch.max(self.min_amp, melspec))
+        log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
         feature_data = torch.transpose(log_melspec, 1, 2)
 
         if self.center:
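
For context: the refactoring moves the plain scalars (n_fft, hop_length, win_length, min_amp) out of register_buffer, presumably because torch.stft expects plain Python ints for these arguments; only the tensors that must follow the module's device (mel_basis, window) remain buffers. Below is a minimal usage sketch of the changed module, assuming the config fields visible in the diff (all values are illustrative, not recommendations):

import torch

from i6_models.primitives.feature_extraction import (
    LogMelFeatureExtractionV1,
    LogMelFeatureExtractionV1Config,
)

# win_size/hop_size are given in seconds and converted to samples in __init__,
# so win_length = int(0.025 * 16000) = 400 and hop_length = int(0.010 * 16000) = 160.
cfg = LogMelFeatureExtractionV1Config(
    sample_rate=16000,
    win_size=0.025,
    hop_size=0.010,
    f_min=60,
    f_max=7600,
    min_amp=1e-10,
    num_filters=80,
    center=True,
    n_fft=400,
)
feature_extraction = LogMelFeatureExtractionV1(cfg)

raw_audio = torch.randn(4, 16000)  # [B, T'] raw waveform, one second per sequence
lengths = torch.tensor([16000, 16000, 12000, 8000])  # valid samples per sequence
features, feature_lengths = feature_extraction(raw_audio, lengths)  # features: [B, T, F=80]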
i6_models/primitives/specaugment.py (new file)

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+import numpy as np
+import torch
+
+
+def _mask(tensor: torch.Tensor, batch_axis: int, axis: int, pos: torch.Tensor, max_len: int) -> torch.Tensor:
+    """
+    :param tensor: e.g. [B, ..., A, ...] but arbitrary axis order
+    :param batch_axis: index of the batch axis
+    :param axis: which axis A to mask
+    :param pos: at which positions along axis to start the mask (size [B])
+    :param max_len: mask length drawn uniformly from [1, max_len]
+    """
+    batch_dim_size = tensor.shape[batch_axis]
+    mask_dim_size = tensor.shape[axis]
+    mask_len = torch.randint(low=1, high=max_len + 1, size=(batch_dim_size,), dtype=torch.int32, device=tensor.device)
+    end_pos = torch.min(pos + mask_len, torch.tensor([mask_dim_size] * batch_dim_size, device=tensor.device))
+    idxs = torch.arange(0, mask_dim_size, device=tensor.device).unsqueeze(0)  # [1,dim]
+    pos_bc = pos.unsqueeze(1)  # [B,1]
+    end_pos_bc = end_pos.unsqueeze(1)  # [B,1]
+    mask = torch.logical_and(torch.greater_equal(idxs, pos_bc), torch.less(idxs, end_pos_bc))  # [B,dim]
+    if batch_axis > axis:
+        mask = mask.transpose(0, 1)  # [dim,B]
+    mask = torch.reshape(mask, shape=[tensor.shape[i] if i in (batch_axis, axis) else 1 for i in range(tensor.ndim)])
+    tensor = torch.where(mask, 0.0, tensor)
+    return tensor
+
+
+def _random_mask(tensor: torch.Tensor, batch_axis: int, axis: int, min_num: int, max_num: int, max_len: int):
+    """
+    Mask tensor along axis using N in [min_num, max_num] masks of length [1, max_len]
+
+    :param tensor: e.g. [B, ..., A, ...] but arbitrary axis order
+    :param batch_axis: index of the batch axis
+    :param axis: which axis to mask
+    :param min_num: minimum number of masks
+    :param max_num: maximum number of masks
+    :param max_len: mask length drawn uniformly from [1, max_len]
+    """
+
+    batch_dim_size = tensor.shape[batch_axis]
+    if max_num < min_num:
+        max_num = min_num
+    num_masks = torch.randint(min_num, max_num + 1, size=(batch_dim_size,), device="cpu")  # [B]
+
+    max_num_masks = num_masks.max().item()
+
+    z = torch.rand((batch_dim_size, tensor.shape[axis]), device=tensor.device)  # [B,dim]
+    _, indices = torch.topk(z, max_num_masks, dim=1)
+
+    # Make num_masks broadcastable to shape of tensor for torch.where.
+    num_masks = torch.reshape(num_masks, [1] * batch_axis + [batch_dim_size] + [1] * (tensor.dim() - batch_axis - 1))
+
+    num_masks = num_masks.to(device=tensor.device)
+
+    for i in range(max_num_masks):
+        tensor = torch.where(i < num_masks, _mask(tensor, batch_axis, axis, indices[:, i], max_len), tensor)
+
+    return tensor
+
+
+def specaugment_v1(
+    audio_features: torch.Tensor,
+    *,
+    time_min_num_masks: int,
+    time_max_num_masks: int,
+    time_mask_max_size: int,
+    freq_min_num_masks: int,
+    freq_max_num_masks: int,
+    freq_mask_max_size: int,
+):
+    """
+    Specaugment from legacy rossenbach/zeineldeen/zeyer attention setups, e.g.
+    https://github.com/rwth-i6/i6_experiments/blob/main/users/zeineldeen/data_aug/specaugment/specaug_tf2.py,
+    but without any step-based scheduling and without dependence on length.
+    See `specaugment_v1_by_length` for a variant which is closer to the original.
+
+    Fills masks with zeros.
+
+    Basically just a convenience wrapper around _random_mask.
+
+    See also: https://arxiv.org/abs/1904.08779
+
+    :param audio_features: e.g. log-mel features as [B, T, F]
+    :param time_min_num_masks: minimum number of masks along T
+    :param time_max_num_masks: maximum number of masks along T
+    :param time_mask_max_size: maximum size of masks along T
+    :param freq_min_num_masks: minimum number of masks along F
+    :param freq_max_num_masks: maximum number of masks along F
+    :param freq_mask_max_size: maximum size of masks along F
+    :return: masked audio features
+    """
+    assert len(audio_features.shape) == 3
+    assert time_min_num_masks <= time_max_num_masks
+    assert freq_min_num_masks <= freq_max_num_masks
+    masked_audio_features = _random_mask(
+        audio_features, 0, 1, time_min_num_masks, time_max_num_masks, time_mask_max_size
+    )  # time masking
+    masked_audio_features = _random_mask(
+        masked_audio_features, 0, 2, freq_min_num_masks, freq_max_num_masks, freq_mask_max_size
+    )  # freq masking
+    return masked_audio_features
+
+
+def specaugment_v1_by_length(
+    audio_features: torch.Tensor,
+    *,
+    time_min_num_masks: int,
+    time_max_mask_per_n_frames: int,
+    time_mask_max_size: int,
+    freq_min_num_masks: int,
+    freq_max_num_masks: int,
+    freq_mask_max_size: int,
+):
+    """
+    Convenience wrapper around specaugment_v1 with a time-length-adaptive number of masks.
+
+    :param audio_features: e.g. log-mel features as [B, T, F]
+    :param time_max_mask_per_n_frames: used for the maximum number of time masks,
+        max_num_masks = T // time_max_mask_per_n_frames for each batch.
+        They are still drawn depending on the full batch length, so shorter sequences
+        might get more masks than that by chance, or none at all when all masks
+        fall into the padding space.
+    :param time_min_num_masks: minimum number of masks along T
+    :param time_mask_max_size: maximum size of masks along T
+    :param freq_min_num_masks: minimum number of masks along F
+    :param freq_max_num_masks: maximum number of masks along F
+    :param freq_mask_max_size: maximum size of masks along F
+    :return: masked audio features
+    """
+    return specaugment_v1(
+        audio_features,
+        time_min_num_masks=time_min_num_masks,
+        time_max_num_masks=np.maximum(audio_features.size(1) // time_max_mask_per_n_frames, time_min_num_masks),
+        time_mask_max_size=time_mask_max_size,
+        freq_min_num_masks=freq_min_num_masks,
+        freq_max_num_masks=freq_max_num_masks,
+        freq_mask_max_size=freq_mask_max_size,
+    )
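
A short usage sketch of the new helpers (all parameter values are illustrative; the functions are meant for training-time augmentation of [B, T, F] features, as documented above). Internally, mask start positions come from torch.topk over uniform noise, so the starts within one sequence are distinct, and each mask is only applied to batch entries where i < num_masks holds:

import torch

log_mel = torch.randn(4, 1000, 80)  # [B, T, F] batch of log-mel features

masked = specaugment_v1_by_length(
    log_mel,
    time_min_num_masks=2,
    time_max_mask_per_n_frames=25,  # allows up to T // 25 = 40 time masks here
    time_mask_max_size=20,
    freq_min_num_masks=2,
    freq_max_num_masks=5,
    freq_mask_max_size=8,
)
assert masked.shape == log_mel.shape  # masks fill with zeros, shape is unchanged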

tests/test_blstm.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def test_blstm_onnx_export():
             "classes": {0: "batch", 1: "time"},
         },
     )
-    session = ort.InferenceSession(f.name)
+    session = ort.InferenceSession(f.name, providers=["CPUExecutionProvider"])
     outputs_onnx = torch.FloatTensor(
         session.run(None, {"data": dummy_data.numpy(), "data_len": dummy_data_len.numpy()})[0]
    )
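
Background on the one-line test fix: since onnxruntime 1.9, GPU-enabled builds refuse to construct an InferenceSession without an explicit providers argument, so pinning CPUExecutionProvider keeps the test working across installs. A minimal sketch (the model path is a placeholder):

import onnxruntime as ort

# Lists what the installed build supports, e.g. ["CPUExecutionProvider", ...]
print(ort.get_available_providers())

# Explicit provider selection; required by GPU builds since onnxruntime 1.9.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])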
