Skip to content

Commit ac3d642

Browse files
authored
Use torch.logsumexp in advanced_tutorial.py
`torch.logsumexp` is numerically stabilized: https://pytorch.org/docs/stable/generated/torch.logsumexp.html Found with TorchFix https://github.com/pytorch-labs/torchfix/
1 parent 46943d6 commit ac3d642

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

beginner_source/nlp/advanced_tutorial.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ def prepare_sequence(seq, to_ix):
142142
def log_sum_exp(vec):
143143
max_score = vec[0, argmax(vec)]
144144
max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
145-
return max_score + \
146-
torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
145+
return max_score + torch.logsumexp(vec - max_score_broadcast)
147146

148147
#####################################################################
149148
# Create model

0 commit comments

Comments
 (0)