Skip to content

Commit d4de7b1

Browse files
committed
added global flow
1 parent 94df7b1 commit d4de7b1

File tree

1 file changed

+80
-21
lines changed

1 file changed

+80
-21
lines changed

tensorflow_probability/python/experimental/vi/cascading_flows.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,20 @@
2525

2626
import tensorflow.compat.v2 as tf
2727

28-
from tensorflow_probability.python.experimental.bijectors import \
29-
build_highway_flow_layer
3028
from tensorflow_probability.python.bijectors import chain
3129
from tensorflow_probability.python.bijectors import reshape
3230
from tensorflow_probability.python.bijectors import scale as scale_lib
3331
from tensorflow_probability.python.bijectors import shift
3432
from tensorflow_probability.python.bijectors import split
35-
3633
from tensorflow_probability.python.distributions import batch_broadcast
3734
from tensorflow_probability.python.distributions import beta
3835
from tensorflow_probability.python.distributions import blockwise
3936
from tensorflow_probability.python.distributions import chi2
37+
from tensorflow_probability.python.distributions import deterministic
4038
from tensorflow_probability.python.distributions import exponential
4139
from tensorflow_probability.python.distributions import gamma
4240
from tensorflow_probability.python.distributions import half_normal
41+
from tensorflow_probability.python.distributions import independent
4342
from tensorflow_probability.python.distributions import \
4443
joint_distribution_auto_batched
4544
from tensorflow_probability.python.distributions import \
@@ -49,10 +48,12 @@
4948
from tensorflow_probability.python.distributions import transformed_distribution
5049
from tensorflow_probability.python.distributions import truncated_normal
5150
from tensorflow_probability.python.distributions import uniform
51+
from tensorflow_probability.python.experimental.bijectors import \
52+
build_highway_flow_layer
5253
from tensorflow_probability.python.internal import samplers
5354

5455
__all__ = [
55-
'register_asvi_substitution_rule',
56+
'register_cf_substitution_rule',
5657
'build_cf_surrogate_posterior'
5758
]
5859

@@ -83,7 +84,7 @@ def _as_substituted_distribution(distribution):
8384

8485

8586
# Todo: inherited from asvi code, do we need this?
86-
def register_asvi_substitution_rule(condition, substitution_fn):
87+
def register_cf_substitution_rule(condition, substitution_fn):
8788
"""Registers a rule for substituting distributions in ASVI surrogates.
8889
8990
Args:
@@ -132,20 +133,20 @@ def register_asvi_substitution_rule(condition, substitution_fn):
132133
# Default substitutions attempt to express distributions using the most
133134
# flexible available parameterization.
134135
# pylint: disable=g-long-lambda
135-
register_asvi_substitution_rule(
136+
register_cf_substitution_rule(
136137
half_normal.HalfNormal,
137138
lambda dist: truncated_normal.TruncatedNormal(
138139
loc=0., scale=dist.scale, low=0., high=dist.scale * 10.))
139-
register_asvi_substitution_rule(
140+
register_cf_substitution_rule(
140141
uniform.Uniform,
141142
lambda dist: shift.Shift(dist.low)(
142143
scale_lib.Scale(dist.high - dist.low)(
143144
beta.Beta(concentration0=tf.ones_like(dist.mean()),
144145
concentration1=1.))))
145-
register_asvi_substitution_rule(
146+
register_cf_substitution_rule(
146147
exponential.Exponential,
147148
lambda dist: gamma.Gamma(concentration=1., rate=dist.rate))
148-
register_asvi_substitution_rule(
149+
register_cf_substitution_rule(
149150
chi2.Chi2,
150151
lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5))
151152

@@ -255,6 +256,7 @@ def model_fn():
255256
_cf_convex_update_for_base_distribution,
256257
initial_prior_weight=initial_prior_weight,
257258
num_auxiliary_variables=num_auxiliary_variables),
259+
num_auxiliary_variables=num_auxiliary_variables,
258260
seed=seed)
259261
surrogate_posterior.also_track = variables
260262
return surrogate_posterior
@@ -264,6 +266,8 @@ def _cf_surrogate_for_distribution(dist,
264266
base_distribution_surrogate_fn,
265267
sample_shape=None,
266268
variables=None,
269+
num_auxiliary_variables=0,
270+
global_auxiliary_variables=None,
267271
seed=None):
268272
# todo: change docstrings
269273
"""Recursively creates ASVI surrogates, and creates new variables if needed.
@@ -303,15 +307,19 @@ def _cf_surrogate_for_distribution(dist,
303307
dist,
304308
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
305309
variables=variables,
310+
num_auxiliary_variables=num_auxiliary_variables,
311+
global_auxiliary_variables=global_auxiliary_variables,
306312
seed=seed)
307313
else:
308314
surrogate_posterior, variables = base_distribution_surrogate_fn(
309-
dist=dist, sample_shape=sample_shape, variables=variables, seed=seed)
315+
dist=dist, sample_shape=sample_shape, variables=variables,
316+
global_auxiliary_variables=global_auxiliary_variables, seed=seed)
310317
return surrogate_posterior, variables
311318

312319

313320
def _cf_surrogate_for_joint_distribution(
314-
dist, base_distribution_surrogate_fn, variables=None, seed=None):
321+
dist, base_distribution_surrogate_fn, variables=None,
322+
num_auxiliary_variables=0, global_auxiliary_variables=None, seed=None):
315323
"""Builds a structured joint surrogate posterior for a joint model."""
316324

