Skip to content

Commit cc2c37e

Browse files
authored
Merge pull request #1191 from jburnim/r0.12
Prepare branch for the TFP 0.12.0rc4 release
2 parents ed47dda + 3de3fe0 commit cc2c37e

File tree

141 files changed

+8243
-3054
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

141 files changed

+8243
-3054
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2020 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
name: Tests
16+
on: [push, pull_request]
17+
env:
18+
TEST_VENV_PATH: ~/test_virtualenv
19+
jobs:
20+
lints:
21+
name: Lints
22+
runs-on: ubuntu-latest
23+
strategy:
24+
matrix:
25+
python-version: [3.7]
26+
steps:
27+
- name: Checkout
28+
uses: actions/checkout@v1
29+
with:
30+
fetch-depth: 20
31+
- name: Setup Python
32+
uses: actions/setup-python@v2
33+
with:
34+
python-version: ${{ matrix.python-version }}
35+
- name: Setup virtualenv
36+
run: |
37+
sudo apt install virtualenv
38+
virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
39+
- name: Lints
40+
run: |
41+
source ${TEST_VENV_PATH}/bin/activate
42+
./testing/run_github_lints.sh
43+
tests:
44+
name: Tests
45+
runs-on: ubuntu-latest
46+
strategy:
47+
matrix:
48+
python-version: [3.7]
49+
shard: [0, 1, 2, 3, 4]
50+
env:
51+
TEST_VENV_PATH: ~/test_virtualenv
52+
SHARD: ${{ matrix.shard }}
53+
NUM_SHARDS: 5
54+
steps:
55+
- name: Checkout
56+
uses: actions/checkout@v1
57+
with:
58+
fetch-depth: 1
59+
- name: Setup Python
60+
uses: actions/setup-python@v2
61+
with:
62+
python-version: ${{ matrix.python-version }}
63+
- name: Setup virtualenv
64+
run: |
65+
sudo apt install virtualenv
66+
virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
67+
- name: Tests
68+
run: |
69+
source ${TEST_VENV_PATH}/bin/activate
70+
./testing/run_github_tests.sh

.travis.yml

Lines changed: 0 additions & 56 deletions
This file was deleted.

CONTRIBUTING.md

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,19 @@ repository (with credit to the original author) and closes the pull request.
3232

3333
## Continuous Integration
3434

35-
We use [Travis CI](https://travis-ci.org/tensorflow/probability) to do automated
36-
style checking and run unit-tests (discussed in more detail below). A build
37-
will be triggered when you open a pull request, or update the pull request by
38-
adding a commit, rebasing etc.
35+
We use [GitHub Actions](https://github.com/tensorflow/probability/actions) to do
36+
automated style checking and run unit-tests (discussed in more detail below). A
37+
build will be triggered when you open a pull request, or update the pull request
38+
by adding a commit, rebasing etc.
3939

40-
We test against TensorFlow nightly on Python 2.7 and 3.6. We shard our tests
40+
We test against TensorFlow nightly on Python 3.7. We shard our tests
4141
across several build jobs (identified by the `SHARD` environment variable).
42-
Linting, in particular, is only done on the first shard, so look at that shard's
43-
logs for lint errors if any.
42+
Lints are also done in a separate job.
4443

4544
All pull-requests will need to pass the automated lint and unit-tests before
46-
being merged. As Travis-CI tests can take a bit of time, see the following
47-
sections on how to run the lint checks and unit-tests locally while you're
48-
developing your change.
45+
being merged. As the tests can take a bit of time, see the following sections
46+
on how to run the lint checks and unit-tests locally while you're developing
47+
your change.
4948

5049
## Style
5150

discussion/fun_mcmc/prefab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def kernel(adaptive_hmc_state):
347347
hmc_state.state,
348348
axis=tuple(range(chain_ndims)) if chain_ndims else None,
349349
window_size=int(np.prod(hmc_state.target_log_prob.shape)) *
350-
variance_window_steps)
350+
variance_window_steps) # pytype: disable=wrong-arg-types
351351

352352
if num_adaptation_steps is not None:
353353
# Take care of adaptation for variance and step size.

spinoffs/inference_gym/inference_gym/BUILD

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# A package for target densities and benchmarking of inference algorithms
1717
# against the same.
1818

19-
# [internal] load pytype.bzl (pytype_library, pytype_strict_library)
19+
# [internal] load pytype.bzl (pytype_strict_library)
2020
# [internal] load dummy dependency
2121

2222
package(
@@ -42,7 +42,6 @@ py_library(
4242
],
4343
)
4444

45-
# pytype
4645
py_library(
4746
name = "using_numpy",
4847
srcs = ["using_numpy.py"],
@@ -56,7 +55,6 @@ py_library(
5655
],
5756
)
5857

59-
# pytype
6058
py_library(
6159
name = "using_jax",
6260
srcs = ["using_jax.py"],
@@ -71,7 +69,6 @@ py_library(
7169
],
7270
)
7371

