Skip to content

Commit 6dbeb34

Browse files
committed
better check
1 parent 6809695 commit 6dbeb34

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

i6_models/decoder/attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,7 @@ def __init__(self, cfg: AttentionLstmDecoderV1Config):
109109
self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False)
110110

111111
self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim)
112+
assert cfg.output_proj_dim % 2 == 0, "output projection dimension must be even for MaxOut"
112113
self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
113114
self.output_dropout = nn.Dropout(cfg.output_dropout)
114115

@@ -173,7 +174,6 @@ def forward(
173174
readout_in = self.readout_in(torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1)) # [B,N,D]
174175

175176
# maxout layer
176-
assert readout_in.size(-1) % 2 == 0
177177
readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2) # [B,N,D/2,2]
178178
readout, _ = torch.max(readout_in, dim=-1) # [B,N,D/2]
179179

0 commit comments

Comments (0)