15
15
# Description:
16
16
# TensorFlow Probability ODE solvers.
17
17
18
+ load (
19
+ "//tensorflow_probability/python:build_defs.bzl" ,
20
+ "multi_substrate_py_library" ,
21
+ "multi_substrate_py_test" ,
22
+ )
23
+
18
24
package (
19
25
default_visibility = [
20
26
"//tensorflow_probability:__subpackages__" ,
@@ -23,18 +29,20 @@ package(
23
29
24
30
licenses (["notice" ])
25
31
26
- py_library (
32
+ multi_substrate_py_library (
27
33
name = "base" ,
28
34
srcs = ["base.py" ],
29
35
srcs_version = "PY3" ,
30
36
deps = [
31
37
# six dep,
32
38
# tensorflow dep,
39
+ "//tensorflow_probability/python/internal:custom_gradient" ,
33
40
"//tensorflow_probability/python/internal:dtype_util" ,
41
+ "//tensorflow_probability/python/math:gradient" ,
34
42
],
35
43
)
36
44
37
- py_library (
45
+ multi_substrate_py_library (
38
46
name = "bdf" ,
39
47
srcs = ["bdf.py" ],
40
48
srcs_version = "PY3" ,
@@ -45,10 +53,12 @@ py_library(
45
53
# numpy dep,
46
54
# tensorflow dep,
47
55
"//tensorflow_probability/python/internal:dtype_util" ,
56
+ "//tensorflow_probability/python/internal:prefer_static" ,
57
+ "//tensorflow_probability/python/internal:tensorshape_util" ,
48
58
],
49
59
)
50
60
51
- py_library (
61
+ multi_substrate_py_library (
52
62
name = "dormand_prince" ,
53
63
srcs = ["dormand_prince.py" ],
54
64
srcs_version = "PY3" ,
@@ -61,17 +71,20 @@ py_library(
61
71
],
62
72
)
63
73
64
- py_library (
74
+ multi_substrate_py_library (
65
75
name = "bdf_util" ,
66
76
srcs = ["bdf_util.py" ],
67
77
srcs_version = "PY3" ,
68
78
deps = [
69
79
# numpy dep,
70
80
# tensorflow dep,
81
+ "//tensorflow_probability/python/internal:dtype_util" ,
82
+ "//tensorflow_probability/python/internal:prefer_static" ,
83
+ "//tensorflow_probability/python/internal:tensorshape_util" ,
71
84
],
72
85
)
73
86
74
- py_test (
87
+ multi_substrate_py_test (
75
88
name = "bdf_util_test" ,
76
89
size = "small" ,
77
90
srcs = ["bdf_util_test.py" ],
@@ -87,7 +100,7 @@ py_test(
87
100
],
88
101
)
89
102
90
- py_library (
103
+ multi_substrate_py_library (
91
104
name = "runge_kutta_util" ,
92
105
srcs = ["runge_kutta_util.py" ],
93
106
srcs_version = "PY3" ,
@@ -99,7 +112,7 @@ py_library(
99
112
],
100
113
)
101
114
102
- py_test (
115
+ multi_substrate_py_test (
103
116
name = "runge_kutta_util_test" ,
104
117
size = "small" ,
105
118
srcs = ["runge_kutta_util_test.py" ],
@@ -113,7 +126,7 @@ py_test(
113
126
],
114
127
)
115
128
116
- py_library (
129
+ multi_substrate_py_library (
117
130
name = "ode" ,
118
131
srcs = ["__init__.py" ],
119
132
srcs_version = "PY3" ,
@@ -124,12 +137,12 @@ py_library(
124
137
],
125
138
)
126
139
127
- py_test (
140
+ multi_substrate_py_test (
128
141
name = "ode_test" ,
129
142
size = "large" ,
130
143
srcs = ["ode_test.py" ],
131
144
python_version = "PY3" ,
132
- shard_count = 6 ,
145
+ shard_count = 8 ,
133
146
srcs_version = "PY3" ,
134
147
deps = [
135
148
# absl/testing:parameterized dep,
@@ -183,17 +196,18 @@ py_test(
183
196
],
184
197
)
185
198
186
- py_library (
199
+ multi_substrate_py_library (
187
200
name = "util" ,
188
201
srcs = ["util.py" ],
189
202
deps = [
190
203
# tensorflow dep,
204
+ "//tensorflow_probability/python/internal:dtype_util" ,
191
205
"//tensorflow_probability/python/internal:prefer_static" ,
192
206
"//tensorflow_probability/python/math:gradient" ,
193
207
],
194
208
)
195
209
196
- py_test (
210
+ multi_substrate_py_test (
197
211
name = "util_test" ,
198
212
size = "small" ,
199
213
srcs = ["util_test.py" ],
0 commit comments