@@ -454,14 +454,11 @@ def test_sets_kinetic_energy(self):
454
454
grads_next_target_log_prob )
455
455
456
456
457
- @test_util .test_all_tf_execution_regimes
458
- @parameterized .named_parameters (
459
- dict (testcase_name = '_default' , use_default = True ),
460
- dict (testcase_name = '_explicit' , use_default = False ))
461
- class PreconditionedHMCTest (test_util .TestCase ):
457
+ class _PreconditionedHMCTest (test_util .TestCase ):
462
458
463
- def test_f64 (self , use_default ):
464
- if use_default :
459
+ @test_util .test_graph_and_eager_modes ()
460
+ def test_f64 (self ):
461
+ if self .use_default_momentum_distribution :
465
462
momentum_distribution = None
466
463
else :
467
464
momentum_distribution = as_composite (
@@ -474,8 +471,9 @@ def test_f64(self, use_default):
474
471
1 , kernel = kernel , current_state = tf .ones ([], tf .float64 ),
475
472
num_burnin_steps = 5 , trace_fn = None , seed = test_util .test_seed ()))
476
473
477
- def test_f64_multichain (self , use_default ):
478
- if use_default :
474
+ @test_util .test_graph_and_eager_modes ()
475
+ def test_f64_multichain (self ):
476
+ if self .use_default_momentum_distribution :
479
477
momentum_distribution = None
480
478
else :
481
479
momentum_distribution = as_composite (
@@ -489,8 +487,9 @@ def test_f64_multichain(self, use_default):
489
487
1 , kernel = kernel , current_state = tf .ones ([nchains ], tf .float64 ),
490
488
num_burnin_steps = 5 , trace_fn = None , seed = test_util .test_seed ()))
491
489
492
- def test_f64_multichain_multipart (self , use_default ):
493
- if use_default :
490
+ @test_util .test_graph_and_eager_modes ()
491
+ def test_f64_multichain_multipart (self ):
492
+ if self .use_default_momentum_distribution :
494
493
momentum_distribution = None
495
494
else :
496
495
momentum_distribution = _make_composite_tensor (
@@ -508,23 +507,20 @@ def test_f64_multichain_multipart(self, use_default):
508
507
tf .ones ([nchains ], tf .float64 )),
509
508
num_burnin_steps = 5 , trace_fn = None , seed = test_util .test_seed ()))
510
509
511
- def test_diag (self , use_default ):
510
+ @test_util .test_graph_mode_only () # Long chains are very slow in eager mode.
511
+ def test_diag (self ):
512
512
"""Test that a diagonal multivariate normal can be effectively sampled from.
513
513
514
514
Note that the effective sample size is expected to be exactly 100: this is
515
515
because the step size is tuned well enough that a single HMC step takes
516
516
a point to nearly the antipodal point, which causes a negative lag 1
517
517
autocorrelation, and the effective sample size calculation cuts off when
518
518
the autocorrelation drops below zero.
519
-
520
- Args:
521
- use_default: bool, whether to use a custom momentum distribution, or
522
- the default.
523
519
"""
524
520
mvn = tfd .MultivariateNormalDiag (
525
521
loc = [1. , 2. , 3. ], scale_diag = [0.1 , 1. , 10. ])
526
522
527
- if use_default :
523
+ if self . use_default_momentum_distribution :
528
524
momentum_distribution = None
529
525
step_size = 0.1
530
526
else :
@@ -547,20 +543,20 @@ def test_diag(self, use_default):
547
543
filter_threshold = 0 ,
548
544
filter_beyond_positive_pairs = False )
549
545
550
- if not use_default :
546
+ if not self . use_default_momentum_distribution :
551
547
self .assertAllClose (ess , tf .fill ([3 ], 100. ))
552
548
else :
553
549
self .assertLess (self .evaluate (tf .reduce_min (ess )), 100. )
554
550
555
- def test_tril ( self , use_default ):
556
- if tf . executing_eagerly ():
557
- self . skipTest ( 'b/169882656 Too many warnings are issued in eager logs' )
551
+ @ test_util . test_graph_mode_only () # Long chains are very slow in eager mode.
552
+ @ test_util . jax_disable_test_missing_functionality ( 'dynamic shapes' )
553
+ def test_tril ( self ):
558
554
cov = 0.9 * tf .ones ([3 , 3 ]) + 0.1 * tf .eye (3 )
559
555
scale = tf .linalg .cholesky (cov )
560
556
mv_tril = tfd .MultivariateNormalTriL (loc = [1. , 2. , 3. ],
561
557
scale_tril = scale )
562
558
563
- if use_default :
559
+ if self . use_default_momentum_distribution :
564
560
momentum_distribution = None
565
561
else :
566
562
momentum_distribution = tfde .MultivariateNormalPrecisionFactorLinearOperator (
@@ -588,16 +584,17 @@ def test_tril(self, use_default):
588
584
# was the wrong one. Why is that? A guess is that since there are *many*
589
585
# ways to have larger ess, these tests don't really test correctness.
590
586
# Perhaps remove all tests like these.
591
- if not use_default :
587
+ if not self . use_default_momentum_distribution :
592
588
self .assertAllClose (ess , tf .fill ([3 ], 100. ))
593
589
else :
594
590
self .assertLess (self .evaluate (tf .reduce_min (ess )), 100. )
595
591
596
- def test_transform (self , use_default ):
592
+ @test_util .test_graph_mode_only () # Long chains are very slow in eager mode.
593
+ def test_transform (self ):
597
594
mvn = tfd .MultivariateNormalDiag (loc = [1. , 2. , 3. ], scale_diag = [1. , 1. , 1. ])
598
595
diag_variance = tf .constant ([0.1 , 1. , 10. ])
599
596
600
- if use_default :
597
+ if self . use_default_momentum_distribution :
601
598
momentum_distribution = None
602
599
else :
603
600
momentum_distribution = tfde .MultivariateNormalPrecisionFactorLinearOperator (
@@ -622,19 +619,20 @@ def test_transform(self, use_default):
622
619
filter_threshold = 0 ,
623
620
filter_beyond_positive_pairs = False )
624
621
625
- if not use_default :
622
+ if not self . use_default_momentum_distribution :
626
623
self .assertAllClose (ess , tf .fill ([3 ], 100. ))
627
624
else :
628
625
self .assertLess (self .evaluate (tf .reduce_min (ess )), 100. )
629
626
630
- def test_multi_state_part (self , use_default ):
627
+ @test_util .test_graph_mode_only () # Long chains are very slow in eager mode.
628
+ def test_multi_state_part (self ):
631
629
mvn = tfd .JointDistributionSequential ([
632
630
tfd .Normal (1. , 0.1 ),
633
631
tfd .Normal (2. , 1. ),
634
632
tfd .Independent (tfd .Normal (3 * tf .ones ([2 , 3 , 4 ]), 10. ), 3 )
635
633
])
636
634
637
- if use_default :
635
+ if self . use_default_momentum_distribution :
638
636
momentum_distribution = None
639
637
step_size = 0.1
640
638
else :
@@ -667,7 +665,7 @@ def test_multi_state_part(self, use_default):
667
665
ess = tfp .mcmc .effective_sample_size (draws ,
668
666
filter_threshold = 0 ,
669
667
filter_beyond_positive_pairs = False )
670
- if not use_default :
668
+ if not self . use_default_momentum_distribution :
671
669
self .assertAllClose (
672
670
self .evaluate (ess ),
673
671
[tf .constant (100. ),
@@ -678,11 +676,12 @@ def test_multi_state_part(self, use_default):
678
676
tf .reduce_min (tf .nest .map_structure (tf .reduce_min , ess ))),
679
677
50. )
680
678
681
- def test_batched_state (self , use_default ):
679
+ @test_util .test_graph_mode_only () # Long chains are very slow in eager mode.
680
+ def test_batched_state (self ):
682
681
mvn = tfd .MultivariateNormalDiag (
683
682
loc = [1. , 2. , 3. ], scale_diag = [0.1 , 1. , 10. ])
684
683
batch_shape = [2 , 4 ]
685
- if use_default :
684
+ if self . use_default_momentum_distribution :
686
685
momentum_distribution = None
687
686
step_size = 0.1
688
687
else :
@@ -705,18 +704,19 @@ def test_batched_state(self, use_default):
705
704
ess = tfp .mcmc .effective_sample_size (draws [10 :], cross_chain_dims = [1 , 2 ],
706
705
filter_threshold = 0 ,
707
706
filter_beyond_positive_pairs = False )
708
- if not use_default :
707
+ if not self . use_default_momentum_distribution :
709
708
self .assertAllClose (self .evaluate (ess ), 100 * 2. * 4. * tf .ones (3 ))
710
709
else :
711
710
self .assertLess (self .evaluate (tf .reduce_min (ess )), 100. )
712
711
713
- def test_batches (self , use_default ):
712
+ @test_util .test_graph_mode_only () # Long chains are very slow in eager mode.
713
+ def test_batches (self ):
714
714
mvn = tfd .JointDistributionSequential (
715
715
[tfd .Normal (1. , 0.1 ),
716
716
tfd .Normal (2. , 1. ),
717
717
tfd .Normal (3. , 10. )])
718
718
n_chains = 10
719
- if use_default :
719
+ if self . use_default_momentum_distribution :
720
720
momentum_distribution = None
721
721
step_size = 0.1
722
722
else :
@@ -751,12 +751,23 @@ def test_batches(self, use_default):
751
751
ess = tfp .mcmc .effective_sample_size (
752
752
draws , cross_chain_dims = [1 for _ in draws ],
753
753
filter_threshold = 0 , filter_beyond_positive_pairs = False )
754
- if not use_default :
754
+ if not self . use_default_momentum_distribution :
755
755
self .assertAllClose (self .evaluate (ess ), 100 * n_chains * tf .ones (3 ))
756
756
else :
757
757
self .assertLess (self .evaluate (tf .reduce_min (ess )), 100. )
758
758
759
759
760
+ class PreconditionedHMCTestDefaultMomentum (_PreconditionedHMCTest ):
761
+ use_default_momentum_distribution = True
762
+
763
+
764
+ class PreconditionedHMCTestExplicitMomentum (_PreconditionedHMCTest ):
765
+ use_default_momentum_distribution = False
766
+
767
+
768
+ del _PreconditionedHMCTest # Don't try to run base class tests.
769
+
770
+
760
771
@test_util .test_all_tf_execution_regimes
761
772
class DistributedPHMCTest (distribute_test_lib .DistributedTest ):
762
773
0 commit comments