@@ -702,8 +702,8 @@ def cumulative_variance(x, sample_axis=0, name=None):
702702 excl_counts = tf .reshape (tf .range (size , dtype = x .dtype ), shape = counts_shp )
703703 incl_counts = excl_counts + 1
704704 excl_sums = tf .cumsum (x , axis = sample_axis , exclusive = True )
705- discrepancies = (excl_sums / excl_counts - x )** 2
706- discrepancies = tf .where (excl_counts == 0 , x ** 2 , discrepancies )
705+ discrepancies = tf . math . square (excl_sums / excl_counts - x )
706+ discrepancies = tf .where (excl_counts == 0 , tf . math . square ( x ) , discrepancies )
707707 adjustments = excl_counts / incl_counts
708708 # The zeroth item's residual contribution is 0, because it has no
709709 # other items to vary from. The preceding expressions, however,
@@ -734,8 +734,10 @@ def windowed_variance(
734734 trailing-window estimators from some iterative process, such as the
735735 last half of an MCMC chain.
736736
737- Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
738- rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
737+ Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
738+ have shape `Bi + [M] + F`, such that:
739+ - `rank(Bx) = rank(Bi) = axis`,
740+ - `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
739741 Then each element of `low_indices` and `high_indices` must be
740742 between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
741743
@@ -847,8 +849,10 @@ def windowed_mean(
847849 trailing-window estimators from some iterative process, such as the
848850 last half of an MCMC chain.
849851
850- Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
851- rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
852+ Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
853+ have shape `Bi + [M] + F`, such that:
854+ - `rank(Bx) = rank(Bi) = axis`,
855+ - `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
852856 Then each element of `low_indices` and `high_indices` must be
853857 between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
854858
0 commit comments