Skip to content

Commit 75dae2d

Browse files
authored
Merge branch 'tensorflow:main' into frighterafix#1384
2 parents 34a11a2 + b78a8fa commit 75dae2d

37 files changed

+1762
-770
lines changed

CONTRIBUTING.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ We strongly recommend running unit tests in an active
7979
extra bazel flags, so we created a wrapper script, which we suggest using. An
8080
example invocation (presumed to run from the root of the TFP repo):
8181

82+
#### Dependencies
83+
84+
To run the unit tests, you'll need several packages installed (again, we
85+
strongly recommend you work in a virtualenv). We include a script to do this for
86+
you, which also does some sanity checks on the environment:
87+
88+
```shell
89+
./testing/install_test_dependencies.sh
90+
```
91+
92+
See the
93+
[header comments in that script](https://github.com/tensorflow/probability/blob/main/testing/install_test_dependencies.sh)
94+
for more details.
95+
8296
#### Helper scripts
8397

8498
```shell
@@ -112,20 +126,6 @@ tfp_test //tensorflow_probability/python/distributions:joint_distribution_corout
112126
tfp_lints tensorflow_probability/python/distributions/joint_distribution_coroutine.py
113127
```
114128

115-
#### Dependencies
116-
117-
To run the unit tests, you'll need several packages installed (again, we
118-
strongly recommend you work in a virtualenv). We include a script to do this for
119-
you, which also does some sanity checks on the environment:
120-
121-
```shell
122-
./testing/install_test_dependencies.sh
123-
```
124-
125-
See the
126-
[header comments in that script](https://github.com/tensorflow/probability/blob/main/testing/install_test_dependencies.sh)
127-
for more details.
128-
129129
### Additional considerations
130130

131131
As of early 2020, tensorflow and tf-nightly include GPU support by default,

tensorflow_probability/examples/jupyter_notebooks/Factorial_Mixture.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
},
7070
"source": [
7171
"In this notebook we show how to use [TensorFlow Probability](https://github.com/tensorflow/probability) (TFP) to sample from a factorial Mixture of Gaussians distribution defined as:\n",
72-
"$$p(x_1, ..., x_n) = \\prod_i p_i(x_i)$$ where: $$\\begin{align*} p_i &\\equiv \\frac{1}{K}\\sum_{i=1}^K \\pi_{ik}\\,\\text{Normal}\\left(\\text{loc}=\\mu_{ik},\\, \\text{scale}=\\sigma_{ik}\\right)\\\\1&=\\sum_{k=1}^K\\pi_{ik}, \\forall i.\\hphantom{MMMMMMMMMMM}\\end{align*}$$\n",
72+
"$$p(x_1, ..., x_n) = \\prod_i p_i(x_i)$$ where: $$\\begin{align*} p_i &\\equiv \\frac{1}{K}\\sum_{k=1}^K \\pi_{ik}\\,\\text{Normal}\\left(\\text{loc}=\\mu_{ik},\\, \\text{scale}=\\sigma_{ik}\\right)\\\\1&=\\sum_{k=1}^K\\pi_{ik}, \\forall i.\\hphantom{MMMMMMMMMMM}\\end{align*}$$\n",
7373
"\n",
7474
"Each variable $x_i$ is modeled as a mixture of Gaussians, and the joint distribution over all $n$ variables is a product of these densities.\n",
7575
"\n",

tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,12 @@
203203
"train_dataset = (datasets['train']\n",
204204
" .map(_preprocess)\n",
205205
" .batch(256)\n",
206-
" .prefetch(tf.data.experimental.AUTOTUNE)\n",
206+
" .prefetch(tf.data.AUTOTUNE)\n",
207207
" .shuffle(int(10e3)))\n",
208208
"eval_dataset = (datasets['test']\n",
209209
" .map(_preprocess)\n",
210210
" .batch(256)\n",
211-
" .prefetch(tf.data.experimental.AUTOTUNE))"
211+
" .prefetch(tf.data.AUTOTUNE))"
212212
]
213213
},
214214
{

tensorflow_probability/python/distributions/__init__.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,6 @@
166166
# pylint: enable=line-too-long
167167

168168
__all__ = [
169-
'FULLY_REPARAMETERIZED',
170-
'NOT_REPARAMETERIZED',
171-
'ReparameterizationType',
172-
'Distribution',
173169
'Autoregressive',
174170
'BatchBroadcast',
175171
'BatchReshape',
@@ -185,34 +181,38 @@
185181
'Chi',
186182
'Chi2',
187183
'CholeskyLKJ',
184+
'DeterminantalPointProcess',
188185
'Deterministic',
186+
'Dirichlet',
187+
'DirichletMultinomial',
188+
'Distribution',
189189
'DoublesidedMaxwell',
190-
'VectorDeterministic',
191-
'DeterminantalPointProcess',
192190
'Empirical',
193-
'ExponentiallyModifiedGaussian',
194191
'ExpGamma',
195192
'ExpInverseGamma',
196193
'Exponential',
197-
'VectorExponentialDiag',
194+
'ExponentiallyModifiedGaussian',
195+
'ExpRelaxedOneHotCategorical',
196+
'FiniteDiscrete',
197+
'FULLY_REPARAMETERIZED',
198198
'Gamma',
199199
'GammaGamma',
200-
'InverseGaussian',
200+
'GaussianProcess',
201+
'GaussianProcessRegressionModel',
202+
'GeneralizedExtremeValue',
201203
'GeneralizedNormal',
202204
'GeneralizedPareto',
203205
'Geometric',
204-
'GaussianProcess',
205-
'GaussianProcessRegressionModel',
206-
'VariationalGaussianProcess',
207206
'Gumbel',
208-
'GeneralizedExtremeValue',
209207
'HalfCauchy',
210208
'HalfNormal',
211209
'HalfStudentT',
212210
'HiddenMarkovModel',
213211
'Horseshoe',
214212
'Independent',
213+
'independent_joint_distribution_from_structure',
215214
'InverseGamma',
215+
'InverseGaussian',
216216
'JohnsonSU',
217217
'JointDistribution',
218218
'JointDistributionCoroutine',
@@ -221,25 +221,55 @@
221221
'JointDistributionNamedAutoBatched',
222222
'JointDistributionSequential',
223223
'JointDistributionSequentialAutoBatched',
224+
'kl_divergence',
224225
'Kumaraswamy',
225226
'LambertWDistribution',
226227
'LambertWNormal',
227228
'Laplace',
228229
'LinearGaussianStateSpaceModel',
229230
'LKJ',
230231
'Logistic',
232+
'LogitNormal',
231233
'LogLogistic',
232234
'LogNormal',
233-
'LogitNormal',
234235
'MarkovChain',
236+
'Masked',
237+
'MatrixNormalLinearOperator',
238+
'MatrixTLinearOperator',
239+
'Mixture',
240+
'MixtureSameFamily',
235241
'Moyal',
242+
'Multinomial',
243+
'MultivariateNormalDiag',
244+
'MultivariateNormalDiagPlusLowRank',
245+
'MultivariateNormalFullCovariance',
246+
'MultivariateNormalLinearOperator',
247+
'MultivariateNormalTriL',
248+
'MultivariateStudentTLinearOperator',
249+
'mvn_conjugate_linear_update',
236250
'NegativeBinomial',
237251
'Normal',
252+
'normal_conjugates_known_scale_posterior',
253+
'normal_conjugates_known_scale_predictive',
238254
'NormalInverseGaussian',
255+
'NOT_REPARAMETERIZED',
256+
'OneHotCategorical',
257+
'OrderedLogistic',
258+
'Pareto',
259+
'PERT',
239260
'PixelCNN',
261+
'PlackettLuce',
240262
'Poisson',
241263
'PoissonLogNormalQuadratureCompound',
264+
'PowerSpherical',
242265
'ProbitBernoulli',
266+
'quadrature_scheme_lognormal_gauss_hermite',
267+
'quadrature_scheme_lognormal_quantiles',
268+
'QuantizedDistribution',
269+
'RegisterKL',
270+
'RelaxedBernoulli',
271+
'RelaxedOneHotCategorical',
272+
'ReparameterizationType',
243273
'Sample',
244274
'SigmoidBeta',
245275
'SinhArcsinh',
@@ -248,47 +278,18 @@
248278
'StoppingRatioLogistic',
249279
'StudentT',
250280
'StudentTProcess',
281+
'TransformedDistribution',
251282
'Triangular',
252283
'TruncatedCauchy',
253284
'TruncatedNormal',
254285
'Uniform',
255-
'Masked',
256-
'MatrixNormalLinearOperator',
257-
'MatrixTLinearOperator',
258-
'MultivariateNormalDiag',
259-
'MultivariateNormalFullCovariance',
260-
'MultivariateNormalLinearOperator',
261-
'MultivariateNormalTriL',
262-
'MultivariateNormalDiagPlusLowRank',
263-
'MultivariateStudentTLinearOperator',
264-
'Dirichlet',
265-
'DirichletMultinomial',
266-
'Multinomial',
286+
'VariationalGaussianProcess',
287+
'VectorDeterministic',
288+
'VectorExponentialDiag',
267289
'VonMises',
268290
'VonMisesFisher',
269291
'Weibull',
270292
'WishartLinearOperator',
271293
'WishartTriL',
272-
'TransformedDistribution',
273-
'QuantizedDistribution',
274-
'Mixture',
275-
'MixtureSameFamily',
276-
'ExpRelaxedOneHotCategorical',
277-
'OneHotCategorical',
278-
'OrderedLogistic',
279-
'Pareto',
280-
'PERT',
281-
'PlackettLuce',
282-
'PowerSpherical',
283-
'RelaxedBernoulli',
284-
'RelaxedOneHotCategorical',
285294
'Zipf',
286-
'kl_divergence',
287-
'RegisterKL',
288-
'independent_joint_distribution_from_structure',
289-
'mvn_conjugate_linear_update',
290-
'normal_conjugates_known_scale_posterior',
291-
'normal_conjugates_known_scale_predictive',
292-
'quadrature_scheme_lognormal_gauss_hermite',
293-
'quadrature_scheme_lognormal_quantiles',
294295
]

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -485,20 +485,23 @@ def _test_slicing(self, data, dist_name, dist):
485485
# slicing the samples from the original.
486486
self.assertAllEqual(sliced_samples.shape, sliced_dist_samples.shape)
487487

488-
# Check that a sliced distribution can compute the log_prob of its own
489-
# samples (up to numerical validation errors).
488+
# Check that the sliced dist's log_prob agrees with slicing the original's
489+
# log_prob.
490+
# First, we make sure that the original sample we have passes the
491+
# original distribution's validations. We break the bijector cache here
492+
# because slicing will break it later too.
490493
with tfp_hps.no_tf_rank_errors():
491494
try:
492-
lp = self.evaluate(dist.log_prob(samples))
495+
lp = self.evaluate(dist.log_prob(
496+
samples + tf.constant(0, dtype=samples.dtype)))
493497
except tf.errors.InvalidArgumentError:
494498
# TODO(b/129271256): d.log_prob(d.sample()) should not fail
495499
# validate_args checks.
496-
# We only tolerate this case for the non-sliced dist.
500+
# `return` here passes the example. If we `hp.assume(False)`
501+
# instead, that would demand from Hypothesis that it find many
502+
# examples where this check (and the next one) passes;
503+
# empirically, it seems to complain that that's too hard.
497504
return
498-
sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))
499-
500-
# Check that the sliced dist's log_prob agrees with slicing the original's
501-
# log_prob.
502505

