
Commit bc40736

implemented enc-dec-att model
1 parent 4ce5419 commit bc40736


2 files changed: +232 −0 lines changed


i6_models/decoder/attention.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


@dataclass
class AdditiveAttentionConfig:
    """
    Attributes:
        attention_dim: attention dimension
        att_weights_dropout: attention weights dropout
    """

    attention_dim: int
    att_weights_dropout: float


class AdditiveAttention(nn.Module):
    """
    Additive attention mechanism. This is defined as:
        energies = v^T * tanh(h + s + beta)  where beta is weight feedback information
        weights = softmax(energies)
        context = weights * h
    """

    def __init__(self, cfg: AdditiveAttentionConfig):
        super().__init__()
        self.linear = nn.Linear(cfg.attention_dim, 1, bias=False)
        self.att_weights_drop = nn.Dropout(cfg.att_weights_dropout)

    def forward(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        query: torch.Tensor,
        weight_feedback: torch.Tensor,
        enc_seq_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param key: encoder keys of shape [B,T,D_k]
        :param value: encoder values of shape [B,T,D_v]
        :param query: query of shape [B,D_k]
        :param weight_feedback: shape is [B,T,D_k]
        :param enc_seq_len: [B]
        :return: context [B,D_v], weights [B,T,1]
        """

        # all inputs are already projected
        energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
        time_arange = torch.arange(energies.size(1))  # [T]
        seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
        energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
        weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
        weights = self.att_weights_drop(weights)
        context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
        context = context.reshape(context.size(0), -1)  # [B,D_v]
        return context, weights
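
For illustration, a minimal standalone sketch (hypothetical toy values, not part of the commit) of how the sequence-length masking above zeroes the attention weights of padded positions:

import torch

energies = torch.rand(2, 4, 1)  # [B,T,1], toy batch of 2 with 4 time frames
enc_seq_len = torch.tensor([2, 4])  # [B], true lengths per sequence
mask = torch.less(torch.arange(4)[None, :], enc_seq_len[:, None])  # [B,T]
energies = torch.where(mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
weights = torch.nn.functional.softmax(energies, dim=1)  # [B,T,1]
assert torch.all(weights[0, 2:, 0] == 0.0)  # frames past the sequence length get exactly zero weight
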
@dataclass
class AttentionLstmDecoderV1Config:
    """
    Attributes:
        encoder_dim: encoder dimension
        vocab_size: vocabulary size
        target_embed_dim: embedding dimension
        target_embed_dropout: embedding dropout
        lstm_hidden_size: LSTM hidden size
        attention_cfg: attention config
        output_proj_dim: output projection dimension
        output_dropout: output dropout
    """

    encoder_dim: int
    vocab_size: int
    target_embed_dim: int
    target_embed_dropout: float
    lstm_hidden_size: int
    attention_cfg: AdditiveAttentionConfig
    output_proj_dim: int
    output_dropout: float


class AttentionLstmDecoderV1(nn.Module):
    """
    Single-headed attention decoder with additive attention mechanism.
    """

    def __init__(self, cfg: AttentionLstmDecoderV1Config):
        super().__init__()

        self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
        self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)

        self.s = nn.LSTMCell(
            input_size=cfg.target_embed_dim + cfg.encoder_dim,
            hidden_size=cfg.lstm_hidden_size,
        )
        self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False)  # query

        # for attention
        self.enc_ctx = nn.Linear(cfg.encoder_dim, cfg.attention_cfg.attention_dim)
        self.attention = AdditiveAttention(cfg.attention_cfg)

        # for weight feedback
        self.inv_fertility = nn.Linear(cfg.encoder_dim, 1, bias=False)  # followed by sigmoid
        self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False)

        self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim)
        self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
        self.output_dropout = nn.Dropout(cfg.output_dropout)
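
Note on the readout: `readout_in` projects to `output_proj_dim`, but `self.output` consumes only `output_proj_dim // 2` features, because `forward` applies a maxout nonlinearity that reshapes the readout into pairs and keeps the per-pair maximum, halving the dimension (so `output_proj_dim` must be even). A standalone sketch of that pair-wise maxout (hypothetical toy shapes, not part of the commit):

import torch

readout_in = torch.rand(10, 7, 12)  # [B,N,D], D = output_proj_dim must be even
pairs = readout_in.view(10, 7, -1, 2)  # [B,N,D/2,2], group features into pairs
readout, _ = torch.max(pairs, dim=-1)  # [B,N,D/2], keep the max of each pair
assert readout.shape == (10, 7, 6)
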
    def forward(
        self,
        encoder_outputs: torch.Tensor,
        labels: torch.Tensor,
        enc_seq_len: torch.Tensor,
        state: Optional[Tuple[torch.Tensor, ...]] = None,
    ):
        """
        :param encoder_outputs: encoder outputs of shape [B,T,D]
        :param labels: labels of shape [B,N]
        :param enc_seq_len: encoder sequence lengths of shape [B]
        :param state: decoder state
        """
        if state is None:
            lstm_state = None
            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
        else:
            lstm_state, att_context, accum_att_weights = state

        target_embeddings = self.target_embed(labels)  # [B,N,D]
        target_embeddings = self.target_embed_dropout(target_embeddings)
        # pad for BOS and drop the last token: the embeddings represent the history, so the last token is not used
        target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]

        enc_ctx = self.enc_ctx(encoder_outputs)  # [B,T,D]
        enc_inv_fertility = nn.functional.sigmoid(self.inv_fertility(encoder_outputs))  # [B,T,1]

        num_steps = labels.size(1)  # N

        # collect outputs so the decoder logits can be computed outside the loop
        s_list = []
        att_context_list = []

        # decoder loop
        for step in range(num_steps):
            target_embed = target_embeddings[:, step, :]  # [B,D]

            lstm_state = self.s(torch.cat([target_embed, att_context], dim=-1), lstm_state)
            lstm_out = lstm_state[0]
            s_transformed = self.s_transformed(lstm_out)  # project query
            s_list.append(lstm_out)

            # attention mechanism
            weight_feedback = self.weight_feedback(accum_att_weights)
            att_context, att_weights = self.attention(
                key=enc_ctx,
                value=encoder_outputs,
                query=s_transformed,
                weight_feedback=weight_feedback,
                enc_seq_len=enc_seq_len,
            )
            att_context_list.append(att_context)
            accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5

        # output layer
        s_stacked = torch.stack(s_list, dim=1)  # [B,N,D]
        att_context_stacked = torch.stack(att_context_list, dim=1)  # [B,N,D]
        readout_in = self.readout_in(torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1))  # [B,N,D]

        # maxout layer
        assert readout_in.size(-1) % 2 == 0
        readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2)  # [B,N,D/2,2]
        readout, _ = torch.max(readout_in, dim=-1)  # [B,N,D/2]

        output = self.output(readout)
        decoder_logits = self.output_dropout(output)

        state = lstm_state, att_context, accum_att_weights

        return decoder_logits, state
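
For orientation, a shape trace of one iteration of the decoder loop, using the configuration from tests/test_enc_dec_att.py below (B=10, T=20, encoder_dim=5, target_embed_dim=3, lstm_hidden_size=12, attention_dim=10):

# target_embed:            [10, 3]
# att_context:             [10, 5]      (from the previous step, zeros at step 0)
# LSTM input (cat):        [10, 8]      (target_embed_dim + encoder_dim)
# lstm_out:                [10, 12]
# s_transformed (query):   [10, 10]
# weight_feedback:         [10, 20, 10]
# new att_context:         [10, 5]
# att_weights:             [10, 20, 1]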

tests/test_enc_dec_att.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import torch
from torch import nn

from i6_models.decoder.attention import AdditiveAttention, AdditiveAttentionConfig
from i6_models.decoder.attention import AttentionLstmDecoderV1, AttentionLstmDecoderV1Config


def test_additive_attention():
    cfg = AdditiveAttentionConfig(attention_dim=5, att_weights_dropout=0.1)
    att = AdditiveAttention(cfg)
    key = torch.rand((10, 20, 5))
    value = torch.rand((10, 20, 5))
    query = torch.rand((10, 5))

    enc_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]

    # pass key as weight feedback for testing
    context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len)
    assert context.shape == (10, 5)
    assert weights.shape == (10, 20, 1)

    # Testing attention weights masking:
    # for the first seq, the enc seq length is 10, so half the weights should be 0
    assert torch.eq(weights[0, 10:, 0], torch.tensor(0.0)).all()
    # test for other seqs
    assert torch.eq(weights[5, 15:, 0], torch.tensor(0.0)).all()


def test_encoder_decoder_attention_model():
    encoder = torch.rand((10, 20, 5))
    encoder_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]
    decoder_cfg = AttentionLstmDecoderV1Config(
        encoder_dim=5,
        vocab_size=15,
        target_embed_dim=3,
        target_embed_dropout=0.1,
        lstm_hidden_size=12,
        attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1),
        output_proj_dim=12,
        output_dropout=0.1,
    )
    decoder = AttentionLstmDecoderV1(decoder_cfg)
    target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]

    decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)

    assert decoder_logits.shape == (10, 7, 15)
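
A possible follow-up (hypothetical, not part of the commit): the returned logits have shape [B,N,vocab_size], so a training step could apply cross-entropy directly to them, e.g.:

loss = nn.functional.cross_entropy(
    decoder_logits.transpose(1, 2),  # cross_entropy expects [B,C,N]
    target_labels,  # [B,N] class indices
)
loss.backward()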
