Skip to content

Commit 0df7250

Browse files
Googlertensorflower-gardener
authored andcommitted
Restore tfd.Mixture to jax substrate.
PiperOrigin-RevId: 455459945
1 parent d3c82ea commit 0df7250

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ multi_substrate_py_library(
3636
srcs = ["__init__.py"],
3737
substrates_omit_deps = [
3838
":pixel_cnn",
39-
":mixture",
4039
],
4140
deps = [
4241
":autoregressive",
@@ -3322,7 +3321,6 @@ multi_substrate_py_test(
33223321
name = "mixture_test",
33233322
size = "medium",
33243323
srcs = ["mixture_test.py"],
3325-
jax_tags = ["notap"],
33263324
numpy_tags = ["notap"],
33273325
deps = [
33283326
# hypothesis dep,

tensorflow_probability/python/distributions/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def _sample_n(self, n, seed=None):
347347
mask = distribution_util.pad_mixture_dimensions(
348348
mask, self, self._cat,
349349
tensorshape_util.rank(self._static_event_shape)) # [n, B, k, [1]*e]
350-
if x.dtype.is_floating:
350+
if dtype_util.is_floating(x.dtype):
351351
masked = tf.math.multiply_no_nan(x, mask)
352352
else:
353353
masked = x * mask

0 commit comments

Comments
 (0)