
Commit cf8017e

put tensors on cuda

1 parent 5ce9542

File tree: 1 file changed (+6 / -5 lines)


i6_models/decoder/attention.py

Lines changed: 6 additions & 5 deletions
@@ -24,7 +24,7 @@ class AdditiveAttention(nn.Module):
     Additive attention mechanism. This is defined as:
         energies = v^T * tanh(h + s + beta) where beta is weight feedback information
         weights = softmax(energies)
-        context = weights * h
+        context = sum_t weights_t * h_t
     """
 
     def __init__(self, cfg: AdditiveAttentionConfig):
@@ -50,7 +50,7 @@ def forward(
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1))  # [T]
+        time_arange = torch.arange(energies.size(1), device="cuda")  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
         energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
@@ -140,15 +140,16 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size))
+            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device="cuda")
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
+            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device="cuda")
+            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device="cuda")
         else:
             lstm_state, att_context, accum_att_weights = state
 
         target_embeddings = self.target_embed(labels)  # [B,N,D]
         target_embeddings = self.target_embed_dropout(target_embeddings)
+
         # pad for BOS and remove last token as this represents history and last token is not used
         target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]
 
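Design note: the commit pins the newly created tensors to device="cuda", which assumes the model always runs on a GPU. A device-agnostic alternative, sketched below as an assumption and not part of this commit, is to derive the device (and dtype) from encoder_outputs so the same code also runs on CPU. The helper name make_initial_state and its signature are hypothetical.

import torch


def make_initial_state(encoder_outputs: torch.Tensor, lstm_hidden_size: int):
    # Hypothetical helper: build the zero-initialized decoder state on the same
    # device and dtype as the encoder outputs instead of hardcoding "cuda".
    device = encoder_outputs.device
    dtype = encoder_outputs.dtype
    batch = encoder_outputs.size(0)

    zeros = torch.zeros((batch, lstm_hidden_size), device=device, dtype=dtype)
    lstm_state = (zeros, zeros)  # (h, c) for the LSTM
    att_context = torch.zeros((batch, encoder_outputs.size(2)), device=device, dtype=dtype)
    accum_att_weights = torch.zeros((batch, encoder_outputs.size(1), 1), device=device, dtype=dtype)
    return lstm_state, att_context, accum_att_weights

The same effect can be had with encoder_outputs.new_zeros(...); the point is only that the device follows the input tensor rather than being fixed in the module.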