|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | + |
| 4 | + |
| 5 | +def _mask(tensor: torch.Tensor, batch_axis: int, axis: int, pos: torch.Tensor, max_len: int) -> torch.Tensor: |
| 6 | + """ |
| 7 | + :param tensor: e.g. [B, ..., A, ...] but arbitrary axis order |
| 8 | + :param batch_axis: index of the batch axis |
| 9 | + :param axis: which axis A to mask |
| 10 | + :param pos: at which positions along axis to start the mask (size [B]) |
| 11 | + :param max_len: mask length drawn uniformly from [0, max_len] |
| 12 | + """ |
| 13 | + batch_dim_size = tensor.shape[batch_axis] |
| 14 | + mask_dim_size = tensor.shape[axis] |
| 15 | + mask_len = torch.randint(low=1, high=max_len + 1, size=(batch_dim_size,), dtype=torch.int32, device=tensor.device) |
| 16 | + end_pos = torch.min(pos + mask_len, torch.tensor([mask_dim_size] * batch_dim_size, device=tensor.device)) |
| 17 | + idxs = torch.arange(0, mask_dim_size, device=tensor.device).unsqueeze(0) # [1,dim] |
| 18 | + pos_bc = pos.unsqueeze(1) # [B,1] |
| 19 | + end_pos_bc = end_pos.unsqueeze(1) # [B,1] |
| 20 | + mask = torch.logical_and(torch.greater_equal(idxs, pos_bc), torch.less(idxs, end_pos_bc)) # [B,dim] |
| 21 | + if batch_axis > axis: |
| 22 | + mask = mask.transpose(0, 1) # [dim,B] |
| 23 | + mask = torch.reshape(mask, shape=[tensor.shape[i] if i in (batch_axis, axis) else 1 for i in range(tensor.ndim)]) |
| 24 | + tensor = torch.where(mask, 0.0, tensor) |
| 25 | + return tensor |
| 26 | + |
| 27 | + |
| 28 | +def _random_mask(tensor: torch.Tensor, batch_axis: int, axis: int, min_num: int, max_num: int, max_len: int): |
| 29 | + """ |
| 30 | + Mask tensor along axis using N in [min_num, max_num] masks of length [0, max_len] |
| 31 | +
|
| 32 | + :param tensor: e.g. [B, ..., A, ...] but arbitrary axis order |
| 33 | + :param batch_axis: index of the batch axis |
| 34 | + :param axis: which axis to mask |
| 35 | + :param min_num: minimum number of masks |
| 36 | + :param max_num: maximum number of masks |
| 37 | + :param max_amount: mask length drawn uniformly from [0, max_amount] |
| 38 | + """ |
| 39 | + |
| 40 | + batch_dim_size = tensor.shape[batch_axis] |
| 41 | + if max_num < min_num: |
| 42 | + max_num = min_num |
| 43 | + num_masks = torch.randint(min_num, max_num + 1, size=(batch_dim_size,), device="cpu") # [B] |
| 44 | + |
| 45 | + max_num_masks = num_masks.max().item() |
| 46 | + |
| 47 | + z = torch.rand((batch_dim_size, tensor.shape[axis]), device=tensor.device) # [B,dim] |
| 48 | + _, indices = torch.topk(z, max_num_masks, dim=1) |
| 49 | + |
| 50 | + # Make num_masks broadcastable to shape of tensor for torch.where. |
| 51 | + num_masks = torch.reshape(num_masks, [1] * batch_axis + [batch_dim_size] + [1] * (tensor.dim() - batch_axis - 1)) |
| 52 | + |
| 53 | + num_masks = num_masks.to(device=tensor.device) |
| 54 | + |
| 55 | + for i in range(max_num_masks): |
| 56 | + tensor = torch.where(i < num_masks, _mask(tensor, batch_axis, axis, indices[:, i], max_len), tensor) |
| 57 | + |
| 58 | + return tensor |
| 59 | + |
| 60 | + |
| 61 | +def specaugment_v1( |
| 62 | + audio_features: torch.Tensor, |
| 63 | + *, |
| 64 | + time_min_num_masks: int, |
| 65 | + time_max_num_masks: int, |
| 66 | + time_mask_max_size: int, |
| 67 | + freq_min_num_masks: int, |
| 68 | + freq_max_num_masks: int, |
| 69 | + freq_mask_max_size: int, |
| 70 | +): |
| 71 | + """ |
| 72 | + Specaugment from legacy rossenbach/zeineldeen/zeyer attention setups e.g., |
| 73 | + https://github.com/rwth-i6/i6_experiments/blob/main/users/zeineldeen/data_aug/specaugment/specaug_tf2.py |
| 74 | + but without any step-based scheduling and without dependence on length. |
| 75 | + See `specaugment_v1_by_length` for a variant which is more close to the original. |
| 76 | +
|
| 77 | + Fills masks with zeros. |
| 78 | +
|
| 79 | + Basically just a convenience wrapper around _random_mask. |
| 80 | +
|
| 81 | + See also: https://arxiv.org/abs/1904.08779 |
| 82 | +
|
| 83 | + :param audio_features: e.g. log-mel features as [B, T, F] |
| 84 | + :param time_min_num_masks: minimum number of masks along T |
| 85 | + :param time_max_num_masks: maximum number of masks along T |
| 86 | + :param time_mask_max_size: maximum size of masks along T |
| 87 | + :param freq_min_num_masks: minimum number of masks along F |
| 88 | + :param freq_max_num_masks: maximum number of masks along F |
| 89 | + :param freq_mask_max_size: maximum size of masks along F |
| 90 | + :return: masked audio features |
| 91 | + """ |
| 92 | + assert len(audio_features.shape) == 3 |
| 93 | + assert time_min_num_masks <= time_max_num_masks |
| 94 | + assert freq_min_num_masks <= freq_max_num_masks |
| 95 | + masked_audio_features = _random_mask( |
| 96 | + audio_features, 0, 1, time_min_num_masks, time_max_num_masks, time_mask_max_size |
| 97 | + ) # time masking |
| 98 | + masked_audio_features = _random_mask( |
| 99 | + masked_audio_features, 0, 2, freq_min_num_masks, freq_max_num_masks, freq_mask_max_size |
| 100 | + ) # freq masking |
| 101 | + return masked_audio_features |
| 102 | + |
| 103 | + |
| 104 | +def specaugment_v1_by_length( |
| 105 | + audio_features: torch.Tensor, |
| 106 | + *, |
| 107 | + time_min_num_masks: int, |
| 108 | + time_max_mask_per_n_frames: int, |
| 109 | + time_mask_max_size: int, |
| 110 | + freq_min_num_masks: int, |
| 111 | + freq_max_num_masks: int, |
| 112 | + freq_mask_max_size: int, |
| 113 | +): |
| 114 | + """ |
| 115 | + Convenience wrapper around specaugment_v1 with time-length adaptive number of masks. |
| 116 | +
|
| 117 | + :param audio_features: e.g. log-mel features as [B, T, F] |
| 118 | + :param time_max_mask_per_n_frames: used for the maximum number time masks, |
| 119 | + max_num_masks = T / max_mask_per_n_frames for each batch. |
| 120 | + They are still drawn depending on the full batch length, so shorter sequences |
| 121 | + might get more masks than that by chance, or none at all when all masks |
| 122 | + fall into the padding space. |
| 123 | + :param time_min_num_masks: minimum number of masks along T |
| 124 | + :param time_mask_max_size: maximum size of masks along T |
| 125 | + :param freq_min_num_masks: minimum number of masks along F |
| 126 | + :param freq_max_num_masks: maximum number of masks along F |
| 127 | + :param freq_mask_max_size: maximum size of masks along F |
| 128 | + :return: masked audio features |
| 129 | + """ |
| 130 | + return specaugment_v1( |
| 131 | + audio_features, |
| 132 | + time_min_num_masks=time_min_num_masks, |
| 133 | + time_max_num_masks=np.maximum(audio_features.size(1) // time_max_mask_per_n_frames, time_min_num_masks), |
| 134 | + time_mask_max_size=time_mask_max_size, |
| 135 | + freq_min_num_masks=freq_min_num_masks, |
| 136 | + freq_max_num_masks=freq_max_num_masks, |
| 137 | + freq_mask_max_size=freq_mask_max_size, |
| 138 | + ) |
0 commit comments