Skip to content

Commit c28faa5

Browse files
committed
Allow lower rank indices
1 parent e020543 commit c28faa5

File tree

2 files changed

+68
-15
lines changed

2 files changed

+68
-15
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -735,12 +735,18 @@ def windowed_variance(
735735
last half of an MCMC chain.
736736
737737
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
738-
have shape `Bi + [M] + F`, such that:
739-
- `rank(Bx) = rank(Bi) = axis`,
740-
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
738+
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
741739
Then each element of `low_indices` and `high_indices` must be
742740
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
743741
742+
The shape of indices must be broadcastable with `x` unless the rank is lower
743+
than the rank of `x`, then the shape is expanded with extra inner dimensions
744+
to match the rank of `x`.
745+
746+
In the special case where the rank of indices is one, i.e when
747+
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
748+
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
749+
744750
The default windows are
745751
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
746752
This corresponds to analyzing `x` as though it were streaming, for
@@ -850,12 +856,18 @@ def windowed_mean(
850856
last half of an MCMC chain.
851857
852858
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
853-
have shape `Bi + [M] + F`, such that:
854-
- `rank(Bx) = rank(Bi) = axis`,
855-
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
859+
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
856860
Then each element of `low_indices` and `high_indices` must be
857861
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
858862
863+
The shape of indices must be broadcastable with `x` unless the rank is lower
864+
than the rank of `x`, then the shape is expanded with extra inner dimensions
865+
to match the rank of `x`.
866+
867+
In the special case where the rank of indices is one, i.e when
868+
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
869+
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
870+
859871
The default windows are
860872
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
861873
This corresponds to analyzing `x` as though it were streaming, for
@@ -913,17 +925,26 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
913925
# Broadcast indices together.
914926
high_indices = high_indices + tf.zeros_like(low_indices)
915927
low_indices = low_indices + tf.zeros_like(high_indices)
916-
indices = ps.stack([low_indices, high_indices], axis=0)
928+
929+
indices_shape = ps.shape(low_indices)
930+
if ps.rank(low_indices) < ps.rank(x):
931+
if ps.rank(low_indices) == 1:
932+
size = ps.size(low_indices)
933+
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
934+
off_value=1)
935+
else:
936+
# we assume the first dimensions are broadcastable with `x`,
937+
# we add trailing dimensions
938+
extra_dims = ps.rank(x) - ps.rank(low_indices)
939+
bc_shape = ps.concat([indices_shape, [1]*extra_dims], axis=0)
940+
else:
941+
bc_shape = indices_shape
942+
943+
bc_shape = ps.concat([[2], bc_shape], axis=0)
944+
indices = tf.stack([low_indices, high_indices], axis=0)
945+
indices = ps.reshape(indices, bc_shape)
917946
x = tf.expand_dims(x, axis=0)
918947
axis += 1
919-
920-
if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2:
921-
# legacy usage, kept for backward compatibility
922-
size = ps.size(indices) // 2
923-
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
924-
off_value=1)
925-
bc_shape = ps.concat([[2], bc_shape[1:]], axis=0)
926-
indices = ps.reshape(indices, bc_shape)
927948
# `take_along_axis` requires the type to be int32
928949
indices = ps.cast(indices, dtype=tf.int32)
929950
return x, indices, axis

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,19 @@ def test_windowed_mean_corner_cases(self):
681681

682682
@test_util.test_all_tf_execution_regimes
683683
class WindowedStatsTest(test_util.TestCase):
684+
685+
def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis):
686+
if len(shape) > len(x.shape):
687+
if len(x.shape) == 1:
688+
bc_shape = np.ones(len(shape), dtype=np.int32)
689+
bc_shape[axis] = x.shape[0]
690+
return x.reshape(bc_shape)
691+
else:
692+
extra_dims = len(shape) - len(x.shape)
693+
bc_shape = x.shape + (1,) * extra_dims
694+
return x.reshape(bc_shape)
695+
return x
696+
684697
def apply_slice_along_axis(self, func, arr, low, high, axis):
685698
"""Applies `func` over slices of `arr` along `axis`. Slices intervals are
686699
specified through `low` and `high`. Support broadcasting.
@@ -705,6 +718,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
705718
for r in range(j):
706719
out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]])
707720
return out
721+
708722
def check_gaussian_windowed(self, shape, indice_shape, axis,
709723
window_func, np_func):
710724
stat_shape = np.array(shape).astype(np.int32)
@@ -717,6 +731,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
717731
indices = rng.randint(shape[axis] + 1, size=indice_shape)
718732
indices = np.sort(indices, axis=0)
719733
low_indices, high_indices = indices[0], indices[1]
734+
low_indices = self._maybe_expand_dims_to_make_broadcastable(
735+
low_indices, x.shape, axis)
736+
high_indices = self._maybe_expand_dims_to_make_broadcastable(
737+
high_indices, x.shape, axis)
720738
a = window_func(x, low_indices=low_indices,
721739
high_indices=high_indices, axis=axis)
722740
b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices,
@@ -732,20 +750,34 @@ def check_windowed(self, func, numpy_func):
732750
check_fn((64, 4, 8), (32, 4, 1), axis=0)
733751
check_fn((64, 4, 8), (32, 4, 8), axis=0)
734752
check_fn((64, 4, 8), (64, 4, 8), axis=0)
753+
check_fn((64, 4, 8), (128, 1), axis=0)
754+
check_fn((64, 4, 8), (32,), axis=0)
755+
check_fn((64, 4, 8), (32, 4), axis=0)
756+
735757
check_fn((64, 4, 8), (64, 64, 1), axis=1)
736758
check_fn((64, 4, 8), (1, 64, 1), axis=1)
737759
check_fn((64, 4, 8), (64, 2, 8), axis=1)
738760
check_fn((64, 4, 8), (64, 4, 8), axis=1)
761+
check_fn((64, 4, 8), (16,), axis=1)
762+
check_fn((64, 4, 8), (1, 64), axis=1)
763+
739764
check_fn((64, 4, 8), (64, 4, 64), axis=2)
740765
check_fn((64, 4, 8), (1, 1, 64), axis=2)
741766
check_fn((64, 4, 8), (64, 4, 4), axis=2)
742767
check_fn((64, 4, 8), (1, 1, 4), axis=2)
743768
check_fn((64, 4, 8), (64, 4, 8), axis=2)
769+
check_fn((64, 4, 8), (16,), axis=2)
770+
check_fn((64, 4, 8), (1, 4), axis=2)
771+
check_fn((64, 4, 8), (64, 4), axis=2)
744772

745773
with self.assertRaises(Exception):
746774
# Non broadcastable shapes
747775
check_fn((64, 4, 8), (4, 1, 4), axis=2)
748776

777+
with self.assertRaises(Exception):
778+
# Non broadcastable shapes
779+
check_fn((64, 4, 8), (2, 4), axis=2)
780+
749781
def test_windowed_mean(self):
750782
self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean)
751783

0 commit comments

Comments
 (0)