Skip to content

Commit 21f3744

Browse files
davmretensorflower-gardener
authored andcommitted
Reduce preconditioned_hmc_test weight by running some tests in graph mode only.
Also replaces the `use_default` parameter with a two-class structure (PreconditionedHMCTestDefaultMomentum and PreconditionedHMCTestExplicitMomentum), because I had trouble getting @parameterized to cooperate with the per-method execution regime decorators. PiperOrigin-RevId: 381318505
1 parent c2a61b5 commit 21f3744

File tree

2 files changed

+46
-36
lines changed

2 files changed

+46
-36
lines changed

tensorflow_probability/python/experimental/mcmc/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ multi_substrate_py_test(
311311
srcs = ["preconditioned_hmc_test.py"],
312312
disabled_substrates = ["numpy"],
313313
python_version = "PY3",
314-
shard_count = 10,
315314
srcs_version = "PY3",
316315
deps = [
317316
# tensorflow dep,

tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -454,14 +454,11 @@ def test_sets_kinetic_energy(self):
454454
grads_next_target_log_prob)
455455

456456

457-
@test_util.test_all_tf_execution_regimes
458-
@parameterized.named_parameters(
459-
dict(testcase_name='_default', use_default=True),
460-
dict(testcase_name='_explicit', use_default=False))
461-
class PreconditionedHMCTest(test_util.TestCase):
457+
class _PreconditionedHMCTest(test_util.TestCase):
462458

463-
def test_f64(self, use_default):
464-
if use_default:
459+
@test_util.test_graph_and_eager_modes()
460+
def test_f64(self):
461+
if self.use_default_momentum_distribution:
465462
momentum_distribution = None
466463
else:
467464
momentum_distribution = as_composite(
@@ -474,8 +471,9 @@ def test_f64(self, use_default):
474471
1, kernel=kernel, current_state=tf.ones([], tf.float64),
475472
num_burnin_steps=5, trace_fn=None, seed=test_util.test_seed()))
476473

477-
def test_f64_multichain(self, use_default):
478-
if use_default:
474+
@test_util.test_graph_and_eager_modes()
475+
def test_f64_multichain(self):
476+
if self.use_default_momentum_distribution:
479477
momentum_distribution = None
480478
else:
481479
momentum_distribution = as_composite(
@@ -489,8 +487,9 @@ def test_f64_multichain(self, use_default):
489487
1, kernel=kernel, current_state=tf.ones([nchains], tf.float64),
490488
num_burnin_steps=5, trace_fn=None, seed=test_util.test_seed()))
491489

492-
def test_f64_multichain_multipart(self, use_default):
493-
if use_default:
490+
@test_util.test_graph_and_eager_modes()
491+
def test_f64_multichain_multipart(self):
492+
if self.use_default_momentum_distribution:
494493
momentum_distribution = None
495494
else:
496495
momentum_distribution = _make_composite_tensor(
@@ -508,23 +507,20 @@ def test_f64_multichain_multipart(self, use_default):
508507
tf.ones([nchains], tf.float64)),
509508
num_burnin_steps=5, trace_fn=None, seed=test_util.test_seed()))
510509

511-
def test_diag(self, use_default):
510+
@test_util.test_graph_mode_only() # Long chains are very slow in eager mode.
511+
def test_diag(self):
512512
"""Test that a diagonal multivariate normal can be effectively sampled from.
513513
514514
Note that the effective sample size is expected to be exactly 100: this is
515515
because the step size is tuned well enough that a single HMC step takes
516516
a point to nearly the antipodal point, which causes a negative lag 1
517517
autocorrelation, and the effective sample size calculation cuts off when
518518
the autocorrelation drops below zero.
519-
520-
Args:
521-
use_default: bool, whether to use a custom momentum distribution, or
522-
the default.
523519
"""
524520
mvn = tfd.MultivariateNormalDiag(
525521
loc=[1., 2., 3.], scale_diag=[0.1, 1., 10.])
526522

527-
if use_default:
523+
if self.use_default_momentum_distribution:
528524
momentum_distribution = None
529525
step_size = 0.1
530526
else:
@@ -547,20 +543,20 @@ def test_diag(self, use_default):
547543
filter_threshold=0,
548544
filter_beyond_positive_pairs=False)
549545

550-
if not use_default:
546+
if not self.use_default_momentum_distribution:
551547
self.assertAllClose(ess, tf.fill([3], 100.))
552548
else:
553549
self.assertLess(self.evaluate(tf.reduce_min(ess)), 100.)
554550

555-
def test_tril(self, use_default):
556-
if tf.executing_eagerly():
557-
self.skipTest('b/169882656 Too many warnings are issued in eager logs')
551+
@test_util.test_graph_mode_only() # Long chains are very slow in eager mode.
552+
@test_util.jax_disable_test_missing_functionality('dynamic shapes')
553+
def test_tril(self):
558554
cov = 0.9 * tf.ones([3, 3]) + 0.1 * tf.eye(3)
559555
scale = tf.linalg.cholesky(cov)
560556
mv_tril = tfd.MultivariateNormalTriL(loc=[1., 2., 3.],
561557
scale_tril=scale)
562558

