
Commit 26f4f12 (parent 56c5c16)

Expose tensorflow.experimental.numpy API to numpy and jax backends

4 files changed: +11 −14 lines

tensorflow_probability/python/internal/backend/jax/rewrite.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ def main(argv):
   if FLAGS.rewrite_numpy_import:
     contents = contents.replace('\nimport numpy as np',
                                 '\nimport numpy as onp; import jax.numpy as np')
+    contents = contents.replace('\nimport numpy as tnp',
+                                '\nimport jax.numpy as tnp')
   else:
     contents = contents.replace('\nimport numpy as np',
                                 '\nimport numpy as np; onp = np')
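Read together with the TF_REPLACEMENTS entry added in substrates/meta/rewrite.py below, the intent appears to be a two-step import rewrite for the JAX backend. A rough sketch of that chain (illustrative only; `source` is a made-up module snippet, not code from this commit):

# Hypothetical module text containing the new TF-numpy import.
source = '\nimport tensorflow.experimental.numpy as tnp\n'

# Step 1 (substrates/meta/rewrite.py): retarget the TF-numpy alias at plain numpy.
source = source.replace('import tensorflow.experimental.numpy as tnp',
                        'import numpy as tnp')

# Step 2 (backend/jax/rewrite.py under --rewrite_numpy_import): point the same
# alias at jax.numpy instead.
source = source.replace('\nimport numpy as tnp',
                        '\nimport jax.numpy as tnp')

print(source)  # prints a blank line followed by: import jax.numpy as tnp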

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 4 additions & 10 deletions
@@ -20,13 +20,7 @@
 # Dependency imports
 import numpy as np
 import tensorflow.compat.v2 as tf
-
-if NUMPY_MODE:
-  take_along_axis = np.take_along_axis
-elif JAX_MODE:
-  from jax.numpy import take_along_axis
-else:
-  from tensorflow.python.ops.numpy_ops import take_along_axis
+import tensorflow.experimental.numpy as tnp

 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import distribution_util
@@ -802,10 +796,10 @@ def windowed_variance(
     def index_for_cumulative(indices):
       return tf.maximum(indices - 1, 0)
     cum_sums = tf.cumsum(x, axis=axis)
-    sums = take_along_axis(
+    sums = tnp.take_along_axis(
         cum_sums, index_for_cumulative(indices), axis=axis)
     cum_variances = cumulative_variance(x, sample_axis=axis)
-    variances = take_along_axis(
+    variances = tnp.take_along_axis(
         cum_variances, index_for_cumulative(indices), axis=axis)

     # This formula is the binary accurate variance merge from [1],
@@ -906,7 +900,7 @@ def windowed_mean(
     paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
                           (rank, 2))
     cum_sums = ps.pad(raw_cumsum, paddings)
-    sums = take_along_axis(cum_sums, indices, axis=axis)
+    sums = tnp.take_along_axis(cum_sums, indices, axis=axis)
     counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
     return tf.math.divide_no_nan(sums[1] - sums[0], counts)
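The call sites above now rely on tf.experimental.numpy matching NumPy's take_along_axis semantics, which is what lets the same code run under the TF, NumPy, and JAX substrates after rewriting. A minimal illustration of the shared semantics (example values are mine, not from the commit; assumes a TF build that ships tf.experimental.numpy):

import numpy as np
import tensorflow.experimental.numpy as tnp

x = np.array([[10., 20., 30.],
              [40., 50., 60.]])
idx = np.array([[2], [0]])  # one index per row, gathered along axis=1

print(np.take_along_axis(x, idx, axis=1))   # [[30.] [40.]]
print(tnp.take_along_axis(x, idx, axis=1))  # same values, as a TF tensor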

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 4 additions & 4 deletions
@@ -788,17 +788,17 @@ def check_windowed(self, func, numpy_func):
     check_fn((64, 4, 8), (2, 4), axis=2)

   def test_windowed_mean(self):
-    self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean)
+    self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean)

   def test_windowed_mean_graph(self):
-    func = tf.function(tfp.stats.windowed_mean)
+    func = tf.function(sample_stats.windowed_mean)
     self.check_windowed(func=func, numpy_func=np.mean)

   def test_windowed_variance(self):
-    self.check_windowed(func=tfp.stats.windowed_variance, numpy_func=np.var)
+    self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)

   def test_windowed_variance_graph(self):
-    func = tf.function(tfp.stats.windowed_variance)
+    func = tf.function(sample_stats.windowed_variance)
     self.check_windowed(func=func, numpy_func=np.var)

tensorflow_probability/substrates/meta/rewrite.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 TF_REPLACEMENTS = {
     'import tensorflow ':
         'from tensorflow_probability.python.internal.backend import numpy ',
+    'import tensorflow.experimental.numpy as tnp': 'import numpy as tnp',
     'import tensorflow.compat.v1':
         'from tensorflow_probability.python.internal.backend.numpy.compat '
         'import v1',
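The meta rewrite applies entries like the one added here as plain textual substitutions over each source file before any backend-specific pass runs. A simplified sketch of that mechanism (not the actual rewrite.py, which does considerably more):

# Minimal model of applying a TF_REPLACEMENTS-style mapping to module text.
TF_REPLACEMENTS = {
    'import tensorflow.experimental.numpy as tnp': 'import numpy as tnp',
}

def rewrite(contents):
  for tf_import, substitute in TF_REPLACEMENTS.items():
    contents = contents.replace(tf_import, substitute)
  return contents

print(rewrite('import tensorflow.experimental.numpy as tnp'))
# -> import numpy as tnp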