74-
# pytype
7572
py_library(
7673
name = "using_tensorflow",
7774
srcs = ["using_tensorflow.py"],

spinoffs/oryx/oryx/core/interpreters/harvest.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -333,22 +333,27 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
333333
if is_map:
334334
# TODO(sharadmv): figure out if invars are mapped or unmapped
335335
params = params.copy()
336+
out_axes_thunk = params['out_axes_thunk']
337+
@jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
338+
def new_out_axes_thunk():
339+
out_axes = out_axes_thunk()
340+
assert all(out_axis == 0 for out_axis in out_axes)
341+
return (0,) * out_tree().num_leaves
336342
new_params = dict(
337343
params,
338-
in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
339-
params['in_axes'])
344+
in_axes=(0,) * len(tree_util.tree_leaves(plants)) + params['in_axes'],
345+
out_axes_thunk=new_out_axes_thunk)
340346
else:
341347
new_params = dict(params)
342348
all_args, all_tree = tree_util.tree_flatten((plants, vals))
343349
num_plants = len(all_args) - len(vals)
344350
if 'donated_invars' in params:
345351
new_params['donated_invars'] = ((False,) * num_plants
346352
+ params['donated_invars'])
347-
f, aux = harvest_eval(f, self, context.settings, all_tree)
353+
f, out_tree = harvest_eval(f, self, context.settings, all_tree)
348354
out_flat = primitive.bind(
349355
f, *all_args, **new_params, name=jax_util.wrap_name(name, 'harvest'))
350-
out_tree = aux()
351-
out, reaps = tree_util.tree_unflatten(out_tree, out_flat)
356+
out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
352357
out_tracers = safe_map(self.pure, out)
353358
reap_tracers = tree_util.tree_map(self.pure, reaps)
354359
if primitive is nest_p and reap_tracers:

spinoffs/oryx/oryx/core/interpreters/inverse/core.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def wrapped(*args, **kwargs):
178178
flat_incells = [InverseAndILDJ.unknown(aval) for aval in flat_forward_avals]
179179
flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
180180
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
181-
flat_constcells, flat_incells, flat_outcells)
181+
flat_constcells, flat_incells, flat_outcells) # pytype: disable=wrong-arg-types
182182
flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
183183
if any(not flat_incell.top() for flat_incell in flat_incells):
184184
raise ValueError('Cannot invert function.')
@@ -332,7 +332,7 @@ def hop_inverse_rule(prim):
332332
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
333333
const_cells, incells = jax_util.split_list(incells, [num_consts])
334334
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, const_cells,
335-
incells, outcells)
335+
incells, outcells) # pytype: disable=wrong-arg-types
336336
new_incells = [env.read(invar) for invar in jaxpr.invars]
337337
new_outcells = [env.read(outvar) for outvar in jaxpr.outvars]
338338
return const_cells + new_incells, new_outcells, None
@@ -377,6 +377,12 @@ def remove_slice(cell):
377377
new_params = dict(params, in_axes=new_in_axes)
378378
if 'donated_invars' in params:
379379
new_params['donated_invars'] = (False,) * len(flat_vals)
380+
if 'out_axes' in params:
381+
assert all(out_axis == 0 for out_axis in params['out_axes'])
382+
new_params['out_axes_thunk'] = jax_util.HashableFunction(
383+
lambda: (0,) * aux().num_leaves,
384+
closure=('ildj', params['out_axes']))
385+
del new_params['out_axes']
380386
subenv_vals = prim.bind(f, *flat_vals, **new_params)
381387
subenv_tree = aux()
382388
subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)

spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,14 @@ def f(x, y):
249249
onp.testing.assert_allclose(y, np.ones(2))
250250
onp.testing.assert_allclose(ildj_, 0., atol=1e-6, rtol=1e-6)
251251

252+
def test_inverse_of_reshape(self):
253+
def f(x):
254+
return np.reshape(x, (4,))
255+
f_inv = core.inverse_and_ildj(f, np.ones((2, 2)))
256+
x, ildj_ = f_inv(np.ones(4))
257+
onp.testing.assert_allclose(x, np.ones((2, 2)))
258+
onp.testing.assert_allclose(ildj_, 0.)
259+
252260
def test_sigmoid_ildj(self):
253261
def naive_sigmoid(x):
254262
# This is the default JAX implementation of sigmoid.

spinoffs/oryx/oryx/core/interpreters/inverse/rules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,8 @@ def reshape_ildj(incells, outcells, **params):
166166
))], None
167167
elif outcell.top() and not incell.top():
168168
val = outcell.val
169-
ndslice = NDSlice.new(np.reshape(val, incell.aval.shape))
170169
new_incells = [
171-
InverseAndILDJ(incell.aval, [ndslice])
170+
InverseAndILDJ.new(np.reshape(val, incell.aval.shape))
172171
]
173172
return new_incells, outcells, None
174173
return incells, outcells, None

