from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn


@dataclass
class AdditiveAttentionConfig:
    """
    Attributes:
        attention_dim: attention dimension
        att_weights_dropout: attention weights dropout
    """

    attention_dim: int
    att_weights_dropout: float


class AdditiveAttention(nn.Module):
    """
    Additive attention mechanism. This is defined as:
        energies = v^T * tanh(h + s + beta) where beta is weight feedback information
        weights = softmax(energies)
        context = weights * h
    """

    def __init__(self, cfg: AdditiveAttentionConfig):
        super().__init__()
        self.linear = nn.Linear(cfg.attention_dim, 1, bias=False)
        self.att_weights_drop = nn.Dropout(cfg.att_weights_dropout)

    def forward(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        query: torch.Tensor,
        weight_feedback: torch.Tensor,
        enc_seq_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param key: encoder keys of shape [B,T,D_k], already projected to the attention dimension
        :param value: encoder values of shape [B,T,D_v]
        :param query: decoder query of shape [B,D_k], already projected to the attention dimension
        :param weight_feedback: accumulated attention weight feedback of shape [B,T,D_k]
        :param enc_seq_len: encoder sequence lengths of shape [B]
        :return: attention context of shape [B,D_v] and attention weights of shape [B,T,1]
        """

        # all inputs are already projected
        energies = self.linear(torch.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
        # mask out padded encoder frames before the softmax
        time_arange = torch.arange(energies.size(1), device=energies.device)  # [T]
        seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
        energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
        weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
        weights = self.att_weights_drop(weights)
        context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
        context = context.reshape(context.size(0), -1)  # [B,D_v]
        return context, weights


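# The small demo below is not part of the original module; it is a minimal, hedged sketch
# showing how AdditiveAttention is meant to be called: key, query, and weight_feedback are
# assumed to be already projected to attention_dim, value keeps the raw encoder dimension,
# and padded frames are masked via enc_seq_len. All dimensions are illustrative assumptions.
def _demo_additive_attention() -> None:
    batch, time, att_dim, enc_dim = 2, 5, 8, 16
    att = AdditiveAttention(AdditiveAttentionConfig(attention_dim=att_dim, att_weights_dropout=0.1))
    key = torch.randn(batch, time, att_dim)  # projected encoder keys [B,T,D_k]
    value = torch.randn(batch, time, enc_dim)  # unprojected encoder values [B,T,D_v]
    query = torch.randn(batch, att_dim)  # projected decoder state [B,D_k]
    weight_feedback = torch.zeros(batch, time, att_dim)  # no accumulated weights yet
    enc_seq_len = torch.tensor([5, 3])  # second sequence is padded after 3 frames
    context, weights = att(
        key=key, value=value, query=query, weight_feedback=weight_feedback, enc_seq_len=enc_seq_len
    )
    assert context.shape == (batch, enc_dim)  # [B,D_v]
    assert weights.shape == (batch, time, 1)  # [B,T,1]

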
@dataclass
class AttentionLstmDecoderV1Config:
    """
    Attributes:
        encoder_dim: encoder dimension
        vocab_size: vocabulary size
        target_embed_dim: embedding dimension
        target_embed_dropout: embedding dropout
        lstm_hidden_size: LSTM hidden size
        attention_cfg: attention config
        output_proj_dim: output projection dimension
        output_dropout: output dropout
    """

    encoder_dim: int
    vocab_size: int
    target_embed_dim: int
    target_embed_dropout: float
    lstm_hidden_size: int
    attention_cfg: AdditiveAttentionConfig
    output_proj_dim: int
    output_dropout: float


class AttentionLstmDecoderV1(nn.Module):
    """
    Single-headed attention decoder with additive attention mechanism.
    """

    def __init__(self, cfg: AttentionLstmDecoderV1Config):
        super().__init__()

        self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
        self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)

        self.s = nn.LSTMCell(
            input_size=cfg.target_embed_dim + cfg.encoder_dim,
            hidden_size=cfg.lstm_hidden_size,
        )
        self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False)  # query

        # for attention
        self.enc_ctx = nn.Linear(cfg.encoder_dim, cfg.attention_cfg.attention_dim)
        self.attention = AdditiveAttention(cfg.attention_cfg)

        # for weight feedback
        self.inv_fertility = nn.Linear(cfg.encoder_dim, 1, bias=False)  # followed by sigmoid
        self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False)

        self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim)
        # output_proj_dim is halved by the maxout layer in forward, hence output_proj_dim // 2 here
        self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
        self.output_dropout = nn.Dropout(cfg.output_dropout)

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        labels: torch.Tensor,
        enc_seq_len: torch.Tensor,
        state: Optional[Tuple[torch.Tensor, ...]] = None,
    ):
        """
        :param encoder_outputs: encoder outputs of shape [B,T,D]
        :param labels: target labels of shape [B,N]
        :param enc_seq_len: encoder sequence lengths of shape [B]
        :param state: decoder state consisting of (lstm_state, att_context, accum_att_weights)
        :return: decoder logits of shape [B,N,vocab_size] and the updated decoder state
        """
        if state is None:
            lstm_state = None
            att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
        else:
            lstm_state, att_context, accum_att_weights = state

        target_embeddings = self.target_embed(labels)  # [B,N,D]
        target_embeddings = self.target_embed_dropout(target_embeddings)
        # shift right by one step: prepend a zero embedding as BOS and drop the last label,
        # so that step n conditions only on the label history up to step n-1
        target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :]  # [B,N,D]

        enc_ctx = self.enc_ctx(encoder_outputs)  # [B,T,D]
        enc_inv_fertility = torch.sigmoid(self.inv_fertility(encoder_outputs))  # [B,T,1]

        num_steps = labels.size(1)  # N

        # collect the per-step outputs so that the decoder logits can be computed outside the loop
        s_list = []
        att_context_list = []

        # decoder loop
        for step in range(num_steps):
            target_embed = target_embeddings[:, step, :]  # [B,D]

            lstm_state = self.s(torch.cat([target_embed, att_context], dim=-1), lstm_state)
            lstm_out = lstm_state[0]
            s_transformed = self.s_transformed(lstm_out)  # project query
            s_list.append(lstm_out)

            # attention mechanism
            weight_feedback = self.weight_feedback(accum_att_weights)
            att_context, att_weights = self.attention(
                key=enc_ctx,
                value=encoder_outputs,
                query=s_transformed,
                weight_feedback=weight_feedback,
                enc_seq_len=enc_seq_len,
            )
            att_context_list.append(att_context)
            accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5

        # output layer
        s_stacked = torch.stack(s_list, dim=1)  # [B,N,D]
        att_context_stacked = torch.stack(att_context_list, dim=1)  # [B,N,D]
        readout_in = self.readout_in(torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1))  # [B,N,D]

        # maxout layer
        assert readout_in.size(-1) % 2 == 0
        readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2)  # [B,N,D/2,2]
        readout, _ = torch.max(readout_in, dim=-1)  # [B,N,D/2]

        output = self.output(readout)
        decoder_logits = self.output_dropout(output)

        state = lstm_state, att_context, accum_att_weights

        return decoder_logits, state
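

# The block below is not part of the original module; it is a minimal, hedged smoke-test
# sketch of a full teacher-forced forward pass. All hyperparameters and tensor shapes are
# illustrative assumptions chosen only so that the dimensions line up.
if __name__ == "__main__":
    cfg = AttentionLstmDecoderV1Config(
        encoder_dim=16,
        vocab_size=11,
        target_embed_dim=8,
        target_embed_dropout=0.1,
        lstm_hidden_size=12,
        attention_cfg=AdditiveAttentionConfig(attention_dim=8, att_weights_dropout=0.1),
        output_proj_dim=20,  # must be even because of the maxout layer
        output_dropout=0.1,
    )
    decoder = AttentionLstmDecoderV1(cfg)
    encoder_outputs = torch.randn(2, 7, cfg.encoder_dim)  # [B,T,D], e.g. from an acoustic encoder
    labels = torch.randint(0, cfg.vocab_size, (2, 4))  # [B,N] teacher-forced target labels
    enc_seq_len = torch.tensor([7, 5])  # [B]
    decoder_logits, state = decoder(encoder_outputs, labels, enc_seq_len)
    assert decoder_logits.shape == (2, 4, cfg.vocab_size)  # [B,N,vocab_size]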