11from returnn .tf .util .basic import get_shape_dim , check_input_dim
2+ from returnn .tf .util .data import Data
3+ from returnn .tf .layers .basic import LayerBase
4+ import tensorflow as tf
5+ from typing import List
26
37
4- def rna_alignment (source , ** kwargs ):
8+ def rna_alignment (source , ** kwargs ) -> tf .Tensor :
9+ """ Used only to create alignments according to RNA loss function.
10+ B: batch, T: time, U:target/labels, V: vocabulary
11+ Args:
12+ source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
13+ which returns one of:
14+ output_log_prob: [B, T, U+1, V] log-probabilities
15+ real_target: [B, U] -> [V] target labels
16+ base:encoder: [B, T, Feat] -> [V] encoder output
17+ Returns:
18+ alignment: [B, T] which holds a value from interval (0:blank_ix) for each alignment frame
519 """
6- Used only to create alignments according to RNA loss function.
7- :sources: [output_log_prob, real_target, "base:encoder"]
8- :return: alignments: [B, T] for each frame a value in [0:blank_ix]
9- """
10- # acts: (B, T, U, V)
11- # targets: (B, U-1)
12- # input_lengths (B,)
13- # label_lengths (B,)
1420 from .rna_align_sum_max_pure_tf import tf_forward_shifted_rna
1521
1622 log_probs = source (0 , as_data = True , auto_convert = False ).get_placeholder_as_batch_major ()
@@ -28,22 +34,23 @@ def rna_alignment(source, **kwargs):
2834 # "log-probs:", tf.shape(log_probs.get_placeholder_as_batch_major())], summarize=-1)
2935
3036 blank_idx = targets .dim # targets is without blank
31- costs , alignment = tf_forward_shifted_rna (log_probs , targets .get_placeholder_as_batch_major (), enc_lens , dec_lens ,
32- blank_index = blank_idx , debug = False , with_alignment = True )
33- return alignment # ( B, T)
37+ _ , alignment = tf_forward_shifted_rna (log_probs , targets .get_placeholder_as_batch_major (), enc_lens , dec_lens ,
38+ blank_index = blank_idx , debug = False , with_alignment = True )
39+ return alignment # [ B, T]
3440
3541
36- def rnnt_alignment (source , ** kwargs ):
37- """
38- Used only to create alignments according to RNNT loss function.
39- :sources: [output_log_prob, real_target, "base:encoder"]
40- :return: alignments: [B, T] for each frame a value in [0:blank_ix]
42+ def rnnt_alignment (source , ** kwargs ) -> tf .Tensor :
43+ """ Used only to create alignments according to RNNT loss function.
44+ B: batch, T: time, U:target/labels, V: vocabulary
45+ Args:
46+ source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
47+ which returns one of:
48+ output_log_prob: [B, T, U+1, V] log-probabilities
49+ real_target: [B, U] -> [V] target labels
50+ base:encoder: [B, T, Feat] -> [V] encoder output
51+ Returns:
52+ alignment: [B, T+U] which holds a value from interval (0:blank_ix) for each alignment frame
4153 """
42- # alignment-length (B,T+U+1)
43- # acts: (B, T, U+1, V)
44- # targets: (B, U)
45- # input_lengths (B,)
46- # label_lengths (B,)
4754 from .rnnt_align_sum_max_pure_tf import tf_forward_shifted_rnnt
4855
4956 log_probs = source (0 , as_data = True , auto_convert = False ).get_placeholder_as_batch_major ()
@@ -63,4 +70,34 @@ def rnnt_alignment(source, **kwargs):
6370 blank_idx = targets .dim
6471 _ , alignment = tf_forward_shifted_rnnt (log_probs , targets .get_placeholder_as_batch_major (), enc_lens , dec_lens ,
6572 blank_index = blank_idx , debug = False , with_alignment = True )
66- return alignment # (B, T)
73+ return alignment # [B, T+U]
74+
75+
def rna_alignment_out_type(sources: List[LayerBase], **_kwargs) -> Data:
    """Build the output ``Data`` template for the RNA alignment.

    Notation: B = batch, T = time, U = target/labels, V = vocabulary.

    Args:
        sources: source layers, expected in this order:
            0. output_log_prob: [B, T, U+1, V] log-probabilities
            1. real_target: [B, U] target labels (not read here)
            2. base:encoder: [B, T, Feat] encoder output
        **_kwargs: accepted for interface compatibility and ignored.

    Returns:
        A sparse ``Data`` for the alignment [B, T]. Its ``dim`` is taken from
        the log-probability source and its sequence lengths from the encoder
        source.
    """
    # Label dimension comes from the log-probability layer.
    label_dim = sources[0].output.dim
    # The alignment has the same per-sequence lengths as the encoder output.
    encoder_lens = sources[2].output.size_placeholder[0]
    return Data(
        name="rna_alignment_output",
        sparse=True,
        dim=label_dim,
        size_placeholder={0: encoder_lens},
    )
89+
90+
def rnnt_alignment_out_type(sources: List[LayerBase], **_kwargs) -> Data:
    """Build the output ``Data`` template for the RNNT alignment.

    Notation: B = batch, T = time, U = target/labels, V = vocabulary.

    Args:
        sources: source layers, expected in this order:
            0. output_log_prob: [B, T, U+1, V] log-probabilities
            1. real_target: [B, U] target labels
            2. base:encoder: [B, T, Feat] encoder output
        **_kwargs: accepted for interface compatibility and ignored.

    Returns:
        A sparse ``Data`` for the alignment [B, T+U]. Its ``dim`` is taken
        from the log-probability source and its sequence lengths are the
        target lengths plus the encoder lengths.
    """
    # Label dimension comes from the log-probability layer.
    label_dim = sources[0].output.dim
    target_lens = sources[1].output.size_placeholder[0]
    encoder_lens = sources[2].output.size_placeholder[0]
    # Alignment length per sequence is the sum of target and encoder lengths.
    alignment_lens = target_lens + encoder_lens
    return Data(
        name="rnnt_alignment_output",
        sparse=True,
        dim=label_dim,
        size_placeholder={0: alignment_lens},
    )
0 commit comments