def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
    num_lookhead: int = 0,
) -> torch.Tensor:
    """Create a chunk-based boolean attention mask for streaming attention.

    Position ``i`` may attend to every position ``j`` with
    ``start_i <= j < end_i`` where, for ``c = i // chunk_size``:

    * ``end_i = min((c + 1) * chunk_size + num_lookhead, size)``
    * ``start_i = 0`` when ``num_left_chunks < 0`` (unlimited history),
      otherwise ``start_i = max((c - num_left_chunks) * chunk_size, 0)``.

    NOTE(review): the parameter list (order and the defaults of
    ``num_left_chunks`` / ``device``) is reconstructed from the function
    body; only ``num_lookhead: int = 0`` was visible in the diff hunk —
    confirm against the original signature.

    Args:
        size: Sequence length; the mask is ``(size, size)``.
        chunk_size: Width of each attention chunk.
        num_left_chunks: Number of past chunks visible from each row;
            a negative value means all history is visible.
        device: Device on which the mask tensors are allocated.
        num_lookhead: Extra future positions visible past the current
            chunk boundary (name kept as-is for caller compatibility).

    Returns:
        A ``(size, size)`` ``torch.bool`` tensor where ``True`` marks an
        attendable position.
    """
    row_indices = torch.arange(size, device=device)
    chunk_indices = row_indices // chunk_size
    if num_left_chunks < 0:
        # Unlimited left context: every row's visible span starts at 0.
        start_indices = torch.zeros_like(row_indices)
    else:
        # First visible chunk per row, clamped so it never precedes chunk 0.
        start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks,
                                          min=0)
        start_indices = start_chunk_indices * chunk_size
    # Exclusive end of each row's visible span, clamped to the sequence end.
    end_indices = torch.clamp((chunk_indices + 1) * chunk_size + num_lookhead,
                              max=size)
    # Broadcast (1, size) columns against (size, 1) bounds to build the mask
    # in one vectorized expression. This replaces the original per-row Python
    # loop and also drops two pieces of dead work the committed version kept:
    # the `torch.zeros(size, size)` buffer that was allocated and then
    # discarded by rebinding `ret`, and the unused `row_indices.unsqueeze(1)`.
    col_indices = torch.arange(size, device=device).unsqueeze(0)
    start_indices = start_indices.unsqueeze(1)
    end_indices = end_indices.unsqueeze(1)
    return (col_indices >= start_indices) & (col_indices < end_indices)
600
614
601
615
def _get_feat_extract_output_lengths (self ,
0 commit comments