
Commit 47407b8

SiegeLordEx authored and tensorflower-gardener committed
Fix *LDJ ratio computation.
The old code was broken in many ways:

- Multipart bijectors didn't work.
- Scalar LDJ bijectors didn't reduce correctly in all situations.
- Subclass *LDJ ratio functions didn't correctly override the parent class's.

There are four major issues with the solution and one small one. The resulting solution feels unsatisfactory.

1) There is no unified interpretation of the LDJ of a multi-part bijector. tfb.Bijector has one piece of code to do it, and tfb.Composition has a different one, and the two are not compatible (tfb.Bijector's is more restrictive). A simple disagreement I ran into is that tfb.Bijector forbids len(set(map(lambda e, me: e - me, event_ndims, min_event_ndims))) > 1 from being True, while tfb.JointMap will happily broadcast component LDJs (a sketch illustrating this follows the commit message). This forces me to use two code paths when computing the *LDJ ratios rather than just one (tfb.Bijector's), as well as to plumb event_ndims args to the custom LDJ ratio functions. I don't know if this is WAI or a bug, but both behaviors are enshrined in tests.

2) JD's default bijector relies on kwargs to function, which forced me to plumb them throughout the entire machinery; the result looks horrid and is likely fragile.

3) I had to put ldj_ratio.py into the `bijector` lib because of circular dependencies. This makes the already complicated code even harder to understand.

4) Kahan summation really should be decided by the bijector and not the LDJ ratio function. Since Kahan summation has a small performance penalty and incomplete JAX support, I turned it off by default.

The small issue is that the seemingly innocuous refactor of tfb.Composition broke the testChainDynamicToStatic test, which enshrines some incorrect code's behavior. While I managed to get the test passing, it's concerning that we're testing extremely internal aspects of the code in an end-to-end test.

PiperOrigin-RevId: 380913022
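An illustrative sketch (not part of the commit) of the disagreement described in (1), using the public tfp.bijectors API; the particular bijectors and shapes are invented for the example:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# A two-part bijector whose parts see different event shapes.
jm = tfb.JointMap([tfb.Exp(), tfb.Scale(2.)])
x = [tf.ones([3, 2]), tf.ones([3])]

# Here event_ndims - min_event_ndims is 1 for the first part and 0 for the
# second, exactly the pattern tfb.Bijector's generic reduction machinery
# forbids. tfb.JointMap instead broadcasts the per-part LDJs and sums them.
ldj = jm.forward_log_det_jacobian(x, event_ndims=[1, 0])
print(ldj.shape)  # (3,)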
1 parent 4deddda commit 47407b8

13 files changed (+477, -190 lines)


tensorflow_probability/python/bijectors/BUILD

Lines changed: 2 additions & 10 deletions
@@ -75,7 +75,6 @@ multi_substrate_py_library(
         ":joint_map",
         ":kumaraswamy_cdf",
         ":lambertw_transform",
-        ":ldj_ratio",
         ":masked_autoregressive",
         ":matrix_inverse_tril",
         ":moyal_cdf",
@@ -121,6 +120,7 @@ multi_substrate_py_library(
         "bijector.py",
         "chain.py",
         "composition.py",
+        "ldj_ratio.py",
     ],
     deps = [
         # numpy dep,
@@ -265,11 +265,11 @@ multi_substrate_py_library(
     name = "scale_matvec_diag",
     srcs = ["scale_matvec_diag.py"],
     deps = [
-        ":ldj_ratio",
         ":scale_matvec_linear_operator",
         # tensorflow dep,
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:parameter_properties",
+        "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:tensor_util",
     ],
 )
@@ -604,14 +604,6 @@ multi_substrate_py_library(
     ],
 )
 
-multi_substrate_py_library(
-    name = "ldj_ratio",
-    srcs = ["ldj_ratio.py"],
-    deps = [
-        # tensorflow dep,
-    ],
-)
-
 multi_substrate_py_library(
     name = "masked_autoregressive",
     srcs = ["masked_autoregressive.py"],

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 14 additions & 11 deletions
@@ -1422,7 +1422,7 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
       # Non-injective bijectors don't use caching, and the resulting
       # LDJ is a tuple of LDJ over possible partitions on `x`.
       return tuple(
-          self._reduce_jacobian_det_over_shape(ildj, reduce_shape)
+          reduce_jacobian_det_over_shape(ildj, reduce_shape)
           for ildj in self._inverse_log_det_jacobian(y, **kwargs))
 
     # Make sure the unreduced ILDJ is in the cache.
@@ -1457,7 +1457,7 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
           'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
           'is implemented. One or the other is required.')
 
-      return self._reduce_jacobian_det_over_shape(ildj, reduce_shape)
+      return reduce_jacobian_det_over_shape(ildj, reduce_shape)
 
   def inverse_log_det_jacobian(self,
                                y,
@@ -1582,7 +1582,7 @@ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
           'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
           'is implemented. One or the other is required.')
 
-      return self._reduce_jacobian_det_over_shape(-ildj, reduce_shape)
+      return reduce_jacobian_det_over_shape(-ildj, reduce_shape)
 
   def forward_log_det_jacobian(self,
                                x,
@@ -1724,14 +1724,6 @@ def _name_and_control_scope(self, name=None):
     with tf.control_dependencies(deps) as deps_scope:
       yield deps_scope
 
-  def _reduce_jacobian_det_over_shape(self, unreduced, reduce_shape):
-    """Reduce LDJ over the rightmost `reduce_shape.ndims` dimensions."""
-    # Broadcast LDJ to the reduce shape (in case of is_constant_jacobian)
-    # and reduce over the trailing dimensions.
-    ones = tf.ones(reduce_shape, unreduced.dtype)
-    reduce_dims = ps.range(-ps.size(reduce_shape), 0)
-    return tf.reduce_sum(ones * unreduced, axis=reduce_dims)
-
   def _parameter_control_dependencies(self, is_init):
     """Returns a list of ops to be executed in members with graph deps.
 
@@ -2167,6 +2159,17 @@ def ldj_reduction_shape(shape_structure,
   return ldj_reduce_shape, assertions
 
 
+def reduce_jacobian_det_over_shape(unreduced,
+                                   reduce_shape,
+                                   sum_fn=tf.reduce_sum):
+  """Reduce LDJ over the rightmost `reduce_shape.ndims` dimensions."""
+  # Broadcast LDJ to the reduce shape (in case of is_constant_jacobian)
+  # and reduce over the trailing dimensions.
+  ones = tf.ones(reduce_shape, unreduced.dtype)
+  reduce_dims = ps.range(-ps.size(reduce_shape), 0)
+  return sum_fn(ones * unreduced, axis=reduce_dims)
+
+
 def _autodiff_log_det_jacobian(fn, x):
   """Automatically compute the log det jacobian of a scalar function."""
   # Note: x must be fully broadcast (`shape(x) == shape(fn(x))`); otherwise
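A usage sketch of the new module-level reduce_jacobian_det_over_shape and its sum_fn hook (hypothetical caller code, not part of the diff; kahan_sum_fn is a name invented here). tfp.math.reduce_kahan_sum returns a (total, correction) pair rather than a bare Tensor, so a Kahan-summing sum_fn needs a small wrapper to match tf.reduce_sum's calling convention:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.bijectors import bijector as bijector_lib

def kahan_sum_fn(xs, axis=None):
  # Keep the compensated total; discard the residual correction term.
  return tfp.math.reduce_kahan_sum(xs, axis=axis).total

unreduced = tf.random.normal([4, 3, 2])  # per-element LDJ values
reduce_shape = tf.constant([3, 2])       # reduce over the trailing [3, 2]

ldj = bijector_lib.reduce_jacobian_det_over_shape(
    unreduced, reduce_shape, sum_fn=kahan_sum_fn)
print(ldj.shape)  # (4,)

As the commit message notes, the default stays sum_fn=tf.reduce_sum because Kahan summation has a small performance penalty and incomplete JAX support.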

tensorflow_probability/python/bijectors/chain.py

Lines changed: 0 additions & 35 deletions
@@ -22,7 +22,6 @@
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python.bijectors import bijector as bijector_lib
 from tensorflow_probability.python.bijectors import composition
-from tensorflow_probability.python.bijectors import ldj_ratio
 from tensorflow_probability.python.internal import parameter_properties
 from tensorflow_probability.python.internal import prefer_static as ps
 
@@ -184,37 +183,3 @@ def __new__(cls, *args, **kwargs):
     'of `bijectors` is not a `CompositeTensor`, then a non-`CompositeTensor` '
     '`_Chain` instance is created instead. Bijector subclasses that inherit '
     'from `Chain` will also inherit from `CompositeTensor`.')
-
-
-@ldj_ratio.RegisterFLDJRatio(_Chain)
-def _fldj_ratio_chain(p, x, q, y):
-  """Sum-of-diffs FLDJRatio for Chains."""
-  if len(p.bijectors) != len(q.bijectors):
-    raise ValueError('Mismatched lengths of bijectors: `p` has '
-                     f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.')
-  ratios = []
-  max_shp = []
-  for p, q in zip(reversed(p.bijectors), reversed(q.bijectors)):
-    ratios.append(ldj_ratio.forward_log_det_jacobian_ratio(
-        p, x, q, y, p.forward_min_event_ndims))
-    max_shp = ps.broadcast_shape(max_shp, ps.shape(ratios[-1]))
-    x, y = p.forward(x), q.forward(y)
-  ratios = [tf.broadcast_to(r, max_shp) for r in ratios]
-  return tf.add_n(ratios)
-
-
-@ldj_ratio.RegisterILDJRatio(_Chain)
-def _ildj_ratio_chain(p, x, q, y):
-  """Sum-of-diffs ILDJRatio for Chains."""
-  if len(p.bijectors) != len(q.bijectors):
-    raise ValueError('Mismatched lengths of bijectors: `p` has '
-                     f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.')
-  ratios = []
-  max_shp = []
-  for p, q in zip(p.bijectors, q.bijectors):
-    ratios.append(ldj_ratio.inverse_log_det_jacobian_ratio(
-        p, x, q, y, p.inverse_min_event_ndims))
-    max_shp = ps.broadcast_shape(max_shp, ps.shape(ratios[-1]))
-    x, y = p.inverse(x), q.inverse(y)
-  ratios = [tf.broadcast_to(r, max_shp) for r in ratios]
-  return tf.add_n(ratios)
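For reference, the registrations deleted above computed a Chain's FLDJ ratio as a sum of per-link ratios, pushing x and y forward through the paired bijectors as they went. Mathematically, the quantity is just log|J_p(x)| - log|J_q(y)|; here is a naive, less numerically careful reference sketch (naive_fldj_ratio is a name invented for this example):

import tensorflow_probability as tfp

tfb = tfp.bijectors

def naive_fldj_ratio(p, x, q, y, event_ndims):
  # Direct difference of forward LDJs; the deleted per-link registrations
  # computed the same quantity with better cancellation of shared terms.
  return (p.forward_log_det_jacobian(x, event_ndims)
          - q.forward_log_det_jacobian(y, event_ndims))

p = tfb.Chain([tfb.Exp(), tfb.Scale(2.)])
q = tfb.Chain([tfb.Exp(), tfb.Scale(3.)])
print(naive_fldj_ratio(p, 0.5, q, 0.5, event_ndims=0))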

tensorflow_probability/python/bijectors/chain_test.py

Lines changed: 3 additions & 0 deletions
@@ -330,6 +330,9 @@ def xform_dynamic(x):
   return tf1.placeholder_with_default(x, shape=None)
 
 def xform_static(x):
+  # Copy the Tensor, because otherwise the set_shape can pass information
+  # into the past.
+  x = tf.identity(x)
   tensorshape_util.set_shape(x, [1])
   return x
 
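A sketch of the hazard the new comment guards against (illustrative, graph mode as in the test; plain set_shape stands in for tensorshape_util.set_shape): set_shape mutates a Tensor's static shape in place, so without the tf.identity copy the same Tensor object seen by an earlier, deliberately shape-dynamic stage would suddenly report a static shape, i.e. information "passes into the past".

import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

with tf1.Graph().as_default():
  x = tf1.placeholder_with_default(tf.zeros([1]), shape=None)
  y = x                              # alias, not a copy
  y.set_shape([1])                   # mutates the shared Tensor in place
  assert x.shape.as_list() == [1]    # the "past" now sees a static shape

  x2 = tf1.placeholder_with_default(tf.zeros([1]), shape=None)
  y2 = tf.identity(x2)               # copy first, as in the fixed test
  y2.set_shape([1])
  assert x2.shape.rank is None       # x2 keeps its unknown shape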