spinoffs/oryx/oryx/core/interpreters/unzip.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -288,19 +288,29 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
288288
in_pvals = [pval if pval.is_known() or in_axis is None else
289289
unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
290290
for pval, in_axis in zip(in_pvals, params['in_axes'])]
291+
out_axes_thunk = params['out_axes_thunk']
292+
@jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
293+
def new_out_axes_thunk():
294+
out_axes = out_axes_thunk()
295+
assert all(out_axis == 0 for out_axis in out_axes)
296+
_, num_outputs, _ = aux()
297+
return (0,) * num_outputs
298+
new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
299+
else:
300+
new_params = params
291301
pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
292302
keys = tuple(t.is_key() for t in tracers)
293303
new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
294304
fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
295-
out_flat = call_primitive.bind(fun, *in_consts, **params)
296-
success, results = aux()
305+
out_flat = call_primitive.bind(fun, *in_consts, **new_params)
306+
success, _, results = aux()
297307
if not success:
298308
out_pvs, out_keys, jaxpr, env = results
299309
out_pv_consts, consts = jax_util.split_list(out_flat, [len(out_pvs)])
300-
out_tracers = self._bound_output_tracers(call_primitive, params, jaxpr,
301-
consts, env, tracers, out_pvs,
302-
out_pv_consts, out_keys, name,
303-
is_map)
310+
out_tracers = self._bound_output_tracers(call_primitive, new_params,
311+
jaxpr, consts, env, tracers,
312+
out_pvs, out_pv_consts,
313+
out_keys, name, is_map)
304314
return out_tracers
305315
init_name = jax_util.wrap_name(name, 'init')
306316
apply_name = jax_util.wrap_name(name, 'apply')
@@ -319,15 +329,16 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
319329
[len(apply_pvs)])
320330

321331
variable_tracers = self._bound_output_tracers(
322-
call_primitive, params, init_jaxpr, init_consts, init_env, key_tracers,
323-
init_pvs, init_pv_consts, [True] * len(init_pvs), init_name, is_map)
332+
call_primitive, new_params, init_jaxpr, init_consts, init_env,
333+
key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
334+
init_name, is_map)
324335

325336
unflat_variables = tree_util.tree_unflatten(variable_tree, variable_tracers)
326337
if call_primitive is harvest.nest_p:
327338
variable_dict = harvest.sow(
328339
dict(safe_zip(variable_names, unflat_variables)),
329340
tag=settings.tag,
330-
name=params['scope'],
341+
name=new_params['scope'],
331342
mode='strict')
332343
unflat_variables = tuple(variable_dict[name] for name in variable_names)
333344
else:
@@ -342,7 +353,7 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
342353
variable_tracers = tree_util.tree_leaves(unflat_variables)
343354

344355
out_tracers = self._bound_output_tracers(
345-
call_primitive, params, apply_jaxpr, apply_consts, apply_env,
356+
call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
346357
variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
347358
apply_keys, apply_name, is_map)
348359
return out_tracers
@@ -365,6 +376,11 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
365376
tuple(v for v, t in zip(params['donated_invars'], in_tracers)
366377
if not t.pval.is_known()))
367378
new_params['donated_invars'] = new_donated_invars
379+
if is_map:
380+
out_axes = params['out_axes_thunk']()
381+
assert all(out_axis == 0 for out_axis in out_axes)
382+
new_params['out_axes'] = (0,) * len(out_tracers)
383+
del new_params['out_axes_thunk']
368384
eqn = pe.new_eqn_recipe(
369385
tuple(const_tracers + env_tracers + in_tracers), out_tracers, primitive,
370386
new_params, source_info_util.current()) # pytype: disable=wrong-arg-types
@@ -442,14 +458,16 @@ def unzip_eval_wrapper(pvs, *consts):
442458
out = (
443459
tuple(init_pv_consts) + tuple(init_consts) + tuple(apply_pv_consts) +
444460
tuple(apply_consts))
445-
yield out, (success, ((init_pvs, len(init_consts), apply_pvs),
446-
(init_jaxpr, apply_jaxpr), (init_env,
447-
apply_env), metadata))
461+
yield out, (success, len(out),
462+
((init_pvs, len(init_consts), apply_pvs),
463+
(init_jaxpr, apply_jaxpr),
464+
(init_env, apply_env),
465+
metadata))
448466
else:
449467
jaxpr, (out_pvals, out_keys, consts, env) = result
450468
out_pvs, out_consts = jax_util.unzip2(out_pvals)
451469
out = tuple(out_consts) + tuple(consts)
452-
yield out, (success, (out_pvs, out_keys, jaxpr, env))
470+
yield out, (success, len(out), (out_pvs, out_keys, jaxpr, env))
453471

454472

455473
@lu.transformation

0 commit comments

Comments
 (0)