24
24
from tensorflow_probability .python .internal import distribute_test_lib as test_lib
25
25
from tensorflow_probability .python .internal import samplers
26
26
from tensorflow_probability .python .internal import test_util
27
+ from tensorflow_probability .python .mcmc .internal import util as mcmc_util
27
28
28
29
tfd = tfp .distributions
29
30
tfp_dist = tfp .experimental .distribute
@@ -55,8 +56,16 @@ def test_diagonal_mass_matrix_no_distribute(self):
55
56
state = tf .zeros (3 )
56
57
pkr = kernel .bootstrap_results (state )
57
58
draws = np .random .randn (10 , 3 ).astype (np .float32 )
58
- for draw , seed in zip (draws , samplers .split_seed (self .key , draws .shape [0 ])):
59
- _ , pkr = kernel .one_step (draw , pkr , seed = seed )
59
+
60
+ def body (pkr_seed , draw ):
61
+ pkr , seed = pkr_seed
62
+ seed , kernel_seed = samplers .split_seed (seed )
63
+ _ , pkr = kernel .one_step (draw , pkr , seed = kernel_seed )
64
+ return (pkr , seed )
65
+
66
+ (pkr , _ ), _ = mcmc_util .trace_scan (body ,
67
+ (pkr , samplers .sanitize_seed (self .key )),
68
+ draws , lambda _ : ())
60
69
61
70
running_variance = pkr .running_variance [0 ]
62
71
emp_mean = draws .sum (axis = 0 ) / 20.
@@ -80,26 +89,31 @@ def run(seed):
80
89
tfp .experimental .stats .RunningVariance .from_stats (
81
90
num_samples = 10. , mean = tf .zeros (3 ), variance = tf .ones (3 )))
82
91
pkr = kernel .bootstrap_results (state )
83
- draws = []
84
- for seed in seeds :
92
+
93
+ def body (draw_pkr , seed ):
94
+ _ , pkr = draw_pkr
85
95
draw_seed , step_seed = samplers .split_seed (seed )
86
96
draw = dist .sample (seed = draw_seed )
87
97
_ , pkr = kernel .one_step (draw , pkr , seed = step_seed )
88
- draws .append (draw )
98
+ return draw , pkr
99
+
100
+ (_ , pkr ), draws = mcmc_util .trace_scan (body ,
101
+ (tf .zeros (dist .event_shape ), pkr ),
102
+ seeds , lambda v : v [0 ])
103
+
89
104
return draws , pkr
90
105
91
106
draws , pkr = self .strategy_run (run , (self .key ,), in_axes = None )
92
- draws = tf .stack (self .evaluate (self .per_replica_to_tensor (draws )), axis = 0 )
93
-
94
107
running_variance = self .per_replica_to_composite_tensor (
95
108
pkr .running_variance [0 ])
109
+ draws = self .per_replica_to_tensor (draws , axis = 1 )
110
+ mean , sum_squared_residuals , draws = self .evaluate (
111
+ (running_variance .mean , running_variance .sum_squared_residuals , draws ))
96
112
emp_mean = tf .reduce_sum (draws , axis = 0 ) / 20.
97
- emp_squared_residuals = (tf .reduce_sum ((draws - emp_mean ) ** 2 , axis = 0 ) +
98
- 10 * emp_mean ** 2 +
99
- 10 )
100
- self .assertAllClose (emp_mean , running_variance .mean )
101
- self .assertAllClose (emp_squared_residuals ,
102
- running_variance .sum_squared_residuals )
113
+ emp_squared_residuals = (
114
+ tf .reduce_sum ((draws - emp_mean )** 2 , axis = 0 ) + 10 * emp_mean ** 2 + 10 )
115
+ self .assertAllClose (emp_mean , mean )
116
+ self .assertAllClose (emp_squared_residuals , sum_squared_residuals )
103
117
104
118
def test_diagonal_mass_matrix_sample (self ):
105
119
@tf .function (autograph = False )
@@ -114,25 +128,29 @@ def run(seed):
114
128
tfp .experimental .stats .RunningVariance .from_stats (
115
129
num_samples = 10. , mean = tf .zeros (3 ), variance = tf .ones (3 )))
116
130
pkr = kernel .bootstrap_results (state )
117
- draws = []
118
- for seed in seeds :
131
+ def body ( draw_pkr , seed ):
132
+ _ , pkr = draw_pkr
119
133
draw_seed , step_seed = samplers .split_seed (seed )
120
134
draw = dist .sample (seed = draw_seed )
121
135
_ , pkr = kernel .one_step (draw , pkr , seed = step_seed )
122
- draws .append (draw )
136
+ return draw , pkr
137
+
138
+ (_ , pkr ), draws = mcmc_util .trace_scan (body ,
139
+ (tf .zeros (dist .event_shape ), pkr ),
140
+ seeds , lambda v : v [0 ])
123
141
return draws , pkr
124
142
125
143
draws , pkr = self .strategy_run (run , (self .key ,), in_axes = None )
126
- draws = tf .stack (self .evaluate (self .per_replica_to_tensor (draws )), axis = 0 )
127
-
128
144
running_variance = self .per_replica_to_composite_tensor (
129
145
pkr .running_variance [0 ])
146
+ draws = self .per_replica_to_tensor (draws , axis = 1 )
147
+ mean , sum_squared_residuals , draws = self .evaluate (
148
+ (running_variance .mean , running_variance .sum_squared_residuals , draws ))
130
149
emp_mean = tf .reduce_sum (draws , axis = 0 ) / 20.
131
150
emp_squared_residuals = tf .reduce_sum (
132
151
(draws - emp_mean [None , ...])** 2 , axis = 0 ) + 10 * emp_mean ** 2 + 10
133
- self .assertAllClose (emp_mean , running_variance .mean )
134
- self .assertAllClose (emp_squared_residuals ,
135
- running_variance .sum_squared_residuals )
152
+ self .assertAllClose (emp_mean , mean )
153
+ self .assertAllClose (emp_squared_residuals , sum_squared_residuals )
136
154
137
155
138
156
if __name__ == '__main__' :
0 commit comments