Skip to content

Commit b7594cb

Browse files
author
Atanas Gruev
committed
torch blank collapse update
1 parent 8667edb commit b7594cb

File tree

2 files changed

+22
-46
lines changed

2 files changed

+22
-46
lines changed

users/gruev/implementations/pytorch/blank_collapse.py

Lines changed: 18 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -63,63 +63,36 @@ def blank_collapse_batched(logprobs, audio_features_len, blank_threshold, blank_
6363
blanks = logprobs[:, :, blank_idx] > blank_threshold # [B, T]
6464

6565
# For batches, adjust individual lengths by mapping paddings to True values in mask
66-
audio_lens_mask = (
67-
torch.arange(time_dim)[None, :] >= audio_features_len[:, None]
68-
) # [B, T]
66+
audio_lens_mask = torch.arange(time_dim)[None, :] >= audio_features_len[:, None] # [B, T]
6967
blanks = blanks | audio_lens_mask # [B, T]
7068

7169
# Obtain counts on initial and final blank frames
72-
sequence_mask, sequence_indices = (~blanks).nonzero(as_tuple=True) # tuple of [T',]
73-
_, sequence_bounds = torch.unique(sequence_mask, return_counts=True) # [B,]
70+
blanks_int = blanks.int() # torch does not support argmin/argmax on Bool
7471

75-
sequence_bounds = torch.cat(
76-
(torch.Tensor([0]).to(torch.int), torch.cumsum(sequence_bounds, dim=0))
77-
) # [B+1,]
72+
init_non_blank_idx = torch.argmin(blanks_int, dim=1) # [B,]
73+
final_non_blank_idx = torch.argmin(torch.fliplr(blanks_int), dim=1) # [B,]
74+
final_non_blank_idx = time_dim - final_non_blank_idx # [B,]
7875

79-
initial_blank_idx = sequence_indices[sequence_bounds[:-1]] # [B,]
80-
final_blank_idx = sequence_indices[(sequence_bounds - 1)[1:]] # [B,]
76+
# Logical-or between "(blanks & blanks_shift)" and "bounds_mask" to restore proper lengths
77+
bounds_range = torch.arange(time_dim).repeat(batch_dim, 1) # [B, T]
78+
bounds_mask = (bounds_range < init_non_blank_idx[:, None]) | (bounds_range >= final_non_blank_idx[:, None]) # [B, T]
8179

8280
# Logical-and between "blanks" and "blanks_shift" to account for label-blank-label case
8381
blanks_shift = torch.roll(blanks, shifts=-1, dims=1) # [B, T]
8482

85-
# Logical-or between "(blanks & blanks_shift)" and "bounds_mask" to restore proper lengths
86-
bounds_mask = torch.arange(time_dim).repeat(batch_dim, 1) # [B, T]
87-
bounds_mask_initial = bounds_mask < initial_blank_idx[:, None] # [B, T]
88-
bounds_mask_final = bounds_mask > final_blank_idx[:, None] # [B, T]
89-
bounds_mask = bounds_mask_initial | bounds_mask_final # [B, T]
90-
9183
# Logical-not to assign True to frames kept
9284
blanks = ~((blanks & blanks_shift) | bounds_mask) # [B, T]
9385

94-
# De-batchify and re-arrange based on changed lengths
95-
batch_mask, batch_indices = blanks.nonzero(as_tuple=True)
96-
_, collapsed_audio_features_len = torch.unique(
97-
batch_mask, return_counts=True
98-
) # [B,]
86+
# De-batchify and compute new time dimension to restore batching based on changed lengths
87+
_, batch_indices = blanks.nonzero(as_tuple=True)
88+
collapsed_audio_features_len = torch.sum(blanks, dim=1)
9989

100-
# Compute new time dimension to restore batching
101-
collapsed_time_dim = torch.max(collapsed_audio_features_len) # T''
102-
103-
# IMPORTANT: After blank collapse, permuting the batch might be necessary due to new audio lengths
104-
# batch_collapsed_order = torch.argsort(collapsed_audio_features_len, descending=True)
90+
# IMPORTANT: After blank collapse, the batch should not be permuted!
91+
# The padding pattern might change, but correspondence to target sequences is the same.
10592

10693
# Align mask and indices to match the collapsed audio lengths in sorted order
107-
batch_mask = torch.arange(batch_dim)[:, None].expand(
108-
batch_dim, collapsed_time_dim
109-
) # [B, T'']
110-
# batch_mask = batch_mask[batch_collapsed_order] # [B, T'']
111-
112-
batch_indices = torch.split(
113-
batch_indices, collapsed_audio_features_len.tolist()
114-
) # tuple (B,)
115-
batch_indices = torch.nn.utils.rnn.pad_sequence(
116-
batch_indices, batch_first=True
117-
) # [B, T'']
118-
# batch_indices = batch_indices[batch_collapsed_order] # [B, T'']
119-
120-
# Restore original order within the batch
121-
collapsed_logprobs = logprobs[batch_mask, batch_indices] # [B, T'', V+1]
122-
# collapsed_logprobs = permuted_logprobs[torch.argsort(batch_collapsed_order)] # [B, T'', V+1]
123-
# collapsed_audio_features_len = collapsed_audio_features_len[batch_collapsed_order] # [B, ]
124-
125-
return collapsed_logprobs, collapsed_audio_features_len
94+
batch_indices = torch.split(batch_indices, collapsed_audio_features_len.tolist()) # tuple (B,)
95+
batch_indices = torch.nn.utils.rnn.pad_sequence(batch_indices, batch_first=True) # [B, T'']
96+
97+
collapsed_logprobs = logprobs[torch.arange(batch_dim)[:, None], batch_indices] # [B, T'', V+1]
98+
return collapsed_logprobs, collapsed_audio_features_len

users/gruev/pytorch/models/i6_base_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def search_init_hook(run_ctx, **kwargs):
168168
import subprocess
169169

170170
arpa_lm = kwargs.get("arpa_lm", None)
171-
lm = subprocess.check_output(["cf", arpa_lm]).decode().strip() if arpa_lm is not None else None
171+
lm = subprocess.check_output(["cf", arpa_lm]).decode().strip() if arpa_lm else None
172172

173173
# Get labels directly, no need to load the vocab file
174174
labels = run_ctx.engine.forward_dataset.datasets["zip_dataset"].targets.labels
@@ -211,6 +211,9 @@ def search_step(*, model: ConformerCTCModel, data, run_ctx, **kwargs):
211211

212212
log_probs_list, audio_features_len = model(audio_features, audio_features_len)
213213

214+
from IPython import embed
215+
embed()
216+
214217
# see also model.forward()
215218
log_probs = log_probs_list[0]
216219

0 commit comments

Comments
 (0)