Skip to content

Commit ed47dda

Browse files
authored
Merge pull request #1174 from jburnim/r0.12
Prepare branch for the TFP 0.12.0rc2 release
2 parents 7784466 + 1fd985d commit ed47dda

File tree

75 files changed

+2646
-1114
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+2646
-1114
lines changed

STYLE_GUIDE.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ they supersede all previous conventions.
109109
* Definitely use named args for 2nd args onward in docstrings.
110110
111111
1. Use names which describe semantics, not computation or mathematics, e.g.,
112-
avoid `xp1 = x+1` or `tfd.Normal(loc=mu, scale=sigma)`.
112+
avoid `xp1 = x + 1` or `tfd.Normal(loc=mu, scale=sigma)`.
113113
114114
1. Prefer inlining intermediates which are used once.
115115
@@ -157,16 +157,16 @@ they supersede all previous conventions.
157157
158158
1. Prefer using the most specific TF operator. E.g,
159159
160-
* Use `tf.squared_difference(x,y)` over `(x-y)**2`.
161-
* Use `tf.rsqrt` over `1./tf.sqrt(x)`.
160+
* Use `tf.squared_difference(x, y)` over `(x - y)**2`.
161+
* Use `tf.rsqrt` over `1. / tf.sqrt(x)`.
162162
163163
1. Worry about gradients! (It's often not automatic for API builders!)
164164
165165
1. When forced to choose between FLOPS and numerical accuracy, prefer numerical
166166
accuracy.
167167
168-
1. Avoid tf.cast if possible. Eg, prefer `tf.where(cond, a, b)` over
169-
`tf.cast(cond,dtype=a.dtype)*a + (1-tf.cast(cond,dtype=b.dtype)*b`
168+
1. Avoid tf.cast if possible. Eg, prefer `tf.where(pred, a, b)` over
169+
`tf.cast(cond, dtype=a.dtype) * a + (1 - tf.cast(cond, dtype=b.dtype) * b`
170170
171171
1. Preserve static shape hints.
172172
@@ -217,3 +217,15 @@ they supersede all previous conventions.
217217
`Tensor`s, and Numpy objects. When converting a user-provided literal to a
218218
`Tensor` (see e.g. `Distribution._call_log_prob`), specify the dtype to
219219
`tf.convert_to_tensor` if it is available.
220+
221+
1. Prefer overloaded operators on `Tensor`s (`+`, `-`, etc.) to explicit
222+
method calls (`tf.add`, `tf.sub`, etc.). Exceptions:
223+
224+
* Prefer `tf.equal` to `==` when checking element-wise equality, because the
225+
semantics of the latter are inconsistent between eager and graph
226+
(`tf.function`) modes.
227+
* Use `&` and `|` only if you want bitwise logic. Note that these are
228+
equivalent to logical ops only if all inputs are `bool`s or are in
229+
`{0, 1}`.
230+
231+

spinoffs/oryx/oryx/bijectors/__init__.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,15 @@
1818
from oryx.bijectors import bijector_extensions
1919
from tensorflow_probability.substrates import jax as tfp
2020

21-
__all__ = [
22-
'bijector_extensions'
23-
]
24-
2521
tfb = tfp.bijectors
2622

27-
_bijectors = {}
23+
__all__ = tfb.__all__
2824

29-
for name in dir(tfb):
25+
for name in __all__:
3026
bij = getattr(tfb, name)
3127
if inspect.isclass(bij) and issubclass(bij, tfb.Bijector):
3228
if bij is not tfb.Bijector:
3329
bij = bijector_extensions.make_type(bij)
34-
_bijectors[name] = bij
35-
36-
37-
for key, val in _bijectors.items():
38-
locals()[key] = val
39-
30+
locals()[name] = bij
4031

41-
del _bijectors
32+
del tfb

spinoffs/oryx/oryx/core/interpreters/harvest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
335335
params = params.copy()
336336
new_params = dict(
337337
params,
338-
mapped_invars=(True,) * len(tree_util.tree_leaves(plants)) +
339-
params['mapped_invars'])
338+
in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
339+
params['in_axes'])
340340
else:
341341
new_params = dict(params)
342342
all_args, all_tree = tree_util.tree_flatten((plants, vals))

spinoffs/oryx/oryx/core/interpreters/inverse/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ def remove_slice(cell):
373373
flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
374374
f, aux = flat_propagate(f, in_tree)
375375
# Assume all invars as mapped
376-
new_mapped_invars = (True,) * len(flat_vals)
377-
new_params = dict(params, mapped_invars=new_mapped_invars)
376+
new_in_axes = (0,) * len(flat_vals)
377+
new_params = dict(params, in_axes=new_in_axes)
378378
if 'donated_invars' in params:
379379
new_params['donated_invars'] = (False,) * len(flat_vals)
380380
subenv_vals = prim.bind(f, *flat_vals, **new_params)

spinoffs/oryx/oryx/core/interpreters/unzip.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from jax import core as jax_core
3535
from jax import custom_derivatives as cd
3636
from jax import linear_util as lu
37-
from jax import source_info_util
3837
from jax import tree_util
3938
from jax import util as jax_util
39+
from jax._src import source_info_util
4040
from jax.interpreters import partial_eval as pe
4141
import numpy as onp
4242

