@@ -102,7 +102,7 @@ def __init__(self,
102
102
lambda logit : tf .stack ([logit , - logit ], axis = - 1 ),
103
103
dtype = self ._inflated_loc_logits .dtype ,
104
104
shape = self ._inflated_loc_logits .shape + (2 ,))
105
- categorical_dist = categorical .Categorical (
105
+ self . _categorical_dist = categorical .Categorical (
106
106
logits = cat_logits ,
107
107
validate_args = validate_args ,
108
108
allow_nan_stats = allow_nan_stats )
@@ -116,14 +116,14 @@ def __init__(self,
116
116
dtype = self ._inflated_loc_probs .dtype ,
117
117
shape = self ._inflated_loc_probs .shape + (2 ,)
118
118
)
119
- categorical_dist = categorical .Categorical (
119
+ self . _categorical_dist = categorical .Categorical (
120
120
probs = cat_probs ,
121
121
validate_args = validate_args ,
122
122
allow_nan_stats = allow_nan_stats )
123
123
probs_or_logits = self ._inflated_loc_probs
124
124
125
125
super (_Inflated , self ).__init__ (
126
- cat = categorical_dist ,
126
+ cat = self . _categorical_dist ,
127
127
components = [
128
128
deterministic .Deterministic (
129
129
DeferredTensor (
@@ -151,6 +151,24 @@ def _parameter_properties(cls, dtype, num_classes=None):
151
151
),
152
152
inflated_loc = parameter_properties .ParameterProperties ())
153
153
154
+ def _log_prob (self , x ):
155
+ # We override the log_prob implementation from Mixture in the case
156
+ # where we are inflating a continuous distribution, because we have
157
+ # found that this "censored" version gives a good maximum likelihood
158
+ # estimate of the continuous distribution's parameters but the
159
+ # default implementation doesn't. This follows the proposal in
160
+ # https://arxiv.org/pdf/2010.09647.pdf for summing distributions of
161
+ # different Hausdorff dimension.
162
+ if isinstance (self ._distribution ,
163
+ distribution_lib .DiscreteDistributionMixin ):
164
+ return super (_Inflated , self )._log_prob (x )
165
+ else :
166
+ return tf .where (
167
+ tf .equal (x , self ._inflated_loc ),
168
+ self ._categorical_dist .log_prob (0 ),
169
+ self ._categorical_dist .log_prob (1 ) +
170
+ self ._distribution .log_prob (x ))
171
+
154
172
@property
155
173
def distribution (self ):
156
174
"""The distribution used for the non-inflated part."""
0 commit comments