Skip to content

Commit a602a8c

Browse files
committed
Doc fix
Replace `**2` with `tf.square`
1 parent 43e874a commit a602a8c

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,8 @@ def cumulative_variance(x, sample_axis=0, name=None):
702702
excl_counts = tf.reshape(tf.range(size, dtype=x.dtype), shape=counts_shp)
703703
incl_counts = excl_counts + 1
704704
excl_sums = tf.cumsum(x, axis=sample_axis, exclusive=True)
705-
discrepancies = (excl_sums / excl_counts - x)**2
706-
discrepancies = tf.where(excl_counts == 0, x**2, discrepancies)
705+
discrepancies = tf.math.square(excl_sums / excl_counts - x)
706+
discrepancies = tf.where(excl_counts == 0, tf.math.square(x), discrepancies)
707707
adjustments = excl_counts / incl_counts
708708
# The zeroth item's residual contribution is 0, because it has no
709709
# other items to vary from. The preceding expressions, however,
@@ -734,8 +734,10 @@ def windowed_variance(
734734
trailing-window estimators from some iterative process, such as the
735735
last half of an MCMC chain.
736736
737-
Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
738-
rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
737+
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
738+
have shape `Bi + [M] + F`, such that:
739+
- `rank(Bx) = rank(Bi) = axis`,
740+
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
739741
Then each element of `low_indices` and `high_indices` must be
740742
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
741743
@@ -847,8 +849,10 @@ def windowed_mean(
847849
trailing-window estimators from some iterative process, such as the
848850
last half of an MCMC chain.
849851
850-
Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
851-
rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
852+
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
853+
have shape `Bi + [M] + F`, such that:
854+
- `rank(Bx) = rank(Bi) = axis`,
855+
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
852856
Then each element of `low_indices` and `high_indices` must be
853857
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
854858

0 commit comments

Comments
 (0)