Skip to content

Commit a4da3a4

Browse files
Googlertensorflower-gardener
authored andcommitted
Forward kwargs into tangent space methods.
And fix how we check whether the tangent space has been specified already. PiperOrigin-RevId: 429622219
1 parent 1e2e658 commit a4da3a4

File tree

3 files changed

+33
-25
lines changed

3 files changed

+33
-25
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,8 @@ def forward_log_det_jacobian(self,
16841684
def experimental_compute_density_correction(self,
16851685
x,
16861686
tangent_space,
1687-
backward_compat=False):
1687+
backward_compat=False,
1688+
**kwargs):
16881689
"""Density correction for this transformation wrt the tangent space, at x.
16891690
16901691
Subclasses of Bijector may call the most specific applicable
@@ -1699,6 +1700,7 @@ def experimental_compute_density_correction(self,
16991700
the support manifold at `x`.
17001701
backward_compat: `bool` specifying whether to assume that the Bijector
17011702
is dimension-preserving.
1703+
**kwargs: Optional keyword arguments forwarded to tangent space methods.
17021704
17031705
Returns:
17041706
density_correction: `Tensor` representing the density correction---in log
@@ -1710,7 +1712,7 @@ def experimental_compute_density_correction(self,
17101712
17111713
"""
17121714
if backward_compat:
1713-
return tangent_space.transform_dimension_preserving(x, self)
1715+
return tangent_space.transform_dimension_preserving(x, self, **kwargs)
17141716
else:
17151717
raise TypeError(
17161718
'Please call the `TangentSpace` method applicable to this Bijector.')

tensorflow_probability/python/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1926,7 +1926,7 @@ def experimental_local_measure(self, value, backward_compat=False, **kwargs):
19261926
"""
19271927
log_prob = self.log_prob(value, **kwargs)
19281928
tangent_space = None
1929-
if getattr(self, '_experimental_tangent_space'):
1929+
if hasattr(self, '_experimental_tangent_space'):
19301930
tangent_space = self._experimental_tangent_space
19311931
elif backward_compat:
19321932
# Import here rather than top-level to avoid circular import.

tensorflow_probability/python/experimental/tangent_spaces/spaces.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TangentSpace(object):
4949
5050
"""
5151

52-
def transform_general(self, x, f):
52+
def transform_general(self, x, f, **kwargs):
5353
"""Returns the density correction, in log space, corresponding to f at x.
5454
5555
Also returns a new `TangentSpace` representing the tangent to fM at f(x).
@@ -58,6 +58,7 @@ def transform_general(self, x, f):
5858
x: `Tensor` (structure). The point at which to calculate the density.
5959
f: `Bijector` or one of its subclasses. The transformation that requires a
6060
density correction based on this tangent space.
61+
**kwargs: Optional keyword arguments as part of the Bijector.
6162
6263
Returns:
6364
log_density: A `Tensor` representing the log density correction of f at x
@@ -69,7 +70,7 @@ def transform_general(self, x, f):
6970
"""
7071
raise NotImplementedError
7172

72-
def transform_dimension_preserving(self, x, f):
73+
def transform_dimension_preserving(self, x, f, **kwargs):
7374
"""Same as `transform_general`, assuming f goes from R^n to R^n.
7475
7576
Default falls back to `transform_general`, which may be overridden
@@ -78,6 +79,7 @@ def transform_dimension_preserving(self, x, f):
7879
Args:
7980
x: same as in `transform_general`.
8081
f: same as in `transform_general`.
82+
**kwargs: same as in `transform_general`.
8183
8284
Returns:
8385
log_density: A `Tensor` representing the log density correction of f at x
@@ -88,9 +90,9 @@ def transform_dimension_preserving(self, x, f):
8890
`transform_general`.
8991
9092
"""
91-
return self.transform_general(x, f)
93+
return self.transform_general(x, f, **kwargs)
9294

93-
def transform_projection(self, x, f):
95+
def transform_projection(self, x, f, **kwargs):
9496
"""Same as `transform_general`, with f a projection (or its inverse).
9597
9698
Default falls back to `transform_general`, which may be overridden
@@ -99,6 +101,7 @@ def transform_projection(self, x, f):
99101
Args:
100102
x: same as in `transform_general`.
101103
f: same as in `transform_general`.
104+
**kwargs: same as in `transform_general`.
102105
103106
Returns:
104107
log_density: A `Tensor` representing the log density correction of f at x
@@ -108,9 +111,9 @@ def transform_projection(self, x, f):
108111
NotImplementedError: if the `TangentSpace` subclass does not implement
109112
`transform_general`.
110113
"""
111-
return self.transform_general(x, f)
114+
return self.transform_general(x, f, **kwargs)
112115

