Skip to content

Commit dbe6138

Browse files
srvasudejburnim
authored andcommitted
Enable experimental/marginalize for the JAX and Numpy backends.
PiperOrigin-RevId: 579016812
1 parent c3586af commit dbe6138

File tree

5 files changed

+39
-29
lines changed

5 files changed

+39
-29
lines changed

tensorflow_probability/python/experimental/marginalize/BUILD

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
# Description:
1616
# Automatic marginalization of latent variables.
1717

18-
# Placeholder: py_library
19-
# Placeholder: py_test
18+
load(
19+
"//tensorflow_probability/python:build_defs.bzl",
20+
"multi_substrate_py_library",
21+
"multi_substrate_py_test",
22+
)
2023

2124
package(
2225
# default_applicable_licenses
@@ -27,17 +30,18 @@ package(
2730

2831
licenses(["notice"])
2932

30-
py_library(
33+
multi_substrate_py_library(
3134
name = "logeinsumexp",
3235
srcs = ["logeinsumexp.py"],
3336
deps = [
3437
# numpy dep,
3538
# opt_einsum dep,
3639
# tensorflow dep,
40+
"//tensorflow_probability/python/internal:prefer_static",
3741
],
3842
)
3943

40-
py_test(
44+
multi_substrate_py_test(
4145
name = "logeinsumexp_test",
4246
size = "medium",
4347
srcs = [
@@ -53,7 +57,7 @@ py_test(
5357
],
5458
)
5559

56-
py_library(
60+
multi_substrate_py_library(
5761
name = "marginalize",
5862
srcs = ["__init__.py"],
5963
deps = [
@@ -62,7 +66,7 @@ py_library(
6266
],
6367
)
6468

65-
py_library(
69+
multi_substrate_py_library(
6670
name = "marginalizable",
6771
srcs = ["marginalizable.py"],
6872
deps = [
@@ -72,13 +76,17 @@ py_library(
7276
"//tensorflow_probability/python/distributions:categorical",
7377
"//tensorflow_probability/python/distributions:joint_distribution_coroutine",
7478
"//tensorflow_probability/python/distributions:sample",
79+
"//tensorflow_probability/python/internal:prefer_static",
80+
"//tensorflow_probability/python/internal:samplers",
7581
],
7682
)
7783

78-
py_test(
84+
multi_substrate_py_test(
7985
name = "marginalizable_test",
8086
size = "medium",
8187
srcs = ["marginalizable_test.py"],
88+
jax_tags = ["notap"],
89+
numpy_tags = ["notap"],
8290
deps = [
8391
":marginalize",
8492
# absl/testing:parameterized dep,
@@ -92,6 +100,7 @@ py_test(
92100
"//tensorflow_probability/python/distributions:normal",
93101
"//tensorflow_probability/python/distributions:poisson",
94102
"//tensorflow_probability/python/distributions:sample",
103+
"//tensorflow_probability/python/internal:prefer_static",
95104
"//tensorflow_probability/python/internal:test_util",
96105
],
97106
)

tensorflow_probability/python/experimental/marginalize/logeinsumexp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"""Compute einsums in log space."""
1616

1717
import opt_einsum as oe
18-
import tensorflow.compat.v1 as tf
18+
import tensorflow.compat.v2 as tf
19+
from tensorflow_probability.python.internal import prefer_static as ps
1920

2021

2122
# pylint: disable=no-member
@@ -72,8 +73,8 @@ def rearrange(src, dst, t):
7273
if i not in src:
7374
new_indices += i
7475
new_src = src + new_indices
75-
new_t = tf.reshape(t, tf.concat(
76-
[tf.shape(t), tf.ones(len(new_indices), dtype=tf.int32)], axis=0))
76+
new_t = tf.reshape(t, ps.concat(
77+
[ps.shape(t), ps.ones(len(new_indices), dtype=tf.int32)], axis=0))
7778
formula = '{}->{}'.format(new_src, dst)
7879
# It is safe to use ordinary `einsum` here as no summations
7980
# are performed.

tensorflow_probability/python/experimental/marginalize/logeinsumexp_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from hypothesis.extra import numpy as hpnp
1919
import hypothesis.strategies as hps
2020
import numpy as np
21-
import tensorflow.compat.v1 as tf
21+
import tensorflow.compat.v2 as tf
2222
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import _binary_einslogsumexp
2323
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp
2424
from tensorflow_probability.python.internal import test_util
@@ -179,7 +179,6 @@ def test_compare_einsum(self):
179179
formula = 'abcdcfg,edfcbaa->bd'
180180
u = tf.math.log(tf.einsum(formula, a, b))
181181
v = logeinsumexp(formula, tf.math.log(a), tf.math.log(b))
182-
183182
self.assertAllClose(u, v)
184183

185184
def test_zero_zero_multiplication(self):

tensorflow_probability/python/experimental/marginalize/marginalizable.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from tensorflow_probability.python.distributions import joint_distribution_coroutine as jdc_lib
2525
from tensorflow_probability.python.distributions import sample as sample_lib
2626
from tensorflow_probability.python.experimental.marginalize.logeinsumexp import logeinsumexp
27+
from tensorflow_probability.python.internal import prefer_static as ps
28+
from tensorflow_probability.python.internal import samplers
2729

2830

2931
__all__ = [
@@ -117,10 +119,9 @@ def _support(dist):
117119
dist.sample_shape, 'expand_sample_shape')
118120
p, rank = _support(dist.distribution)
119121
product = _power(p, n)
120-
new_shape = tf.concat([tf.shape(product)[:-1], sample_shape], axis=-1)
122+
new_shape = ps.concat([ps.shape(product)[:-1], sample_shape], axis=-1)
121123

122-
new_rank = rank + tf.compat.v2.compat.dimension_value(
123-
sample_shape.shape[0])
124+
new_rank = rank + tf.compat.dimension_value(sample_shape.shape[0])
124125
return tf.reshape(product, new_shape), new_rank
125126
else:
126127
raise ValueError('Unable to find support for distribution ' +
@@ -141,11 +142,11 @@ def _expand_right(a, n, pos):
141142
Tensor with inserted dimensions.
142143
"""
143144

144-
axis = tf.rank(a) + pos + 1
145-
return tf.reshape(a, tf.concat([
146-
tf.shape(a)[:axis],
147-
tf.ones([n], dtype=tf.int32),
148-
tf.shape(a)[axis:]], axis=0))
145+
axis = ps.rank(a) + pos + 1
146+
return tf.reshape(a, ps.concat([
147+
ps.shape(a)[:axis],
148+
ps.ones([n], dtype=tf.int32),
149+
ps.shape(a)[axis:]], axis=0))
149150

150151

151152
def _letter(i):
@@ -216,7 +217,9 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
216217

217218
with tf.name_scope(name):
218219
ds = self._call_execute_model(
219-
sample_and_trace_fn=jd_lib.trace_distributions_only)
220+
sample_and_trace_fn=jd_lib.trace_distributions_only,
221+
# Only used for tracing so can be fixed.
222+
seed=samplers.zeros_seed())
220223

221224
# Both 'marginalize' and 'tabulate' indicate that
222225
# instead of using samples provided by the user, this method
@@ -229,7 +232,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
229232
for value, dist in zip(values, ds):
230233
if value == 'marginalize':
231234
supp, rank = _support(dist)
232-
r = supp.shape.rank
235+
r = ps.rank(supp)
233236
num_new_variables = r - rank
234237
# We can think of supp as being a tensor containing tensors,
235238
# each of which is a draw from the distribution.
@@ -251,7 +254,7 @@ def marginalized_log_prob(self, values, name='marginalized_log_prob',
251254
formula.append(indices)
252255
elif value == 'tabulate':
253256
supp, rank = _support(dist)
254-
r = supp.shape.rank
257+
r = ps.rank(supp)
255258
if r is None:
256259
raise ValueError('Need to be able to statically find rank of'
257260
'support of random variable: {}'.format(str(dist)))

tensorflow_probability/python/experimental/marginalize/marginalizable_test.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tensorflow_probability.python.distributions import poisson
3535
from tensorflow_probability.python.distributions import sample as sample_dist_lib
3636
import tensorflow_probability.python.experimental.marginalize as marginalize
37+
from tensorflow_probability.python.internal import prefer_static as ps
3738
from tensorflow_probability.python.internal import test_util
3839

3940

@@ -48,10 +49,6 @@ def _conform(ts):
4849
return [tf.broadcast_to(a, shape) for a in ts]
4950

5051

51-
def _cat(*ts):
52-
return tf.concat(ts, axis=0)
53-
54-
5552
def _stack(*ts):
5653
return tf.stack(_conform(ts), axis=-1)
5754

@@ -209,7 +206,7 @@ def test_hmm(self):
209206
n_steps = 4
210207
infer_step = 2
211208

212-
observations = [-1.0, 0.0, 1.0, 2.0]
209+
observations = np.array([-1.0, 0.0, 1.0, 2.0], np.float32)
213210

214211
initial_prob = tf.constant([0.6, 0.4], dtype=tf.float32)
215212
transition_matrix = tf.constant([[0.6, 0.4],
@@ -309,7 +306,7 @@ def model():
309306
0.4 * tf.roll(o, shift=[1, 0], axis=[-2, -1]))
310307

311308
# Reshape just last two dimensions.
312-
p = tf.reshape(p, _cat(p.shape[:-2], [-1]))
309+
p = tf.reshape(p, ps.concat([ps.shape(p)[:-2], [-1]], axis=0))
313310
xy = yield categorical.Categorical(probs=p, dtype=tf.int32)
314311
x[i] = xy // n
315312
y[i] = xy % n
@@ -342,6 +339,7 @@ def model():
342339
# order chosen by `tf.einsum` closer matches `_tree_example` above.
343340
self.assertAllClose(p, q)
344341

342+
@test_util.numpy_disable_gradient_test
345343
def test_marginalized_gradient(self):
346344
n = 10
347345

0 commit comments

Comments
 (0)