@@ -87,7 +87,7 @@ def wrapper(a: np.ndarray) -> DaskArray:
8787
8888DASK_CONVERTERS = {
8989 f : _chunked_1d (f )
90- for f in (_helpers .as_dense_dask_array , _helpers .as_sparse_dask_array )
90+ for f in (_helpers .as_dense_dask_array , _helpers .as_sparse_dask_matrix )
9191}
9292
9393
@@ -131,7 +131,10 @@ def gen_pca_params(
131131 svd_solver_type : Literal ["valid" , "invalid" ] | None ,
132132 zero_center : bool ,
133133) -> Generator [tuple [SVDSolver | None , str | None , str | None ], None , None ]:
134- if array_type is DASK_CONVERTERS [_helpers .as_sparse_dask_array ] and not zero_center :
134+ if (
135+ array_type is DASK_CONVERTERS [_helpers .as_sparse_dask_matrix ]
136+ and not zero_center
137+ ):
135138 xfail_reason = "Sparse-in-dask with zero_center=False not implemented yet"
136139 yield None , None , xfail_reason
137140 return
@@ -171,7 +174,7 @@ def possible_solvers(
171174 svd_solvers = {"auto" , "full" , "tsqr" , "randomized" , "covariance_eigh" }
172175 case (dc , False ) if dc is DASK_CONVERTERS [_helpers .as_dense_dask_array ]:
173176 svd_solvers = {"tsqr" , "randomized" }
174- case (dc , True ) if dc is DASK_CONVERTERS [_helpers .as_sparse_dask_array ]:
177+ case (dc , True ) if dc is DASK_CONVERTERS [_helpers .as_sparse_dask_matrix ]:
175178 svd_solvers = {"covariance_eigh" }
176179 case (type () as dc , True ) if issubclass (dc , CSBase ):
177180 svd_solvers = {"arpack" } | SKLEARN_ADDITIONAL
@@ -570,7 +573,7 @@ def test_pca_layer():
570573 "other_array_type" ,
571574 [
572575 lambda x : x .toarray (),
573- DASK_CONVERTERS [_helpers .as_sparse_dask_array ],
576+ DASK_CONVERTERS [_helpers .as_sparse_dask_matrix ],
574577 DASK_CONVERTERS [_helpers .as_dense_dask_array ],
575578 ],
576579 ids = ["dense-mem" , "sparse-dask" , "dense-dask" ],
@@ -616,7 +619,7 @@ def test_covariance_eigh_impls(other_array_type):
616619)
617620def test_sparse_dask_input_errors (msg_re : str , op : Callable [[DaskArray ], DaskArray ]):
618621 adata_sparse = pbmc3k_normalized ()
619- adata_sparse .X = op (DASK_CONVERTERS [_helpers .as_sparse_dask_array ](adata_sparse .X ))
622+ adata_sparse .X = op (DASK_CONVERTERS [_helpers .as_sparse_dask_matrix ](adata_sparse .X ))
620623
621624 with pytest .raises (ValueError , match = msg_re ):
622625 sc .pp .pca (adata_sparse , svd_solver = "covariance_eigh" )
@@ -635,7 +638,7 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
635638)
636639def test_cov_sparse_dask (dtype , dtype_arg , rtol ):
637640 x_arr = A_list .astype (dtype )
638- x = DASK_CONVERTERS [_helpers .as_sparse_dask_array ](x_arr )
641+ x = DASK_CONVERTERS [_helpers .as_sparse_dask_matrix ](x_arr )
639642 cov , gram , mean = _cov_sparse_dask (x , return_gram = True , dtype = dtype_arg )
640643 np .testing .assert_allclose (mean , np .mean (x_arr , axis = 0 ))
641644 np .testing .assert_allclose (gram , (x_arr .T @ x_arr ) / x .shape [0 ])
0 commit comments