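```python
def forward_with_loss(self, rnn_output, target):
    if self.training:
        # Training: NCE loss computed from the target and the sampled words
        target_prob, sample_prob, sample = self(rnn_output, target)
        loss = self.nce_loss(target_prob, sample_prob, target, sample)
        return loss.mean()
    else:
        # Evaluation: full-vocabulary logits = word_bias + rnn_output @ word_embeddings.weight.T
        # (keyword-free form of addmm; the old positional (beta, input, alpha, ...) form is deprecated)
        output = torch.addmm(
            self.word_bias.weight.view(-1), rnn_output, self.word_embeddings.weight.t()
        )
        return self.CE(output, target)
```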
Also, why does the evaluation branch call `addmm` on `rnn_output` and `word_embeddings.weight`?
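A minimal sketch of what the evaluation branch computes, assuming `word_embeddings` is an `nn.Embedding(V, H)`, `word_bias` is an `nn.Embedding(V, 1)` and `self.CE` is `nn.CrossEntropyLoss()` (these module types and the sizes below are assumptions, not taken from the issue). `torch.addmm(input, mat1, mat2)` returns `input + mat1 @ mat2`, so the call yields one logit per vocabulary word for each hidden state: the embedding matrix is reused (tied) as the output projection and `word_bias` supplies a per-word bias. This is the same result as `F.linear(rnn_output, word_embeddings.weight, bias)`. Presumably the exact, fully normalized softmax is used at evaluation time so that the reported cross-entropy is not the sampled NCE estimate.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: N hidden states, hidden size H, vocabulary size V.
N, H, V = 4, 8, 10

rnn_output = torch.randn(N, H)              # (N, H) hidden states from the RNN
word_embeddings = torch.nn.Embedding(V, H)  # assumed module; weight is (V, H)
word_bias = torch.nn.Embedding(V, 1)        # assumed module; weight is (V, 1)
target = torch.randint(0, V, (N,))          # (N,) gold word indices

bias = word_bias.weight.view(-1)            # flatten to (V,), broadcast over rows

# addmm(input, mat1, mat2) = input + mat1 @ mat2 (beta = alpha = 1 by default)
logits = torch.addmm(bias, rnn_output, word_embeddings.weight.t())  # (N, V)

# The same logits written two other ways, for comparison
manual = rnn_output @ word_embeddings.weight.t() + bias
linear = F.linear(rnn_output, word_embeddings.weight, bias)

# Full-softmax cross-entropy over all V words, averaged over the N positions
loss = F.cross_entropy(logits, target)
print(logits.shape, loss.item())
```

The full product costs on the order of N·V·H multiply-adds, which is the expense NCE is meant to avoid during training when the vocabulary is large.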