Skip to content

Commit 57a4be0

Browse files
committed
merged conflicts
2 parents 8671bd1 + fbde45e commit 57a4be0

File tree

231 files changed

+7332
-2590
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

231 files changed

+7332
-2590
lines changed

.github/workflows/continuous-integration.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,9 @@ jobs:
6868
run: |
6969
source ${TEST_VENV_PATH}/bin/activate
7070
./testing/run_github_tests.sh
71+
- name: Upload test logs
72+
if: failure()
73+
uses: actions/upload-artifact@v1
74+
with:
75+
name: testlogs-${{ matrix.shard }}
76+
path: bazel-testlogs

discussion/fun_mcmc/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ py_library(
137137
],
138138
)
139139

140-
# pytype
141140
py_library(
142141
name = "backend_jax",
143142
srcs = ["backend_jax.py"],
@@ -182,6 +181,7 @@ py_library(
182181

183182
py_test(
184183
name = "fun_mcmc_test",
184+
size = "large",
185185
srcs = ["fun_mcmc_test.py"],
186186
python_version = "PY3",
187187
shard_count = 8,

discussion/fun_mcmc/backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
'get_backend',
4141
'JAX',
4242
'MANUAL_TRANSFORMS',
43+
'prefer_static',
4344
'set_backend',
4445
'TENSORFLOW',
4546
'tf',
@@ -171,6 +172,7 @@ def __getattr__(self, attr):
171172
return ret
172173

173174

175+
prefer_static = _Dispatcher('prefer_static')
174176
tf = _Dispatcher('tf')
175177
tfp = _Dispatcher('tfp')
176178
util = _Dispatcher('util')

discussion/fun_mcmc/backend_jax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
from discussion.fun_mcmc import tf_on_jax
1818
from discussion.fun_mcmc import util_jax as util
1919
from tensorflow_probability.substrates import jax as tfp
20+
from tensorflow_probability.substrates.jax.internal import prefer_static
2021

2122
tf = tf_on_jax.tf
2223

2324
__all__ = [
2425
'BACKEND_NAME',
2526
'multi_backend_test',
27+
'prefer_static',
2628
'tf',
2729
'tfp',
2830
'util',

discussion/fun_mcmc/backend_tf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
import tensorflow.compat.v2 as tf
1818
import tensorflow_probability as tfp
1919
from discussion.fun_mcmc import util_tf as util
20+
from tensorflow_probability.python.internal import prefer_static
2021

2122
__all__ = [
2223
'BACKEND_NAME',
2324
'multi_backend_test',
25+
'prefer_static',
2426
'tf',
2527
'tfp',
2628
'util',

discussion/fun_mcmc/fun_mcmc_lib.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
from discussion.fun_mcmc import backend
4747

48+
ps = backend.prefer_static
4849
tf = backend.tf
4950
tfp = backend.tfp
5051
util = backend.util
@@ -1349,7 +1350,7 @@ def hamiltonian_integrator(
13491350
state_grads = int_state.state_grads
13501351
state_extra = int_state.state_extra
13511352

1352-
num_steps = tf.convert_to_tensor(num_steps)
1353+
num_steps = ps.convert_to_shape_tensor(num_steps)
13531354
is_ragged = len(num_steps.shape) > 0 # pylint: disable=g-explicit-length-test
13541355

13551356
kinetic_energy, kinetic_energy_extra = call_potential_fn(
@@ -1358,7 +1359,7 @@ def hamiltonian_integrator(
13581359

13591360
if is_ragged:
13601361
step = 0
1361-
max_num_steps = tf.reduce_max(num_steps)
1362+
max_num_steps = ps.reduce_max(num_steps)
13621363
else:
13631364
step = []
13641365
max_num_steps = num_steps

discussion/fun_mcmc/util_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def wrapper(_, *rest):
7171

7272
def get_shallow_tree(is_leaf, tree):
7373
"""Returns a shallow tree, expanding only when is_leaf(subtree) is False."""
74-
return tree_util.tree_map(is_leaf, tree, is_leaf)
74+
return tree_util.tree_map(is_leaf, tree, is_leaf=is_leaf)
7575

7676

7777
def assert_same_shallow_tree(shallow, tree):

spinoffs/fun_mc/fun_mc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ py_library(
108108

109109
py_test(
110110
name = "fun_mc_test",
111+
size = "large",
111112
srcs = ["fun_mc_test.py"],
112113
python_version = "PY3",
113114
shard_count = 8,

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
from fun_mc.dynamic.backend_jax import tf_on_jax
1818
from fun_mc.dynamic.backend_jax import util
1919
from tensorflow_probability.substrates import jax as tfp
20+
from tensorflow_probability.substrates.jax.internal import prefer_static
2021

2122
tf = tf_on_jax.tf
2223

2324
__all__ = [
25+
'prefer_static',
2426
'tf',
2527
'tfp',
2628
'util',

spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def wrapper(_, *rest):
7171

7272
def get_shallow_tree(is_leaf, tree):
7373
"""Returns a shallow tree, expanding only when is_leaf(subtree) is False."""
74-
return tree_util.tree_map(is_leaf, tree, is_leaf)
74+
return tree_util.tree_map(is_leaf, tree, is_leaf=is_leaf)
7575

7676

7777
def assert_same_shallow_tree(shallow, tree):

0 commit comments

Comments
 (0)