113-
def transform_coordinatewise(self, x, f):
116+
def transform_coordinatewise(self, x, f, **kwargs):
114117
"""Same as `transform_dimension_preserving`, for a coordinatewise f.
115118
116119
Default falls back to `transform_dimension_preserving`, which may
@@ -119,6 +122,7 @@ def transform_coordinatewise(self, x, f):
119122
Args:
120123
x: same as in `transform_dimension_preserving`.
121124
f: same as in `transform_dimension_preserving`.
125+
**kwargs: same as in `transform_dimension_preserving`.
122126
123127
Returns:
124128
log_density: A `Tensor` representing the log density correction of f at x
@@ -129,7 +133,7 @@ def transform_coordinatewise(self, x, f):
129133
`transform_dimension_preserving`.
130134
131135
"""
132-
return self.transform_dimension_preserving(x, f)
136+
return self.transform_dimension_preserving(x, f, **kwargs)
133137

134138

135139
def unit_basis():
@@ -161,31 +165,32 @@ def __init__(self, axis_mask):
161165
"""
162166
self.axis_mask = axis_mask
163167

164-
def transform_general(self, x, f):
168+
def transform_general(self, x, f, **kwargs):
165169
as_general_space = GeneralSpace(unit_basis_on(self.axis_mask), 1)
166-
return as_general_space.transform_general(x, f)
170+
return as_general_space.transform_general(x, f, **kwargs)
167171

168-
def transform_projection(self, x, f):
172+
def transform_projection(self, x, f, **kwargs):
169173
if not hasattr(f, 'experimental_update_live_dimensions'):
170174
msg = ('When calling `transform_projection` the Bijector must implement '
171175
'the `experimental_update_live_dimensions` method.')
172176
raise NotImplementedError(msg)
173-
new_live_dimensions = f.experimental_update_live_dimensions(self.axis_mask)
177+
new_live_dimensions = f.experimental_update_live_dimensions(
178+
self.axis_mask, **kwargs)
174179
if all(tf.get_static_value(new_live_dimensions)):
175180
# Special-case a bijector (direction) that knows that the result
176181
# of the projection will be a full space
177182
return 0, FullSpace()
178183
else:
179184
return 0, AxisAlignedSpace(new_live_dimensions)
180185

181-
def transform_coordinatewise(self, x, f):
186+
def transform_coordinatewise(self, x, f, **kwargs):
182187
# TODO(pravnar): compute the derivative of f along x along the
183188
# live dimensions.
184189
raise NotImplementedError
185190

186191

187-
def jacobian_determinant(x, f):
188-
return f.forward_log_det_jacobian(x)
192+
def jacobian_determinant(x, f, **kwargs):
193+
return f.forward_log_det_jacobian(x, **kwargs)
189194

190195

191196
class FullSpace(TangentSpace):
@@ -197,16 +202,17 @@ class FullSpace(TangentSpace):
197202
at all.
198203
"""
199204

200-
def transform_general(self, x, f):
205+
def transform_general(self, x, f, **kwargs):
201206
"""If the bijector is weird, fall back to the general case."""
202207
as_general_space = GeneralSpace(unit_basis(), 1)
203-
return as_general_space.transform_general(x, f)
208+
return as_general_space.transform_general(x, f, **kwargs)
204209

205-
def transform_dimension_preserving(self, x, f):
206-
return jacobian_determinant(x, f), FullSpace()
210+
def transform_dimension_preserving(self, x, f, **kwargs):
211+
return jacobian_determinant(x, f, **kwargs), FullSpace()
207212

208-
def transform_projection(self, x, f):
209-
return AxisAlignedSpace(tf.ones_like(x)).transform_projection(x, f)
213+
def transform_projection(self, x, f, **kwargs):
214+
return AxisAlignedSpace(tf.ones_like(x)).transform_projection(
215+
x, f, **kwargs)
210216

211217

212218
def volume_coefficient(basis):
@@ -223,7 +229,7 @@ def __init__(self, basis, computed_volume=None):
223229
computed_volume = volume_coefficient(basis)
224230
self.volume = computed_volume
225231

226-
def transform_general(self, x, f):
232+
def transform_general(self, x, f, **kwargs):
227233
raise NotImplementedError
228234

229235

@@ -236,7 +242,7 @@ class ZeroSpace(TangentSpace):
236242
237243
"""
238244

239-
def transform_general(self, x, f):
245+
def transform_general(self, x, f, **kwargs):
240246
del x, f
241247
return 0, ZeroSpace()
242248

0 commit comments

Comments
 (0)