@@ -685,6 +685,19 @@ def test_windowed_mean_corner_cases(self):
685685
686686@test_util .test_all_tf_execution_regimes
687687class WindowedStatsTest (test_util .TestCase ):
688+
689+ def _maybe_expand_dims_to_make_broadcastable (self , x , shape , axis ):
690+ if len (shape ) > len (x .shape ):
691+ if len (x .shape ) == 1 :
692+ bc_shape = np .ones (len (shape ), dtype = np .int32 )
693+ bc_shape [axis ] = x .shape [0 ]
694+ return x .reshape (bc_shape )
695+ else :
696+ extra_dims = len (shape ) - len (x .shape )
697+ bc_shape = x .shape + (1 ,) * extra_dims
698+ return x .reshape (bc_shape )
699+ return x
700+
688701 def apply_slice_along_axis (self , func , arr , low , high , axis ):
689702 """Applies `func` over slices of `arr` along `axis`. Slices intervals are
690703 specified through `low` and `high`. Support broadcasting.
@@ -709,6 +722,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
709722 for r in range (j ):
710723 out_1d [r ] = func (a_1d [low_1d [r ]:high_1d [r ]])
711724 return out
725+
712726 def check_gaussian_windowed (self , shape , indice_shape , axis ,
713727 window_func , np_func ):
714728 stat_shape = np .array (shape ).astype (np .int32 )
@@ -721,6 +735,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
721735 indices = rng .randint (shape [axis ] + 1 , size = indice_shape )
722736 indices = np .sort (indices , axis = 0 )
723737 low_indices , high_indices = indices [0 ], indices [1 ]
738+ low_indices = self ._maybe_expand_dims_to_make_broadcastable (
739+ low_indices , x .shape , axis )
740+ high_indices = self ._maybe_expand_dims_to_make_broadcastable (
741+ high_indices , x .shape , axis )
724742 a = window_func (x , low_indices = low_indices ,
725743 high_indices = high_indices , axis = axis )
726744 b = self .apply_slice_along_axis (np_func , x , low_indices , high_indices ,
@@ -736,20 +754,34 @@ def check_windowed(self, func, numpy_func):
736754 check_fn ((64 , 4 , 8 ), (32 , 4 , 1 ), axis = 0 )
737755 check_fn ((64 , 4 , 8 ), (32 , 4 , 8 ), axis = 0 )
738756 check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 0 )
757+ check_fn ((64 , 4 , 8 ), (128 , 1 ), axis = 0 )
758+ check_fn ((64 , 4 , 8 ), (32 ,), axis = 0 )
759+ check_fn ((64 , 4 , 8 ), (32 , 4 ), axis = 0 )
760+
739761 check_fn ((64 , 4 , 8 ), (64 , 64 , 1 ), axis = 1 )
740762 check_fn ((64 , 4 , 8 ), (1 , 64 , 1 ), axis = 1 )
741763 check_fn ((64 , 4 , 8 ), (64 , 2 , 8 ), axis = 1 )
742764 check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 1 )
765+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 1 )
766+ check_fn ((64 , 4 , 8 ), (1 , 64 ), axis = 1 )
767+
743768 check_fn ((64 , 4 , 8 ), (64 , 4 , 64 ), axis = 2 )
744769 check_fn ((64 , 4 , 8 ), (1 , 1 , 64 ), axis = 2 )
745770 check_fn ((64 , 4 , 8 ), (64 , 4 , 4 ), axis = 2 )
746771 check_fn ((64 , 4 , 8 ), (1 , 1 , 4 ), axis = 2 )
747772 check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 2 )
773+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 2 )
774+ check_fn ((64 , 4 , 8 ), (1 , 4 ), axis = 2 )
775+ check_fn ((64 , 4 , 8 ), (64 , 4 ), axis = 2 )
748776
749777 with self .assertRaises (Exception ):
750778 # Non broadcastable shapes
751779 check_fn ((64 , 4 , 8 ), (4 , 1 , 4 ), axis = 2 )
752780
781+ with self .assertRaises (Exception ):
782+ # Non broadcastable shapes
783+ check_fn ((64 , 4 , 8 ), (2 , 4 ), axis = 2 )
784+
753785 def test_windowed_mean (self ):
754786 self .check_windowed (func = tfp .stats .windowed_mean , numpy_func = np .mean )
755787
0 commit comments