@@ -471,14 +471,14 @@ def test_supply_full_step_size(self):
471
471
loc = tf .zeros (3 ), scale_diag = tf .constant ([1. , 2. , 3. ]))
472
472
})
473
473
474
- init_step_size = {'a' : tf .reshape (tf .linspace (1. , 2. , 20 ), (20 , 1 )),
475
- 'b' : tf .reshape (tf .linspace (1. , 2. , 60 ), (20 , 3 ))}
474
+ init_step_size = {'a' : tf .reshape (tf .linspace (1. , 2. , 3 ), (3 , 1 )),
475
+ 'b' : tf .reshape (tf .linspace (1. , 2. , 9 ), (3 , 3 ))}
476
476
477
477
_ , actual_step_size = tfp .experimental .mcmc .windowed_adaptive_hmc (
478
478
1 ,
479
479
jd_model ,
480
- num_adaptation_steps = 100 ,
481
- n_chains = 20 ,
480
+ num_adaptation_steps = 25 ,
481
+ n_chains = 3 ,
482
482
init_step_size = init_step_size ,
483
483
num_leapfrog_steps = 5 ,
484
484
discard_tuning = False ,
@@ -504,8 +504,8 @@ def test_supply_partial_step_size(self):
504
504
_ , actual_step_size = tfp .experimental .mcmc .windowed_adaptive_hmc (
505
505
1 ,
506
506
jd_model ,
507
- num_adaptation_steps = 100 ,
508
- n_chains = 20 ,
507
+ num_adaptation_steps = 25 ,
508
+ n_chains = 3 ,
509
509
init_step_size = init_step_size ,
510
510
num_leapfrog_steps = 5 ,
511
511
discard_tuning = False ,
@@ -531,15 +531,15 @@ def test_supply_single_step_size(self):
531
531
tfp .experimental .mcmc .windowed_adaptive_hmc (
532
532
1 ,
533
533
jd_model ,
534
- num_adaptation_steps = 100 ,
534
+ num_adaptation_steps = 25 ,
535
535
n_chains = 20 ,
536
536
init_step_size = init_step_size ,
537
537
num_leapfrog_steps = 5 ,
538
538
discard_tuning = False ,
539
539
trace_fn = lambda * args : unnest .get_innermost (args [- 1 ], 'step_size' ),
540
540
seed = stream ()))
541
541
542
- self .assertEqual ((100 + 1 ,), traced_step_size .shape )
542
+ self .assertEqual ((25 + 1 ,), traced_step_size .shape )
543
543
self .assertAllClose (1. , traced_step_size [0 ])
544
544
545
545
def test_sequential_step_size (self ):
@@ -551,8 +551,8 @@ def test_sequential_step_size(self):
551
551
_ , actual_step_size = tfp .experimental .mcmc .windowed_adaptive_nuts (
552
552
1 ,
553
553
jd_model ,
554
- num_adaptation_steps = 100 ,
555
- n_chains = 20 ,
554
+ num_adaptation_steps = 25 ,
555
+ n_chains = 3 ,
556
556
init_step_size = init_step_size ,
557
557
discard_tuning = False ,
558
558
trace_fn = lambda * args : unnest .get_innermost (args [- 1 ], 'step_size' ),
0 commit comments