Commit 8030a30

make device configurable

1 parent cf8017e
File tree:
  i6_models/decoder/attention.py
  tests/test_enc_dec_att.py

2 files changed: 16 additions & 5 deletions

i6_models/decoder/attention.py

Lines changed: 11 additions & 4 deletions
@@ -39,18 +39,20 @@ def forward(
         query: torch.Tensor,
         weight_feedback: torch.Tensor,
         enc_seq_len: torch.Tensor,
+        device: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param key: encoder keys of shape [B,T,D_k]
         :param value: encoder values of shape [B,T,D_v]
         :param query: query of shape [B,D_k]
         :param weight_feedback: shape is [B,T,D_k]
         :param enc_seq_len: encoder sequence lengths [B]
+        :param device: device where to run the model (cpu or cuda)
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1), device="cuda")  # [T]
+        time_arange = torch.arange(energies.size(1), device=device)  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
         energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
@@ -74,6 +76,7 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: attention config
     output_proj_dim: output projection dimension
     output_dropout: output dropout
+    device: device where to run the model (cpu or cuda)
     """
 
     encoder_dim: int
@@ -86,6 +89,7 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
+    device: str
 
 
 class AttentionLSTMDecoderV1(nn.Module):
@@ -126,6 +130,8 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
         self.output_dropout = nn.Dropout(cfg.output_dropout)
 
+        self.device = cfg.device
+
     def forward(
         self,
         encoder_outputs: torch.Tensor,
@@ -140,10 +146,10 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device="cuda")
+            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device)
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device="cuda")
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device="cuda")
+            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device)
+            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device)
         else:
             lstm_state, att_context, accum_att_weights = state
 
@@ -179,6 +185,7 @@ def forward(
                 query=s_transformed,
                 weight_feedback=weight_feedback,
                 enc_seq_len=enc_seq_len,
+                device=self.device,
             )
             att_context_list.append(att_context)
             accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
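For reference, a minimal standalone sketch of the sequence-length masking that the changed `torch.arange` line feeds into, with shapes taken from the docstring above ([B,T,1] energies, [B] lengths). The concrete values are made up for illustration; the point is that the time-index tensor is now created on a configurable device instead of a hard-coded `"cuda"`:

```python
import torch

# Sketch of the masking pattern this commit touches (not the repo's module itself).
B, T = 2, 5
energies = torch.randn(B, T, 1)      # unmasked attention energies [B,T,1]
enc_seq_len = torch.tensor([3, 5])   # valid length per sequence [B]
device = "cpu"                       # configurable now, instead of hard-coded "cuda"

time_arange = torch.arange(T, device=device)  # [T]
seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
weights = torch.nn.functional.softmax(energies, dim=1)  # [B,T,1]

# padded positions get zero weight; each sequence's weights still sum to 1
assert (weights[0, 3:] == 0).all()
assert torch.allclose(weights.sum(dim=1), torch.ones(B, 1))
```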

tests/test_enc_dec_att.py

Lines changed: 5 additions & 1 deletion
@@ -15,7 +15,9 @@ def test_additive_attention():
     enc_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]
 
     # pass key as weight feedback just for testing
-    context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len)
+    context, weights = att(
+        key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu"
+    )
     assert context.shape == (10, 5)
     assert weights.shape == (10, 20, 1)
 
@@ -40,6 +42,7 @@ def test_encoder_decoder_attention_model():
         output_dropout=0.1,
         zoneout_drop_c=0.0,
         zoneout_drop_h=0.0,
+        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
@@ -66,6 +69,7 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
         output_dropout=0.1,
         zoneout_drop_c=zoneout_drop_c,
         zoneout_drop_h=zoneout_drop_h,
+        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
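A hedged usage sketch: with `device` now a config field, a call site can pick the device once based on availability instead of assuming CUDA. The config construction below is abbreviated; the remaining fields are the ones shown in the tests above:

```python
import torch

# Choose the device once at the call site; fall back to CPU when CUDA is absent.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Then thread it through the config (sketch; other required fields as in the tests):
# decoder_cfg = AttentionLSTMDecoderV1Config(..., device=device)
# decoder = AttentionLSTMDecoderV1(decoder_cfg)
```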
