Skip to content

Commit 1969da7

Browse files
committed
Make TanhNormal.log_prob(1.0) != -inf
1 parent c3bc852 commit 1969da7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/garage/torch/distributions/tanh_normal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
4949
"""
5050
# pylint: disable=arguments-differ
5151
if pre_tanh_value is None:
52-
pre_tanh_value = torch.log((1 + value) / (1 - value)) / 2
52+
pre_tanh_value = torch.log((1 + epsilon + value) / (1 + epsilon - value)) / 2
5353
norm_lp = self._normal.log_prob(pre_tanh_value)
5454
ret = (norm_lp - torch.sum(
5555
torch.log(self._clip_but_pass_gradient((1. - value**2)) + epsilon),

0 commit comments

Comments
 (0)