563-
if use_default:
559+
if self.use_default_momentum_distribution:
564560
momentum_distribution = None
565561
else:
566562
momentum_distribution = tfde.MultivariateNormalPrecisionFactorLinearOperator(
@@ -588,16 +584,17 @@ def test_tril(self, use_default):
588584
# was the wrong one. Why is that? A guess is that since there are *many*
589585
# ways to have larger ess, these tests don't really test correctness.
590586
# Perhaps remove all tests like these.
591-
if not use_default:
587+
if not self.use_default_momentum_distribution:
592588
self.assertAllClose(ess, tf.fill([3], 100.))
593589
else:
594590
self.assertLess(self.evaluate(tf.reduce_min(ess)), 100.)
595591

596-
def test_transform(self, use_default):
592+
@test_util.test_graph_mode_only() # Long chains are very slow in eager mode.
593+
def test_transform(self):
597594
mvn = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_diag=[1., 1., 1.])
598595
diag_variance = tf.constant([0.1, 1., 10.])
599596

600-
if use_default:
597+
if self.use_default_momentum_distribution:
601598
momentum_distribution = None
602599
else:
603600
momentum_distribution = tfde.MultivariateNormalPrecisionFactorLinearOperator(
@@ -622,19 +619,20 @@ def test_transform(self, use_default):
622619
filter_threshold=0,
623620
filter_beyond_positive_pairs=False)
624621

625-
if not use_default:
622+
if not self.use_default_momentum_distribution:
626623
self.assertAllClose(ess, tf.fill([3], 100.))
627624
else:
628625
self.assertLess(self.evaluate(tf.reduce_min(ess)), 100.)
629626

630-
def test_multi_state_part(self, use_default):
627+
@test_util.test_graph_mode_only() # Long chains are very slow in eager mode.
628+
def test_multi_state_part(self):
631629
mvn = tfd.JointDistributionSequential([
632630
tfd.Normal(1., 0.1),
633631
tfd.Normal(2., 1.),
634632
tfd.Independent(tfd.Normal(3 * tf.ones([2, 3, 4]), 10.), 3)
635633
])
636634

637-
if use_default:
635+
if self.use_default_momentum_distribution:
638636
momentum_distribution = None
639637
step_size = 0.1
640638
else:
@@ -667,7 +665,7 @@ def test_multi_state_part(self, use_default):
667665
ess = tfp.mcmc.effective_sample_size(draws,
668666
filter_threshold=0,
669667
filter_beyond_positive_pairs=False)
670-
if not use_default:
668+
if not self.use_default_momentum_distribution:
671669
self.assertAllClose(
672670
self.evaluate(ess),
673671
[tf.constant(100.),
@@ -678,11 +676,12 @@ def test_multi_state_part(self, use_default):
678676
tf.reduce_min(tf.nest.map_structure(tf.reduce_min, ess))),
679677
50.)
680678

681-
def test_batched_state(self, use_default):
679+
@test_util.test_graph_mode_only() # Long chains are very slow in eager mode.
680+
def test_batched_state(self):
682681
mvn = tfd.MultivariateNormalDiag(
683682
loc=[1., 2., 3.], scale_diag=[0.1, 1., 10.])
684683
batch_shape = [2, 4]
685-
if use_default:
684+
if self.use_default_momentum_distribution:
686685
momentum_distribution = None
687686
step_size = 0.1
688687
else:
@@ -705,18 +704,19 @@ def test_batched_state(self, use_default):
705704
ess = tfp.mcmc.effective_sample_size(draws[10:], cross_chain_dims=[1, 2],
706705
filter_threshold=0,
707706
filter_beyond_positive_pairs=False)
708-
if not use_default:
707+
if not self.use_default_momentum_distribution:
709708
self.assertAllClose(self.evaluate(ess), 100 * 2. * 4. * tf.ones(3))
710709
else:
711710
self.assertLess(self.evaluate(tf.reduce_min(ess)), 100.)
712711

713-
def test_batches(self, use_default):
712+
@test_util.test_graph_mode_only() # Long chains are very slow in eager mode.
713+
def test_batches(self):
714714
mvn = tfd.JointDistributionSequential(
715715
[tfd.Normal(1., 0.1),
716716
tfd.Normal(2., 1.),
717717
tfd.Normal(3., 10.)])
718718
n_chains = 10
719-
if use_default:
719+
if self.use_default_momentum_distribution:
720720
momentum_distribution = None
721721
step_size = 0.1
722722
else:
@@ -751,12 +751,23 @@ def test_batches(self, use_default):
751751
ess = tfp.mcmc.effective_sample_size(
752752
draws, cross_chain_dims=[1 for _ in draws],
753753
filter_threshold=0, filter_beyond_positive_pairs=False)
754-
if not use_default:
754+
if not self.use_default_momentum_distribution:
755755
self.assertAllClose(self.evaluate(ess), 100 * n_chains * tf.ones(3))
756756
else:
757757
self.assertLess(self.evaluate(tf.reduce_min(ess)), 100.)
758758

759759

760+
class PreconditionedHMCTestDefaultMomentum(_PreconditionedHMCTest):
761+
use_default_momentum_distribution = True
762+
763+
764+
class PreconditionedHMCTestExplicitMomentum(_PreconditionedHMCTest):
765+
use_default_momentum_distribution = False
766+
767+
768+
del _PreconditionedHMCTest # Don't try to run base class tests.
769+
770+
760771
@test_util.test_all_tf_execution_regimes
761772
class DistributedPHMCTest(distribute_test_lib.DistributedTest):
762773

0 commit comments

Comments
 (0)