@@ -143,14 +143,16 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
143143
144144 self .matrix_type = matrix_type
145145
146- indptr = np .asarray (A .indptr ) # double check it's a numpy array
146+ A = self ._validate_csr_matrix (A )
147+
148+ max_a_ind_itemsize = max (A .indptr .itemsize , A .indices .itemsize )
147149 mkl_int_size = get_mkl_int_size ()
148150 mkl_int64_size = get_mkl_int64_size ()
149151
150- target_int_size = mkl_int_size if indptr . itemsize <= mkl_int_size else mkl_int64_size
152+ target_int_size = mkl_int_size if max_a_ind_itemsize <= mkl_int_size else mkl_int64_size
151153 self ._ind_dtype = np .dtype (f"i{ target_int_size } " )
152154
153- data , indptr , indices = self ._validate_matrix (A )
155+ data , indptr , indices = self ._validate_matrix_dtypes (A )
154156 self ._data = data
155157 self ._indptr = indptr
156158 self ._indices = indices
@@ -185,7 +187,9 @@ def refactor(self, A):
185187 raise TypeError ("A is not a sparse matrix." )
186188 if A .shape != self .shape :
187189 raise ValueError ("A is not the same size as the previous matrix." )
188- data , indptr , indices = self ._validate_matrix (A )
190+
191+ A = self ._validate_csr_matrix (A )
192+ data , indptr , indices = self ._validate_matrix_dtypes (A )
189193 if len (data ) != len (self ._data ):
190194 raise ValueError ("new A matrix does not have the same number of non zeros." )
191195
@@ -284,21 +288,24 @@ def iparm(self):
284288 """
285289 return np .array (self ._handle .iparm )
286290
287- def _validate_matrix (self , mat ):
288-
291+ def _validate_csr_matrix (self , mat ):
289292 if self .matrix_type in [- 2 , 2 , - 4 , 4 , 6 ]:
290- # Symmetric matrices must have only the upper triangle
291- if sp .isspmatrix_csc (mat ):
292- mat = mat .T # Transpose to get a CSR matrix since it's symmetric
293+ # only grab the upper triangle.
293294 mat = sp .triu (mat , format = 'csr' )
294295
295- if not (sp .isspmatrix_csr (mat )):
296- warnings .warn ("Converting %s matrix to CSR format."
297- % mat .__class__ .__name__ , PardisoTypeConversionWarning )
296+ if mat .format != 'csr' :
297+ warnings .warn (
298+ "Converting %s matrix to CSR format." % A .__class__ .__name__ ,
299+ PardisoTypeConversionWarning ,
300+ stacklevel = 3
301+ )
298302 mat = mat .tocsr ()
303+
299304 mat .sort_indices ()
300305 mat .sum_duplicates ()
306+ return mat
301307
308+ def _validate_matrix_dtypes (self , mat ):
302309 data = np .require (mat .data , self ._data_dtype , requirements = "C" )
303310 indptr = np .require (mat .indptr , self ._ind_dtype , requirements = "C" )
304311 indices = np .require (mat .indices , self ._ind_dtype , requirements = "C" )
0 commit comments