Skip to content

Commit 5d1fd40

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Make step size tests a little smaller.
PiperOrigin-RevId: 377373340
1 parent 827ea3d commit 5d1fd40

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tensorflow_probability/python/experimental/mcmc/windowed_sampling_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -471,14 +471,14 @@ def test_supply_full_step_size(self):
471471
loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
472472
})
473473

474-
init_step_size = {'a': tf.reshape(tf.linspace(1., 2., 20), (20, 1)),
475-
'b': tf.reshape(tf.linspace(1., 2., 60), (20, 3))}
474+
init_step_size = {'a': tf.reshape(tf.linspace(1., 2., 3), (3, 1)),
475+
'b': tf.reshape(tf.linspace(1., 2., 9), (3, 3))}
476476

477477
_, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
478478
1,
479479
jd_model,
480-
num_adaptation_steps=100,
481-
n_chains=20,
480+
num_adaptation_steps=25,
481+
n_chains=3,
482482
init_step_size=init_step_size,
483483
num_leapfrog_steps=5,
484484
discard_tuning=False,
@@ -504,8 +504,8 @@ def test_supply_partial_step_size(self):
504504
_, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
505505
1,
506506
jd_model,
507-
num_adaptation_steps=100,
508-
n_chains=20,
507+
num_adaptation_steps=25,
508+
n_chains=3,
509509
init_step_size=init_step_size,
510510
num_leapfrog_steps=5,
511511
discard_tuning=False,
@@ -531,15 +531,15 @@ def test_supply_single_step_size(self):
531531
tfp.experimental.mcmc.windowed_adaptive_hmc(
532532
1,
533533
jd_model,
534-
num_adaptation_steps=100,
534+
num_adaptation_steps=25,
535535
n_chains=20,
536536
init_step_size=init_step_size,
537537
num_leapfrog_steps=5,
538538
discard_tuning=False,
539539
trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
540540
seed=stream()))
541541

542-
self.assertEqual((100 + 1,), traced_step_size.shape)
542+
self.assertEqual((25 + 1,), traced_step_size.shape)
543543
self.assertAllClose(1., traced_step_size[0])
544544

545545
def test_sequential_step_size(self):
@@ -551,8 +551,8 @@ def test_sequential_step_size(self):
551551
_, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_nuts(
552552
1,
553553
jd_model,
554-
num_adaptation_steps=100,
555-
n_chains=20,
554+
num_adaptation_steps=25,
555+
n_chains=3,
556556
init_step_size=init_step_size,
557557
discard_tuning=False,
558558
trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),

0 commit comments

Comments
 (0)