@@ -132,12 +132,20 @@ def forward(
132132 labels : torch .Tensor ,
133133 enc_seq_len : torch .Tensor ,
134134 state : Optional [Tuple [torch .Tensor , ...]] = None ,
135+ shift_embeddings = True ,
135136 ):
136137 """
137- :param encoder_outputs: encoder outputs of shape [B,T,D]
138- :param labels: labels of shape [B,T]
139- :param enc_seq_len: encoder sequence lengths of shape [B,T]
138+ :param encoder_outputs: encoder outputs of shape [B,T,D], same for training and search
139+ :param labels:
140+ training: labels of shape [B,N]
141+ (greedy-)search: last label of each hypothesis, shape [B,1]
142+ :param enc_seq_len: encoder sequence lengths of shape [B,T], same for training and search
140143 :param state: decoder state
144+ training: Usually None, unless decoding should be initialized with a certain state (e.g. for context init)
145+ search: current state of the active hypotheses
146+ :param shift_embeddings: shift the embeddings by one position along N (the label axis), padding with zeros in front and dropping the last position
147+ training: this should be "True", in order to start with a zero target embedding
148+ search: use True for the first step in order to start with a zero embedding, False otherwise
141149 """
142150 if state is None :
143151 zeros = encoder_outputs .new_zeros ((encoder_outputs .size (0 ), self .lstm_hidden_size ))
@@ -150,8 +158,9 @@ def forward(
150158 target_embeddings = self .target_embed (labels ) # [B,N,D]
151159 target_embeddings = self .target_embed_dropout (target_embeddings )
152160
153- # pad for BOS and remove last token as this represents history and last token is not used
154- target_embeddings = nn .functional .pad (target_embeddings , (0 , 0 , 1 , 0 ), value = 0 )[:, :- 1 , :] # [B,N,D]
161+ if shift_embeddings :
162+ # pad for BOS and remove last token as this represents history and last token is not used
163+ target_embeddings = nn .functional .pad (target_embeddings , (0 , 0 , 1 , 0 ), value = 0 )[:, :- 1 , :] # [B,N,D]
155164
156165 enc_ctx = self .enc_ctx (encoder_outputs ) # [B,T,D]
157166 enc_inv_fertility = nn .functional .sigmoid (self .inv_fertility (encoder_outputs )) # [B,T,1]
0 commit comments