25
25
26
26
import tensorflow .compat .v2 as tf
27
27
28
- from tensorflow_probability .python .experimental .bijectors import \
29
- build_highway_flow_layer
30
28
from tensorflow_probability .python .bijectors import chain
31
29
from tensorflow_probability .python .bijectors import reshape
32
30
from tensorflow_probability .python .bijectors import scale as scale_lib
33
31
from tensorflow_probability .python .bijectors import shift
34
32
from tensorflow_probability .python .bijectors import split
35
-
36
33
from tensorflow_probability .python .distributions import batch_broadcast
37
34
from tensorflow_probability .python .distributions import beta
38
35
from tensorflow_probability .python .distributions import blockwise
39
36
from tensorflow_probability .python .distributions import chi2
37
+ from tensorflow_probability .python .distributions import deterministic
40
38
from tensorflow_probability .python .distributions import exponential
41
39
from tensorflow_probability .python .distributions import gamma
42
40
from tensorflow_probability .python .distributions import half_normal
41
+ from tensorflow_probability .python .distributions import independent
43
42
from tensorflow_probability .python .distributions import \
44
43
joint_distribution_auto_batched
45
44
from tensorflow_probability .python .distributions import \
49
48
from tensorflow_probability .python .distributions import transformed_distribution
50
49
from tensorflow_probability .python .distributions import truncated_normal
51
50
from tensorflow_probability .python .distributions import uniform
51
+ from tensorflow_probability .python .experimental .bijectors import \
52
+ build_highway_flow_layer
52
53
from tensorflow_probability .python .internal import samplers
53
54
54
55
__all__ = [
55
- 'register_asvi_substitution_rule ' ,
56
+ 'register_cf_substitution_rule ' ,
56
57
'build_cf_surrogate_posterior'
57
58
]
58
59
@@ -83,7 +84,7 @@ def _as_substituted_distribution(distribution):
83
84
84
85
85
86
# TODO: inherited from the ASVI code; do we still need this?
86
- def register_asvi_substitution_rule (condition , substitution_fn ):
87
+ def register_cf_substitution_rule (condition , substitution_fn ):
87
88
"""Registers a rule for substituting distributions in ASVI surrogates.
88
89
89
90
Args:
@@ -132,20 +133,20 @@ def register_asvi_substitution_rule(condition, substitution_fn):
132
133
# Default substitutions attempt to express distributions using the most
133
134
# flexible available parameterization.
134
135
# pylint: disable=g-long-lambda
135
- register_asvi_substitution_rule (
136
+ register_cf_substitution_rule (
136
137
half_normal .HalfNormal ,
137
138
lambda dist : truncated_normal .TruncatedNormal (
138
139
loc = 0. , scale = dist .scale , low = 0. , high = dist .scale * 10. ))
139
- register_asvi_substitution_rule (
140
+ register_cf_substitution_rule (
140
141
uniform .Uniform ,
141
142
lambda dist : shift .Shift (dist .low )(
142
143
scale_lib .Scale (dist .high - dist .low )(
143
144
beta .Beta (concentration0 = tf .ones_like (dist .mean ()),
144
145
concentration1 = 1. ))))
145
- register_asvi_substitution_rule (
146
+ register_cf_substitution_rule (
146
147
exponential .Exponential ,
147
148
lambda dist : gamma .Gamma (concentration = 1. , rate = dist .rate ))
148
- register_asvi_substitution_rule (
149
+ register_cf_substitution_rule (
149
150
chi2 .Chi2 ,
150
151
lambda dist : gamma .Gamma (concentration = 0.5 * dist .df , rate = 0.5 ))
151
152
@@ -255,6 +256,7 @@ def model_fn():
255
256
_cf_convex_update_for_base_distribution ,
256
257
initial_prior_weight = initial_prior_weight ,
257
258
num_auxiliary_variables = num_auxiliary_variables ),
259
+ num_auxiliary_variables = num_auxiliary_variables ,
258
260
seed = seed )
259
261
surrogate_posterior .also_track = variables
260
262
return surrogate_posterior
@@ -264,6 +266,8 @@ def _cf_surrogate_for_distribution(dist,
264
266
base_distribution_surrogate_fn ,
265
267
sample_shape = None ,
266
268
variables = None ,
269
+ num_auxiliary_variables = 0 ,
270
+ global_auxiliary_variables = None ,
267
271
seed = None ):
268
272
# TODO: update the docstrings.
269
273
"""Recursively creates ASVI surrogates, and creates new variables if needed.
@@ -303,15 +307,19 @@ def _cf_surrogate_for_distribution(dist,
303
307
dist ,
304
308
base_distribution_surrogate_fn = base_distribution_surrogate_fn ,
305
309
variables = variables ,
310
+ num_auxiliary_variables = num_auxiliary_variables ,
311
+ global_auxiliary_variables = global_auxiliary_variables ,
306
312
seed = seed )
307
313
else :
308
314
surrogate_posterior , variables = base_distribution_surrogate_fn (
309
- dist = dist , sample_shape = sample_shape , variables = variables , seed = seed )
315
+ dist = dist , sample_shape = sample_shape , variables = variables ,
316
+ global_auxiliary_variables = global_auxiliary_variables , seed = seed )
310
317
return surrogate_posterior , variables
311
318
312
319
313
320
def _cf_surrogate_for_joint_distribution (
314
- dist , base_distribution_surrogate_fn , variables = None , seed = None ):
321
+ dist , base_distribution_surrogate_fn , variables = None ,
322
+ num_auxiliary_variables = 0 , global_auxiliary_variables = None , seed = None ):
315
323
"""Builds a structured joint surrogate posterior for a joint model."""
316
324
317
325
# Probabilistic program for CF surrogate posterior.
@@ -322,7 +330,46 @@ def _cf_surrogate_for_joint_distribution(
322
330
def posterior_generator (seed = seed ):
323
331
prior_gen = prior_coroutine ()
324
332
dist = next (prior_gen )
325
- i = 0
333
+
334
+ if num_auxiliary_variables > 0 :
335
+ i = 1
336
+
337
+ if flat_variables :
338
+ variables = flat_variables [0 ]
339
+
340
+ else :
341
+ layers = 3
342
+ bijectors = []
343
+
344
+ for _ in range (0 , layers - 1 ):
345
+ bijectors .append (
346
+ build_highway_flow_layer (num_auxiliary_variables ,
347
+ residual_fraction_initial_value = 0.5 ,
348
+ activation_fn = True , gate_first_n = 0 ,
349
+ seed = seed ))
350
+ bijectors .append (
351
+ build_highway_flow_layer (num_auxiliary_variables ,
352
+ residual_fraction_initial_value = 0.5 ,
353
+ activation_fn = False , gate_first_n = 0 ,
354
+ seed = seed ))
355
+
356
+ variables = chain .Chain (bijectors = list (reversed (bijectors )))
357
+
358
+ eps = transformed_distribution .TransformedDistribution (
359
+ distribution = sample .Sample (normal .Normal (0. , 0.1 ),
360
+ num_auxiliary_variables ),
361
+ bijector = variables )
362
+
363
+ eps = Root (eps )
364
+
365
+ value_out = yield (eps if flat_variables
366
+ else (eps , variables ))
367
+
368
+ global_auxiliary_variables = value_out
369
+
370
+ else :
371
+ i = 0
372
+
326
373
try :
327
374
while True :
328
375
was_root = isinstance (dist , Root )
@@ -334,9 +381,10 @@ def posterior_generator(seed=seed):
334
381
dist ,
335
382
base_distribution_surrogate_fn = base_distribution_surrogate_fn ,
336
383
variables = flat_variables [i ] if flat_variables else None ,
384
+ global_auxiliary_variables = global_auxiliary_variables ,
337
385
seed = init_seed )
338
386
339
- if was_root :
387
+ if was_root and num_auxiliary_variables == 0 :
340
388
surrogate_posterior = Root (surrogate_posterior )
341
389
# If variables were not given---i.e., we're creating new
342
390
# variables---then yield the new variables along with the surrogate
@@ -367,6 +415,8 @@ def posterior_generator(seed=seed):
367
415
return _cf_surrogate_for_joint_distribution (
368
416
dist = dist ,
369
417
base_distribution_surrogate_fn = base_distribution_surrogate_fn ,
418
+ num_auxiliary_variables = num_auxiliary_variables ,
419
+ global_auxiliary_variables = global_auxiliary_variables ,
370
420
variables = dist ._model_unflatten ( # pylint: disable=protected-access
371
421
_extract_variables_from_coroutine_model (
372
422
posterior_generator , seed = seed )))
@@ -401,6 +451,7 @@ def posterior_generator(seed=seed):
401
451
def _cf_convex_update_for_base_distribution (dist ,
402
452
initial_prior_weight ,
403
453
num_auxiliary_variables = 0 ,
454
+ global_auxiliary_variables = None ,
404
455
sample_shape = None ,
405
456
variables = None ,
406
457
seed = None ):
@@ -412,31 +463,39 @@ def _cf_convex_update_for_base_distribution(dist,
412
463
actual_event_shape .shape .as_list ()[0 ] > 0 else 1
413
464
layers = 3
414
465
bijectors = [reshape .Reshape ([- 1 ],
415
- event_shape_in = actual_event_shape +
416
- num_auxiliary_variables )]
466
+ event_shape_in = actual_event_shape +
467
+ num_auxiliary_variables )]
417
468
418
469
for _ in range (0 , layers - 1 ):
419
470
bijectors .append (
420
471
build_highway_flow_layer (
421
472
tf .reduce_prod (actual_event_shape + num_auxiliary_variables ),
422
473
residual_fraction_initial_value = initial_prior_weight ,
423
- activation_fn = True , gate_first_n = int_event_shape ))
474
+ activation_fn = True , gate_first_n = int_event_shape , seed = seed ))
424
475
bijectors .append (
425
476
build_highway_flow_layer (
426
477
tf .reduce_prod (actual_event_shape + num_auxiliary_variables ),
427
478
residual_fraction_initial_value = initial_prior_weight ,
428
- activation_fn = False , gate_first_n = int_event_shape ))
429
- bijectors .append (reshape .Reshape (actual_event_shape + num_auxiliary_variables ))
479
+ activation_fn = False , gate_first_n = int_event_shape , seed = seed ))
480
+ bijectors .append (
481
+ reshape .Reshape (actual_event_shape + num_auxiliary_variables ))
430
482
431
483
variables = chain .Chain (bijectors = list (reversed (bijectors )))
432
484
433
485
if num_auxiliary_variables > 0 :
486
+ batch_shape = global_auxiliary_variables .shape [0 ] if len (
487
+ global_auxiliary_variables .shape ) > 1 else []
488
+
434
489
cascading_flows = split .Split (
435
490
[- 1 , num_auxiliary_variables ])(
436
491
transformed_distribution .TransformedDistribution (
437
- distribution = blockwise .Blockwise ([dist , batch_broadcast .BatchBroadcast (
438
- sample .Sample (normal .Normal (0. , .1 ), num_auxiliary_variables ),
439
- to_shape = dist .batch_shape )]),
492
+ distribution = blockwise .Blockwise ([
493
+ batch_broadcast .BatchBroadcast (dist ,
494
+ to_shape = batch_shape ),
495
+ independent .Independent (
496
+ deterministic .Deterministic (
497
+ global_auxiliary_variables ),
498
+ reinterpreted_batch_ndims = 1 )]),
440
499
bijector = variables ))
441
500
442
501
else :
0 commit comments