Skip to content

Commit 2675210

Browse files
committed
Enable batch support for windowed_mean|variance
1 parent 9a34093 commit 2675210

File tree

2 files changed

+135
-56
lines changed

2 files changed

+135
-56
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@
1414
# ============================================================================
1515
"""Functions for computing statistics of samples."""
1616

17+
JAX_MODE = False
18+
NUMPY_MODE = False
19+
1720
# Dependency imports
1821
import numpy as np
1922
import tensorflow.compat.v2 as tf
2023

24+
if JAX_MODE or NUMPY_MODE:
25+
tnp = np
26+
else:
27+
import tensorflow.experimental.numpy as tnp
28+
2129
from tensorflow_probability.python.internal import assert_util
2230
from tensorflow_probability.python.internal import distribution_util
2331
from tensorflow_probability.python.internal import dtype_util
@@ -712,7 +720,7 @@ def windowed_variance(
712720
713721
Computes variances among data in the Tensor `x` along the given windows:
714722
715-
result[i] = variance(x[low_indices[i]:high_indices[i]+1])
723+
result[i] = variance(x[low_indices[i]:high_indices[i]])
716724
717725
accurately and efficiently. To wit, if K is the size of
718726
`low_indices` and `high_indices`, and `N` is the size of `x` along
@@ -727,10 +735,9 @@ def windowed_variance(
727735
last half of an MCMC chain.
728736
729737
Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
730-
rank `axis`, and `low_indices` and `high_indices` broadcast to shape
731-
`[M]`. Then each element of `low_indices` and `high_indices`
732-
must be between 0 and N+1, and the shape of the output will be
733-
`Bx + [M] + E`. Batch shape in the indices is not currently supported.
738+
rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
739+
Then each element of `low_indices` and `high_indices` must be
740+
between 0 and N (inclusive), and the shape of the output will be
`Bx + [M] + E`, where `M` is the size of the broadcast indices along `axis`.
734741
735742
The default windows are
736743
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -769,7 +776,7 @@ def windowed_variance(
769776
"""
770777
with tf.name_scope(name or 'windowed_variance'):
771778
x = tf.convert_to_tensor(x)
772-
low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
779+
x, indices, axis = _prepare_window_args(
773780
x, low_indices, high_indices, axis)
774781

775782
# We have a problem with indexing: the standard convention demands
@@ -786,15 +793,11 @@ def windowed_variance(
786793
def index_for_cumulative(indices):
787794
return tf.maximum(indices - 1, 0)
788795
cum_sums = tf.cumsum(x, axis=axis)
789-
low_sums = tf.gather(
790-
cum_sums, index_for_cumulative(low_indices), axis=axis)
791-
high_sums = tf.gather(
792-
cum_sums, index_for_cumulative(high_indices), axis=axis)
796+
sums = tnp.take_along_axis(
797+
cum_sums, index_for_cumulative(indices), axis=axis)
793798
cum_variances = cumulative_variance(x, sample_axis=axis)
794-
low_variances = tf.gather(
795-
cum_variances, index_for_cumulative(low_indices), axis=axis)
796-
high_variances = tf.gather(
797-
cum_variances, index_for_cumulative(high_indices), axis=axis)
799+
variances = tnp.take_along_axis(
800+
cum_variances, index_for_cumulative(indices), axis=axis)
798801

799802
# This formula is the binary accurate variance merge from [1],
800803
# adapted to subtract and batched across the indexed counts, sums,
@@ -812,15 +815,18 @@ def index_for_cumulative(indices):
812815
# This formula can also be read as implementing the above variance
813816
# computation by "unioning" A u B with a notional "negative B"
814817
# multiset.
815-
counts = high_counts - low_counts # |A|
816-
discrepancies = (
817-
_safe_average(high_sums, high_counts) -
818-
_safe_average(low_sums, low_counts))**2 # (mean(A u B) - mean(B))**2
819-
adjustments = high_counts * (-low_counts) / counts # |A u B| * -|B| / |A|
820-
residuals = (high_variances * high_counts -
821-
low_variances * low_counts +
818+
bounds = ps.cast(indices, sums.dtype)
819+
counts = bounds[1] - bounds[0] # |A|
820+
sum_averages = tf.math.divide_no_nan(sums, bounds)
821+
# (mean(A u B) - mean(B))**2
822+
discrepancies = tf.square(sum_averages[1] - sum_averages[0])
823+
# |A u B| * -|B| / |A|
824+
adjustments = tf.math.divide_no_nan(bounds[1] * (-bounds[0]), counts)
825+
variances_scaled = variances * bounds
826+
residuals = (variances_scaled[1] -
827+
variances_scaled[0] +
822828
adjustments * discrepancies)
823-
return _safe_average(residuals, counts)
829+
return tf.math.divide_no_nan(residuals, counts)
824830

825831

826832
def windowed_mean(
@@ -829,7 +835,7 @@ def windowed_mean(
829835
830836
Computes means among data in the Tensor `x` along the given windows:
831837
832-
result[i] = mean(x[low_indices[i]:high_indices[i]+1])
838+
result[i] = mean(x[low_indices[i]:high_indices[i]])
833839
834840
efficiently. To wit, if K is the size of `low_indices` and
835841
`high_indices`, and `N` is the size of `x` along the given `axis`,
@@ -842,10 +848,9 @@ def windowed_mean(
842848
last half of an MCMC chain.
843849
844850
Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
845-
rank `axis`, and `low_indices` and `high_indices` broadcast to shape
846-
`[M]`. Then each element of `low_indices` and `high_indices`
847-
must be between 0 and N+1, and the shape of the output will be
848-
`Bx + [M] + E`. Batch shape in the indices is not currently supported.
851+
rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
852+
Then each element of `low_indices` and `high_indices` must be
853+
between 0 and N (inclusive), and the shape of the output will be
`Bx + [M] + E`, where `M` is the size of the broadcast indices along `axis`.
849854
850855
The default windows are
851856
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -878,18 +883,17 @@ def windowed_mean(
878883
"""
879884
with tf.name_scope(name or 'windowed_mean'):
880885
x = tf.convert_to_tensor(x)
881-
low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
882-
x, low_indices, high_indices, axis)
886+
x, indices, axis = _prepare_window_args(x, low_indices, high_indices, axis)
883887

884888
raw_cumsum = tf.cumsum(x, axis=axis)
885-
cum_sums = tf.concat(
886-
[tf.zeros_like(tf.gather(raw_cumsum, [0], axis=axis)), raw_cumsum],
887-
axis=axis)
888-
low_sums = tf.gather(cum_sums, low_indices, axis=axis)
889-
high_sums = tf.gather(cum_sums, high_indices, axis=axis)
890-
891-
counts = high_counts - low_counts
892-
return _safe_average(high_sums - low_sums, counts)
889+
rank = ps.rank(x)
890+
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
891+
(rank, 2))
892+
cum_sums = ps.pad(raw_cumsum, paddings)
893+
sums = tnp.take_along_axis(cum_sums, indices,
894+
axis=axis)
895+
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
896+
return tf.math.divide_no_nan(sums[1] - sums[0], counts)
893897

894898

895899
def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
@@ -905,24 +909,20 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
905909
# Broadcast indices together.
906910
high_indices = high_indices + tf.zeros_like(low_indices)
907911
low_indices = low_indices + tf.zeros_like(high_indices)
908-
909-
# TODO(axch): Support batch low and high indices. That would
910-
# complicate this shape munging (though tf.gather should work
911-
# fine).
912-
913-
# We want to place `low_counts` and `high_counts` at the `axis`
914-
# position, so we reshape them to shape `[1, 1, ..., 1, N, 1, ...,
915-
# 1]`, where the `N` is at `axis`. The `counts_shp`, below,
916-
# is this shape.
917-
size = ps.size(high_indices)
918-
counts_shp = ps.one_hot(
919-
axis, depth=ps.rank(x), on_value=size, off_value=1)
920-
921-
low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype),
922-
shape=counts_shp)
923-
high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype),
924-
shape=counts_shp)
925-
return low_indices, high_indices, low_counts, high_counts
912+
indices = ps.stack([low_indices, high_indices], axis=0)
913+
x = tf.expand_dims(x, axis=0)
914+
axis += 1
915+
916+
if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2:
917+
# legacy usage, kept for backward compatibility
918+
size = ps.size(indices) // 2
919+
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
920+
off_value=1)
921+
bc_shape = ps.concat([[2], bc_shape[1:]], axis=0)
922+
indices = ps.reshape(indices, bc_shape)
923+
# `take_along_axis` requires the type to be int32
924+
indices = ps.cast(indices, dtype=tf.int32)
925+
return x, indices, axis
926926

