Commit af48c96

add topology and document alignment and loss (#62)
1 parent: 9c057aa

3 files changed: +153 −59
common/models/transducer/alignment.py

Lines changed: 60 additions & 23 deletions
@@ -1,16 +1,22 @@
 from returnn.tf.util.basic import get_shape_dim, check_input_dim
+from returnn.tf.util.data import Data
+from returnn.tf.layers.basic import LayerBase
+import tensorflow as tf
+from typing import List
 
 
-def rna_alignment(source, **kwargs):
+def rna_alignment(source, **kwargs) -> tf.Tensor:
+  """ Used only to create alignments according to the RNA loss function.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
+      which returns one of:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      real_target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    alignment: [B, T], holding a value in [0:blank_ix] for each alignment frame
   """
-  Used only to create alignments according to RNA loss function.
-  :sources: [output_log_prob, real_target, "base:encoder"]
-  :return: alignments: [B, T] for each frame a value in [0:blank_ix]
-  """
-  # acts: (B, T, U, V)
-  # targets: (B, U-1)
-  # input_lengths (B,)
-  # label_lengths (B,)
   from .rna_align_sum_max_pure_tf import tf_forward_shifted_rna
 
   log_probs = source(0, as_data=True, auto_convert=False).get_placeholder_as_batch_major()
@@ -28,22 +34,23 @@ def rna_alignment(source, **kwargs):
   #          "log-probs:", tf.shape(log_probs.get_placeholder_as_batch_major())], summarize=-1)
 
   blank_idx = targets.dim  # targets is without blank
-  costs, alignment = tf_forward_shifted_rna(log_probs, targets.get_placeholder_as_batch_major(), enc_lens, dec_lens,
-                                            blank_index=blank_idx, debug=False, with_alignment=True)
-  return alignment  # (B, T)
+  _, alignment = tf_forward_shifted_rna(log_probs, targets.get_placeholder_as_batch_major(), enc_lens, dec_lens,
+                                        blank_index=blank_idx, debug=False, with_alignment=True)
+  return alignment  # [B, T]
 
 
-def rnnt_alignment(source, **kwargs):
-  """
-  Used only to create alignments according to RNNT loss function.
-  :sources: [output_log_prob, real_target, "base:encoder"]
-  :return: alignments: [B, T] for each frame a value in [0:blank_ix]
+def rnnt_alignment(source, **kwargs) -> tf.Tensor:
+  """ Used only to create alignments according to the RNNT loss function.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
+      which returns one of:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      real_target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    alignment: [B, T+U], holding a value in [0:blank_ix] for each alignment frame
   """
-  # alignment-length (B, T+U+1)
-  # acts: (B, T, U+1, V)
-  # targets: (B, U)
-  # input_lengths (B,)
-  # label_lengths (B,)
   from .rnnt_align_sum_max_pure_tf import tf_forward_shifted_rnnt
 
   log_probs = source(0, as_data=True, auto_convert=False).get_placeholder_as_batch_major()
@@ -63,4 +70,34 @@ def rnnt_alignment(source, **kwargs):
   blank_idx = targets.dim
   _, alignment = tf_forward_shifted_rnnt(log_probs, targets.get_placeholder_as_batch_major(), enc_lens, dec_lens,
                                          blank_index=blank_idx, debug=False, with_alignment=True)
-  return alignment  # (B, T)
+  return alignment  # [B, T+U]
+
+
+def rna_alignment_out_type(sources: List[LayerBase], **_kwargs) -> Data:
+  """ Computes the out_type Data for the RNA alignment.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    sources:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      real_target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    alignment [B, T]
+  """
+  return Data(name="rna_alignment_output", sparse=True, dim=sources[0].output.dim,
+              size_placeholder={0: sources[2].output.size_placeholder[0]})
+
+
+def rnnt_alignment_out_type(sources: List[LayerBase], **_kwargs) -> Data:
+  """ Computes the out_type Data for the RNNT alignment.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    sources:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      real_target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    alignment [B, T+U]
+  """
+  return Data(name="rnnt_alignment_output", sparse=True, dim=sources[0].output.dim,
+              size_placeholder={0: sources[1].output.size_placeholder[0] + sources[2].output.size_placeholder[0]})
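
For orientation, here is a minimal sketch of how these alignment helpers are meant to be hooked into a RETURNN network dict via an eval layer, as the docstrings describe; the layer name and the "data:classes" source are illustrative assumptions, not part of this commit:

# Hypothetical wiring sketch (names are assumptions). The source order must
# match the docstrings: 0: output_log_prob, 1: targets, 2: base:encoder.
network = {
  "rna_alignment": {
    "class": "eval",
    "from": ["output_log_prob", "data:classes", "base:encoder"],
    "eval": rna_alignment,               # receives the source(i, as_data=...) callback
    "out_type": rna_alignment_out_type,  # sparse [B, T], time axis taken from the encoder
    "is_output_layer": True,             # so the alignment can be dumped/exported
  },
}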

common/models/transducer/loss.py

Lines changed: 40 additions & 36 deletions
@@ -1,18 +1,19 @@
-
 from returnn.tf.util.data import Data
-
-
-def rnnt_loss(source, **_kwargs):
-  """
-  Computes the RNN-T loss function. Native TF kernel implementation.
-
-  :param log_prob:
-  :return:
+import tensorflow as tf
+
+
+def rnnt_loss(source, **_kwargs) -> tf.Tensor:
+  """ Computes the RNN-T loss function. Native TF kernel implementation.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
+      which returns one of:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    costs: [B]
   """
-  # acts: (B, T, U + 1, V)
-  # targets: (B, T)
-  # input_lengths (B,)
-  # label_lengths (B,)
   from returnn.extern.HawkAaronWarpTransducer import rnnt_loss
 
   log_probs = source(0, as_data=True, auto_convert=False)
@@ -31,17 +32,18 @@ def rnnt_loss(source, **_kwargs):
   return costs
 
 
-def rnnt_tf_loss(source, **kwargs):
-  """
-  Computes the RNN-T loss function. Pure TF.
-
-  :param log_prob:
-  :return:
+def rnnt_tf_loss(source, **kwargs) -> tf.Tensor:
+  """ Computes the RNN-T loss function. Pure TF.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
+      which returns one of:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    costs: [B]
   """
-  # acts: (B, T, U + 1, V)
-  # targets: (B, T)
-  # input_lengths (B,)
-  # label_lengths (B,)
   from .rnnt_align_sum_max_pure_tf import tf_forward_shifted_rnnt
 
   log_probs = source(0, as_data=True, auto_convert=False)
@@ -60,20 +62,18 @@ def rnnt_tf_loss(source, **kwargs):
   return costs
 
 
-def rnnt_loss_out_type(**_kwargs) -> Data:
-  return Data(name="rnnt_loss", shape=())
-
-
-def rna_tf_loss(source, **kwargs):
-  """
-  Computes the RNA loss. Pure TF.
-  :param log_prob:
-  :return:
+def rna_tf_loss(source, **kwargs) -> tf.Tensor:
+  """ Computes the RNA loss. Pure TF.
+  B: batch, T: time, U: target/labels, V: vocabulary
+  Args:
+    source: function (i: int, as_data: bool = False, ...) -> tf.Tensor|Data
+      which returns one of:
+      output_log_prob: [B, T, U+1, V] log-probabilities
+      target: [B, U] -> [V] target labels
+      base:encoder: [B, T, Feat] -> [V] encoder output
+  Returns:
+    costs: [B]
   """
-  # acts: (B, T, U, V)
-  # targets: (B, U-1)
-  # input_lengths (B,)
-  # label_lengths (B,)
   from .rna_align_sum_max_pure_tf import tf_forward_shifted_rna
   from returnn.tf.compat import v1 as tf
 
@@ -92,5 +92,9 @@ def rna_tf_loss(source, **kwargs):
   return costs
 
 
+def rnnt_loss_out_type(**_kwargs) -> Data:
+  return Data(name="rnnt_loss", shape=())
+
+
 def rna_loss_out_type(**_kwargs) -> Data:
   return Data(name="rnna_loss", shape=())
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+from .loss import (rnnt_loss, rnnt_loss_out_type,
+                   rnnt_tf_loss, rna_tf_loss, rna_loss_out_type)
+from .alignment import (rnnt_alignment, rnnt_alignment_out_type,
+                        rna_alignment, rna_alignment_out_type)
+
+
+class Topology:
+  """
+  Holds information about a label topology such as RNA or RNN-T: its loss and
+  alignment functions and their out_types.
+
+  The loss, alignment and alignment_out_type functions are all meant to be used
+  in EvalLayers, taking input from layers that output the following:
+    output_log_prob: [B, T, U+1, V] log-probabilities
+    target: [B, U] -> [V] target labels
+    base:encoder: [B, T, Feat] -> [V] encoder output
+  where
+    B: batch, T: time, U: target/labels, V: vocabulary, U': seq_len of alignment
+  The EvalLayer offers a source() callback, which has to be used to get the mentioned data.
+  """
+  def __init__(self,
+               name: str,
+               loss,
+               loss_out_type,
+               alignment,
+               alignment_out_type):
+    """ Label topology such as rnnt, rna, ctc.
+    Args:
+      loss: function (source: (i: int, as_data: bool = False, ...) -> tf.Tensor|Data, ...) -> tf.Tensor[B]
+      loss_out_type: function (...) -> Data[B]
+      alignment: function (source: (i: int, as_data: bool = False, ...) -> tf.Tensor|Data, ...) -> tf.Tensor[B,U']
+      alignment_out_type: function (sources: list[LayerBase], ...) -> Data[B,U']
+    """
+    self.name = name
+    self.loss = loss
+    self.loss_out_type = loss_out_type
+    self.alignment = alignment
+    self.alignment_out_type = alignment_out_type
+
+
+rna_topology = Topology(
+  name="rna",
+  loss=rna_tf_loss,
+  loss_out_type=rna_loss_out_type,
+  alignment=rna_alignment,
+  alignment_out_type=rna_alignment_out_type)
+
+rnnt_topology = Topology(
+  name="rnnt",
+  loss=rnnt_tf_loss,
+  loss_out_type=rnnt_loss_out_type,
+  alignment=rnnt_alignment,
+  alignment_out_type=rnnt_alignment_out_type)
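
To illustrate the abstraction this class buys, a small hedged sketch: a hypothetical helper (not part of this commit) that builds both eval layers for whichever topology is selected, so switching between RNA and RNN-T is a one-line change:

# Hypothetical helper (names are assumptions): derive the loss and alignment
# eval layers from a Topology object.
def make_topology_layers(topo: Topology, sources: list) -> dict:
  """Builds network-dict entries for the given topology's loss and alignment."""
  return {
    "%s_loss" % topo.name: {
      "class": "eval", "from": sources,
      "eval": topo.loss, "out_type": topo.loss_out_type, "loss": "as_is"},
    "%s_alignment" % topo.name: {
      "class": "eval", "from": sources,
      "eval": topo.alignment, "out_type": topo.alignment_out_type,
      "is_output_layer": True},
  }

# Example: only the topology object changes between RNA and RNN-T setups.
layers = make_topology_layers(rnnt_topology, ["output_log_prob", "data:classes", "base:encoder"])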