503506
# This `hp.assume` is suppressing array sizes that cause the sliced and
504507
# non-sliced distribution to follow different Eigen code paths. Those
@@ -518,6 +521,10 @@ def _test_slicing(self, data, dist_name, dist):
518521
hp.note('Non-packetization check {}'.format(all_non_packetized))
519522
hp.assume(all_packetized or all_non_packetized)
520523

524+
# Actually evaluate and test the sliced log_prob
525+
with tfp_hps.no_tf_rank_errors():
526+
sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))
527+
521528
self.assertAllClose(lp[slices], sliced_lp,
522529
atol=SLICING_LOGPROB_ATOL[dist_name],
523530
rtol=SLICING_LOGPROB_RTOL[dist_name])

tensorflow_probability/python/experimental/mcmc/windowed_sampling_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ def hmc_kwargs(self):
621621
('nuts_jit_sig', 'nuts'))
622622
def test_base_kernel(self, kind):
623623
self.skip_if_no_xla()
624+
self.skipTest('b/195070752')
624625

625626
if JAX_MODE:
626627
input_signature = None

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,9 @@ def one_step_predictive(model,
340340
distribution of each timestep given previous timesteps.
341341
"""
342342
dtype = dtype_util.common_dtype([
343-
posterior_samples.level_scale.dtype,
344-
posterior_samples.observation_noise_scale.dtype,
345-
posterior_samples.level.dtype,
343+
posterior_samples.level_scale,
344+
posterior_samples.observation_noise_scale,
345+
posterior_samples.level,
346346
original_mean,
347347
original_scale], dtype_hint=tf.float32)
348348
num_observed_steps = prefer_static.shape(posterior_samples.level)[-1]

tensorflow_probability/python/experimental/sts_gibbs/gibbs_sampler_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def reshape_chain_and_sample(x):
202202

203203
@parameterized.named_parameters(
204204
{'testcase_name': 'float32_xla', 'dtype': tf.float32, 'use_xla': True},
205-
{'testcase_name': 'float16', 'dtype': tf.float16, 'use_xla': False})
205+
{'testcase_name': 'float64', 'dtype': tf.float64, 'use_xla': False})
206206
def test_end_to_end_prediction_works_and_is_deterministic(
207207
self, dtype, use_xla):
208208
if not tf.executing_eagerly():
@@ -211,7 +211,8 @@ def test_end_to_end_prediction_works_and_is_deterministic(
211211
model, observed_time_series, is_missing = self._build_test_model(
212212
num_timesteps=5,
213213
batch_shape=[3],
214-
prior_class=gibbs_sampler.XLACompilableInverseGamma)
214+
prior_class=gibbs_sampler.XLACompilableInverseGamma,
215+
dtype=dtype)
215216

216217
@tf.function(jit_compile=use_xla)
217218
def do_sampling(observed_time_series, is_missing):

tensorflow_probability/python/internal/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,17 @@ multi_substrate_py_library(
233233
],
234234
)
235235

236+
multi_substrate_py_test(
237+
name = "custom_gradient_test",
238+
srcs = ["custom_gradient_test.py"],
239+
deps = [
240+
":custom_gradient",
241+
# tensorflow dep,
242+
"//tensorflow_probability/python/internal:test_util",
243+
"//tensorflow_probability/python/math:gradient",
244+
],
245+
)
246+
236247
py_test(
237248
name = "cache_util_test",
238249
size = "small",

tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,17 @@ def dimension_at_index(shape, index):
213213

214214
@tf_export(v1=["Dimension"])
215215
class Dimension(object):
216-
"""Represents the value of one dimension in a TensorShape."""
216+
"""Represents the value of one dimension in a TensorShape.
217+
218+
@compatibility(TF2)
219+
In TF2, members of a `TensorShape` object are integers. The `Dimension` class
220+
is not part of TF2's data model.
221+
222+
Please refer to the [TensorShape section of the migration guide]
223+
(https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code
224+
patterns adapting Dimension objects to a TF2 syntax.
225+
@end_compatibility
226+
"""
217227

218228
__slots__ = ["_value"]
219229

0 commit comments

Comments
 (0)