@@ -24,7 +24,7 @@ class AdditiveAttention(nn.Module):
     Additive attention mechanism. This is defined as:
         energies = v^T * tanh(h + s + beta) where beta is weight feedback information
         weights = softmax(energies)
-        context = weights * h
+        context = sum_t weights_t * h_t
     """

     def __init__(self, cfg: AdditiveAttentionConfig):
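The docstring change replaces the elementwise product with a sum over encoder time steps, which is what the forward pass actually computes. A minimal, self-contained sketch of that context computation (the tensor names `h` and `weights` here are illustrative, not taken from the repository):

```python
import torch

B, T, D = 2, 5, 8                                     # batch, encoder time steps, feature dim
h = torch.randn(B, T, D)                              # encoder outputs h_t
weights = torch.softmax(torch.randn(B, T, 1), dim=1)  # attention weights, sum to 1 over T

# context = sum_t weights_t * h_t: broadcast multiply, then reduce over the time axis
context = torch.sum(weights * h, dim=1)               # [B,D]
```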
@@ -50,7 +50,7 @@ def forward(
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1))  # [T]
+        time_arange = torch.arange(energies.size(1), device="cuda")  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
         energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
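For reference, a hedged sketch of the length-masked softmax performed above. The patch pins the arange to `device="cuda"`; the variant below instead derives the device from `energies`, so it also runs on CPU (an illustrative alternative, not the patch's exact code):

```python
import torch

def masked_attention_weights(energies: torch.Tensor, enc_seq_len: torch.Tensor) -> torch.Tensor:
    """energies: [B,T,1], enc_seq_len: [B] -> weights: [B,T,1] with zero mass on padded frames."""
    time_arange = torch.arange(energies.size(1), device=energies.device)   # [T]
    seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
    energies = torch.where(
        seq_len_mask.unsqueeze(2),
        energies,
        torch.tensor(-float("inf"), device=energies.device),
    )
    return torch.softmax(energies, dim=1)                                  # [B,T,1]
```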
@@ -140,15 +140,16 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size))
+            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device="cuda")
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
+            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device="cuda")
+            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device="cuda")
         else:
             lstm_state, att_context, accum_att_weights = state

         target_embeddings = self.target_embed(labels)  # [B,N,D]
         target_embeddings = self.target_embed_dropout(target_embeddings)
+
         # pad for BOS and remove last token as this represents history and last token is not used
         target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]

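The BOS handling above shifts the target embeddings one step to the right: a zero embedding is padded in front of the time axis and the last step is dropped, so position n of the decoder input holds the embedding of label n-1. A small sketch of just that shift (shapes chosen arbitrarily):

```python
import torch
import torch.nn as nn

B, N, D = 2, 4, 3
target_embeddings = torch.randn(B, N, D)

# pad one zero step at the front of the time axis, then drop the last step
shifted = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]

assert torch.equal(shifted[:, 0], torch.zeros(B, D))           # BOS slot is all zeros
assert torch.equal(shifted[:, 1:], target_embeddings[:, :-1])  # remaining steps shifted by one
```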