We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fb64bcb commit 0ac8e69 — Copy full SHA for 0ac8e69
i6_models/decoder/attention.py
@@ -143,7 +143,7 @@ def forward(
143
zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size))
144
lstm_state = (zeros, zeros)
145
att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
146
- accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
+ accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
147
else:
148
lstm_state, att_context, accum_att_weights = state
149
0 commit comments