Skip to content

Commit ebb151f

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Speed up diagonal_mass_matrix_adaptation_test.
PiperOrigin-RevId: 387203021
1 parent 9ac2b6b commit ebb151f

File tree

1 file changed

+39
-21
lines changed

1 file changed

+39
-21
lines changed

tensorflow_probability/python/experimental/distribute/diagonal_mass_matrix_adaptation_test.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tensorflow_probability.python.internal import distribute_test_lib as test_lib
2525
from tensorflow_probability.python.internal import samplers
2626
from tensorflow_probability.python.internal import test_util
27+
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
2728

2829
tfd = tfp.distributions
2930
tfp_dist = tfp.experimental.distribute
@@ -55,8 +56,16 @@ def test_diagonal_mass_matrix_no_distribute(self):
5556
state = tf.zeros(3)
5657
pkr = kernel.bootstrap_results(state)
5758
draws = np.random.randn(10, 3).astype(np.float32)
58-
for draw, seed in zip(draws, samplers.split_seed(self.key, draws.shape[0])):
59-
_, pkr = kernel.one_step(draw, pkr, seed=seed)
59+
60+
def body(pkr_seed, draw):
61+
pkr, seed = pkr_seed
62+
seed, kernel_seed = samplers.split_seed(seed)
63+
_, pkr = kernel.one_step(draw, pkr, seed=kernel_seed)
64+
return (pkr, seed)
65+
66+
(pkr, _), _ = mcmc_util.trace_scan(body,
67+
(pkr, samplers.sanitize_seed(self.key)),
68+
draws, lambda _: ())
6069

6170
running_variance = pkr.running_variance[0]
6271
emp_mean = draws.sum(axis=0) / 20.
@@ -80,26 +89,31 @@ def run(seed):
8089
tfp.experimental.stats.RunningVariance.from_stats(
8190
num_samples=10., mean=tf.zeros(3), variance=tf.ones(3)))
8291
pkr = kernel.bootstrap_results(state)
83-
draws = []
84-
for seed in seeds:
92+
93+
def body(draw_pkr, seed):
94+
_, pkr = draw_pkr
8595
draw_seed, step_seed = samplers.split_seed(seed)
8696
draw = dist.sample(seed=draw_seed)
8797
_, pkr = kernel.one_step(draw, pkr, seed=step_seed)
88-
draws.append(draw)
98+
return draw, pkr
99+
100+
(_, pkr), draws = mcmc_util.trace_scan(body,
101+
(tf.zeros(dist.event_shape), pkr),
102+
seeds, lambda v: v[0])
103+
89104
return draws, pkr
90105

91106
draws, pkr = self.strategy_run(run, (self.key,), in_axes=None)
92-
draws = tf.stack(self.evaluate(self.per_replica_to_tensor(draws)), axis=0)
93-
94107
running_variance = self.per_replica_to_composite_tensor(
95108
pkr.running_variance[0])
109+
draws = self.per_replica_to_tensor(draws, axis=1)
110+
mean, sum_squared_residuals, draws = self.evaluate(
111+
(running_variance.mean, running_variance.sum_squared_residuals, draws))
96112
emp_mean = tf.reduce_sum(draws, axis=0) / 20.
97-
emp_squared_residuals = (tf.reduce_sum((draws - emp_mean) ** 2, axis=0) +
98-
10 * emp_mean ** 2 +
99-
10)
100-
self.assertAllClose(emp_mean, running_variance.mean)
101-
self.assertAllClose(emp_squared_residuals,
102-
running_variance.sum_squared_residuals)
113+
emp_squared_residuals = (
114+
tf.reduce_sum((draws - emp_mean)**2, axis=0) + 10 * emp_mean**2 + 10)
115+
self.assertAllClose(emp_mean, mean)
116+
self.assertAllClose(emp_squared_residuals, sum_squared_residuals)
103117

104118
def test_diagonal_mass_matrix_sample(self):
105119
@tf.function(autograph=False)
@@ -114,25 +128,29 @@ def run(seed):
114128
tfp.experimental.stats.RunningVariance.from_stats(
115129
num_samples=10., mean=tf.zeros(3), variance=tf.ones(3)))
116130
pkr = kernel.bootstrap_results(state)
117-
draws = []
118-
for seed in seeds:
131+
def body(draw_pkr, seed):
132+
_, pkr = draw_pkr
119133
draw_seed, step_seed = samplers.split_seed(seed)
120134
draw = dist.sample(seed=draw_seed)
121135
_, pkr = kernel.one_step(draw, pkr, seed=step_seed)
122-
draws.append(draw)
136+
return draw, pkr
137+
138+
(_, pkr), draws = mcmc_util.trace_scan(body,
139+
(tf.zeros(dist.event_shape), pkr),
140+
seeds, lambda v: v[0])
123141
return draws, pkr
124142

125143
draws, pkr = self.strategy_run(run, (self.key,), in_axes=None)
126-
draws = tf.stack(self.evaluate(self.per_replica_to_tensor(draws)), axis=0)
127-
128144
running_variance = self.per_replica_to_composite_tensor(
129145
pkr.running_variance[0])
146+
draws = self.per_replica_to_tensor(draws, axis=1)
147+
mean, sum_squared_residuals, draws = self.evaluate(
148+
(running_variance.mean, running_variance.sum_squared_residuals, draws))
130149
emp_mean = tf.reduce_sum(draws, axis=0) / 20.
131150
emp_squared_residuals = tf.reduce_sum(
132151
(draws - emp_mean[None, ...])**2, axis=0) + 10 * emp_mean**2 + 10
133-
self.assertAllClose(emp_mean, running_variance.mean)
134-
self.assertAllClose(emp_squared_residuals,
135-
running_variance.sum_squared_residuals)
152+
self.assertAllClose(emp_mean, mean)
153+
self.assertAllClose(emp_squared_residuals, sum_squared_residuals)
136154

137155

138156
if __name__ == '__main__':

0 commit comments

Comments
 (0)