
Commit b661cee

refactor + implement zoneout
1 parent 5770d0e commit b661cee

2 files changed: +54 -9 lines changed


i6_models/decoder/attention.py

Lines changed: 20 additions & 6 deletions
@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from .zoneout_lstm import ZoneoutLSTMCell
+
 
 @dataclass
 class AdditiveAttentionConfig:
@@ -46,7 +48,6 @@ def forward(
         :param enc_seq_len: encoder sequence lengths [B]
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
-
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
         time_arange = torch.arange(energies.size(1))  # [T]
@@ -60,14 +61,16 @@ def forward(
 
 
 @dataclass
-class AttentionLstmDecoderV1Config:
+class AttentionLSTMDecoderV1Config:
     """
     Attributes:
         encoder_dim: encoder dimension
         vocab_size: vocabulary size
         target_embed_dim: embedding dimension
         target_embed_dropout: embedding dropout
         lstm_hidden_size: LSTM hidden size
+        zoneout_drop_h: zoneout drop probability for hidden state
+        zoneout_drop_c: zoneout drop probability for cell state
         attention_cfg: attention config
         output_proj_dim: output projection dimension
         output_dropout: output dropout
@@ -78,26 +81,36 @@ class AttentionLstmDecoderV1Config:
     target_embed_dim: int
     target_embed_dropout: float
     lstm_hidden_size: int
+    zoneout_drop_h: float
+    zoneout_drop_c: float
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
 
 
-class AttentionLstmDecoderV1(nn.Module):
+class AttentionLSTMDecoderV1(nn.Module):
     """
     Single-headed Attention decoder with additive attention mechanism.
     """
 
-    def __init__(self, cfg: AttentionLstmDecoderV1Config):
+    def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         super().__init__()
 
         self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
         self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)
 
-        self.s = nn.LSTMCell(
+        lstm_cell = nn.LSTMCell(
             input_size=cfg.target_embed_dim + cfg.encoder_dim,
             hidden_size=cfg.lstm_hidden_size,
         )
+        self.lstm_hidden_size = cfg.lstm_hidden_size
+        # if zoneout drop probs are 0, then it is equivalent to normal LSTMCell
+        self.s = ZoneoutLSTMCell(
+            cell=lstm_cell,
+            zoneout_h=cfg.zoneout_drop_h,
+            zoneout_c=cfg.zoneout_drop_c,
+        )
+
         self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False)  # query
 
         # for attention
@@ -127,7 +140,8 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            lstm_state = None
+            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size))
+            lstm_state = (zeros, zeros)
             att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
             accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
         else:
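
Note: the ZoneoutLSTMCell imported above lives in i6_models/decoder/zoneout_lstm.py, which is not part of this diff. For context, here is a minimal sketch of how such a wrapper can implement zoneout (Krueger et al., 2017): during training, each hidden and cell unit keeps its previous value with the configured probability; at inference, the previous and new states are interpolated deterministically. Only the constructor keywords (cell, zoneout_h, zoneout_c) are taken from the call site above; the body is an assumption, not the repository's actual implementation.

import torch
from torch import nn


class ZoneoutLSTMCell(nn.Module):
    """Sketch of an LSTMCell wrapper with zoneout; the real i6_models code may differ."""

    def __init__(self, cell: nn.LSTMCell, zoneout_h: float, zoneout_c: float):
        super().__init__()
        self.cell = cell
        self.zoneout_h = zoneout_h
        self.zoneout_c = zoneout_c

    def forward(self, inputs: torch.Tensor, state):
        prev_h, prev_c = state
        h, c = self.cell(inputs, state)
        return self._zoneout(prev_h, h, self.zoneout_h), self._zoneout(prev_c, c, self.zoneout_c)

    def _zoneout(self, prev: torch.Tensor, curr: torch.Tensor, p: float):
        if p == 0.0:
            return curr  # zero drop prob: behaves exactly like the wrapped LSTMCell
        if self.training:
            # per-unit Bernoulli mask: keep the previous state with probability p
            mask = torch.bernoulli(torch.full_like(prev, p))
            return mask * prev + (1.0 - mask) * curr
        # inference: use the expected value of the stochastic update
        return p * prev + (1.0 - p) * curr

This sketch also motivates the change to the decoder's initial state above: a wrapper that interpolates with the previous state needs an explicit (h, c) pair, which is presumably why the diff replaces lstm_state = None with a tuple of zero tensors.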

tests/test_enc_dec_att.py

Lines changed: 34 additions & 3 deletions
@@ -2,7 +2,7 @@
 from torch import nn
 
 from i6_models.decoder.attention import AdditiveAttention, AdditiveAttentionConfig
-from i6_models.decoder.attention import AttentionLstmDecoderV1, AttentionLstmDecoderV1Config
+from i6_models.decoder.attention import AttentionLSTMDecoderV1, AttentionLSTMDecoderV1Config
 
 
 def test_additive_attention():
@@ -29,7 +29,7 @@ def test_additive_attention():
 def test_encoder_decoder_attention_model():
     encoder = torch.rand((10, 20, 5))
     encoder_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]
-    decoder_cfg = AttentionLstmDecoderV1Config(
+    decoder_cfg = AttentionLSTMDecoderV1Config(
         encoder_dim=5,
         vocab_size=15,
         target_embed_dim=3,
@@ -38,10 +38,41 @@ def test_encoder_decoder_attention_model():
         attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1),
         output_proj_dim=12,
         output_dropout=0.1,
+        zoneout_drop_c=0.0,
+        zoneout_drop_h=0.0,
     )
-    decoder = AttentionLstmDecoderV1(decoder_cfg)
+    decoder = AttentionLSTMDecoderV1(decoder_cfg)
     target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
 
     decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
 
     assert decoder_logits.shape == (10, 7, 15)
+
+
+def test_zoneout_lstm_cell():
+    encoder = torch.rand((10, 20, 5))
+    encoder_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]
+    target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
+
+    def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
+        decoder_cfg = AttentionLSTMDecoderV1Config(
+            encoder_dim=5,
+            vocab_size=15,
+            target_embed_dim=3,
+            target_embed_dropout=0.1,
+            lstm_hidden_size=12,
+            attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1),
+            output_proj_dim=12,
+            output_dropout=0.1,
+            zoneout_drop_c=zoneout_drop_c,
+            zoneout_drop_h=zoneout_drop_h,
+        )
+        decoder = AttentionLSTMDecoderV1(decoder_cfg)
+        decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
+        return decoder_logits
+
+    decoder_logits = forward_decoder(zoneout_drop_c=0.15, zoneout_drop_h=0.05)
+    assert decoder_logits.shape == (10, 7, 15)
+
+    decoder_logits = forward_decoder(zoneout_drop_c=0.0, zoneout_drop_h=0.0)
+    assert decoder_logits.shape == (10, 7, 15)
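
The new test runs the decoder with stochastic (zoneout_drop_c=0.15, zoneout_drop_h=0.05) and disabled (0.0, 0.0) zoneout, but only asserts output shapes. The comment in attention.py claims that zero drop probabilities make the wrapper equivalent to a plain LSTMCell; a hypothetical unit test along the following lines would check that claim directly. It assumes the wrapper's forward mirrors nn.LSTMCell, taking (input, (h, c)) and returning (h, c), which this diff does not confirm.

import torch
from torch import nn

from i6_models.decoder.zoneout_lstm import ZoneoutLSTMCell


def test_zoneout_zero_probs_match_plain_cell():
    torch.manual_seed(0)
    cell = nn.LSTMCell(input_size=8, hidden_size=12)
    zoneout_cell = ZoneoutLSTMCell(cell=cell, zoneout_h=0.0, zoneout_c=0.0)

    inputs = torch.rand((10, 8))
    state = (torch.zeros((10, 12)), torch.zeros((10, 12)))

    # both cells share the same weights, so with zoneout disabled
    # their outputs must match exactly
    h_ref, c_ref = cell(inputs, state)
    h_zo, c_zo = zoneout_cell(inputs, state)

    assert torch.allclose(h_ref, h_zo)
    assert torch.allclose(c_ref, c_zo)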
