@@ -1170,6 +1170,8 @@ def floor(x):
11701170
11711171def  gather (params , indices , axis = None ):
11721172    op  =  P .Gather ()
1173+     if  axis  is  None :
1174+         axis  =  0 
11731175    return  op (params , indices , axis )
11741176
11751177
@@ -1590,10 +1592,7 @@ def reduce_std(x, axis=None, keepdims=False):
15901592
15911593
15921594def  reduce_sum (x , axis = None , keepdims = False ):
1593-     op  =  P .ReduceSum (keep_dims = keepdims )
1594-     if  axis  is  None :
1595-         return  op (x )
1596-     return  op (x , axis = axis )
1595+     return  msnp .sum (x , axis = axis , keepdims = keepdims )
15971596
15981597
15991598def  reduce_variance (x , axis = None , keepdims = False ):
@@ -1729,11 +1728,15 @@ def tanh(x):
17291728
17301729def  any (x , axis = None , keepdims = False ):
17311730    op  =  P .ReduceAny (keep_dims = keepdims )
1731+     if  axis  is  None :
1732+         return  op (x )
17321733    return  op (x , axis )
17331734
17341735
17351736def  all (x , axis = None , keepdims = False ):
17361737    op  =  P .ReduceAll (keep_dims = keepdims )
1738+     if  axis  is  None :
1739+         return  op (x )
17371740    return  op (x , axis )
17381741
17391742
@@ -1779,8 +1782,7 @@ def zeros_like(x, dtype=None):
17791782
17801783
17811784def  squeeze (x , axis = None ):
1782-     op  =  P .Squeeze (axis )
1783-     return  op (x )
1785+     return  msnp .squeeze (x , axis )
17841786
17851787
17861788def  unsorted_segment_sum (x , segment_ids , num_segments ):
@@ -1792,7 +1794,7 @@ def unsorted_segment_sum(x, segment_ids, num_segments):
17921794def  unsorted_segment_mean (x , segment_ids , num_segments ):
17931795    segment_ids  =  ms .Tensor (segment_ids )
17941796    op  =  P .UnsortedSegmentSum ()
1795-     x_one  =    msnp .ones_like (x , dtype = x .dtype )
1797+     x_one  =  msnp .ones_like (x , dtype = x .dtype )
17961798    sum  =  op (x , segment_ids , num_segments )
17971799    one  =  op (x_one , segment_ids , num_segments )
17981800    return  sum / one 
0 commit comments