@@ -681,6 +681,19 @@ def test_windowed_mean_corner_cases(self):
681
681
682
682
@test_util .test_all_tf_execution_regimes
683
683
class WindowedStatsTest (test_util .TestCase ):
684
+
685
+ def _maybe_expand_dims_to_make_broadcastable (self , x , shape , axis ):
686
+ if len (shape ) > len (x .shape ):
687
+ if len (x .shape ) == 1 :
688
+ bc_shape = np .ones (len (shape ), dtype = np .int32 )
689
+ bc_shape [axis ] = x .shape [0 ]
690
+ return x .reshape (bc_shape )
691
+ else :
692
+ extra_dims = len (shape ) - len (x .shape )
693
+ bc_shape = x .shape + (1 ,) * extra_dims
694
+ return x .reshape (bc_shape )
695
+ return x
696
+
684
697
def apply_slice_along_axis (self , func , arr , low , high , axis ):
685
698
"""Applies `func` over slices of `arr` along `axis`. Slices intervals are
686
699
specified through `low` and `high`. Support broadcasting.
@@ -705,6 +718,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
705
718
for r in range (j ):
706
719
out_1d [r ] = func (a_1d [low_1d [r ]:high_1d [r ]])
707
720
return out
721
+
708
722
def check_gaussian_windowed (self , shape , indice_shape , axis ,
709
723
window_func , np_func ):
710
724
stat_shape = np .array (shape ).astype (np .int32 )
@@ -717,6 +731,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
717
731
indices = rng .randint (shape [axis ] + 1 , size = indice_shape )
718
732
indices = np .sort (indices , axis = 0 )
719
733
low_indices , high_indices = indices [0 ], indices [1 ]
734
+ low_indices = self ._maybe_expand_dims_to_make_broadcastable (
735
+ low_indices , x .shape , axis )
736
+ high_indices = self ._maybe_expand_dims_to_make_broadcastable (
737
+ high_indices , x .shape , axis )
720
738
a = window_func (x , low_indices = low_indices ,
721
739
high_indices = high_indices , axis = axis )
722
740
b = self .apply_slice_along_axis (np_func , x , low_indices , high_indices ,
@@ -732,20 +750,34 @@ def check_windowed(self, func, numpy_func):
732
750
check_fn ((64 , 4 , 8 ), (32 , 4 , 1 ), axis = 0 )
733
751
check_fn ((64 , 4 , 8 ), (32 , 4 , 8 ), axis = 0 )
734
752
check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 0 )
753
+ check_fn ((64 , 4 , 8 ), (128 , 1 ), axis = 0 )
754
+ check_fn ((64 , 4 , 8 ), (32 ,), axis = 0 )
755
+ check_fn ((64 , 4 , 8 ), (32 , 4 ), axis = 0 )
756
+
735
757
check_fn ((64 , 4 , 8 ), (64 , 64 , 1 ), axis = 1 )
736
758
check_fn ((64 , 4 , 8 ), (1 , 64 , 1 ), axis = 1 )
737
759
check_fn ((64 , 4 , 8 ), (64 , 2 , 8 ), axis = 1 )
738
760
check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 1 )
761
+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 1 )
762
+ check_fn ((64 , 4 , 8 ), (1 , 64 ), axis = 1 )
763
+
739
764
check_fn ((64 , 4 , 8 ), (64 , 4 , 64 ), axis = 2 )
740
765
check_fn ((64 , 4 , 8 ), (1 , 1 , 64 ), axis = 2 )
741
766
check_fn ((64 , 4 , 8 ), (64 , 4 , 4 ), axis = 2 )
742
767
check_fn ((64 , 4 , 8 ), (1 , 1 , 4 ), axis = 2 )
743
768
check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 2 )
769
+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 2 )
770
+ check_fn ((64 , 4 , 8 ), (1 , 4 ), axis = 2 )
771
+ check_fn ((64 , 4 , 8 ), (64 , 4 ), axis = 2 )
744
772
745
773
with self .assertRaises (Exception ):
746
774
# Non broadcastable shapes
747
775
check_fn ((64 , 4 , 8 ), (4 , 1 , 4 ), axis = 2 )
748
776
777
+ with self .assertRaises (Exception ):
778
+ # Non broadcastable shapes
779
+ check_fn ((64 , 4 , 8 ), (2 , 4 ), axis = 2 )
780
+
749
781
def test_windowed_mean (self ):
750
782
self .check_windowed (func = tfp .stats .windowed_mean , numpy_func = np .mean )
751
783
0 commit comments