 import torch
 from torch import nn
 
+from .zoneout_lstm import ZoneoutLSTMCell
+
 
 @dataclass
 class AdditiveAttentionConfig:
@@ -46,7 +48,6 @@ def forward(
         :param enc_seq_len: encoder sequence lengths [B]
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
-
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
         time_arange = torch.arange(energies.size(1))  # [T]
@@ -60,14 +61,16 @@ def forward(
 
 
 @dataclass
-class AttentionLstmDecoderV1Config:
+class AttentionLSTMDecoderV1Config:
     """
     Attributes:
         encoder_dim: encoder dimension
         vocab_size: vocabulary size
         target_embed_dim: embedding dimension
         target_embed_dropout: embedding dropout
         lstm_hidden_size: LSTM hidden size
+        zoneout_drop_h: zoneout drop probability for hidden state
+        zoneout_drop_c: zoneout drop probability for cell state
         attention_cfg: attention config
         output_proj_dim: output projection dimension
         output_dropout: output dropout
@@ -78,26 +81,36 @@ class AttentionLstmDecoderV1Config:
     target_embed_dim: int
     target_embed_dropout: float
     lstm_hidden_size: int
+    zoneout_drop_h: float
+    zoneout_drop_c: float
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
 
 
-class AttentionLstmDecoderV1(nn.Module):
+class AttentionLSTMDecoderV1(nn.Module):
     """
     Single-headed attention decoder with an additive attention mechanism.
     """
 
-    def __init__(self, cfg: AttentionLstmDecoderV1Config):
+    def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         super().__init__()
 
         self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
         self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)
 
-        self.s = nn.LSTMCell(
+        lstm_cell = nn.LSTMCell(
             input_size=cfg.target_embed_dim + cfg.encoder_dim,
             hidden_size=cfg.lstm_hidden_size,
         )
+        self.lstm_hidden_size = cfg.lstm_hidden_size
+        # with zoneout drop probabilities of 0, this is equivalent to a plain LSTMCell
+        self.s = ZoneoutLSTMCell(
+            cell=lstm_cell,
+            zoneout_h=cfg.zoneout_drop_h,
+            zoneout_c=cfg.zoneout_drop_c,
+        )
+
         self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False)  # query
 
         # for attention
@@ -127,7 +140,8 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            lstm_state = None
+            zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))  # match device/dtype of encoder output
+            lstm_state = (zeros, zeros)
             att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
             accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
         else:
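The `ZoneoutLSTMCell` imported at the top lives in `.zoneout_lstm` and is not part of this diff. For readers unfamiliar with zoneout (Krueger et al., 2016), below is a minimal sketch of what such a wrapper cell could look like: the constructor signature (`cell`, `zoneout_h`, `zoneout_c`) is taken from the call site above, while the body is an assumption following the standard zoneout recipe, where each state unit keeps its previous value with the given probability during training, and the previous and new states are interpolated deterministically at inference.

```python
import torch
from torch import nn


class ZoneoutLSTMCell(nn.Module):
    """Wraps an nn.LSTMCell, applying zoneout to the hidden and cell states."""

    def __init__(self, cell: nn.LSTMCell, zoneout_h: float, zoneout_c: float):
        super().__init__()
        self.cell = cell
        self.zoneout_h = zoneout_h
        self.zoneout_c = zoneout_c

    def forward(self, inputs: torch.Tensor, state):
        h_prev, c_prev = state
        h_new, c_new = self.cell(inputs, state)
        h = self._apply_zoneout(h_prev, h_new, self.zoneout_h)
        c = self._apply_zoneout(c_prev, c_new, self.zoneout_c)
        return h, c

    def _apply_zoneout(self, prev: torch.Tensor, new: torch.Tensor, p: float) -> torch.Tensor:
        if p == 0.0:
            return new  # zoneout disabled -> behaves like the wrapped LSTMCell
        if self.training:
            # each unit keeps its previous value with probability p
            keep = torch.empty_like(prev).bernoulli_(p)
            return keep * prev + (1.0 - keep) * new
        # inference: deterministic interpolation (expectation of the random mask)
        return p * prev + (1.0 - p) * new
```

This also explains why the diff replaces `lstm_state = None` with an explicit zero state: unlike `nn.LSTMCell`, which accepts `None` as shorthand for zeros, a wrapper of this shape presumably needs a concrete previous state to unpack.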
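For reference, instantiating the renamed decoder with the two new zoneout fields might look as follows. All values are illustrative, and the `AdditiveAttentionConfig` constructor call is a hypothetical: only its `attention_dim` field is visible in this diff (referenced in `__init__` above).

```python
# Illustrative values only; AdditiveAttentionConfig's full field list is not
# shown in this diff, so the constructor call below is an assumption.
att_cfg = AdditiveAttentionConfig(attention_dim=1024)

decoder = AttentionLSTMDecoderV1(
    AttentionLSTMDecoderV1Config(
        encoder_dim=512,
        vocab_size=10_025,
        target_embed_dim=640,
        target_embed_dropout=0.1,
        lstm_hidden_size=1024,
        zoneout_drop_h=0.05,  # zoneout drop probability for the hidden state
        zoneout_drop_c=0.15,  # zoneout drop probability for the cell state
        attention_cfg=att_cfg,
        output_proj_dim=1024,
        output_dropout=0.3,
    )
)
```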