Skip to content

Commit 7062386

Browse files
davmretensorflower-gardener
authored andcommitted
Add use_markov_chain flag to the inference gym's Brownian motion model.
PiperOrigin-RevId: 385001418
1 parent f572d16 commit 7062386

File tree

3 files changed

+104
-20
lines changed

3 files changed

+104
-20
lines changed

spinoffs/inference_gym/inference_gym/targets/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,10 @@ py_test(
128128
srcs_version = "PY3",
129129
deps = [
130130
":brownian_motion",
131+
# absl/testing:parameterized dep,
131132
# numpy dep,
132133
# tensorflow dep,
134+
# tensorflow_probability/python/internal:test_util dep,
133135
"//inference_gym/internal:test_util",
134136
],
135137
)

spinoffs/inference_gym/inference_gym/targets/brownian_motion.py

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,17 @@
3838
Root = tfd.JointDistributionCoroutine.Root
3939

4040

41-
def brownian_motion_prior_fn(num_timesteps, innovation_noise_scale):
41+
def brownian_motion_as_markov_chain(num_timesteps, innovation_noise_scale):
42+
return tfd.MarkovChain(
43+
initial_state_prior=tfd.Normal(loc=0., scale=innovation_noise_scale),
44+
transition_fn=lambda _, x_t: tfd.Normal( # pylint: disable=g-long-lambda
45+
loc=x_t, scale=innovation_noise_scale),
46+
num_steps=num_timesteps,
47+
name='locs')
48+
49+
50+
def brownian_motion_prior_fn(num_timesteps,
51+
innovation_noise_scale):
4252
"""Generative process for the Brownian Motion model."""
4353
prior_loc = 0.
4454
new = yield Root(tfd.Normal(loc=prior_loc,
@@ -50,25 +60,34 @@ def brownian_motion_prior_fn(num_timesteps, innovation_noise_scale):
5060
name='x_{}'.format(t))
5161

5262

53-
def brownian_motion_unknown_scales_prior_fn(num_timesteps):
63+
def brownian_motion_unknown_scales_prior_fn(num_timesteps, use_markov_chain):
5464
"""Generative process for the Brownian Motion model with unknown scales."""
5565
innovation_noise_scale = yield Root(tfd.LogNormal(
5666
0., 2., name='innovation_noise_scale'))
5767
_ = yield Root(tfd.LogNormal(0., 2., name='observation_noise_scale'))
58-
yield from brownian_motion_prior_fn(
59-
num_timesteps,
60-
innovation_noise_scale=innovation_noise_scale)
68+
if use_markov_chain:
69+
yield brownian_motion_as_markov_chain(
70+
num_timesteps=num_timesteps,
71+
innovation_noise_scale=innovation_noise_scale)
72+
else:
73+
yield from brownian_motion_prior_fn(
74+
num_timesteps,
75+
innovation_noise_scale=innovation_noise_scale)
6176

6277

6378
def brownian_motion_log_likelihood_fn(values,
6479
observed_locs,
80+
use_markov_chain,
6581
observation_noise_scale=None):
6682
"""Likelihood of observed data under the Brownian Motion model."""
6783
if observation_noise_scale is None:
68-
(_, observation_noise_scale), values = values[:2], values[2:]
84+
(_, observation_noise_scale) = values[:2]
85+
latents = values[2] if use_markov_chain else tf.stack(values[2:], axis=-1)
86+
else:
87+
latents = values if use_markov_chain else tf.stack(values, axis=-1)
88+
6989
observation_noise_scale = tf.convert_to_tensor(
7090
observation_noise_scale, name='observation_noise_scale')
71-
latents = tf.stack(values, axis=-1)
7291
is_observed = ~tf.math.is_nan(observed_locs)
7392
lps = tfd.Normal(
7493
loc=latents, scale=observation_noise_scale[..., tf.newaxis]).log_prob(
@@ -98,6 +117,7 @@ def __init__(self,
98117
observed_locs,
99118
innovation_noise_scale,
100119
observation_noise_scale,
120+
use_markov_chain=False,
101121
name='brownian_motion',
102122
pretty_name='Brownian Motion'):
103123
"""Construct the Brownian Motion model.
@@ -107,35 +127,52 @@ def __init__(self,
107127
unobserved.
108128
innovation_noise_scale: Python `float`.
109129
observation_noise_scale: Python `float`.
130+
use_markov_chain: Python `bool` indicating whether to use the
131+
`MarkovChain` distribution in place of separate random variables for
132+
each time step. The default of `False` is for backwards compatibility;
133+
setting this to `True` should significantly improve performance.
110134
name: Python `str` name prefixed to Ops created by this class.
111135
pretty_name: A Python `str`. The pretty name of this model.
112136
"""
113137
with tf.name_scope(name):
114138
num_timesteps = observed_locs.shape[0]
115-
self._prior_dist = tfd.JointDistributionCoroutine(
116-
functools.partial(
117-
brownian_motion_prior_fn,
118-
num_timesteps=num_timesteps,
119-
innovation_noise_scale=innovation_noise_scale))
139+
if use_markov_chain:
140+
self._prior_dist = brownian_motion_as_markov_chain(
141+
num_timesteps=num_timesteps,
142+
innovation_noise_scale=innovation_noise_scale)
143+
else:
144+
self._prior_dist = tfd.JointDistributionCoroutine(
145+
functools.partial(
146+
brownian_motion_prior_fn,
147+
num_timesteps=num_timesteps,
148+
innovation_noise_scale=innovation_noise_scale))
120149

121150
self._log_likelihood_fn = functools.partial(
122151
brownian_motion_log_likelihood_fn,
123152
observation_noise_scale=observation_noise_scale,
124-
observed_locs=observed_locs)
153+
observed_locs=observed_locs,
154+
use_markov_chain=use_markov_chain)
125155

126156
def _ext_identity(params):
127157
return tf.stack(params, axis=-1)
128158

159+
def _ext_identity_markov_chain(params):
160+
return params
161+
129162
sample_transformations = {
130163
'identity':
131164
model.Model.SampleTransformation(
132-
fn=_ext_identity,
165+
fn=(_ext_identity_markov_chain
166+
if use_markov_chain else _ext_identity),
133167
pretty_name='Identity',
134168
)
135169
}
136170

137-
event_space_bijector = type(
138-
self._prior_dist.dtype)(*([tfb.Identity()] * num_timesteps))
171+
if use_markov_chain:
172+
event_space_bijector = tfb.Identity()
173+
else:
174+
event_space_bijector = type(
175+
self._prior_dist.dtype)(*([tfb.Identity()] * num_timesteps))
139176
super(BrownianMotion, self).__init__(
140177
default_event_space_bijector=event_space_bijector,
141178
event_shape=self._prior_dist.event_shape,
@@ -157,11 +194,12 @@ class BrownianMotionMissingMiddleObservations(BrownianMotion):
157194

158195
GROUND_TRUTH_MODULE = brownian_motion_missing_middle_observations
159196

160-
def __init__(self):
197+
def __init__(self, use_markov_chain=False):
161198
dataset = data.brownian_motion_missing_middle_observations()
162199
super(BrownianMotionMissingMiddleObservations, self).__init__(
163200
name='brownian_motion_missing_middle_observations',
164201
pretty_name='Brownian Motion Missing Middle Observations',
202+
use_markov_chain=use_markov_chain,
165203
**dataset)
166204

167205

@@ -188,13 +226,19 @@ class BrownianMotionUnknownScales(bayesian_model.BayesianModel):
188226

189227
def __init__(self,
190228
observed_locs,
229+
use_markov_chain=False,
191230
name='brownian_motion_unknown_scales',
192231
pretty_name='Brownian Motion with Unknown Scales'):
193232
"""Construct the Brownian Motion model with unknown scales.
194233
195234
Args:
196235
observed_locs: Array of loc parameters with nan value if loc is
197236
unobserved.
237+
use_markov_chain: Python `bool` indicating whether to use the
238+
`MarkovChain` distribution in place of separate random variables for
239+
each time step. The default of `False` is for backwards compatibility;
240+
setting this to `True` should significantly improve performance.
241+
Default value: `False`.
198242
name: Python `str` name prefixed to Ops created by this class.
199243
pretty_name: A Python `str`. The pretty name of this model.
200244
"""
@@ -203,16 +247,20 @@ def __init__(self,
203247
self._prior_dist = tfd.JointDistributionCoroutine(
204248
functools.partial(
205249
brownian_motion_unknown_scales_prior_fn,
250+
use_markov_chain=use_markov_chain,
206251
num_timesteps=num_timesteps))
207252

208253
self._log_likelihood_fn = functools.partial(
209254
brownian_motion_log_likelihood_fn,
255+
use_markov_chain=use_markov_chain,
210256
observed_locs=observed_locs)
211257

212258
def _ext_identity(params):
213259
return {'innovation_noise_scale': params[0],
214260
'observation_noise_scale': params[1],
215-
'locs': tf.stack(params[2:], axis=-1)}
261+
'locs': (params[2]
262+
if use_markov_chain
263+
else tf.stack(params[2:], axis=-1))}
216264

217265
sample_transformations = {
218266
'identity':
@@ -228,7 +276,8 @@ def _ext_identity(params):
228276
self._prior_dist.dtype)(*(
229277
[tfb.Softplus(),
230278
tfb.Softplus()
231-
] + [tfb.Identity()] * num_timesteps))
279+
] + [tfb.Identity()] * (
280+
1 if use_markov_chain else num_timesteps)))
232281
super(BrownianMotionUnknownScales, self).__init__(
233282
default_event_space_bijector=event_space_bijector,
234283
event_shape=self._prior_dist.event_shape,
@@ -252,11 +301,12 @@ class BrownianMotionUnknownScalesMissingMiddleObservations(
252301
GROUND_TRUTH_MODULE = (
253302
brownian_motion_unknown_scales_missing_middle_observations)
254303

255-
def __init__(self):
304+
def __init__(self, use_markov_chain=False):
256305
dataset = data.brownian_motion_missing_middle_observations()
257306
del dataset['innovation_noise_scale']
258307
del dataset['observation_noise_scale']
259308
super(BrownianMotionUnknownScalesMissingMiddleObservations, self).__init__(
260309
name='brownian_motion_unknown_scales_missing_middle_observations',
261310
pretty_name='Brownian Motion with Unknown Scales',
311+
use_markov_chain=use_markov_chain,
262312
**dataset)

spinoffs/inference_gym/inference_gym/targets/brownian_motion_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import tensorflow.compat.v2 as tf
2121

22+
from tensorflow_probability.python.internal import test_util as tfp_test_util
2223
from inference_gym.internal import test_util
2324
from inference_gym.targets import brownian_motion
2425

@@ -112,5 +113,36 @@ def testBrownianMotionUnknownScalesMissingMiddleObservationsHMC(self):
112113
step_size=0.03,
113114
)
114115

116+
def testBrownianMotionMarkovChainLogprobMatchesOriginal(self):
117+
model = (
118+
brownian_motion.BrownianMotionMissingMiddleObservations(
119+
use_markov_chain=False))
120+
markov_chain_model = (
121+
brownian_motion.BrownianMotionMissingMiddleObservations(
122+
use_markov_chain=True))
123+
124+
x = self.evaluate(model.prior_distribution().sample(
125+
400, seed=tfp_test_util.test_seed()))
126+
self.assertAllClose(model.unnormalized_log_prob(x),
127+
markov_chain_model.unnormalized_log_prob(
128+
tf.stack(x, axis=-1)),
129+
atol=1e-2)
130+
131+
def testBrownianMotionUnknownScalesMarkovChainLogprobMatchesOriginal(self):
132+
model = (
133+
brownian_motion.BrownianMotionUnknownScalesMissingMiddleObservations(
134+
use_markov_chain=False))
135+
markov_chain_model = (
136+
brownian_motion.BrownianMotionUnknownScalesMissingMiddleObservations(
137+
use_markov_chain=True))
138+
139+
x = self.evaluate(model.prior_distribution().sample(
140+
400, seed=tfp_test_util.test_seed()))
141+
self.assertAllClose(model.unnormalized_log_prob(x),
142+
markov_chain_model.unnormalized_log_prob(
143+
type(markov_chain_model.dtype)(
144+
x[0], x[1], tf.stack(x[2:], axis=-1))),
145+
atol=1e-2)
146+
115147
if __name__ == '__main__':
116148
tf.test.main()

0 commit comments

Comments
 (0)