@@ -1170,6 +1170,8 @@ def floor(x):
11701170
11711171def gather (params , indices , axis = None ):
11721172 op = P .Gather ()
1173+ if axis is None :
1174+ axis = 0
11731175 return op (params , indices , axis )
11741176
11751177
@@ -1590,10 +1592,7 @@ def reduce_std(x, axis=None, keepdims=False):
15901592
15911593
15921594def reduce_sum (x , axis = None , keepdims = False ):
1593- op = P .ReduceSum (keep_dims = keepdims )
1594- if axis is None :
1595- return op (x )
1596- return op (x , axis = axis )
1595+ return msnp .sum (x , axis = axis , keepdims = keepdims )
15971596
15981597
15991598def reduce_variance (x , axis = None , keepdims = False ):
@@ -1729,11 +1728,15 @@ def tanh(x):
17291728
17301729def any (x , axis = None , keepdims = False ):
17311730 op = P .ReduceAny (keep_dims = keepdims )
1731+ if axis is None :
1732+ return op (x )
17321733 return op (x , axis )
17331734
17341735
17351736def all (x , axis = None , keepdims = False ):
17361737 op = P .ReduceAll (keep_dims = keepdims )
1738+ if axis is None :
1739+ return op (x )
17371740 return op (x , axis )
17381741
17391742
@@ -1779,8 +1782,7 @@ def zeros_like(x, dtype=None):
17791782
17801783
17811784def squeeze (x , axis = None ):
1782- op = P .Squeeze (axis )
1783- return op (x )
1785+ return msnp .squeeze (x , axis )
17841786
17851787
17861788def unsorted_segment_sum (x , segment_ids , num_segments ):
@@ -1792,7 +1794,7 @@ def unsorted_segment_sum(x, segment_ids, num_segments):
17921794def unsorted_segment_mean (x , segment_ids , num_segments ):
17931795 segment_ids = ms .Tensor (segment_ids )
17941796 op = P .UnsortedSegmentSum ()
1795- x_one = msnp .ones_like (x , dtype = x .dtype )
1797+ x_one = msnp .ones_like (x , dtype = x .dtype )
17961798 sum = op (x , segment_ids , num_segments )
17971799 one = op (x_one , segment_ids , num_segments )
17981800 return sum / one
0 commit comments