Skip to content

Commit fb64bcb

Browse files
committed
implement zoneout lstm cell
1 parent b661cee commit fb64bcb

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

i6_models/decoder/zoneout_lstm.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import torch
2+
from torch import nn
3+
4+
from typing import Tuple
5+
6+
7+
class ZoneoutLSTMCell(nn.Module):
8+
"""
9+
Wrap an LSTM cell with Zoneout regularization (https://arxiv.org/abs/1606.01305)
10+
"""
11+
12+
def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float):
13+
"""
14+
:param cell: LSTM cell
15+
:param zoneout_h: zoneout drop probability for hidden state
16+
:param zoneout_c: zoneout drop probability for cell state
17+
"""
18+
super().__init__()
19+
self.cell = cell
20+
assert 0.0 <= zoneout_h <= 1.0 and 0.0 <= zoneout_c <= 1.0, "Zoneout drop probability must be in [0, 1]"
21+
self.zoneout_h = zoneout_h
22+
self.zoneout_c = zoneout_c
23+
24+
def forward(
25+
self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
26+
) -> Tuple[torch.Tensor, torch.Tensor]:
27+
h, c = self.cell(inputs)
28+
prev_h, prev_c = state
29+
h = self._zoneout(prev_h, h, self.zoneout_h)
30+
c = self._zoneout(prev_c, c, self.zoneout_c)
31+
return h, c
32+
33+
def _zoneout(self, prev_state: torch.Tensor, curr_state: torch.Tensor, factor: float):
34+
"""
35+
Apply Zoneout.
36+
37+
:param prev: previous state tensor
38+
:param curr: current state tensor
39+
:param factor: drop probability
40+
"""
41+
if factor == 0.0:
42+
return curr_state
43+
if self.training:
44+
mask = curr_state.new_empty(size=curr_state.size()).bernoulli_(factor)
45+
return mask * prev_state + (1 - mask) * curr_state
46+
else:
47+
return factor * prev_state + (1 - factor) * curr_state

0 commit comments

Comments
 (0)