927927

928928
def _safe_average(totals, counts):

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Tests for Sample Stats Ops."""
1616

1717
# Dependency imports
18-
18+
import functools
1919
import numpy as np
2020
import tensorflow.compat.v1 as tf1
2121
import tensorflow.compat.v2 as tf
@@ -679,6 +679,85 @@ def test_windowed_mean_corner_cases(self):
679679
self.evaluate(sample_stats.windowed_mean(y)))
680680

681681

682+
@test_util.test_all_tf_execution_regimes
683+
class WindowedStatsTest(test_util.TestCase):
684+
def apply_slice_along_axis(self, func, arr, low, high, axis):
685+
"""Applies `func` over slices of `arr` along `axis`. Slice intervals are
686+
specified through `low` and `high`. Supports broadcasting.
687+
"""
688+
np.testing.assert_equal(low.shape, high.shape)
689+
ni, _, nk = arr.shape[:axis], arr.shape[axis], arr.shape[axis + 1:]
690+
si, j, sk = low.shape[:axis], low.shape[axis], low.shape[axis + 1:]
691+
mk = max(nk, sk)
692+
mi = max(ni, si)
693+
out = np.empty(mi + (j,) + mk)
694+
for ki in np.ndindex(ni):
695+
for kk in np.ndindex(mk):
696+
ak = tuple(np.mod(kk, nk))
697+
ik = tuple(np.mod(kk, sk))
698+
ai = tuple(np.mod(ki, ni))
699+
ii = tuple(np.mod(ki, si))
700+
a_1d = arr[ai + np.s_[:, ] + ak]
701+
out_1d = out[ki + np.s_[:, ] + kk]
702+
low_1d = low[ii + np.s_[:, ] + ik]
703+
high_1d = high[ii + np.s_[:, ] + ik]
704+
705+
for r in range(j):
706+
out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]])
707+
return out
708+
def check_gaussian_windowed(self, shape, indice_shape, axis,
709+
window_func, np_func):
710+
stat_shape = np.array(shape).astype(np.int32)
711+
stat_shape[axis] = 1
712+
loc = np.arange(np.prod(stat_shape)).reshape(stat_shape)
713+
scale = 0.1 * np.arange(np.prod(stat_shape)).reshape(stat_shape)
714+
rng = test_util.test_np_rng()
715+
x = rng.normal(loc=loc, scale=scale, size=shape)
716+
indice_shape = [2] + list(indice_shape)
717+
indices = rng.randint(shape[axis] + 1, size=indice_shape)
718+
indices = np.sort(indices, axis=0)
719+
low_indices, high_indices = indices[0], indices[1]
720+
a = window_func(x, low_indices=low_indices,
721+
high_indices=high_indices, axis=axis)
722+
b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices,
723+
axis=axis)
724+
b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros
725+
self.assertAllClose(a, b)
726+
727+
def check_windowed(self, func, numpy_func):
728+
check_fn = functools.partial(self.check_gaussian_windowed,
729+
window_func=func, np_func=numpy_func)
730+
check_fn((64, 4, 8), (128, 1, 1), axis=0)
731+
check_fn((64, 4, 8), (32, 1, 1), axis=0)
732+
check_fn((64, 4, 8), (32, 4, 1), axis=0)
733+
check_fn((64, 4, 8), (32, 4, 8), axis=0)
734+
check_fn((64, 4, 8), (64, 64, 1), axis=1)
735+
check_fn((64, 4, 8), (1, 64, 1), axis=1)
736+
check_fn((64, 4, 8), (64, 2, 8), axis=1)
737+
check_fn((64, 4, 8), (64, 4, 64), axis=2)
738+
check_fn((64, 4, 8), (1, 1, 64), axis=2)
739+
check_fn((64, 4, 8), (64, 4, 4), axis=2)
740+
check_fn((64, 4, 8), (1, 1, 4), axis=2)
741+
742+
with self.assertRaises(Exception):
743+
# Non-broadcastable shapes
744+
check_fn((64, 4, 8), (4, 1, 4), axis=2)
745+
746+
def test_windowed_mean(self):
747+
self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean)
748+
749+
def test_windowed_mean_graph(self):
750+
func = tf.function(tfp.stats.windowed_mean)
751+
self.check_windowed(func=func, numpy_func=np.mean)
752+
753+
def test_windowed_variance(self):
754+
self.check_windowed(func=tfp.stats.windowed_variance, numpy_func=np.var)
755+
756+
def test_windowed_variance_graph(self):
757+
func = tf.function(tfp.stats.windowed_variance)
758+
self.check_windowed(func=func, numpy_func=np.var)
759+
760+
682761
@test_util.test_all_tf_execution_regimes
683762
class LogAverageProbsTest(test_util.TestCase):
684763

0 commit comments

Comments
 (0)