317325
# Probabilistic program for CF surrogate posterior.
@@ -322,7 +330,46 @@ def _cf_surrogate_for_joint_distribution(
322330
def posterior_generator(seed=seed):
323331
prior_gen = prior_coroutine()
324332
dist = next(prior_gen)
325-
i = 0
333+
334+
if num_auxiliary_variables > 0:
335+
i = 1
336+
337+
if flat_variables:
338+
variables = flat_variables[0]
339+
340+
else:
341+
layers = 3
342+
bijectors = []
343+
344+
for _ in range(0, layers - 1):
345+
bijectors.append(
346+
build_highway_flow_layer(num_auxiliary_variables,
347+
residual_fraction_initial_value=0.5,
348+
activation_fn=True, gate_first_n=0,
349+
seed=seed))
350+
bijectors.append(
351+
build_highway_flow_layer(num_auxiliary_variables,
352+
residual_fraction_initial_value=0.5,
353+
activation_fn=False, gate_first_n=0,
354+
seed=seed))
355+
356+
variables = chain.Chain(bijectors=list(reversed(bijectors)))
357+
358+
eps = transformed_distribution.TransformedDistribution(
359+
distribution=sample.Sample(normal.Normal(0., 0.1),
360+
num_auxiliary_variables),
361+
bijector=variables)
362+
363+
eps = Root(eps)
364+
365+
value_out = yield (eps if flat_variables
366+
else (eps, variables))
367+
368+
global_auxiliary_variables = value_out
369+
370+
else:
371+
i = 0
372+
326373
try:
327374
while True:
328375
was_root = isinstance(dist, Root)
@@ -334,9 +381,10 @@ def posterior_generator(seed=seed):
334381
dist,
335382
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
336383
variables=flat_variables[i] if flat_variables else None,
384+
global_auxiliary_variables=global_auxiliary_variables,
337385
seed=init_seed)
338386

339-
if was_root:
387+
if was_root and num_auxiliary_variables == 0:
340388
surrogate_posterior = Root(surrogate_posterior)
341389
# If variables were not given---i.e., we're creating new
342390
# variables---then yield the new variables along with the surrogate
@@ -367,6 +415,8 @@ def posterior_generator(seed=seed):
367415
return _cf_surrogate_for_joint_distribution(
368416
dist=dist,
369417
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
418+
num_auxiliary_variables=num_auxiliary_variables,
419+
global_auxiliary_variables=global_auxiliary_variables,
370420
variables=dist._model_unflatten( # pylint: disable=protected-access
371421
_extract_variables_from_coroutine_model(
372422
posterior_generator, seed=seed)))
@@ -401,6 +451,7 @@ def posterior_generator(seed=seed):
401451
def _cf_convex_update_for_base_distribution(dist,
402452
initial_prior_weight,
403453
num_auxiliary_variables=0,
454+
global_auxiliary_variables=None,
404455
sample_shape=None,
405456
variables=None,
406457
seed=None):
@@ -412,31 +463,39 @@ def _cf_convex_update_for_base_distribution(dist,
412463
actual_event_shape.shape.as_list()[0] > 0 else 1
413464
layers = 3
414465
bijectors = [reshape.Reshape([-1],
415-
event_shape_in=actual_event_shape +
416-
num_auxiliary_variables)]
466+
event_shape_in=actual_event_shape +
467+
num_auxiliary_variables)]
417468

418469
for _ in range(0, layers - 1):
419470
bijectors.append(
420471
build_highway_flow_layer(
421472
tf.reduce_prod(actual_event_shape + num_auxiliary_variables),
422473
residual_fraction_initial_value=initial_prior_weight,
423-
activation_fn=True, gate_first_n=int_event_shape))
474+
activation_fn=True, gate_first_n=int_event_shape, seed=seed))
424475
bijectors.append(
425476
build_highway_flow_layer(
426477
tf.reduce_prod(actual_event_shape + num_auxiliary_variables),
427478
residual_fraction_initial_value=initial_prior_weight,
428-
activation_fn=False, gate_first_n=int_event_shape))
429-
bijectors.append(reshape.Reshape(actual_event_shape + num_auxiliary_variables))
479+
activation_fn=False, gate_first_n=int_event_shape, seed=seed))
480+
bijectors.append(
481+
reshape.Reshape(actual_event_shape + num_auxiliary_variables))
430482

431483
variables = chain.Chain(bijectors=list(reversed(bijectors)))
432484

433485
if num_auxiliary_variables > 0:
486+
batch_shape = global_auxiliary_variables.shape[0] if len(
487+
global_auxiliary_variables.shape) > 1 else []
488+
434489
cascading_flows = split.Split(
435490
[-1, num_auxiliary_variables])(
436491
transformed_distribution.TransformedDistribution(
437-
distribution=blockwise.Blockwise([dist, batch_broadcast.BatchBroadcast(
438-
sample.Sample(normal.Normal(0., .1), num_auxiliary_variables),
439-
to_shape=dist.batch_shape)]),
492+
distribution=blockwise.Blockwise([
493+
batch_broadcast.BatchBroadcast(dist,
494+
to_shape=batch_shape),
495+
independent.Independent(
496+
deterministic.Deterministic(
497+
global_auxiliary_variables),
498+
reinterpreted_batch_ndims=1)]),
440499
bijector=variables))
441500

442501
else:

0 commit comments

Comments
 (0)