 __all__ = ["soft_round", "soft_round_inverse", "soft_round_conditional_mean"]
 
 
-def soft_round(x, alpha, eps=1e-12):
+def soft_round(x, alpha, eps=1e-3):
   """Differentiable approximation to round().
 
   Larger alphas correspond to closer approximations of the round function.
@@ -39,28 +39,19 @@ def soft_round(x, alpha, eps=1e-12):
   Returns:
     tf.Tensor
   """
-
   if isinstance(alpha, (float, int)) and alpha < eps:
     return tf.identity(x, name="soft_round")
 
-  m = tf.floor(x) + 0.5
+  m = tf.floor(x) + .5
   r = x - m
-  z = tf.maximum(tf.tanh(alpha / 2.0) * 2.0, eps)
+  z = tf.tanh(alpha / 2.) * 2.
   y = m + tf.tanh(alpha * r) / z
 
   # For very low alphas, soft_round behaves like identity
   return tf.where(alpha < eps, x, y, name="soft_round")
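For context on this hunk: raising `eps` from 1e-12 to 1e-3 makes the identity fallback trigger well before `z = 2 * tanh(alpha / 2) ≈ alpha` becomes small enough to make the division unstable, which is what lets the `tf.maximum(..., eps)` guard on `z` go away. A minimal sketch of what the patched function computes, assuming TF 2.x eager mode (`soft_round_sketch` is a hypothetical name, not the library API, and it elides the `tf.where` path for tensor-valued alphas):

```python
import tensorflow as tf

# For x in the unit bin [n, n + 1): m = n + .5 is the bin midpoint and
# r = x - m lies in [-.5, .5). Because tanh(alpha * .5) / z == .5 at the
# bin edges, the output stays inside the same bin; alpha -> inf recovers
# round(), and alpha -> 0 makes tanh(alpha * r) / z -> r, i.e. the
# identity, which is why small alphas can short-circuit to identity.
def soft_round_sketch(x, alpha, eps=1e-3):
  if isinstance(alpha, (float, int)) and alpha < eps:
    return tf.identity(x)
  m = tf.floor(x) + .5
  r = x - m
  z = tf.tanh(alpha / 2.) * 2.  # no tf.maximum guard needed with eps=1e-3
  return m + tf.tanh(alpha * r) / z

x = tf.constant([0.2, 0.8, -0.7])
print(soft_round_sketch(x, 100.))  # ~[0., 1., -1.], close to tf.round(x)
print(soft_round_sketch(x, 1e-4))  # == x, identity branch
```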
 
 
-@tf.custom_gradient
-def _clip_st(s):
-  """Clip s to [-1 + 1e-7, 1 - 1e-7] with straight-through gradients."""
-  s = tf.clip_by_value(s, -1 + 1e-7, 1 - 1e-7)
-  grad = lambda x: x
-  return s, grad
-
-
-def soft_round_inverse(y, alpha, eps=1e-12):
+def soft_round_inverse(y, alpha, eps=1e-3):
   """Inverse of soft_round().
 
   This is described in Sec. 4.1. in the paper
@@ -77,21 +68,19 @@ def soft_round_inverse(y, alpha, eps=1e-12):
   Returns:
     tf.Tensor
   """
-
   if isinstance(alpha, (float, int)) and alpha < eps:
     return tf.identity(y, name="soft_round_inverse")
 
-  m = tf.floor(y) + 0.5
-  s = (y - m) * (tf.tanh(alpha / 2.0) * 2.0)
-  # We have -0.5 <= (y-m) <= 0.5 and -1 < tanh < 1, so
-  # -1 <= s <= 1. However tf.atanh is only stable for inputs
-  # in the range [-1+1e-7, 1-1e-7], so we (safely) clip s to this range.
-  # In the rare case where `1-|s| < 1e-7`, we use straight-through for the
-  # gradient.
-  s = _clip_st(s)
-  r = tf.atanh(s) / tf.maximum(alpha, eps)
+  m = tf.floor(y) + .5
+  s = (y - m) * (tf.tanh(alpha / 2.) * 2.)
+  r = tf.atanh(s) / alpha
+  # `r` must be between -.5 and .5 by definition. In case atanh becomes +-inf
+  # due to numerical instability, this prevents the forward pass from yielding
+  # infinite values. Note that it doesn't prevent the backward pass from
+  # returning non-finite values.
+  r = tf.clip_by_value(r, -.5, .5)
 
-  # For very low alphas, soft_round behaves like identity
+  # For very low alphas, soft_round behaves like identity.
   return tf.where(alpha < eps, y, m + r, name="soft_round_inverse")
 
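With the straight-through `_clip_st` helper deleted above, the inverse now handles atanh's instability on the output side: rather than clipping `s` into atanh's stable domain before the call, it clips the result `r` to its valid range [-.5, .5] afterwards, accepting non-finite gradients in the rare overflow case. A self-contained sketch with a round-trip check, under the same TF 2.x assumption (`soft_round_inverse_sketch` is a hypothetical name):

```python
import tensorflow as tf

# Mirrors the '+' lines above. s recovers tanh(alpha * r) up to rounding,
# so atanh only blows up when |s| rounds to 1; clipping r afterwards keeps
# the forward pass finite without any custom gradient.
def soft_round_inverse_sketch(y, alpha, eps=1e-3):
  if isinstance(alpha, (float, int)) and alpha < eps:
    return tf.identity(y)
  m = tf.floor(y) + .5
  s = (y - m) * (tf.tanh(alpha / 2.) * 2.)
  r = tf.atanh(s) / alpha           # can overflow to +-inf when |s| ~ 1
  r = tf.clip_by_value(r, -.5, .5)  # bounds hold by definition of r
  return m + r

# Round trip through the forward map y = m + tanh(alpha * r) / (2 * tanh(alpha / 2)):
alpha = 5.
x = tf.constant([0.3, 1.9])
m = tf.floor(x) + .5
y = m + tf.tanh(alpha * (x - m)) / (tf.tanh(alpha / 2.) * 2.)
print(soft_round_inverse_sketch(y, alpha))  # ~[0.3, 1.9]
```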
@@ -107,12 +96,11 @@ def soft_round_conditional_mean(inputs, alpha):
   > Eirikur Agustsson & Lucas Theis<br />
   > https://arxiv.org/abs/2006.09952
 
-
   Args:
     inputs: The input tensor.
     alpha: The softround alpha.
 
   Returns:
     The conditional mean, of same shape as `inputs`.
   """
-  return soft_round_inverse(inputs - 0.5, alpha) + 0.5
+  return soft_round_inverse(inputs - .5, alpha) + .5
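Per the docstring's reference to Sec. 4.1 of the cited paper, this evaluates the conditional-mean reconstruction r(y) = s_alpha^{-1}(y - .5) + .5 for a noisy soft-rounded value y. A hypothetical end-to-end usage sketch, reusing the two sketches above (the dither setup is my illustration, not part of the diff):

```python
# u models the uniform noise added between soft rounding and
# reconstruction in the universally quantized setup of the paper.
alpha = 5.
x = tf.constant([0.3, 1.9])
u = tf.random.uniform(tf.shape(x), -.5, .5)
y = soft_round_sketch(x, alpha) + u
x_hat = soft_round_inverse_sketch(y - .5, alpha) + .5  # conditional-mean reconstruction
```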