Skip to content

Commit 8a47c15

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Remove unnecessary code from ldj_ratio.py
It was somehow left in as part of addressing review comments last time. This also revealed a bug in the previously skipped _default_ildj_ratio_fn. PiperOrigin-RevId: 387008770
1 parent d54cb4a commit 8a47c15

File tree

1 file changed

+5
-32
lines changed

1 file changed

+5
-32
lines changed

tensorflow_probability/python/bijectors/ldj_ratio.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
'RegisterILDJRatio',
3333
]
3434

35-
3635
_fldj_ratio_registry = {}
3736
_ildj_ratio_registry = {}
3837

@@ -45,9 +44,8 @@ def _reduce_ldj_ratio(unreduced_ldj_ratio, p, q, input_shape, min_event_ndims,
4544
p._parameter_batch_shape is not None and
4645
q._parameter_batch_shape is not None)
4746
if have_parameter_batch_shape:
48-
parameter_batch_shape = ps.broadcast_shape(
49-
p._parameter_batch_shape,
50-
q._parameter_batch_shape)
47+
parameter_batch_shape = ps.broadcast_shape(p._parameter_batch_shape,
48+
q._parameter_batch_shape)
5149
else:
5250
parameter_batch_shape = None
5351

@@ -78,8 +76,8 @@ def _default_fldj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs):
7876
def _default_ildj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs):
7977
min_event_ndims = p.inverse_min_event_ndims
8078
unreduced_fldj_ratio = (
81-
p.inverse_log_det_jacobian(x, event_ndims=event_ndims, **p_kwargs) -
82-
q.inverse_log_det_jacobian(y, event_ndims=event_ndims, **q_kwargs))
79+
p.inverse_log_det_jacobian(x, event_ndims=min_event_ndims, **p_kwargs) -
80+
q.inverse_log_det_jacobian(y, event_ndims=min_event_ndims, **q_kwargs))
8381
return _reduce_ldj_ratio(unreduced_fldj_ratio, p, q, ps.shape(x),
8482
min_event_ndims, event_ndims)
8583

@@ -200,32 +198,7 @@ def inverse_ildj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs):
200198
else:
201199
ildj_ratio_fn = inverse_ildj_ratio_fn
202200

203-
if tf.nest.is_nested(p.inverse_min_event_ndims):
204-
# See the comment in forward_log_det_jacobian_ratio about why we do this.
205-
return ildj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs)
206-
else:
207-
# Evaluate the ratio at minimum event ndims, and then reduce the unreduced
208-
# LDJ.
209-
min_event_ndims = p.inverse_min_event_ndims
210-
have_parameter_batch_shape = (
211-
p._parameter_batch_shape is not None and # pylint: disable=protected-access
212-
q._parameter_batch_shape is not None) # pylint: disable=protected-access
213-
reduce_shape, assertions = bijector_lib.ldj_reduction_shape(
214-
ps.shape(x),
215-
event_ndims=event_ndims,
216-
min_event_ndims=min_event_ndims,
217-
parameter_batch_shape=(ps.broadcast_shape(p._parameter_batch_shape, # pylint: disable=protected-access
218-
q._parameter_batch_shape) # pylint: disable=protected-access
219-
if have_parameter_batch_shape else None),
220-
allow_event_shape_broadcasting=True,
221-
validate_args=p.validate_args or q.validate_args)
222-
223-
sum_fn = getattr(p, '_sum_fn', getattr(q, '_sum_fn', tf.reduce_sum))
224-
with tf.control_dependencies(assertions):
225-
return bijector_lib.reduce_jacobian_det_over_shape(
226-
ildj_ratio_fn(p, x, q, y, min_event_ndims, p_kwargs, q_kwargs),
227-
reduce_shape=reduce_shape,
228-
sum_fn=sum_fn)
201+
return ildj_ratio_fn(p, x, q, y, event_ndims, p_kwargs, q_kwargs)
229202

230203

231204
class RegisterFLDJRatio(object):

0 commit comments

Comments
 (0)