Skip to content

Commit b74b166

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Enable JAX/NumPy backends for tfp.math.ode.
Along the way I deprecated the use_pfor_to_compute_jacobian arg for ease of implementation. The code now unconditionally tries pfor, and only falls back to loops if an exception is raised (see the behavior inside gradient.py). PiperOrigin-RevId: 388329718
1 parent a100002 commit b74b166

File tree

11 files changed

+548
-399
lines changed

11 files changed

+548
-399
lines changed

tensorflow_probability/python/math/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ multi_substrate_py_library(
3636
substrates_omit_deps = [
3737
":minimize",
3838
":sparse",
39-
"//tensorflow_probability/python/math/ode",
4039
],
4140
deps = [
4241
":bessel",

tensorflow_probability/python/math/ode/BUILD

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
# Description:
1616
# TensorFlow Probability ODE solvers.
1717

18+
load(
19+
"//tensorflow_probability/python:build_defs.bzl",
20+
"multi_substrate_py_library",
21+
"multi_substrate_py_test",
22+
)
23+
1824
package(
1925
default_visibility = [
2026
"//tensorflow_probability:__subpackages__",
@@ -23,18 +29,20 @@ package(
2329

2430
licenses(["notice"])
2531

26-
py_library(
32+
multi_substrate_py_library(
2733
name = "base",
2834
srcs = ["base.py"],
2935
srcs_version = "PY3",
3036
deps = [
3137
# six dep,
3238
# tensorflow dep,
39+
"//tensorflow_probability/python/internal:custom_gradient",
3340
"//tensorflow_probability/python/internal:dtype_util",
41+
"//tensorflow_probability/python/math:gradient",
3442
],
3543
)
3644

37-
py_library(
45+
multi_substrate_py_library(
3846
name = "bdf",
3947
srcs = ["bdf.py"],
4048
srcs_version = "PY3",
@@ -45,10 +53,12 @@ py_library(
4553
# numpy dep,
4654
# tensorflow dep,
4755
"//tensorflow_probability/python/internal:dtype_util",
56+
"//tensorflow_probability/python/internal:prefer_static",
57+
"//tensorflow_probability/python/internal:tensorshape_util",
4858
],
4959
)
5060

51-
py_library(
61+
multi_substrate_py_library(
5262
name = "dormand_prince",
5363
srcs = ["dormand_prince.py"],
5464
srcs_version = "PY3",
@@ -61,17 +71,20 @@ py_library(
6171
],
6272
)
6373

64-
py_library(
74+
multi_substrate_py_library(
6575
name = "bdf_util",
6676
srcs = ["bdf_util.py"],
6777
srcs_version = "PY3",
6878
deps = [
6979
# numpy dep,
7080
# tensorflow dep,
81+
"//tensorflow_probability/python/internal:dtype_util",
82+
"//tensorflow_probability/python/internal:prefer_static",
83+
"//tensorflow_probability/python/internal:tensorshape_util",
7184
],
7285
)
7386

74-
py_test(
87+
multi_substrate_py_test(
7588
name = "bdf_util_test",
7689
size = "small",
7790
srcs = ["bdf_util_test.py"],
@@ -87,7 +100,7 @@ py_test(
87100
],
88101
)
89102

90-
py_library(
103+
multi_substrate_py_library(
91104
name = "runge_kutta_util",
92105
srcs = ["runge_kutta_util.py"],
93106
srcs_version = "PY3",
@@ -99,7 +112,7 @@ py_library(
99112
],
100113
)
101114

102-
py_test(
115+
multi_substrate_py_test(
103116
name = "runge_kutta_util_test",
104117
size = "small",
105118
srcs = ["runge_kutta_util_test.py"],
@@ -113,7 +126,7 @@ py_test(
113126
],
114127
)
115128

116-
py_library(
129+
multi_substrate_py_library(
117130
name = "ode",
118131
srcs = ["__init__.py"],
119132
srcs_version = "PY3",
@@ -124,12 +137,12 @@ py_library(
124137
],
125138
)
126139

127-
py_test(
140+
multi_substrate_py_test(
128141
name = "ode_test",
129142
size = "large",
130143
srcs = ["ode_test.py"],
131144
python_version = "PY3",
132-
shard_count = 6,
145+
shard_count = 8,
133146
srcs_version = "PY3",
134147
deps = [
135148
# absl/testing:parameterized dep,
@@ -183,17 +196,18 @@ py_test(
183196
],
184197
)
185198

186-
py_library(
199+
multi_substrate_py_library(
187200
name = "util",
188201
srcs = ["util.py"],
189202
deps = [
190203
# tensorflow dep,
204+
"//tensorflow_probability/python/internal:dtype_util",
191205
"//tensorflow_probability/python/internal:prefer_static",
192206
"//tensorflow_probability/python/math:gradient",
193207
],
194208
)
195209

196-
py_test(
210+
multi_substrate_py_test(
197211
name = "util_test",
198212
size = "small",
199213
srcs = ["util_test.py"],

0 commit comments

Comments
 (0)