Skip to content

Commit b71b6af

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Add tfp.experimental.math.manual_special_functions.
These trade speed for precision. They are sometimes useful when running large MCMC problems on a TPU. PiperOrigin-RevId: 390191832
1 parent c69428c commit b71b6af

File tree

7 files changed

+847
-3
lines changed

7 files changed

+847
-3
lines changed

tensorflow_probability/python/build_defs.bzl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ RUNFILES_ROOT = "tensorflow_probability/"
2828

2929
def _substrate_src(src, substrate):
3030
"""Rewrite a single src filename for the given substrate."""
31-
return "_{}/_generated_{}".format(substrate, src)
31+
32+
# When src's are sourced from a different package we cut away the package
33+
# name.
34+
parts = src.split(":")
35+
return "_{}/_generated_{}".format(substrate, parts[-1])
3236

3337
def _substrate_srcs(srcs, substrate):
3438
"""Rewrite src filenames for the given substrate."""
@@ -271,6 +275,7 @@ def multi_substrate_py_test(
271275
srcs = [],
272276
main = None,
273277
deps = [],
278+
jax_extra_deps = [],
274279
tags = [],
275280
numpy_tags = [],
276281
jax_tags = [],
@@ -279,7 +284,8 @@ def multi_substrate_py_test(
279284
srcs_version = "PY3",
280285
python_version = "PY3",
281286
timeout = None,
282-
shard_count = None):
287+
shard_count = None,
288+
args = []):
283289
"""A TFP `py_test` for each of TF, NumPy, and JAX.
284290
285291
Args:
@@ -296,6 +302,7 @@ def multi_substrate_py_test(
296302
use-case of the `main` argument is a secondary, i.e. GPU, test.
297303
deps: As with `py_test`. The list is rewritten to depend on
298304
substrate-specific libraries for substrate variants.
305+
jax_extra_deps: Extra dependencies for the JAX substrate.
299306
tags: Tags global to this test target. NumPy also gets a `'tfp_numpy'`
300307
tag, and JAX gets a `'tfp_jax'` tag. A `f'_{name}'` tag is used
301308
to produce the `test_suite`.
@@ -308,6 +315,7 @@ def multi_substrate_py_test(
308315
python_version: As with `py_test`.
309316
timeout: As with `py_test`.
310317
shard_count: As with `py_test`.
318+
args: As with `py_test`.
311319
"""
312320

313321
tags = [t for t in tags]
@@ -325,6 +333,7 @@ def multi_substrate_py_test(
325333
python_version = python_version,
326334
timeout = timeout,
327335
shard_count = shard_count,
336+
args = args,
328337
)
329338
test_targets.append(":{}.tf".format(name))
330339

@@ -349,6 +358,7 @@ def multi_substrate_py_test(
349358
python_version = "PY3",
350359
timeout = timeout,
351360
shard_count = shard_count,
361+
args = args,
352362
)
353363
test_targets.append(":{}.numpy".format(name))
354364

@@ -362,7 +372,7 @@ def multi_substrate_py_test(
362372
cmd = "$(location {}) $(SRCS) --numpy_to_jax > $@".format(REWRITER_TARGET),
363373
exec_tools = [REWRITER_TARGET],
364374
)
365-
jax_deps = _substrate_deps(deps, "jax")
375+
jax_deps = _substrate_deps(deps, "jax") + jax_extra_deps
366376
# [internal] Add JAX build dep
367377
native.py_test(
368378
name = "{}.jax".format(name),
@@ -375,6 +385,7 @@ def multi_substrate_py_test(
375385
python_version = "PY3",
376386
timeout = timeout,
377387
shard_count = shard_count,
388+
args = args,
378389
)
379390
test_targets.append(":{}.jax".format(name))
380391

tensorflow_probability/python/experimental/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ multi_substrate_py_library(
5757
"//tensorflow_probability/python/experimental/distributions",
5858
"//tensorflow_probability/python/experimental/linalg",
5959
"//tensorflow_probability/python/experimental/marginalize",
60+
"//tensorflow_probability/python/experimental/math",
6061
"//tensorflow_probability/python/experimental/mcmc",
6162
"//tensorflow_probability/python/experimental/nn",
6263
"//tensorflow_probability/python/experimental/parallel_filter",

tensorflow_probability/python/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from tensorflow_probability.python.experimental import distributions
3838
from tensorflow_probability.python.experimental import linalg
3939
from tensorflow_probability.python.experimental import marginalize
40+
from tensorflow_probability.python.experimental import math
4041
from tensorflow_probability.python.experimental import mcmc
4142
from tensorflow_probability.python.experimental import nn
4243
from tensorflow_probability.python.experimental import parallel_filter
@@ -62,6 +63,7 @@
6263
'distributions',
6364
'linalg',
6465
'marginalize',
66+
'math',
6567
'mcmc',
6668
'nn',
6769
'parallel_filter',
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
# Description:
16+
# Experimental math.
17+
18+
load(
19+
"//tensorflow_probability/python:build_defs.bzl",
20+
"multi_substrate_py_library",
21+
"multi_substrate_py_test",
22+
)
23+
24+
package(
25+
default_visibility = [
26+
"//tensorflow_probability:__subpackages__",
27+
],
28+
)
29+
30+
licenses(["notice"])
31+
32+
multi_substrate_py_library(
33+
name = "math",
34+
srcs = ["__init__.py"],
35+
srcs_version = "PY3",
36+
deps = [
37+
":manual_special_functions",
38+
],
39+
)
40+
41+
multi_substrate_py_library(
42+
name = "manual_special_functions",
43+
srcs = ["manual_special_functions.py"],
44+
srcs_version = "PY3",
45+
deps = [
46+
# numpy dep,
47+
# tensorflow dep,
48+
"//tensorflow_probability/python/internal:custom_gradient",
49+
"//tensorflow_probability/python/internal:dtype_util",
50+
],
51+
)
52+
53+
multi_substrate_py_test(
54+
name = "manual_special_functions_test",
55+
srcs = ["manual_special_functions_test.py"],
56+
srcs_version = "PY3",
57+
tags = [
58+
"tf1-broken",
59+
],
60+
deps = [
61+
# numpy dep,
62+
# tensorflow dep,
63+
"//tensorflow_probability",
64+
"//tensorflow_probability/python/internal:dtype_util",
65+
"//tensorflow_probability/python/internal:test_util",
66+
],
67+
)
68+
69+
exports_files(
70+
glob(["**/*.py"]),
71+
visibility = ["//tensorflow_probability:__subpackages__"],
72+
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2021 The TensorFlow Probability Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""Experimental math."""
16+
17+
from tensorflow_probability.python.experimental.math.manual_special_functions import exp_pade_4_4
18+
from tensorflow_probability.python.experimental.math.manual_special_functions import expm1_pade_4_4
19+
from tensorflow_probability.python.experimental.math.manual_special_functions import log1p_pade_4_4
20+
from tensorflow_probability.python.experimental.math.manual_special_functions import log_pade_4_4_newton
21+
from tensorflow_probability.python.experimental.math.manual_special_functions import patch_manual_special_functions
22+
from tensorflow_probability.python.experimental.math.manual_special_functions import reduce_logsumexp
23+
from tensorflow_probability.python.experimental.math.manual_special_functions import softplus
24+
25+
__all__ = [
26+
'exp_pade_4_4',
27+
'expm1_pade_4_4',
28+
'log1p_pade_4_4',
29+
'log_pade_4_4_newton',
30+
'patch_manual_special_functions',
31+
'reduce_logsumexp',
32+
'softplus',
33+
]

0 commit comments

Comments
 (0)