
Commit 1002af4

Add shift_embeddings flag for Attention Decoder
Allows passing the labels unshifted for step-wise search, without needing a separate function besides "forward".
1 parent 36f4de5 commit 1002af4
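
A minimal sketch of how the flag can drive a step-wise greedy search through the regular forward call. This is hypothetical usage for illustration only: the decoder instance, the full argument order, and the assumption that forward returns a (logits, state) pair are not shown in this commit.

```python
import torch


def greedy_search(decoder, encoder_outputs, enc_seq_len, max_steps=100):
    """Step-wise greedy search reusing the decoder's regular forward().

    Hypothetical sketch: assumes forward(encoder_outputs, labels, enc_seq_len,
    state=..., shift_embeddings=...) returns (logits, state); the commit only
    shows the added shift_embeddings argument, not the return values.
    """
    batch_size = encoder_outputs.size(0)
    # Dummy label for step 0; with shift_embeddings=True it is shifted out
    # and replaced by the zero (BOS) embedding.
    labels = torch.zeros((batch_size, 1), dtype=torch.long)
    state = None
    hyps = []
    for step in range(max_steps):
        logits, state = decoder(
            encoder_outputs,
            labels,
            enc_seq_len,
            state=state,
            # True only in the first step (start from the zero embedding);
            # afterwards the previous label is fed unshifted.
            shift_embeddings=(step == 0),
        )
        labels = logits[:, -1:, :].argmax(dim=-1)  # [B,1] next label per hypothesis
        hyps.append(labels)
    return torch.cat(hyps, dim=1)  # [B,max_steps]
```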

File tree

2 files changed: +14 additions, -5 deletions


i6_models/decoder/__init__.py

Whitespace-only changes.

i6_models/decoder/attention.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -132,12 +132,20 @@ def forward(
         labels: torch.Tensor,
         enc_seq_len: torch.Tensor,
         state: Optional[Tuple[torch.Tensor, ...]] = None,
+        shift_embeddings=True,
     ):
         """
-        :param encoder_outputs: encoder outputs of shape [B,T,D]
-        :param labels: labels of shape [B,T]
-        :param enc_seq_len: encoder sequence lengths of shape [B,T]
+        :param encoder_outputs: encoder outputs of shape [B,T,D], same for training and search
+        :param labels:
+            training: labels of shape [B,N]
+            (greedy-)search: hypotheses last label as [B,1]
+        :param enc_seq_len: encoder sequence lengths of shape [B,T], same for training and search
         :param state: decoder state
+            training: Usually None, unless decoding should be initialized with a certain state (e.g. for context init)
+            search: current state of the active hypotheses
+        :param shift_embeddings: shift the embeddings by one position along U, padding with zero in front and drop last
+            training: this should be "True", in order to start with a zero target embedding
+            search: use True for the first step in order to start with a zero embedding, False otherwise
         """
         if state is None:
             zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))
@@ -150,8 +158,9 @@ def forward(
         target_embeddings = self.target_embed(labels)  # [B,N,D]
         target_embeddings = self.target_embed_dropout(target_embeddings)
 
-        # pad for BOS and remove last token as this represents history and last token is not used
-        target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]
+        if shift_embeddings:
+            # pad for BOS and remove last token as this represents history and last token is not used
+            target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]
 
         enc_ctx = self.enc_ctx(encoder_outputs)  # [B,T,D]
         enc_inv_fertility = nn.functional.sigmoid(self.inv_fertility(encoder_outputs))  # [B,T,1]
```
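
For reference, a small self-contained sketch of what shift_embeddings=True does to the target embeddings, using the same nn.functional.pad call as above; the toy tensor values are only for illustration.

```python
import torch
import torch.nn.functional as F

# Toy target embeddings: B=1 sequence, N=3 label positions, D=2 features.
target_embeddings = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)

# shift_embeddings=True: pad one zero vector in front along the label axis
# (acts as the BOS embedding) and drop the last position, so position i
# now holds the embedding of label i-1, i.e. its history.
shifted = F.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]

print(shifted[0, 0])  # tensor([0., 0.])  -> zero "BOS" embedding
print(shifted[0, 1])  # tensor([0., 1.])  -> embedding of label 0
```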
