32
32
'RegisterILDJRatio' ,
33
33
]
34
34
35
-
36
35
_fldj_ratio_registry = {}
37
36
_ildj_ratio_registry = {}
38
37
@@ -45,9 +44,8 @@ def _reduce_ldj_ratio(unreduced_ldj_ratio, p, q, input_shape, min_event_ndims,
45
44
p ._parameter_batch_shape is not None and
46
45
q ._parameter_batch_shape is not None )
47
46
if have_parameter_batch_shape :
48
- parameter_batch_shape = ps .broadcast_shape (
49
- p ._parameter_batch_shape ,
50
- q ._parameter_batch_shape )
47
+ parameter_batch_shape = ps .broadcast_shape (p ._parameter_batch_shape ,
48
+ q ._parameter_batch_shape )
51
49
else :
52
50
parameter_batch_shape = None
53
51
@@ -78,8 +76,8 @@ def _default_fldj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs):
78
76
def _default_ildj_ratio_fn (p , x , q , y , event_ndims , p_kwargs , q_kwargs ):
79
77
min_event_ndims = p .inverse_min_event_ndims
80
78
unreduced_fldj_ratio = (
81
- p .inverse_log_det_jacobian (x , event_ndims = event_ndims , ** p_kwargs ) -
82
- q .inverse_log_det_jacobian (y , event_ndims = event_ndims , ** q_kwargs ))
79
+ p .inverse_log_det_jacobian (x , event_ndims = min_event_ndims , ** p_kwargs ) -
80
+ q .inverse_log_det_jacobian (y , event_ndims = min_event_ndims , ** q_kwargs ))
83
81
return _reduce_ldj_ratio (unreduced_fldj_ratio , p , q , ps .shape (x ),
84
82
min_event_ndims , event_ndims )
85
83
@@ -200,32 +198,7 @@ def inverse_ildj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs):
200
198
else :
201
199
ildj_ratio_fn = inverse_ildj_ratio_fn
202
200
203
- if tf .nest .is_nested (p .inverse_min_event_ndims ):
204
- # See the comment in forward_log_det_jacobian_ratio about why we do this.
205
- return ildj_ratio_fn (p , x , q , y , event_ndims , p_kwargs , q_kwargs )
206
- else :
207
- # Evaluate the ratio at minimum event ndims, and then reduce the unreduced
208
- # LDJ.
209
- min_event_ndims = p .inverse_min_event_ndims
210
- have_parameter_batch_shape = (
211
- p ._parameter_batch_shape is not None and # pylint: disable=protected-access
212
- q ._parameter_batch_shape is not None ) # pylint: disable=protected-access
213
- reduce_shape , assertions = bijector_lib .ldj_reduction_shape (
214
- ps .shape (x ),
215
- event_ndims = event_ndims ,
216
- min_event_ndims = min_event_ndims ,
217
- parameter_batch_shape = (ps .broadcast_shape (p ._parameter_batch_shape , # pylint: disable=protected-access
218
- q ._parameter_batch_shape ) # pylint: disable=protected-access
219
- if have_parameter_batch_shape else None ),
220
- allow_event_shape_broadcasting = True ,
221
- validate_args = p .validate_args or q .validate_args )
222
-
223
- sum_fn = getattr (p , '_sum_fn' , getattr (q , '_sum_fn' , tf .reduce_sum ))
224
- with tf .control_dependencies (assertions ):
225
- return bijector_lib .reduce_jacobian_det_over_shape (
226
- ildj_ratio_fn (p , x , q , y , min_event_ndims , p_kwargs , q_kwargs ),
227
- reduce_shape = reduce_shape ,
228
- sum_fn = sum_fn )
201
+ return ildj_ratio_fn (p , x , q , y , event_ndims , p_kwargs , q_kwargs )
229
202
230
203
231
204
class RegisterFLDJRatio (object ):
0 commit comments