@@ -702,8 +702,8 @@ def cumulative_variance(x, sample_axis=0, name=None):
702
702
excl_counts = tf .reshape (tf .range (size , dtype = x .dtype ), shape = counts_shp )
703
703
incl_counts = excl_counts + 1
704
704
excl_sums = tf .cumsum (x , axis = sample_axis , exclusive = True )
705
- discrepancies = (excl_sums / excl_counts - x )** 2
706
- discrepancies = tf .where (excl_counts == 0 , x ** 2 , discrepancies )
705
+ discrepancies = tf . math . square (excl_sums / excl_counts - x )
706
+ discrepancies = tf .where (excl_counts == 0 , tf . math . square ( x ) , discrepancies )
707
707
adjustments = excl_counts / incl_counts
708
708
# The zeroth item's residual contribution is 0, because it has no
709
709
# other items to vary from. The preceding expressions, however,
@@ -734,8 +734,10 @@ def windowed_variance(
734
734
trailing-window estimators from some iterative process, such as the
735
735
last half of an MCMC chain.
736
736
737
- Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
738
- rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
737
+ Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
738
+ have shape `Bi + [M] + F`, such that:
739
+ - `rank(Bx) = rank(Bi) = axis`,
740
+ - `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
739
741
Then each element of `low_indices` and `high_indices` must be
740
742
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
741
743
@@ -847,8 +849,10 @@ def windowed_mean(
847
849
trailing-window estimators from some iterative process, such as the
848
850
last half of an MCMC chain.
849
851
850
- Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
851
- rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
852
+ Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
853
+ have shape `Bi + [M] + F`, such that:
854
+ - `rank(Bx) = rank(Bi) = axis`,
855
+ - `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
852
856
Then each element of `low_indices` and `high_indices` must be
853
857
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
854
858
0 commit comments