@@ -2541,7 +2541,7 @@ def lstsq(a, b, rcond=None):
25412541 return wrap (x ), wrap (resids ), rank , s
25422542
25432543
2544- def _multi_svd_norm (x , row_axis , col_axis , op ):
2544+ def _multi_svd_norm (x , row_axis , col_axis , op , initial = None ):
25452545 """Compute a function of the singular values of the 2-D matrices in `x`.
25462546
25472547 This is a private utility function used by `numpy.linalg.norm()`.
@@ -2565,7 +2565,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op):
25652565
25662566 """
25672567 y = moveaxis (x , (row_axis , col_axis ), (- 2 , - 1 ))
2568- result = op (svd (y , compute_uv = False ), axis = - 1 )
2568+ result = op (svd (y , compute_uv = False ), axis = - 1 , initial = initial )
25692569 return result
25702570
25712571
@@ -2763,7 +2763,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
27632763
27642764 if len (axis ) == 1 :
27652765 if ord == inf :
2766- return abs (x ).max (axis = axis , keepdims = keepdims )
2766+ return abs (x ).max (axis = axis , keepdims = keepdims , initial = 0 )
27672767 elif ord == - inf :
27682768 return abs (x ).min (axis = axis , keepdims = keepdims )
27692769 elif ord == 0 :
@@ -2797,17 +2797,17 @@ def norm(x, ord=None, axis=None, keepdims=False):
27972797 if row_axis == col_axis :
27982798 raise ValueError ('Duplicate axes given.' )
27992799 if ord == 2 :
2800- ret = _multi_svd_norm (x , row_axis , col_axis , amax )
2800+ ret = _multi_svd_norm (x , row_axis , col_axis , amax , 0 )
28012801 elif ord == - 2 :
28022802 ret = _multi_svd_norm (x , row_axis , col_axis , amin )
28032803 elif ord == 1 :
28042804 if col_axis > row_axis :
28052805 col_axis -= 1
2806- ret = add .reduce (abs (x ), axis = row_axis ).max (axis = col_axis )
2806+ ret = add .reduce (abs (x ), axis = row_axis ).max (axis = col_axis , initial = 0 )
28072807 elif ord == inf :
28082808 if row_axis > col_axis :
28092809 row_axis -= 1
2810- ret = add .reduce (abs (x ), axis = col_axis ).max (axis = row_axis )
2810+ ret = add .reduce (abs (x ), axis = col_axis ).max (axis = row_axis , initial = 0 )
28112811 elif ord == - 1 :
28122812 if col_axis > row_axis :
28132813 col_axis -= 1
@@ -2819,7 +2819,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
28192819 elif ord in [None , 'fro' , 'f' ]:
28202820 ret = sqrt (add .reduce ((x .conj () * x ).real , axis = axis ))
28212821 elif ord == 'nuc' :
2822- ret = _multi_svd_norm (x , row_axis , col_axis , sum )
2822+ ret = _multi_svd_norm (x , row_axis , col_axis , sum , 0 )
28232823 else :
28242824 raise ValueError ("Invalid norm order for matrices." )
28252825 if keepdims :
0 commit comments