@@ -63,63 +63,36 @@ def blank_collapse_batched(logprobs, audio_features_len, blank_threshold, blank_
     blanks = logprobs[:, :, blank_idx] > blank_threshold  # [B, T]
 
     # For batches, adjust individual lengths by mapping paddings to True values in mask
-    audio_lens_mask = (
-        torch.arange(time_dim)[None, :] >= audio_features_len[:, None]
-    )  # [B, T]
+    audio_lens_mask = torch.arange(time_dim)[None, :] >= audio_features_len[:, None]  # [B, T]
     blanks = blanks | audio_lens_mask  # [B, T]
 
     # Obtain counts on initial and final blank frames
-    sequence_mask, sequence_indices = (~blanks).nonzero(as_tuple=True)  # tuple of [T',]
-    _, sequence_bounds = torch.unique(sequence_mask, return_counts=True)  # [B,]
+    blanks_int = blanks.int()  # torch does not support argmin/argmax on Bool
 
-    sequence_bounds = torch.cat(
-        (torch.Tensor([0]).to(torch.int), torch.cumsum(sequence_bounds, dim=0))
-    )  # [B+1,]
+    init_non_blank_idx = torch.argmin(blanks_int, dim=1)  # [B,]
+    final_non_blank_idx = torch.argmin(torch.fliplr(blanks_int), dim=1)  # [B,]
+    final_non_blank_idx = time_dim - final_non_blank_idx  # [B,]
 
-    initial_blank_idx = sequence_indices[sequence_bounds[:-1]]  # [B,]
-    final_blank_idx = sequence_indices[(sequence_bounds - 1)[1:]]  # [B,]
+    # Logical-or between "(blanks & blanks_shift)" and "bounds_mask" to restore proper lengths
+    bounds_range = torch.arange(time_dim).repeat(batch_dim, 1)  # [B, T]
+    bounds_mask = (bounds_range < init_non_blank_idx[:, None]) | (bounds_range >= final_non_blank_idx[:, None])  # [B, T]
 
     # Logical-and between "blanks" and "blanks_shift" to account for label-blank-label case
     blanks_shift = torch.roll(blanks, shifts=-1, dims=1)  # [B, T]
 
-    # Logical-or between "(blanks & blanks_shift)" and "bounds_mask" to restore proper lengths
-    bounds_mask = torch.arange(time_dim).repeat(batch_dim, 1)  # [B, T]
-    bounds_mask_initial = bounds_mask < initial_blank_idx[:, None]  # [B, T]
-    bounds_mask_final = bounds_mask > final_blank_idx[:, None]  # [B, T]
-    bounds_mask = bounds_mask_initial | bounds_mask_final  # [B, T]
-
     # Logical-not to assign True to frames kept
     blanks = ~((blanks & blanks_shift) | bounds_mask)  # [B, T]
 
-    # De-batchify and re-arrange based on changed lengths
-    batch_mask, batch_indices = blanks.nonzero(as_tuple=True)
-    _, collapsed_audio_features_len = torch.unique(
-        batch_mask, return_counts=True
-    )  # [B,]
+    # De-batchify and compute new time dimension to restore batching based on changed lengths
+    _, batch_indices = blanks.nonzero(as_tuple=True)
+    collapsed_audio_features_len = torch.sum(blanks, dim=1)
 
-    # Compute new time dimension to restore batching
-    collapsed_time_dim = torch.max(collapsed_audio_features_len)  # T''
-
-    # IMPORTANT: After blank collapse, permuting the batch might be necessary due to new audio lengths
-    # batch_collapsed_order = torch.argsort(collapsed_audio_features_len, descending=True)
+    # IMPORTANT: After blank collapse, the batch should not be permuted!
+    # The padding pattern might change, but correspondence to target sequences is the same.
 
     # Align mask and indices to match the collapsed audio lengths in sorted order
-    batch_mask = torch.arange(batch_dim)[:, None].expand(
-        batch_dim, collapsed_time_dim
-    )  # [B, T'']
-    # batch_mask = batch_mask[batch_collapsed_order]  # [B, T'']
-
-    batch_indices = torch.split(
-        batch_indices, collapsed_audio_features_len.tolist()
-    )  # tuple (B,)
-    batch_indices = torch.nn.utils.rnn.pad_sequence(
-        batch_indices, batch_first=True
-    )  # [B, T'']
-    # batch_indices = batch_indices[batch_collapsed_order]  # [B, T'']
-
-    # Restore original order within the batch
-    collapsed_logprobs = logprobs[batch_mask, batch_indices]  # [B, T'', V+1]
-    # collapsed_logprobs = permuted_logprobs[torch.argsort(batch_collapsed_order)]  # [B, T'', V+1]
-    # collapsed_audio_features_len = collapsed_audio_features_len[batch_collapsed_order]  # [B, ]
-
-    return collapsed_logprobs, collapsed_audio_features_len
+    batch_indices = torch.split(batch_indices, collapsed_audio_features_len.tolist())  # tuple (B,)
+    batch_indices = torch.nn.utils.rnn.pad_sequence(batch_indices, batch_first=True)  # [B, T'']
+
+    collapsed_logprobs = logprobs[torch.arange(batch_dim)[:, None], batch_indices]  # [B, T'', V+1]
+    return collapsed_logprobs, collapsed_audio_features_len
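
For context, a minimal usage sketch of the patched function is given below. The shapes, the log-space threshold value, and the choice of the blank index as the last vocabulary entry are assumptions for illustration only; they are not part of this change, and blank_collapse_batched is assumed to be in scope (e.g. imported from this module).

import torch

B, T, V = 2, 12, 5
logprobs = torch.log_softmax(torch.randn(B, T, V + 1), dim=-1)  # [B, T, V+1] log-probabilities
audio_features_len = torch.tensor([12, 9])                      # [B,] valid frames per sequence

collapsed_logprobs, collapsed_len = blank_collapse_batched(
    logprobs,
    audio_features_len,
    torch.log(torch.tensor(0.9)),  # blank_threshold in log space (assumption): frames whose blank log-prob exceeds it are treated as blank
    V,                             # blank index, assumed here to be the last vocabulary entry
)

# collapsed_logprobs: [B, T'', V+1] with T'' = collapsed_len.max();
# collapsed_len[b] counts the frames kept for sequence b, in the original batch order.
print(collapsed_logprobs.shape, collapsed_len)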