Skip to content

Commit 9dd940e

Browse files
committed
Remove use of dask.array.ma in PCA in favour of array API compliant functions
1 parent 161b0c7 commit 9dd940e

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

sgkit/stats/preprocessing.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def fit(
6262
Alternate allele counts with missing values encoded as either nan
6363
or negative numbers.
6464
"""
65-
X = da.ma.masked_array(X, mask=da.isnan(X) | (X < 0))
66-
self.mean_ = da.ma.filled(da.mean(X, axis=0), fill_value=np.nan)
65+
X = _replace_missing_with_nan(X)
66+
self.mean_ = da.nanmean(X, axis=0)
6767
p = self.mean_ / self.ploidy
6868
self.scale_ = da.sqrt(p * (1 - p))
6969
self.n_features_in_ = X.shape[1]
@@ -90,10 +90,10 @@ def transform(
9090
Alternate allele counts with missing values encoded as either nan
9191
or negative numbers.
9292
"""
93-
X = da.ma.masked_array(X, mask=da.isnan(X) | (X < 0))
93+
X = _replace_missing_with_nan(X)
9494
X -= self.mean_
9595
X /= self.scale_
96-
return da.ma.filled(X, fill_value=np.nan)
96+
return X
9797

9898
def inverse_transform(self, X: ArrayLike, copy: Optional[bool] = None) -> ArrayLike:
9999
"""Invert transform
@@ -109,6 +109,14 @@ def inverse_transform(self, X: ArrayLike, copy: Optional[bool] = None) -> ArrayL
109109
return X
110110

111111

112+
def _replace_missing_with_nan(X):
113+
if np.issubdtype(X.dtype, np.floating):
114+
nanarray = da.asarray(np.nan, dtype=X.dtype)
115+
else:
116+
nanarray = da.asarray(np.nan)
117+
return da.where(X < 0, nanarray, X)
118+
119+
112120
def filter_partial_calls(
113121
ds: Dataset,
114122
*,

0 commit comments

Comments
 (0)