from typing import Tuple

import torch
from torch import nn


class ZoneoutLSTMCell(nn.Module):
    """
    Wrap an LSTM cell with Zoneout regularization (https://arxiv.org/abs/1606.01305).
    """

    def __init__(self, cell: nn.LSTMCell, zoneout_h: float, zoneout_c: float):
        """
        :param cell: LSTM cell to wrap
        :param zoneout_h: zoneout drop probability for the hidden state
        :param zoneout_c: zoneout drop probability for the cell state
        """
        super().__init__()
        self.cell = cell
        assert 0.0 <= zoneout_h <= 1.0 and 0.0 <= zoneout_c <= 1.0, "Zoneout drop probability must be in [0, 1]"
        self.zoneout_h = zoneout_h
        self.zoneout_c = zoneout_c

    def forward(
        self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Run the wrapped LSTM cell on the current input and the previous state.
        h, c = self.cell(inputs, state)
        prev_h, prev_c = state
        h = self._zoneout(prev_h, h, self.zoneout_h)
        c = self._zoneout(prev_c, c, self.zoneout_c)
        return h, c

    def _zoneout(self, prev_state: torch.Tensor, curr_state: torch.Tensor, factor: float) -> torch.Tensor:
        """
        Apply Zoneout: each unit keeps its previous value with the given drop probability.

        :param prev_state: previous state tensor
        :param curr_state: current state tensor
        :param factor: drop probability
        """
        if factor == 0.0:
            return curr_state
        if self.training:
            # Training: sample a binary mask; 1 keeps the previous state, 0 takes the new one.
            mask = curr_state.new_empty(size=curr_state.size()).bernoulli_(factor)
            return mask * prev_state + (1 - mask) * curr_state
        else:
            # Evaluation: use the expected value of the stochastic mask.
            return factor * prev_state + (1 - factor) * curr_state
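

# A minimal usage sketch (illustrative, not part of the original commit):
# wrap a standard nn.LSTMCell and unroll it one step at a time over a sequence.
# All shapes and hyperparameters below are assumptions for demonstration.
if __name__ == "__main__":
    cell = ZoneoutLSTMCell(nn.LSTMCell(input_size=16, hidden_size=32), zoneout_h=0.1, zoneout_c=0.1)
    x = torch.randn(8, 5, 16)  # (batch, time, features)
    h = x.new_zeros(8, 32)
    c = x.new_zeros(8, 32)
    for t in range(x.size(1)):
        h, c = cell(x[:, t], (h, c))
    print(h.shape, c.shape)  # torch.Size([8, 32]) torch.Size([8, 32])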