@@ -14,10 +14,18 @@
 # ============================================================================
 """Functions for computing statistics of samples."""

+JAX_MODE = False
+NUMPY_MODE = False
+
 # Dependency imports
 import numpy as np
 import tensorflow.compat.v2 as tf

+if JAX_MODE or NUMPY_MODE:
+  tnp = np
+else:
+  import tensorflow.experimental.numpy as tnp
+
 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import distribution_util
 from tensorflow_probability.python.internal import dtype_util
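`JAX_MODE` and `NUMPY_MODE` are flags rewritten by TFP's backend-substitution
tooling, so under the JAX and NumPy backends `tnp` resolves to plain NumPy.
The one API the patch leans on is `take_along_axis`; a minimal NumPy sketch of
its semantics, with illustrative values:

    import numpy as np

    x = np.arange(12.).reshape(3, 4)    # shape [3, 4]
    idx = np.array([[0], [2], [1]])     # per-row column picks, shape [3, 1]
    # Gathers x[i, idx[i, j]]; indices must have the same rank as x.
    np.take_along_axis(x, idx, axis=1)  # => [[0.], [6.], [9.]]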
@@ -712,7 +720,7 @@ def windowed_variance(

   Computes variances among data in the Tensor `x` along the given windows:

-     result[i] = variance(x[low_indices[i]:high_indices[i]+1])
+     result[i] = variance(x[low_indices[i]:high_indices[i]])

   accurately and efficiently.  To wit, if K is the size of
   `low_indices` and `high_indices`, and `N` is the size of `x` along
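This docstring change reflects the new half-open convention: `high_indices[i]`
now points one past the last element of the window, matching Python slicing.
A quick NumPy illustration of the two conventions, with made-up values:

    import numpy as np

    x = np.array([1., 2., 3., 4., 5.])
    lo, hi = 1, 4
    np.var(x[lo:hi])      # new: window is x[1:4] == [2., 3., 4.]
    np.var(x[lo:hi + 1])  # old: the upper index was included as well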
@@ -727,10 +735,9 @@ def windowed_variance(
   last half of an MCMC chain.

   Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
-  rank `axis`, and `low_indices` and `high_indices` broadcast to shape
-  `[M]`.  Then each element of `low_indices` and `high_indices`
-  must be between 0 and N+1, and the shape of the output will be
-  `Bx + [M] + E`.  Batch shape in the indices is not currently supported.
+  rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
+  Then each element of `low_indices` and `high_indices` must be
+  between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.

   The default windows are
   `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
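For orientation, a hedged usage sketch of the public API this diff touches;
the window bounds below are arbitrary:

    import tensorflow as tf
    import tensorflow_probability as tfp

    x = tf.random.normal([100])
    # One variance estimate per half-open window x[low:high]; output shape [2].
    tfp.stats.windowed_variance(
        x, low_indices=[25, 50], high_indices=[100, 100], axis=0)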
@@ -769,7 +776,7 @@ def windowed_variance(
   """
   with tf.name_scope(name or 'windowed_variance'):
     x = tf.convert_to_tensor(x)
-    low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
+    x, indices, axis = _prepare_window_args(
         x, low_indices, high_indices, axis)

     # We have a problem with indexing: the standard convention demands
@@ -786,15 +793,11 @@ def windowed_variance(
     def index_for_cumulative(indices):
       return tf.maximum(indices - 1, 0)
     cum_sums = tf.cumsum(x, axis=axis)
-    low_sums = tf.gather(
-        cum_sums, index_for_cumulative(low_indices), axis=axis)
-    high_sums = tf.gather(
-        cum_sums, index_for_cumulative(high_indices), axis=axis)
+    sums = tnp.take_along_axis(
+        cum_sums, index_for_cumulative(indices), axis=axis)
     cum_variances = cumulative_variance(x, sample_axis=axis)
-    low_variances = tf.gather(
-        cum_variances, index_for_cumulative(low_indices), axis=axis)
-    high_variances = tf.gather(
-        cum_variances, index_for_cumulative(high_indices), axis=axis)
+    variances = tnp.take_along_axis(
+        cum_variances, index_for_cumulative(indices), axis=axis)

     # This formula is the binary accurate variance merge from [1],
     # adapted to subtract and batched across the indexed counts, sums,
@@ -812,15 +815,18 @@ def index_for_cumulative(indices):
     # This formula can also be read as implementing the above variance
     # computation by "unioning" A u B with a notional "negative B"
     # multiset.
-    counts = high_counts - low_counts  # |A|
-    discrepancies = (
-        _safe_average(high_sums, high_counts) -
-        _safe_average(low_sums, low_counts))**2  # (mean(A u B) - mean(B))**2
-    adjustments = high_counts * (-low_counts) / counts  # |A u B| * -|B| / |A|
-    residuals = (high_variances * high_counts -
-                 low_variances * low_counts +
+    bounds = ps.cast(indices, sums.dtype)
+    counts = bounds[1] - bounds[0]  # |A|
+    sum_averages = tf.math.divide_no_nan(sums, bounds)
+    # (mean(A u B) - mean(B))**2
+    discrepancies = tf.square(sum_averages[1] - sum_averages[0])
+    # |A u B| * -|B| / |A|
+    adjustments = tf.math.divide_no_nan(bounds[1] * (-bounds[0]), counts)
+    variances_scaled = variances * bounds
+    residuals = (variances_scaled[1] -
+                 variances_scaled[0] +
                  adjustments * discrepancies)
-    return _safe_average(residuals, counts)
+    return tf.math.divide_no_nan(residuals, counts)


 def windowed_mean(
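The rewritten block vectorizes the subtractive variance merge described in the
comments: with `B = x[:low]` and `A u B = x[:high]`, the variance of the window
`A = x[low:high]` is recovered from cumulative statistics alone. A NumPy check
of the identity on illustrative data (`np.var` is the population variance,
matching the `variance * count` bookkeeping):

    import numpy as np

    x = np.array([3., 1., 4., 1., 5., 9., 2., 6.])
    lo, hi = 2, 6                    # window A = x[2:6]
    n_u, n_b, n_a = hi, lo, hi - lo  # |A u B|, |B|, |A|
    disc = (np.mean(x[:hi]) - np.mean(x[:lo]))**2  # (mean(A u B) - mean(B))**2
    adj = n_u * (-n_b) / n_a                       # |A u B| * -|B| / |A|
    residual = np.var(x[:hi]) * n_u - np.var(x[:lo]) * n_b + adj * disc
    np.allclose(residual / n_a, np.var(x[lo:hi]))  # => True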
@@ -829,7 +835,7 @@ def windowed_mean(

   Computes means among data in the Tensor `x` along the given windows:

-     result[i] = mean(x[low_indices[i]:high_indices[i]+1])
+     result[i] = mean(x[low_indices[i]:high_indices[i]])

   efficiently.  To wit, if K is the size of `low_indices` and
   `high_indices`, and `N` is the size of `x` along the given `axis`,
@@ -842,10 +848,9 @@ def windowed_mean(
   last half of an MCMC chain.

   Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
-  rank `axis`, and `low_indices` and `high_indices` broadcast to shape
-  `[M]`.  Then each element of `low_indices` and `high_indices`
-  must be between 0 and N+1, and the shape of the output will be
-  `Bx + [M] + E`.  Batch shape in the indices is not currently supported.
+  rank `axis`, and `low_indices` and `high_indices` broadcast to `x`.
+  Then each element of `low_indices` and `high_indices` must be
+  between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.

   The default windows are
   `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -878,18 +883,17 @@ def windowed_mean(
   """
   with tf.name_scope(name or 'windowed_mean'):
     x = tf.convert_to_tensor(x)
-    low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
-        x, low_indices, high_indices, axis)
+    x, indices, axis = _prepare_window_args(x, low_indices, high_indices, axis)

     raw_cumsum = tf.cumsum(x, axis=axis)
-    cum_sums = tf.concat(
-        [tf.zeros_like(tf.gather(raw_cumsum, [0], axis=axis)), raw_cumsum],
-        axis=axis)
-    low_sums = tf.gather(cum_sums, low_indices, axis=axis)
-    high_sums = tf.gather(cum_sums, high_indices, axis=axis)
-
-    counts = high_counts - low_counts
-    return _safe_average(high_sums - low_sums, counts)
+    rank = ps.rank(x)
+    paddings = ps.reshape(ps.one_hot(2 * axis, depth=2 * rank, dtype=tf.int32),
+                          (rank, 2))
+    cum_sums = ps.pad(raw_cumsum, paddings)
+    sums = tnp.take_along_axis(cum_sums, indices,
+                               axis=axis)
+    counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
+    return tf.math.divide_no_nan(sums[1] - sums[0], counts)


 def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
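The `one_hot`/`reshape` pair above simply builds a `paddings` matrix that
prepends one zero along `axis` (so an index of 0 selects an empty prefix sum),
replacing the earlier `concat` of a zeros slice. The same trick in plain NumPy:

    import numpy as np

    axis, rank = 1, 3
    # Row `axis` of the [rank, 2] paddings is [1, 0]: one zero before, none after.
    paddings = np.eye(2 * rank, dtype=np.int32)[2 * axis].reshape(rank, 2)
    # paddings == [[0, 0], [1, 0], [0, 0]]
    np.pad(np.ones([2, 4, 5]), paddings).shape  # => (2, 5, 5)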
@@ -905,24 +909,20 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
   # Broadcast indices together.
   high_indices = high_indices + tf.zeros_like(low_indices)
   low_indices = low_indices + tf.zeros_like(high_indices)
-
-  # TODO(axch): Support batch low and high indices.  That would
-  # complicate this shape munging (though tf.gather should work
-  # fine).
-
-  # We want to place `low_counts` and `high_counts` at the `axis`
-  # position, so we reshape them to shape `[1, 1, ..., 1, N, 1, ...,
-  # 1]`, where the `N` is at `axis`.  The `counts_shp`, below,
-  # is this shape.
-  size = ps.size(high_indices)
-  counts_shp = ps.one_hot(
-      axis, depth=ps.rank(x), on_value=size, off_value=1)
-
-  low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype),
-                          shape=counts_shp)
-  high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype),
-                           shape=counts_shp)
-  return low_indices, high_indices, low_counts, high_counts
+  indices = ps.stack([low_indices, high_indices], axis=0)
+  x = tf.expand_dims(x, axis=0)
+  axis += 1
+
+  if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2:
+    # Legacy usage, kept for backward compatibility.
+    size = ps.size(indices) // 2
+    bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
+                          off_value=1)
+    bc_shape = ps.concat([[2], bc_shape[1:]], axis=0)
+    indices = ps.reshape(indices, bc_shape)
+  # `take_along_axis` requires the indices to be of type int32.
+  indices = ps.cast(indices, dtype=tf.int32)
+  return x, indices, axis


 def _safe_average(totals, counts):
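With this rewrite, `_prepare_window_args` stacks the two bounds into a leading
size-2 dimension and gives `x` a matching dummy leading dimension, so a single
`take_along_axis` call gathers both bounds of every window at once; the rank-2
branch keeps the legacy `[M]`-windows calling convention working. A
self-contained NumPy sketch of the stacked gather behind the new
`windowed_mean`:

    import numpy as np

    x = np.arange(10.)
    low, high = np.array([0, 2, 4]), np.array([5, 8, 10])
    indices = np.stack([low, high])             # shape [2, 3]
    cum = np.concatenate([[0.], np.cumsum(x)])  # padded cumulative sums
    sums = np.take_along_axis(cum[np.newaxis], indices, axis=1)
    means = (sums[1] - sums[0]) / (indices[1] - indices[0])
    # means[i] == x[low[i]:high[i]].mean()  => [2., 4.5, 6.5]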