28
28
from tensorflow_probability .python .internal import auto_composite_tensor
29
29
from tensorflow_probability .python .internal import dtype_util
30
30
from tensorflow_probability .python .internal import parameter_properties
31
+ from tensorflow_probability .python .internal import samplers
31
32
from tensorflow_probability .python .internal import tensor_util
32
33
from tensorflow_probability .python .util .deferred_tensor import DeferredTensor
33
34
34
35
__all__ = ['Inflated' , 'inflated_factory' , 'ZeroInflatedNegativeBinomial' ]
35
36
36
37
38
def _safe_value_for_distribution(dist):
  """Returns an x at which `dist.log_prob(x)` is safe to differentiate.

  The value is a draw from `dist` itself, taken with the seed produced by
  `samplers.zeros_seed()`.

  Args:
    dist: Object exposing a `sample(seed=...)` method.

  Returns:
    The result of `dist.sample(seed=samplers.zeros_seed())`.
  """
  # NOTE(review): presumably a sampled point lies where the gradient of
  # `log_prob` is finite -- confirm for the distributions being inflated.
  fixed_seed = samplers.zeros_seed()
  return dist.sample(seed=fixed_seed)
41
+
42
+
37
43
class _Inflated (mixture .Mixture ):
38
44
"""A mixture of a point-mass and another distribution.
39
45
@@ -53,6 +59,8 @@ def __init__(self,
53
59
inflated_loc_logits = None ,
54
60
inflated_loc_probs = None ,
55
61
inflated_loc = 0.0 ,
62
+ inflated_loc_atol = None ,
63
+ inflated_loc_rtol = None ,
56
64
validate_args = False ,
57
65
allow_nan_stats = True ,
58
66
name = 'Inflated' ):
@@ -71,6 +79,12 @@ def __init__(self,
71
79
`inflated_loc_logits` should be passed in.
72
80
inflated_loc: A scalar or tensor containing the locations of the point
73
81
mass component of the mixture.
82
+ inflated_loc_atol: Non-negative `Tensor` of same `dtype` as
83
+ `inflated_loc` and broadcastable shape. The absolute tolerance for
84
+ comparing closeness to `inflated_loc`. Default is `0`.
85
+ inflated_loc_rtol: Non-negative `Tensor` of same `dtype` as
86
+ `inflated_loc` and broadcastable shape. The relative tolerance for
87
+ comparing closeness to `inflated_loc`. Default is `0`.
74
88
validate_args: If true, inconsistent batch or event sizes raise a runtime
75
89
error.
76
90
allow_nan_stats: If false, any undefined statistics for any batch member
@@ -95,6 +109,12 @@ def __init__(self,
95
109
inflated_loc_probs , dtype = dtype , name = 'inflated_loc_probs' )
96
110
self ._inflated_loc = tensor_util .convert_nonref_to_tensor (
97
111
inflated_loc , dtype = dtype , name = 'inflated_loc' )
112
+ self ._inflated_loc_atol = tensor_util .convert_nonref_to_tensor (
113
+ 0 if inflated_loc_atol is None else inflated_loc_atol ,
114
+ dtype = dtype , name = 'inflated_loc_atol' )
115
+ self ._inflated_loc_rtol = tensor_util .convert_nonref_to_tensor (
116
+ 0 if inflated_loc_rtol is None else inflated_loc_rtol ,
117
+ dtype = dtype , name = 'inflated_loc_rtol' )
98
118
99
119
if inflated_loc_probs is None :
100
120
cat_logits = DeferredTensor (
@@ -122,17 +142,23 @@ def __init__(self,
122
142
allow_nan_stats = allow_nan_stats )
123
143
probs_or_logits = self ._inflated_loc_probs
124
144
145
+ self ._deterministic = deterministic .Deterministic (
146
+ DeferredTensor (
147
+ probs_or_logits ,
148
+ # pylint: disable=g-long-lambda
149
+ lambda _ : tf .broadcast_to (self ._inflated_loc ,
150
+ tf .shape (probs_or_logits ))
151
+ # pylint: enable=g-long-lambda
152
+ ),
153
+ atol = self ._inflated_loc_atol ,
154
+ rtol = self ._inflated_loc_rtol ,
155
+ validate_args = validate_args ,
156
+ allow_nan_stats = allow_nan_stats )
157
+
125
158
super (_Inflated , self ).__init__ (
126
159
cat = self ._categorical_dist ,
127
160
components = [
128
- deterministic .Deterministic (
129
- DeferredTensor (
130
- probs_or_logits ,
131
- lambda x : tf .constant ( # pylint: disable=g-long-lambda
132
- inflated_loc , dtype = distribution .dtype ,
133
- shape = probs_or_logits .shape )),
134
- validate_args = validate_args ,
135
- allow_nan_stats = allow_nan_stats ),
161
+ self ._deterministic ,
136
162
distribution
137
163
],
138
164
validate_args = validate_args ,
@@ -151,6 +177,12 @@ def _parameter_properties(cls, dtype, num_classes=None):
151
177
),
152
178
inflated_loc = parameter_properties .ParameterProperties ())
153
179
180
def _almost_inflated_loc(self, x):
  """Tests whether `x` is within tolerance of the inflated location.

  Args:
    x: Value(s) to compare against `self._inflated_loc`.

  Returns:
    The result of `|x - inflated_loc| <= slack`, where `slack` is obtained
    from the point-mass component's `_slack` helper evaluated at
    `self._inflated_loc`.
  """
  distance = tf.abs(x - self._inflated_loc)
  # The tolerance computation is delegated to the private `_slack` helper of
  # the `Deterministic` component, hence the protected-access suppression.
  # pylint: disable=protected-access
  tolerance = self._deterministic._slack(self._inflated_loc)
  # pylint: enable=protected-access
  return distance <= tolerance
185
+
154
186
def _log_prob (self , x ):
155
187
# We override the log_prob implementation from Mixture in the case
156
188
# where we are inflating a continuous distribution, because we have
@@ -163,11 +195,19 @@ def _log_prob(self, x):
163
195
distribution_lib .DiscreteDistributionMixin ):
164
196
return super (_Inflated , self )._log_prob (x )
165
197
else :
198
+ # Enable non-NaN gradients of the log_prob, even if the gradient of
199
+ # the continuous distribution is NaN at _inflated_loc. See
200
+ # https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
201
+ # for details.
202
+ safe_x = tf .where (
203
+ self ._almost_inflated_loc (x ),
204
+ _safe_value_for_distribution (self ._distribution ),
205
+ x )
166
206
return tf .where (
167
- tf . equal ( x , self . _inflated_loc ),
207
+ self . _almost_inflated_loc ( x ),
168
208
self ._categorical_dist .log_prob (0 ),
169
209
self ._categorical_dist .log_prob (1 ) +
170
- self ._distribution .log_prob (x ))
210
+ self ._distribution .log_prob (safe_x ))
171
211
172
212
@property
173
213
def distribution (self ):
0 commit comments