 import tensorflow.compat.v2 as tf

 if JAX_MODE or NUMPY_MODE:
-  tnp = np
+  numpy_ops = np
 else:
-  import tensorflow.experimental.numpy as tnp
+  from tensorflow.python.ops import numpy_ops

 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import distribution_util
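Note: under the JAX and NumPy backends `numpy_ops` is just NumPy itself, so the only API relied on below is `take_along_axis`, which behaves the same across backends. A minimal sketch of that call pattern in plain NumPy (the array and indices here are made up for illustration):

import numpy as np

cum = np.cumsum(np.arange(12.).reshape(4, 3), axis=0)   # prefix sums, shape [4, 3]
idx = np.array([[1, 3, 2]])                             # one index per column, shape [1, 3]
picked = np.take_along_axis(cum, idx, axis=0)           # gathers cum[idx[0, j], j], shape [1, 3]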
@@ -739,13 +739,12 @@ def windowed_variance(
   Then each element of `low_indices` and `high_indices` must be
   between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.

-  The shape of indices must be broadcastable with `x` unless the rank is lower
-  than the rank of `x`, then the shape is expanded with extra inner dimensions
-  to match the rank of `x`.
+  The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.

-  In the special case where the rank of indices is one, i.e when
-  `rank(Bi) = rank(F) = 0`, the indices are reshaped to
-  `[1] * rank(Bx) + [M] + [1] * rank(E)`.
+  If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
+  with extra inner dimensions to match the rank of `x`. In the special
+  case where the rank of indices is one, i.e. when `rank(Bi) = rank(F) = 0`,
+  the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.

   The default windows are
   `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
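For concreteness, a hedged sketch of the rank-1 index case described above, via the public `tfp.stats.windowed_variance` wrapper (shapes only; the values are random here):

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.random.normal([100, 3])                 # N = 100, Bx = [], E = [3]
low = tf.constant([0, 10, 20])                 # M = 3 windows, rank-1 indices
high = tf.constant([50, 60, 70])
v = tfp.stats.windowed_variance(x, low_indices=low, high_indices=high, axis=0)
# Expected shape: Bx + [M] + E = [3, 3]; per the special case above, the indices
# are reshaped internally to [M] + [1] * rank(E) = [3, 1].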
@@ -801,10 +800,10 @@ def windowed_variance(
   def index_for_cumulative(indices):
     return tf.maximum(indices - 1, 0)
   cum_sums = tf.cumsum(x, axis=axis)
-  sums = tnp.take_along_axis(
+  sums = numpy_ops.take_along_axis(
       cum_sums, index_for_cumulative(indices), axis=axis)
   cum_variances = cumulative_variance(x, sample_axis=axis)
-  variances = tnp.take_along_axis(
+  variances = numpy_ops.take_along_axis(
       cum_variances, index_for_cumulative(indices), axis=axis)

   # This formula is the binary accurate variance merge from [1],
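The gathered `cum_sums`/`cum_variances` feed a subtractive use of the pairwise variance combination formula: the prefix [0, hi) is the merge of the prefix [0, lo) with the window [lo, hi), so the window's variance can be recovered by inverting that merge. A plain-NumPy sketch of the idea only, not the library's exact batched formula:

import numpy as np

def window_variance_from_prefixes(x, lo, hi):
  s = np.concatenate([[0.], np.cumsum(x)])          # s[k] == sum(x[:k])
  m2 = lambda k: np.var(x[:k]) * k if k else 0.     # prefix sum of squared deviations
  n_w = hi - lo
  mean_lo = s[lo] / lo if lo else 0.
  mean_w = (s[hi] - s[lo]) / n_w
  delta = mean_w - mean_lo
  # Invert the merge M2_hi = M2_lo + M2_w + delta**2 * lo * n_w / hi:
  m2_w = m2(hi) - m2(lo) - delta**2 * lo * n_w / hi
  return m2_w / n_w                                 # population variance of x[lo:hi]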
@@ -860,13 +859,12 @@ def windowed_mean(
   Then each element of `low_indices` and `high_indices` must be
   between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.

-  The shape of indices must be broadcastable with `x` unless the rank is lower
-  than the rank of `x`, then the shape is expanded with extra inner dimensions
-  to match the rank of `x`.
+  The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.

-  In the special case where the rank of indices is one, i.e when
-  `rank(Bi) = rank(F) = 0`, the indices are reshaped to
-  `[1] * rank(Bx) + [M] + [1] * rank(E)`.
+  If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
+  with extra inner dimensions to match the rank of `x`. In the special
+  case where the rank of indices is one, i.e. when `rank(Bi) = rank(F) = 0`,
+  the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.

   The default windows are
   `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
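A small hedged example of the default windows quoted above, using the public `tfp.stats.windowed_mean` wrapper:

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([1., 2., 3., 4., 5.])
m = tfp.stats.windowed_mean(x)
# With windows [0, 1), [1, 2), [1, 3), [2, 4), [2, 5) this should come out to
# approximately [1., 2., 2.5, 3.5, 4.].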
@@ -906,7 +904,7 @@ def windowed_mean(
   paddings = ps.reshape(ps.one_hot(2 * axis, depth=2 * rank, dtype=tf.int32),
                         (rank, 2))
   cum_sums = ps.pad(raw_cumsum, paddings)
-  sums = tnp.take_along_axis(cum_sums, indices,
+  sums = numpy_ops.take_along_axis(cum_sums, indices,
                              axis=axis)
   counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
   return tf.math.divide_no_nan(sums[1] - sums[0], counts)
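The padded cumulative sum plus `take_along_axis` above is the usual prefix-sum trick: pad a zero in front so `cum[i]` is the sum of the first `i` elements, gather at both window boundaries, and divide the difference by the window length. A minimal 1-D NumPy sketch (no batching, names are hypothetical):

import numpy as np

def windowed_mean_1d(x, low, high):
  cum = np.concatenate([[0.], np.cumsum(x)])     # cum[i] == sum(x[:i])
  sums = cum[high] - cum[low]                    # sum over each window [low, high)
  counts = high - low
  return np.where(counts > 0, sums / np.maximum(counts, 1), 0.)  # divide-no-nan analogue

windowed_mean_1d(np.array([1., 2., 3., 4., 5.]),
                 np.array([0, 1, 2]), np.array([1, 3, 5]))       # -> [1., 2.5, 4.]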
@@ -915,7 +913,7 @@ def windowed_mean(
 def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
   """Common argument defaulting logic for windowed statistics."""
   if high_indices is None:
-    high_indices = tf.range(ps.shape(x)[axis]) + 1
+    high_indices = ps.range(ps.shape(x)[axis]) + 1
   else:
     high_indices = tf.convert_to_tensor(high_indices)
   if low_indices is None:
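For reference, the `high_indices` default here, together with the `low_indices` default elided by this hunk (presumably `high_indices // 2`), reproduces the default windows quoted in the docstrings:

import numpy as np

high = np.arange(5) + 1    # [1, 2, 3, 4, 5]
low = high // 2            # [0, 1, 1, 2, 2]  (assumed default; not shown in this hunk)
list(zip(low, high))       # windows [0, 1), [1, 2), [1, 3), [2, 4), [2, 5)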
@@ -941,7 +939,7 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
     bc_shape = indices_shape

   bc_shape = ps.concat([[2], bc_shape], axis=0)
-  indices = tf.stack([low_indices, high_indices], axis=0)
+  indices = ps.stack([low_indices, high_indices], axis=0)
   indices = ps.reshape(indices, bc_shape)
   x = tf.expand_dims(x, axis=0)
   axis += 1
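The stacking above prepends a length-2 dimension holding the low and high boundaries, so a single `take_along_axis` gather serves both ends of every window and the callers can then form `sums[1] - sums[0]` and `indices[1] - indices[0]`. A hedged NumPy illustration of that layout:

import numpy as np

x = np.arange(12.).reshape(6, 2)                                 # N = 6, E = [2]
low, high = np.array([0, 2, 3]), np.array([2, 5, 6])
indices = np.stack([low, high], axis=0)[..., np.newaxis]         # shape [2, M, 1]

cum = np.concatenate([np.zeros((1, 2)), np.cumsum(x, axis=0)])   # padded prefix sums
sums = np.take_along_axis(cum[np.newaxis], indices, axis=1)      # shape [2, M, 2]
means = (sums[1] - sums[0]) / (high - low)[:, np.newaxis]        # one gather, both ends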