
Commit 315b579

remove explicit device, fix return value

1 parent 4ddfae6

File tree: 1 file changed, +6 -14 lines

i6_models/decoder/attention.py

Lines changed: 6 additions & 14 deletions
@@ -39,22 +39,20 @@ def forward(
         query: torch.Tensor,
         weight_feedback: torch.Tensor,
         enc_seq_len: torch.Tensor,
-        device: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param key: encoder keys of shape [B,T,D_k]
         :param value: encoder values of shape [B,T,D_v]
         :param query: query of shape [B,D_k]
         :param weight_feedback: shape is [B,T,D_k]
         :param enc_seq_len: encoder sequence lengths [B]
-        :param device: device where to run the model (cpu or cuda)
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1), device=device)  # [T]
+        time_arange = torch.arange(energies.size(1), device=energies.device)  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
-        energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
+        energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
         weights = self.att_weights_drop(weights)
         context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
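The two replaced lines above follow one pattern: derive the device (and dtype) of freshly created tensors from a tensor you already have, instead of threading a device argument through the call chain. A minimal, self-contained sketch of that masking pattern, with illustrative shapes and values:

import torch

energies = torch.randn(2, 5, 1)     # [B,T,1]; could live on any device
enc_seq_len = torch.tensor([3, 5])  # [B]

# the arange is created directly on energies' device
time_arange = torch.arange(energies.size(1), device=energies.device)   # [T]
seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]

# new_tensor inherits device and dtype from energies, unlike a bare torch.tensor
energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
weights = torch.nn.functional.softmax(energies, dim=1)  # [B,T,1]; padded steps get weight 0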
@@ -76,7 +74,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: attention config
     output_proj_dim: output projection dimension
     output_dropout: output dropout
-    device: device where to run the model (cpu or cuda)
     """

     encoder_dim: int
@@ -89,7 +86,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
-    device: str


 class AttentionLSTMDecoderV1(nn.Module):
@@ -100,6 +96,7 @@ class AttentionLSTMDecoderV1(nn.Module):
     def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         super().__init__()

+        print(cfg.vocab_size)
         self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
         self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)
@@ -130,10 +127,6 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
         self.output_dropout = nn.Dropout(cfg.output_dropout)

-        if "cuda" in cfg.device:
-            assert torch.cuda.is_available(), "CUDA is not available"
-        self.device = cfg.device
-
     def forward(
         self,
         encoder_outputs: torch.Tensor,
@@ -148,10 +141,10 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device)
+            zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device)
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device)
+            att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
+            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
         else:
             lstm_state, att_context, accum_att_weights = state
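For the zero initializations, new_zeros does the same job: it allocates on the source tensor's device with its dtype, so no stored device string is needed. A small runnable sketch (shapes are illustrative assumptions):

import torch

encoder_outputs = torch.randn(4, 17, 512, dtype=torch.float16)  # stand-in for a possibly-GPU tensor
zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), 128))
assert zeros.device == encoder_outputs.device
assert zeros.dtype == encoder_outputs.dtype  # float16 carried over; torch.zeros would default to float32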

@@ -187,7 +180,6 @@ def forward(
                 query=s_transformed,
                 weight_feedback=weight_feedback,
                 enc_seq_len=enc_seq_len,
-                device=self.device,
             )
             att_context_list.append(att_context)
             accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
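Taken together, the removals mean the decoder no longer stores a device at all: parameters are placed by nn.Module.to(), and intermediate buffers follow the input tensors. A hedged usage sketch with a stand-in module (ToyDecoder is illustrative, not the repository's class):

import torch
from torch import nn

class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        state = x.new_zeros(x.size(0), 8)  # created on x.device; no self.device needed
        return self.proj(x) + state

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyDecoder().to(device)                # parameters move once, up front
out = model(torch.randn(2, 8, device=device))  # internal buffers follow the input's device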
