Skip to content

Commit 4d8ce35

Browse files
Merge pull request #1270 from gisilvs:cascading_flow_vi
PiperOrigin-RevId: 374964013
2 parents 5838e0e + 430d151 commit 4d8ce35

File tree

4 files changed

+621
-1
lines changed

4 files changed

+621
-1
lines changed

tensorflow_probability/python/experimental/bijectors/BUILD

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ multi_substrate_py_library(
3636
srcs_version = "PY3",
3737
deps = [
3838
":distribution_bijectors",
39+
":highway_flow",
3940
":scalar_function_with_inferred_inverse",
4041
":sharded",
4142
"//tensorflow_probability/python/bijectors:ldj_ratio",
@@ -104,6 +105,37 @@ multi_substrate_py_test(
104105
],
105106
)
106107

108+
multi_substrate_py_library(
109+
name = "highway_flow",
110+
srcs = ["highway_flow.py"],
111+
srcs_version = "PY3",
112+
deps = [
113+
":scalar_function_with_inferred_inverse",
114+
# numpy dep,
115+
# tensorflow dep,
116+
"//tensorflow_probability/python/bijectors",
117+
"//tensorflow_probability/python/internal:samplers",
118+
"//tensorflow_probability/python/util",
119+
],
120+
)
121+
122+
multi_substrate_py_test(
123+
name = "highway_flow_test",
124+
size = "medium",
125+
srcs = ["highway_flow_test.py"],
126+
disabled_substrates = ["numpy"],
127+
jax_size = "medium",
128+
python_version = "PY3",
129+
srcs_version = "PY3",
130+
deps = [
131+
# numpy dep,
132+
# tensorflow dep,
133+
"//tensorflow_probability",
134+
"//tensorflow_probability/python/bijectors:bijector_test_util",
135+
"//tensorflow_probability/python/internal:test_util",
136+
],
137+
)
138+
107139
multi_substrate_py_library(
108140
name = "sharded",
109141
srcs = ["sharded.py"],

tensorflow_probability/python/experimental/bijectors/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717
from tensorflow_probability.python.bijectors.ldj_ratio import forward_log_det_jacobian_ratio
1818
from tensorflow_probability.python.bijectors.ldj_ratio import inverse_log_det_jacobian_ratio
1919
from tensorflow_probability.python.experimental.bijectors.distribution_bijectors import make_distribution_bijector
20+
from tensorflow_probability.python.experimental.bijectors.highway_flow import build_trainable_highway_flow
21+
from tensorflow_probability.python.experimental.bijectors.highway_flow import HighwayFlow
2022
from tensorflow_probability.python.experimental.bijectors.scalar_function_with_inferred_inverse import ScalarFunctionWithInferredInverse
2123
from tensorflow_probability.python.experimental.bijectors.sharded import Sharded
2224

23-
2425
__all__ = [
26+
'build_trainable_highway_flow',
2527
'forward_log_det_jacobian_ratio',
2628
'inverse_log_det_jacobian_ratio',
2729
'make_distribution_bijector',
30+
'HighwayFlow',
2831
'ScalarFunctionWithInferredInverse',
2932
'Sharded',
3033
]

0 commit comments

Comments
 (0)