@@ -681,6 +681,19 @@ def test_windowed_mean_corner_cases(self):
681681
682682@test_util .test_all_tf_execution_regimes
683683class WindowedStatsTest (test_util .TestCase ):
684+
685+ def _maybe_expand_dims_to_make_broadcastable (self , x , shape , axis ):
686+ if len (shape ) > len (x .shape ):
687+ if len (x .shape ) == 1 :
688+ bc_shape = np .ones (len (shape ), dtype = np .int32 )
689+ bc_shape [axis ] = x .shape [0 ]
690+ return x .reshape (bc_shape )
691+ else :
692+ extra_dims = len (shape ) - len (x .shape )
693+ bc_shape = x .shape + (1 ,) * extra_dims
694+ return x .reshape (bc_shape )
695+ return x
696+
684697 def apply_slice_along_axis (self , func , arr , low , high , axis ):
685698 """Applies `func` over slices of `arr` along `axis`. Slices intervals are
686699 specified through `low` and `high`. Support broadcasting.
@@ -705,6 +718,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
705718 for r in range (j ):
706719 out_1d [r ] = func (a_1d [low_1d [r ]:high_1d [r ]])
707720 return out
721+
708722 def check_gaussian_windowed (self , shape , indice_shape , axis ,
709723 window_func , np_func ):
710724 stat_shape = np .array (shape ).astype (np .int32 )
@@ -717,6 +731,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
717731 indices = rng .randint (shape [axis ] + 1 , size = indice_shape )
718732 indices = np .sort (indices , axis = 0 )
719733 low_indices , high_indices = indices [0 ], indices [1 ]
734+ low_indices = self ._maybe_expand_dims_to_make_broadcastable (
735+ low_indices , x .shape , axis )
736+ high_indices = self ._maybe_expand_dims_to_make_broadcastable (
737+ high_indices , x .shape , axis )
720738 a = window_func (x , low_indices = low_indices ,
721739 high_indices = high_indices , axis = axis )
722740 b = self .apply_slice_along_axis (np_func , x , low_indices , high_indices ,
@@ -732,20 +750,34 @@ def check_windowed(self, func, numpy_func):
732750 check_fn ((64 , 4 , 8 ), (32 , 4 , 1 ), axis = 0 )
733751 check_fn ((64 , 4 , 8 ), (32 , 4 , 8 ), axis = 0 )
734752 check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 0 )
753+ check_fn ((64 , 4 , 8 ), (128 , 1 ), axis = 0 )
754+ check_fn ((64 , 4 , 8 ), (32 ,), axis = 0 )
755+ check_fn ((64 , 4 , 8 ), (32 , 4 ), axis = 0 )
756+
735757 check_fn ((64 , 4 , 8 ), (64 , 64 , 1 ), axis = 1 )
736758 check_fn ((64 , 4 , 8 ), (1 , 64 , 1 ), axis = 1 )
737759 check_fn ((64 , 4 , 8 ), (64 , 2 , 8 ), axis = 1 )
738760 check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 1 )
761+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 1 )
762+ check_fn ((64 , 4 , 8 ), (1 , 64 ), axis = 1 )
763+
739764 check_fn ((64 , 4 , 8 ), (64 , 4 , 64 ), axis = 2 )
740765 check_fn ((64 , 4 , 8 ), (1 , 1 , 64 ), axis = 2 )
741766 check_fn ((64 , 4 , 8 ), (64 , 4 , 4 ), axis = 2 )
742767 check_fn ((64 , 4 , 8 ), (1 , 1 , 4 ), axis = 2 )
743768 check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 2 )
769+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 2 )
770+ check_fn ((64 , 4 , 8 ), (1 , 4 ), axis = 2 )
771+ check_fn ((64 , 4 , 8 ), (64 , 4 ), axis = 2 )
744772
745773 with self .assertRaises (Exception ):
746774 # Non broadcastable shapes
747775 check_fn ((64 , 4 , 8 ), (4 , 1 , 4 ), axis = 2 )
748776
777+ with self .assertRaises (Exception ):
778+ # Non broadcastable shapes
779+ check_fn ((64 , 4 , 8 ), (2 , 4 ), axis = 2 )
780+
749781 def test_windowed_mean (self ):
750782 self .check_windowed (func = tfp .stats .windowed_mean , numpy_func = np .mean )
751783
0 commit comments