@@ -282,14 +282,13 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
282282
return current_custom_rules()[call_primitive](self, f, *tracers, **params)
283283
if call_primitive in pe.call_partial_eval_rules:
284284
raise NotImplementedError
285-
in_pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
285+
in_pvals = [t.pval for t in tracers]
286286
if is_map:
287-
pvs = [
288-
None if pv is None else mapped_aval(params['axis_size'], pv)
289-
for pv in in_pvs
290-
]
291-
else:
292-
pvs = in_pvs
287+
unknown = pe.PartialVal.unknown
288+
in_pvals = [pval if pval.is_known() or in_axis is None else
289+
unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
290+
for pval, in_axis in zip(in_pvals, params['in_axes'])]
291+
pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
293292
keys = tuple(t.is_key() for t in tracers)
294293
new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
295294
fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
@@ -360,12 +359,6 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
360359
for pv, const, key in safe_zip(out_pvs, out_consts, out_keys)
361360
]
362361
new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr)
363-
if is_map:
364-
new_params = dict(
365-
new_params,
366-
mapped_invars=tuple([True] * len(const_tracers) +
367-
[False] * len(env_tracers) +
368-
[True] * len(in_tracers)))
369362
if 'donated_invars' in params:
370363
new_donated_invars = (
371364
(False,) * len(const_tracers) + (False,) * len(env_tracers) +

spinoffs/oryx/oryx/distributions/__init__.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,12 @@
1616
from oryx.distributions import distribution_extensions
1717
from tensorflow_probability.substrates import jax as tfp
1818

19-
__all__ = [
20-
'distribution_extensions'
21-
]
22-
23-
2419
tfd = tfp.distributions
2520

26-
_distributions = {}
21+
__all__ = tfd.__all__
2722

28-
for name in dir(tfd):
23+
for name in __all__:
2924
dist = getattr(tfd, name)
30-
_distributions[name] = dist
31-
32-
33-
for key, val in _distributions.items():
34-
locals()[key] = val
35-
25+
locals()[name] = dist
3626

37-
del _distributions
38-
del distribution_extensions # Only needed for registration.
27+
del tfd

spinoffs/oryx/oryx/experimental/nn/normalization_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def test_check_grads(self):
171171
net = net_init.init(net_rng, state.Shape(in_shape))
172172

173173
x = random.normal(data_rng, in_shape)
174-
jtu.check_grads(net, (x,), 2)
174+
jtu.check_grads(net.call, (x,), 2)
175175

176176

177177
def mse(x, y):

tensorflow_probability/python/__init__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,23 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import functools
22+
2123
from tensorflow_probability.python.internal import all_util
2224
from tensorflow_probability.python.internal import lazy_loader
2325

2426

25-
# Ensure TensorFlow is importable and its version is sufficiently recent. This
26-
# needs to happen before anything else, since the imports below will try to
27-
# import tensorflow, too.
2827
# pylint: disable=g-import-not-at-top
29-
def _ensure_tf_install():
30-
"""Attempt to import tensorflow, and ensure its version is sufficient.
28+
def _validate_tf_environment(package):
29+
"""Check TF version and (depending on package) warn about TensorFloat32.
30+
31+
Args:
32+
package: Python `str` indicating which package is being imported. Used for
33+
package-dependent warning about TensorFloat32.
3134
3235
Raises:
3336
ImportError: if either tensorflow is not importable or its version is
34-
inadequate.
37+
inadequate.
3538
"""
3639
try:
3740
import tensorflow.compat.v1 as tf
@@ -62,9 +65,10 @@ def _ensure_tf_install():
6265
required=required_tensorflow_version,
6366
present=tf.__version__))
6467

65-
if tf.config.experimental.tensor_float_32_execution_enabled():
68+
if (package == 'mcmc' and
69+
tf.config.experimental.tensor_float_32_execution_enabled()):
6670
# Must import here, because symbols get pruned to __all__.
67-
import warnings # pylint: disable=g-import-not-at-top
71+
import warnings
6872
warnings.warn(
6973
'TensorFloat-32 matmul/conv are enabled for NVIDIA Ampere+ GPUs. The '
7074
'resulting loss of precision may hinder MCMC convergence. To turn off, '
@@ -94,6 +98,8 @@ def _ensure_tf_install():
9498
for pkg in _allowed_symbols:
9599
globals()[pkg] = lazy_loader.LazyLoader(
96100
pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg),
97-
on_first_access=_ensure_tf_install)
101+
# These checks need to happen before lazy-loading, since the modules
102+
# themselves will try to import tensorflow, too.
103+
on_first_access=functools.partial(_validate_tf_environment, pkg))
98104

99105
all_util.remove_undocumented(__name__, _allowed_symbols)

tensorflow_probability/python/bijectors/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from tensorflow_probability.python.bijectors.expm1 import Log1p
4141
from tensorflow_probability.python.bijectors.ffjord import FFJORD
4242
from tensorflow_probability.python.bijectors.fill_scale_tril import FillScaleTriL
43-
from tensorflow_probability.python.bijectors.fill_scale_tril import ScaleTriL
4443
from tensorflow_probability.python.bijectors.fill_triangular import FillTriangular
4544
from tensorflow_probability.python.bijectors.frechet_cdf import FrechetCDF
4645
from tensorflow_probability.python.bijectors.generalized_pareto import GeneralizedPareto
@@ -159,7 +158,6 @@
159158
"ScaleMatvecLinearOperatorBlock",
160159
"ScaleMatvecLU",
161160
"ScaleMatvecTriL",
162-
"ScaleTriL",
163161
"Shift",
164162
"ShiftedGompertzCDF",
165163
"Sigmoid",

tensorflow_probability/python/bijectors/bijector_properties_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
'ScaleMatvecTriL',
7878
'Shift',
7979
'ShiftedGompertzCDF',
80-
'ScaleTriL',
8180
'Sigmoid',
8281
'Sinh',
8382
'SinhArcsinh',

0 commit comments

Comments
 (0)