Skip to content

Commit bd53773

Browse files
authored
Merge branch 'main' into frighterafix#1505
2 parents 944af40 + 8e72c11 commit bd53773

File tree

392 files changed

+10778
-4681
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

392 files changed

+10778
-4681
lines changed

discussion/examples/windowed_sampling.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,9 +1837,9 @@
18371837
"WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop.\n",
18381838
"Fast window 75\n",
18391839
"Slow window 25\n",
1840-
"WARNING:tensorflow:5 out of the last 5 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
1840+
"WARNING:tensorflow:5 out of the last 5 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
18411841
"Slow window 50\n",
1842-
"WARNING:tensorflow:6 out of the last 6 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
1842+
"WARNING:tensorflow:6 out of the last 6 calls to \u003cfunction slow_window at 0x7f9456031ea0\u003e triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
18431843
"Slow window 100\n",
18441844
"Slow window 200\n",
18451845
"Fast window 75\n",

discussion/meads/README.md

Lines changed: 6 additions & 0 deletions

discussion/meads/meads.ipynb

Lines changed: 1336 additions & 0 deletions
Large diffs are not rendered by default.

discussion/turnkey_inference_candidate/window_tune_nuts_sampling.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _sample_posterior(target_log_prob_unconstrained,
4646
parallel_iterations=10,
4747
jit_compile=True,
4848
use_input_signature=False,
49-
experimental_relax_shapes=False):
49+
reduce_retracing=False):
5050
"""MCMC sampling with HMC/NUTS using an expanding epoch tuning scheme."""
5151

5252
seed_stream = tfp.util.SeedStream(seed, 'window_tune_nuts_sampling')
@@ -117,7 +117,7 @@ def _sample_posterior(target_log_prob_unconstrained,
117117
input_signature=input_signature,
118118
autograph=False,
119119
jit_compile=jit_compile,
120-
experimental_relax_shapes=experimental_relax_shapes)
120+
reduce_retracing=reduce_retracing)
121121
def fast_adaptation_interval(num_steps, previous_state):
122122
"""Step size only adaptation interval.
123123
@@ -179,7 +179,7 @@ def body_fn_window2(
179179
input_signature=input_signature,
180180
autograph=False,
181181
jit_compile=jit_compile,
182-
experimental_relax_shapes=experimental_relax_shapes)
182+
reduce_retracing=reduce_retracing)
183183
def slow_adaptation_interval(num_steps, previous_n, previous_state,
184184
previous_mean, previous_cov):
185185
"""Interval that tunes the mass matrix and step size simultaneously.
@@ -328,7 +328,7 @@ def window_tune_nuts_sampling(target_log_prob,
328328
parallel_iterations=10,
329329
jit_compile=True,
330330
use_input_signature=True,
331-
experimental_relax_shapes=False):
331+
reduce_retracing=False):
332332
"""Sample from a density with NUTS and an expanding window tuning scheme.
333333
334334
This function implements a turnkey MCMC sampling routine using NUTS and an
@@ -347,7 +347,7 @@ def window_tune_nuts_sampling(target_log_prob,
347347
of the tuning epoch (window 1, 2, and 3 in Stan [1]) run with two @tf.function
348348
compiled functions. The user can control the compilation options using the
349349
kwargs `jit_compile`, `use_input_signature`, and
350-
`experimental_relax_shapes`. Setting all to True would compile to XLA and
350+
`reduce_retracing`. Setting all to True would compile to XLA and
351351
potentially avoid the small overhead of function recompilation (note that it
352352
is not yet the case in XLA right now). It is not yet clear whether doing it
353353
this way is better than just wrapping the full inference routine in
@@ -403,7 +403,7 @@ def window_tune_nuts_sampling(target_log_prob,
403403
function is always compiled by XLA.
404404
use_input_signature: If True, generate an input_signature kwarg to pass to
405405
tf.function decorator.
406-
experimental_relax_shapes: kwarg pass to tf.function decorator. When True,
406+
reduce_retracing: kwarg pass to tf.function decorator. When True,
407407
tf.function may generate fewer, graphs that are less specialized on input
408408
shapes.
409409
@@ -564,6 +564,6 @@ def target_log_prob_unconstrained_concated(x):
564564
parallel_iterations=parallel_iterations,
565565
jit_compile=jit_compile,
566566
use_input_signature=use_input_signature,
567-
experimental_relax_shapes=experimental_relax_shapes)
567+
reduce_retracing=reduce_retracing)
568568
return forward_transform(
569569
split_and_reshape(nuts_samples)), diagnostic, conditioning_bijector

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
def map_tree(fn, tree, *args):
4949
"""Maps `fn` over the leaves of a nested structure."""
50-
return tree_util.tree_multimap(fn, tree, *args)
50+
return tree_util.tree_map(fn, tree, *args)
5151

5252

5353
def flatten_tree(tree):
@@ -66,7 +66,7 @@ def map_tree_up_to(shallow, fn, tree, *rest):
6666
def wrapper(_, *rest):
6767
return fn(*rest)
6868

69-
return tree_util.tree_multimap(wrapper, shallow, tree, *rest)
69+
return tree_util.tree_map(wrapper, shallow, tree, *rest)
7070

7171

7272
def get_shallow_tree(is_leaf, tree):
@@ -76,7 +76,7 @@ def get_shallow_tree(is_leaf, tree):
7676

7777
def assert_same_shallow_tree(shallow, tree):
7878
"""Asserts that `tree` has the same shallow structure as `shallow`."""
79-
# Do a dummy multimap for the side-effect of verifying that the structures are
79+
# Do a dummy map for the side-effect of verifying that the structures are
8080
# the same. This doesn't catch all the errors we actually care about, sadly.
8181
map_tree_up_to(shallow, lambda *args: (), tree)
8282

spinoffs/inference_gym/inference_gym/backends/jax_integration_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Lint as: python3
21
# Copyright 2020 The TensorFlow Probability Authors.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");

spinoffs/inference_gym/inference_gym/backends/numpy_integration_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Lint as: python3
21
# Copyright 2020 The TensorFlow Probability Authors.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");

spinoffs/inference_gym/inference_gym/backends/rewrite.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Lint as: python3
21
# Copyright 2020 The TensorFlow Probability Authors.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");

spinoffs/inference_gym/inference_gym/backends/tensorflow_integration_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Lint as: python3
21
# Copyright 2020 The TensorFlow Probability Authors.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");

spinoffs/inference_gym/inference_gym/backends/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Lint as: python3
21
# Copyright 2020 The TensorFlow Probability Authors.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");

0 commit comments

Comments
 (0)