@@ -39,13 +39,14 @@ def soft_round(x, alpha, eps=1e-3):
39
39
Returns:
40
40
tf.Tensor
41
41
"""
42
- if isinstance (alpha , (float , int )) and alpha < eps :
43
- return tf .identity (x , name = "soft_round" )
42
+ # This guards the gradient of tf.where below against NaNs, while maintaining
43
+ # correctness, as for alpha < eps the result is ignored.
44
+ alpha_bounded = tf .maximum (alpha , eps )
44
45
45
46
m = tf .floor (x ) + .5
46
47
r = x - m
47
- z = tf .tanh (alpha / 2. ) * 2.
48
- y = m + tf .tanh (alpha * r ) / z
48
+ z = tf .tanh (alpha_bounded / 2. ) * 2.
49
+ y = m + tf .tanh (alpha_bounded * r ) / z
49
50
50
51
# For very low alphas, soft_round behaves like identity
51
52
return tf .where (alpha < eps , x , y , name = "soft_round" )
@@ -68,12 +69,12 @@ def soft_round_inverse(y, alpha, eps=1e-3):
68
69
Returns:
69
70
tf.Tensor
70
71
"""
71
- if isinstance ( alpha , ( float , int )) and alpha < eps :
72
- return tf . identity ( y , name = "soft_round_inverse" )
73
-
72
+ # This guards the gradient of tf.where below against NaNs, while maintaining
73
+ # correctness, as for alpha < eps the result is ignored.
74
+ alpha_bounded = tf . maximum ( alpha , eps )
74
75
m = tf .floor (y ) + .5
75
- s = (y - m ) * (tf .tanh (alpha / 2. ) * 2. )
76
- r = tf .atanh (s ) / alpha
76
+ s = (y - m ) * (tf .tanh (alpha_bounded / 2. ) * 2. )
77
+ r = tf .atanh (s ) / alpha_bounded
77
78
# `r` must be between -.5 and .5 by definition. In case atanh becomes +-inf
78
79
# due to numerical instability, this prevents the forward pass from yielding
79
80
# infinite values. Note that it doesn't prevent the backward pass from
0 commit comments