Commit 0df33ac

davmre authored and tensorflower-gardener committed
Use batch shape annotations to ensure that autodiff computes correct bijector LDJs.
PiperOrigin-RevId: 375770777
1 parent 193fc50 commit 0df33ac

4 files changed: +77 −20 lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 36 additions & 2 deletions

@@ -1415,7 +1415,23 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
         x = self.inverse(y, **kwargs)  # Fall back to computing `-fldj(x)`
         ildj = attrs['ildj'] = -self._forward_log_det_jacobian(x, **kwargs)
       elif self._is_scalar:
-        ildj = _autodiff_log_det_jacobian(self._inverse, y)
+        try:
+          scalar_batch_shape = self.experimental_batch_shape_tensor(
+              y_event_ndims=0)
+        except NotImplementedError:
+          raise NotImplementedError(
+              'Cannot derive `inverse_log_det_jacobian` using automatic '
+              'differentiation because its shape could not be determined. '
+              'Please implement at least one of:\n'
+              '`{bijector_type}._parameter_properties`\n'
+              '`{bijector_type}._batch_shape_tensor`\n'
+              '`{bijector_type}._forward_log_det_jacobian`\n '
+              '`{bijector_type}._inverse_log_det_jacobian`.'.format(
+                  bijector_type=type(self).__name__))
+        ildj = _autodiff_log_det_jacobian(
+            self.inverse,
+            tf.broadcast_to(y, ps.broadcast_shape(ps.shape(y),
+                                                  scalar_batch_shape)))
       else:
         raise NotImplementedError(
             'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
@@ -1524,7 +1540,23 @@ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
         y = self.forward(x, **kwargs)  # Fall back to computing `ildj(y)`
         ildj = attrs['ildj'] = self._inverse_log_det_jacobian(y, **kwargs)
       elif self._is_scalar:
-        ildj = -_autodiff_log_det_jacobian(self._forward, x)
+        try:
+          scalar_batch_shape = self.experimental_batch_shape_tensor(
+              x_event_ndims=0)
+        except NotImplementedError:
+          raise NotImplementedError(
+              'Cannot derive `forward_log_det_jacobian` using automatic '
+              'differentiation because its shape could not be determined. '
+              'Please implement at least one of:\n'
+              '`{bijector_type}._parameter_properties`\n'
+              '`{bijector_type}._batch_shape_tensor`\n'
+              '`{bijector_type}._forward_log_det_jacobian`\n '
+              '`{bijector_type}._inverse_log_det_jacobian`.'.format(
+                  bijector_type=type(self).__name__))
+        ildj = -_autodiff_log_det_jacobian(
+            self.forward,
+            tf.broadcast_to(x, ps.broadcast_shape(ps.shape(x),
+                                                  scalar_batch_shape)))
       else:
         raise NotImplementedError(
             'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
@@ -2111,6 +2143,8 @@ def ldj_reduction_shape(shape_structure,

 def _autodiff_log_det_jacobian(fn, x):
   """Automatically compute the log det jacobian of a scalar function."""
+  # Note: x must be fully broadcast (`shape(x) == shape(fn(x))`); otherwise
+  # the gradients will be (incorrectly) summed.
   _, grads = gradient.value_and_gradient(fn, x)
   if grads is None:
     raise ValueError('Cannot compute log det jacobian; function {} has `None` '
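To illustrate the note added to `_autodiff_log_det_jacobian`, here is a small standalone sketch (not part of the commit, and the Scale-like `fn` below is made up): if the input is not broadcast to the bijector's full batch shape before differentiating, reverse-mode autodiff sums the per-batch-member derivatives and produces a single, incorrect log-det-Jacobian.

# Illustrative sketch only; `scale` and `fn` are hypothetical stand-ins.
import tensorflow as tf

scale = tf.constant([1., 2., 3.])   # a batch of three scalar bijectors
fn = lambda x: scale * x            # forward fn with batch shape [3]

x = tf.constant(2.)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = fn(x)                         # shape [3]
# Gradient w.r.t. the scalar input sums over the batch: 1. + 2. + 3. = 6.
wrong_fldj = tf.math.log(tf.abs(tape.gradient(y, x)))        # log(6.)

x_b = tf.broadcast_to(x, tf.shape(scale))                     # shape [3]
with tf.GradientTape() as tape:
  tape.watch(x_b)
  y_b = fn(x_b)
# With the broadcast input, each batch member keeps its own derivative.
correct_fldj = tf.math.log(tf.abs(tape.gradient(y_b, x_b)))   # [0., log 2., log 3.]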

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tensorflow_probability.python import bijectors as tfb
3030
from tensorflow_probability.python.bijectors import bijector as bijector_lib
3131
from tensorflow_probability.python.internal import cache_util
32+
from tensorflow_probability.python.internal import parameter_properties
3233
from tensorflow_probability.python.internal import prefer_static as ps
3334
from tensorflow_probability.python.internal import tensor_util
3435
from tensorflow_probability.python.internal import test_util
@@ -84,12 +85,12 @@ def __init__(self):
8485

8586
with self.assertRaisesRegexp(
8687
NotImplementedError,
87-
'inverse not implemented'):
88+
'Cannot derive `inverse_log_det_jacobian`'):
8889
bij.inverse_log_det_jacobian(0, event_ndims=0)
8990

9091
with self.assertRaisesRegexp(
9192
NotImplementedError,
92-
'forward not implemented'):
93+
'Cannot derive `forward_log_det_jacobian`'):
9394
bij.forward_log_det_jacobian(0, event_ndims=0)
9495

9596
@test_util.disable_test_for_backend(
@@ -128,8 +129,11 @@ def _forward(self, x):
128129
error_clazz, 'Tensor conversion requested dtype'):
129130
b64.forward(x32)
130131

132+
@parameterized.named_parameters(
133+
('no_batch_shape', 1.4),
134+
('with_batch_shape', [[[2., 3.], [5., 7.]]]))
131135
@test_util.numpy_disable_gradient_test
132-
def testAutodiffLogDetJacobian(self):
136+
def testAutodiffLogDetJacobian(self, bijector_scale):
133137

134138
class NoJacobianBijector(tfb.Bijector):
135139
"""Bijector with no log det jacobian methods."""
@@ -148,7 +152,12 @@ def _forward(self, x):
148152
def _inverse(self, y):
149153
return tf.math.log(y) / self._scale
150154

151-
b = NoJacobianBijector(scale=1.4)
155+
@classmethod
156+
def _parameter_properties(cls, dtype, num_classes=None):
157+
return dict(
158+
scale=parameter_properties.ParameterProperties(event_ndims=0))
159+
160+
b = NoJacobianBijector(scale=bijector_scale)
152161
x = tf.convert_to_tensor([2., -3.])
153162
[
154163
fldj,
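For readers skimming the diff, here is a self-contained sketch of what the parameterized test now exercises. The bijector's `__init__` below is a plausible reconstruction (the diff omits it), and the printed values are our own illustration, not the test's assertions.

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import parameter_properties

tfb = tfp.bijectors


class NoJacobianBijector(tfb.Bijector):
  """Bijector with no log det jacobian methods (mirrors the test)."""

  def __init__(self, scale=1., validate_args=True):
    parameters = dict(locals())
    self._scale = tf.convert_to_tensor(scale)
    super().__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        parameters=parameters)

  def _forward(self, x):
    return tf.exp(self._scale * x)

  def _inverse(self, y):
    return tf.math.log(y) / self._scale

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    return dict(
        scale=parameter_properties.ParameterProperties(event_ndims=0))


# With `event_ndims=0` annotated on `scale`, the batch shape is inferred from
# the parameter's shape, so the autodiff LDJ fallback can broadcast its input.
b = NoJacobianBijector(scale=[[[2., 3.], [5., 7.]]])
print(b.experimental_batch_shape_tensor(x_event_ndims=0))   # [1 2 2]
fldj = b.forward_log_det_jacobian(tf.constant([2., -3.]), event_ndims=0)
print(fldj.shape)   # (1, 2, 2): one log-det-Jacobian per batch member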

tensorflow_probability/python/experimental/bijectors/distribution_bijectors_test.py

Lines changed: 0 additions & 2 deletions

@@ -97,8 +97,6 @@ def test_all_distributions_either_work_or_raise_error(self, dist_name, data):

     dist = data.draw(dhps.base_distributions(
         dist_name=dist_name,
-        # TODO(b/175354524) fix autodiff for batch LDJs and enable batch tests.
-        batch_shape=[],
         enable_vars=False,
         param_strategy_fn=_constrained_zeros_fn))
     try:

tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py

Lines changed: 28 additions & 12 deletions

@@ -19,8 +19,10 @@

 from tensorflow_probability.python import math as tfp_math
 from tensorflow_probability.python.bijectors import bijector
+from tensorflow_probability.python.internal import callable_util
 from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
 from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import tensorshape_util

 __all__ = ['ScalarFunctionWithInferredInverse']

@@ -35,6 +37,7 @@ def __init__(self,
                max_iterations=50,
                require_convergence=True,
                additional_scalar_parameters_requiring_gradients=(),
+               dtype=None,
                validate_args=False,
                name='scalar_function_with_inferred_inverse'):
     """Initialize the ScalarFunctionWithInferredInverse bijector.
@@ -72,6 +75,9 @@ def __init__(self,
         anything in the closure of `fn`) will not, in general, receive
         gradients.
         Default value: `()`.
+      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
+        enforced.
+        Default value: `None`.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -91,14 +97,14 @@ def __init__(self,
       # VJPs and JVPs can be computed efficiently using actual matrix ops.
       self._additional_scalar_parameters_requiring_gradients = (
          additional_scalar_parameters_requiring_gradients)
-      self._cached_fn_batch_shape = None

       self._bound_fn = (
           lambda x: fn(x, *additional_scalar_parameters_requiring_gradients))
       self._inverse = self._wrap_inverse_with_implicit_gradient()

       super(ScalarFunctionWithInferredInverse, self).__init__(
           parameters=parameters,
+          dtype=dtype,
           forward_min_event_ndims=0,
           inverse_min_event_ndims=0,
           validate_args=validate_args,
@@ -129,15 +135,25 @@ def bound_fn(self):
     """Forward `fn` with any extra args bound, so that `y = bound_fn(x)`."""
     return self._bound_fn

-  def _fn_batch_shape(self):
-    if self._cached_fn_batch_shape is None:
-      # Evaluating at a scalar value (0.) exposes the function's batch shape.
-      # For example, evaluating
-      #   `fn = lambda x: x * constant([1., 2., 3.])`
-      # returns a result of shape `[3]`.
-      self._cached_fn_batch_shape = ps.shape(
-          self.bound_fn(self.domain_constraint_fn(0.)))  # pylint: disable=not-callable
-    return self._cached_fn_batch_shape
+  def _batch_shape(self, x_event_ndims):
+    try:
+      # Trace the function to extract its batch shape without executing it.
+      fn_shape = callable_util.get_output_spec(
+          lambda x: self.bound_fn(self.domain_constraint_fn(x)),  # pylint: disable=not-callable
+          tf.TensorSpec([], dtype=self.dtype if self.dtype else tf.float32)
+      ).shape
+    except TypeError:  # `dtype` wasn't specified.
+      return tf.TensorShape(None)
+
+    fn_rank = tensorshape_util.rank(fn_shape)
+    if fn_rank is not None:
+      return fn_shape[:fn_rank - x_event_ndims]
+    return fn_shape
+
+  def _batch_shape_tensor(self, x_event_ndims):
+    fn_shape = ps.shape(
+        self.bound_fn(self.domain_constraint_fn(0.)))  # pylint: disable=not-callable
+    return fn_shape[:ps.rank_from_shape(fn_shape) - x_event_ndims]

   def _forward(self, x):
     return self.bound_fn(x)
@@ -220,8 +236,8 @@ def _arg_broadcasting_wrapped_inverse(y):
       # TODO(davmre): Do gradient reductions directly in the VJP using
       # `tf.raw_ops.BroadcastGradientArgs` so we can remove this wrapper
       # and avoid spurious broadcasting.
-      full_batch_shape = ps.broadcast_shape(self._fn_batch_shape(),
-                                            ps.shape(y))
+      full_batch_shape = ps.broadcast_shape(
+          self.experimental_batch_shape_tensor(), ps.shape(y))
       args = [tf.broadcast_to(arg, full_batch_shape) for arg in args]
       return _inverse_with_gradient(y, *args)
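As a rough public-API analogue of what `callable_util.get_output_spec` does in the new `_batch_shape` method (an illustration only, not the TFP internal; `scales` and `fn` below are made up): tracing the function at a scalar `tf.TensorSpec` exposes its output shape, and hence its batch shape, without actually running it.

import tensorflow as tf

scales = tf.constant([1., 2., 3.])
fn = lambda x: scales * x   # a batch of three scalar functions

# Tracing at a scalar spec (no concrete value) reveals the output shape [3].
concrete = tf.function(fn).get_concrete_function(tf.TensorSpec([], tf.float32))
print(concrete.output_shapes)   # TensorShape([3])

# The dynamic `_batch_shape_tensor` path instead evaluates the function at a
# concrete scalar (0.) and takes the shape of the result.
print(tf.shape(fn(0.)))         # [3]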
