38
38
Root = tfd .JointDistributionCoroutine .Root
39
39
40
40
41
- def brownian_motion_prior_fn (num_timesteps , innovation_noise_scale ):
41
+ def brownian_motion_as_markov_chain (num_timesteps , innovation_noise_scale ):
42
+ return tfd .MarkovChain (
43
+ initial_state_prior = tfd .Normal (loc = 0. , scale = innovation_noise_scale ),
44
+ transition_fn = lambda _ , x_t : tfd .Normal ( # pylint: disable=g-long-lambda
45
+ loc = x_t , scale = innovation_noise_scale ),
46
+ num_steps = num_timesteps ,
47
+ name = 'locs' )
48
+
49
+
50
+ def brownian_motion_prior_fn (num_timesteps ,
51
+ innovation_noise_scale ):
42
52
"""Generative process for the Brownian Motion model."""
43
53
prior_loc = 0.
44
54
new = yield Root (tfd .Normal (loc = prior_loc ,
@@ -50,25 +60,34 @@ def brownian_motion_prior_fn(num_timesteps, innovation_noise_scale):
50
60
name = 'x_{}' .format (t ))
51
61
52
62
53
- def brownian_motion_unknown_scales_prior_fn (num_timesteps ):
63
+ def brownian_motion_unknown_scales_prior_fn (num_timesteps , use_markov_chain ):
54
64
"""Generative process for the Brownian Motion model with unknown scales."""
55
65
innovation_noise_scale = yield Root (tfd .LogNormal (
56
66
0. , 2. , name = 'innovation_noise_scale' ))
57
67
_ = yield Root (tfd .LogNormal (0. , 2. , name = 'observation_noise_scale' ))
58
- yield from brownian_motion_prior_fn (
59
- num_timesteps ,
60
- innovation_noise_scale = innovation_noise_scale )
68
+ if use_markov_chain :
69
+ yield brownian_motion_as_markov_chain (
70
+ num_timesteps = num_timesteps ,
71
+ innovation_noise_scale = innovation_noise_scale )
72
+ else :
73
+ yield from brownian_motion_prior_fn (
74
+ num_timesteps ,
75
+ innovation_noise_scale = innovation_noise_scale )
61
76
62
77
63
78
def brownian_motion_log_likelihood_fn (values ,
64
79
observed_locs ,
80
+ use_markov_chain ,
65
81
observation_noise_scale = None ):
66
82
"""Likelihood of observed data under the Brownian Motion model."""
67
83
if observation_noise_scale is None :
68
- (_ , observation_noise_scale ), values = values [:2 ], values [2 :]
84
+ (_ , observation_noise_scale ) = values [:2 ]
85
+ latents = values [2 ] if use_markov_chain else tf .stack (values [2 :], axis = - 1 )
86
+ else :
87
+ latents = values if use_markov_chain else tf .stack (values , axis = - 1 )
88
+
69
89
observation_noise_scale = tf .convert_to_tensor (
70
90
observation_noise_scale , name = 'observation_noise_scale' )
71
- latents = tf .stack (values , axis = - 1 )
72
91
is_observed = ~ tf .math .is_nan (observed_locs )
73
92
lps = tfd .Normal (
74
93
loc = latents , scale = observation_noise_scale [..., tf .newaxis ]).log_prob (
@@ -98,6 +117,7 @@ def __init__(self,
98
117
observed_locs ,
99
118
innovation_noise_scale ,
100
119
observation_noise_scale ,
120
+ use_markov_chain = False ,
101
121
name = 'brownian_motion' ,
102
122
pretty_name = 'Brownian Motion' ):
103
123
"""Construct the Brownian Motion model.
@@ -107,35 +127,52 @@ def __init__(self,
107
127
unobserved.
108
128
innovation_noise_scale: Python `float`.
109
129
observation_noise_scale: Python `float`.
130
+ use_markov_chain: Python `bool` indicating whether to use the
131
+ `MarkovChain` distribution in place of separate random variables for
132
+ each time step. The default of `False` is for backwards compatibility;
133
+ setting this to `True` should significantly improve performance.
110
134
name: Python `str` name prefixed to Ops created by this class.
111
135
pretty_name: A Python `str`. The pretty name of this model.
112
136
"""
113
137
with tf .name_scope (name ):
114
138
num_timesteps = observed_locs .shape [0 ]
115
- self ._prior_dist = tfd .JointDistributionCoroutine (
116
- functools .partial (
117
- brownian_motion_prior_fn ,
118
- num_timesteps = num_timesteps ,
119
- innovation_noise_scale = innovation_noise_scale ))
139
+ if use_markov_chain :
140
+ self ._prior_dist = brownian_motion_as_markov_chain (
141
+ num_timesteps = num_timesteps ,
142
+ innovation_noise_scale = innovation_noise_scale )
143
+ else :
144
+ self ._prior_dist = tfd .JointDistributionCoroutine (
145
+ functools .partial (
146
+ brownian_motion_prior_fn ,
147
+ num_timesteps = num_timesteps ,
148
+ innovation_noise_scale = innovation_noise_scale ))
120
149
121
150
self ._log_likelihood_fn = functools .partial (
122
151
brownian_motion_log_likelihood_fn ,
123
152
observation_noise_scale = observation_noise_scale ,
124
- observed_locs = observed_locs )
153
+ observed_locs = observed_locs ,
154
+ use_markov_chain = use_markov_chain )
125
155
126
156
def _ext_identity (params ):
127
157
return tf .stack (params , axis = - 1 )
128
158
159
+ def _ext_identity_markov_chain (params ):
160
+ return params
161
+
129
162
sample_transformations = {
130
163
'identity' :
131
164
model .Model .SampleTransformation (
132
- fn = _ext_identity ,
165
+ fn = (_ext_identity_markov_chain
166
+ if use_markov_chain else _ext_identity ),
133
167
pretty_name = 'Identity' ,
134
168
)
135
169
}
136
170
137
- event_space_bijector = type (
138
- self ._prior_dist .dtype )(* ([tfb .Identity ()] * num_timesteps ))
171
+ if use_markov_chain :
172
+ event_space_bijector = tfb .Identity ()
173
+ else :
174
+ event_space_bijector = type (
175
+ self ._prior_dist .dtype )(* ([tfb .Identity ()] * num_timesteps ))
139
176
super (BrownianMotion , self ).__init__ (
140
177
default_event_space_bijector = event_space_bijector ,
141
178
event_shape = self ._prior_dist .event_shape ,
@@ -157,11 +194,12 @@ class BrownianMotionMissingMiddleObservations(BrownianMotion):
157
194
158
195
GROUND_TRUTH_MODULE = brownian_motion_missing_middle_observations
159
196
160
- def __init__ (self ):
197
+ def __init__ (self , use_markov_chain = False ):
161
198
dataset = data .brownian_motion_missing_middle_observations ()
162
199
super (BrownianMotionMissingMiddleObservations , self ).__init__ (
163
200
name = 'brownian_motion_missing_middle_observations' ,
164
201
pretty_name = 'Brownian Motion Missing Middle Observations' ,
202
+ use_markov_chain = use_markov_chain ,
165
203
** dataset )
166
204
167
205
@@ -188,13 +226,19 @@ class BrownianMotionUnknownScales(bayesian_model.BayesianModel):
188
226
189
227
def __init__ (self ,
190
228
observed_locs ,
229
+ use_markov_chain = False ,
191
230
name = 'brownian_motion_unknown_scales' ,
192
231
pretty_name = 'Brownian Motion with Unknown Scales' ):
193
232
"""Construct the Brownian Motion model with unknown scales.
194
233
195
234
Args:
196
235
observed_locs: Array of loc parameters with nan value if loc is
197
236
unobserved.
237
+ use_markov_chain: Python `bool` indicating whether to use the
238
+ `MarkovChain` distribution in place of separate random variables for
239
+ each time step. The default of `False` is for backwards compatibility;
240
+ setting this to `True` should significantly improve performance.
241
+ Default value: `False`.
198
242
name: Python `str` name prefixed to Ops created by this class.
199
243
pretty_name: A Python `str`. The pretty name of this model.
200
244
"""
@@ -203,16 +247,20 @@ def __init__(self,
203
247
self ._prior_dist = tfd .JointDistributionCoroutine (
204
248
functools .partial (
205
249
brownian_motion_unknown_scales_prior_fn ,
250
+ use_markov_chain = use_markov_chain ,
206
251
num_timesteps = num_timesteps ))
207
252
208
253
self ._log_likelihood_fn = functools .partial (
209
254
brownian_motion_log_likelihood_fn ,
255
+ use_markov_chain = use_markov_chain ,
210
256
observed_locs = observed_locs )
211
257
212
258
def _ext_identity (params ):
213
259
return {'innovation_noise_scale' : params [0 ],
214
260
'observation_noise_scale' : params [1 ],
215
- 'locs' : tf .stack (params [2 :], axis = - 1 )}
261
+ 'locs' : (params [2 ]
262
+ if use_markov_chain
263
+ else tf .stack (params [2 :], axis = - 1 ))}
216
264
217
265
sample_transformations = {
218
266
'identity' :
@@ -228,7 +276,8 @@ def _ext_identity(params):
228
276
self ._prior_dist .dtype )(* (
229
277
[tfb .Softplus (),
230
278
tfb .Softplus ()
231
- ] + [tfb .Identity ()] * num_timesteps ))
279
+ ] + [tfb .Identity ()] * (
280
+ 1 if use_markov_chain else num_timesteps )))
232
281
super (BrownianMotionUnknownScales , self ).__init__ (
233
282
default_event_space_bijector = event_space_bijector ,
234
283
event_shape = self ._prior_dist .event_shape ,
@@ -252,11 +301,12 @@ class BrownianMotionUnknownScalesMissingMiddleObservations(
252
301
GROUND_TRUTH_MODULE = (
253
302
brownian_motion_unknown_scales_missing_middle_observations )
254
303
255
- def __init__ (self ):
304
+ def __init__ (self , use_markov_chain = False ):
256
305
dataset = data .brownian_motion_missing_middle_observations ()
257
306
del dataset ['innovation_noise_scale' ]
258
307
del dataset ['observation_noise_scale' ]
259
308
super (BrownianMotionUnknownScalesMissingMiddleObservations , self ).__init__ (
260
309
name = 'brownian_motion_unknown_scales_missing_middle_observations' ,
261
310
pretty_name = 'Brownian Motion with Unknown Scales' ,
311
+ use_markov_chain = use_markov_chain ,
262
312
** dataset )
0 commit comments