Skip to content

Commit 967e917

Browse files
Merge pull request #61 from theislab/trustregion_plus_sparse
Trustregion NR and IRLS and full sparse data support
2 parents 00756f3 + 5d42742 commit 967e917

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+3200
-1429
lines changed

batchglm/data.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import dask
1414
import dask.array
1515

16+
from .external import SparseXArrayDataArray, SparseXArrayDataSet
17+
1618
try:
1719
import anndata
1820
except ImportError:
@@ -31,17 +33,18 @@ def fetch_X(idx):
3133
if idx.size == 1:
3234
retval = np.squeeze(retval, axis=0)
3335

34-
return retval.astype(np.float32)
36+
return retval.astype(np.float64)
3537

3638
delayed_fetch = dask.delayed(fetch_X, pure=True)
3739
X = [
3840
dask.array.from_delayed(
3941
delayed_fetch(idx),
4042
shape=(num_features,),
41-
dtype=np.float32
43+
dtype=np.float64
4244
) for idx in range(num_observations)
4345
]
44-
X = xr.DataArray(dask.array.stack(X), dims=dims)
46+
47+
X = data
4548

4649
# currently broken:
4750
# X = data.X
@@ -53,9 +56,9 @@ def fetch_X(idx):
5356

5457

5558
def xarray_from_data(
56-
data: Union[anndata.AnnData, xr.DataArray, xr.Dataset, np.ndarray],
59+
data: Union[anndata.AnnData, anndata.base.Raw, xr.DataArray, xr.Dataset, np.ndarray, scipy.sparse.csr_matrix],
5760
dims: Union[Tuple, List] = ("observations", "features")
58-
) -> xr.DataArray:
61+
):
5962
"""
6063
Parse any array-like object, xr.DataArray, xr.Dataset or anndata.Anndata and return a xarray containing
6164
the observations.
@@ -64,26 +67,52 @@ def xarray_from_data(
6467
:param dims: tuple or list with two strings. Specifies the names of the xarray dimensions.
6568
:return: xr.DataArray of shape `dims`
6669
"""
67-
if anndata is not None and isinstance(data, anndata.AnnData):
70+
if anndata is not None and (isinstance(data, anndata.AnnData) or isinstance(data, anndata.base.Raw)):
71+
# Anndata.raw does not have obs_names.
72+
if isinstance(data, anndata.AnnData):
73+
obs_names = np.asarray(data.obs_names)
74+
else:
75+
obs_names = ["obs_" + str(i) for i in range(data.X.shape[0])]
76+
6877
if scipy.sparse.issparse(data.X):
69-
X = _sparse_to_xarray(data.X, dims=dims)
70-
X.coords[dims[0]] = np.asarray(data.obs_names)
71-
X.coords[dims[1]] = np.asarray(data.var_names)
78+
# X = _sparse_to_xarray(data.X, dims=dims)
79+
# X.coords[dims[0]] = np.asarray(data.obs_names)
80+
# X.coords[dims[1]] = np.asarray(data.var_names)
81+
X = SparseXArrayDataSet(
82+
X=data.X,
83+
obs_names=np.asarray(obs_names),
84+
feature_names=np.asarray(data.var_names),
85+
dims=dims
86+
)
7287
else:
73-
X = data.X
74-
X = xr.DataArray(X, dims=dims, coords={
75-
dims[0]: np.asarray(data.obs_names),
76-
dims[1]: np.asarray(data.var_names),
77-
})
88+
X = xr.DataArray(
89+
data.X,
90+
dims=dims,
91+
coords={
92+
dims[0]: np.asarray(obs_names),
93+
dims[1]: np.asarray(data.var_names),
94+
}
95+
)
7896
elif isinstance(data, xr.Dataset):
7997
X: xr.DataArray = data["X"]
8098
elif isinstance(data, xr.DataArray):
8199
X = data
100+
elif isinstance(data, SparseXArrayDataSet):
101+
X = data
102+
elif scipy.sparse.issparse(data):
103+
# X = _sparse_to_xarray(data, dims=dims)
104+
# X.coords[dims[0]] = np.asarray(data.obs_names)
105+
# X.coords[dims[1]] = np.asarray(data.var_names)
106+
X = SparseXArrayDataSet(
107+
X=data,
108+
obs_names=None,
109+
feature_names=None,
110+
dims=dims
111+
)
112+
elif isinstance(data, np.ndarray):
113+
X = xr.DataArray(data, dims=dims)
82114
else:
83-
if scipy.sparse.issparse(data):
84-
X = _sparse_to_xarray(data, dims=dims)
85-
else:
86-
X = xr.DataArray(data, dims=dims)
115+
raise ValueError("batchglm data parsing: data format %s not recognized" % type(data))
87116

88117
return X
89118

@@ -537,7 +566,7 @@ def parse_constraints(
537566
Parse constraint matrix into xarray.
538567
539568
:param dmat: Design matrix.
540-
:param a constraint matrix
569+
:param constraints: a constraint matrix
541570
:return: constraint matrix in xarray format
542571
"""
543572
constraints_ar = xr.DataArray(

batchglm/external.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from batchglm.xarray_sparse import SparseXArrayDataArray, SparseXArrayDataSet

batchglm/models/base/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .input import _InputData_Base, INPUT_DATA_PARAMS
1+
from .input import _InputData_Base, INPUT_DATA_PARAMS, SparseXArrayDataSet, SparseXArrayDataArray
22
from .estimator import _Estimator_Base, _EstimatorStore_XArray_Base
33
from .model import _Model_Base, _Model_XArray_Base
44
from .simulator import _Simulator_Base

batchglm/models/base/external.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
import batchglm.pkg_constants as pkg_constants
22
import batchglm.data as data_utils
3+
from batchglm.xarray_sparse import SparseXArrayDataSet, SparseXArrayDataArray

batchglm/models/base/input.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import abc
22
import os
33
import logging
4+
from typing import Union
45

6+
import numpy as np
7+
import scipy
8+
import scipy.sparse
59
import xarray as xr
610

711
try:
812
import anndata
913
except ImportError:
1014
anndata = None
1115

12-
from .external import pkg_constants, data_utils
16+
from .external import pkg_constants, data_utils, SparseXArrayDataSet, SparseXArrayDataArray
1317

1418
logger = logging.getLogger(__name__)
1519

@@ -21,7 +25,7 @@ class _InputData_Base:
2125
"""
2226
Base class for all input data types.
2327
"""
24-
data: xr.Dataset
28+
data: Union[xr.Dataset, SparseXArrayDataSet]
2529

2630
@classmethod
2731
@abc.abstractmethod
@@ -57,20 +61,40 @@ def new(cls, data, observation_names=None, feature_names=None, cast_dtype=None):
5761
X = X.astype(cast_dtype)
5862
# X = X.chunk({"observations": 1})
5963

60-
retval = cls(xr.Dataset({
61-
"X": X,
62-
}, coords={
63-
"feature_allzero": ~X.any(dim="observations")
64-
}))
65-
if observation_names is not None:
66-
retval.observations = observation_names
67-
elif "observations" not in retval.data.coords:
68-
retval.observations = retval.data.coords["observations"]
69-
70-
if feature_names is not None:
71-
retval.features = feature_names
72-
elif "features" not in retval.data.coords:
73-
retval.features = retval.data.coords["features"]
64+
if scipy.sparse.issparse(X):
65+
retval = cls(SparseXArrayDataSet(
66+
X=X,
67+
obs_names=observation_names,
68+
feature_names=feature_names
69+
))
70+
elif isinstance(X, SparseXArrayDataArray):
71+
retval = cls(SparseXArrayDataSet(
72+
X=X.X,
73+
obs_names=X.coords[X.dims[0]] if observation_names is None else observation_names,
74+
feature_names=X.coords[X.dims[1]] if feature_names is None else feature_names,
75+
dims=X.dims
76+
))
77+
elif isinstance(X, SparseXArrayDataSet):
78+
retval = cls(X)
79+
if observation_names is not None:
80+
retval.observations = observation_names
81+
if feature_names is not None:
82+
retval.features = feature_names
83+
else:
84+
retval = cls(xr.Dataset({
85+
"X": X,
86+
}, coords={
87+
"feature_allzero": ~X.any(dim="observations")
88+
}))
89+
if observation_names is not None:
90+
retval.observations = observation_names
91+
elif "observations" not in retval.data.coords:
92+
retval.observations = retval.data.coords["observations"]
93+
94+
if feature_names is not None:
95+
retval.features = feature_names
96+
elif "features" not in retval.data.coords:
97+
retval.features = retval.data.coords["features"]
7498

7599
return retval
76100

@@ -117,7 +141,7 @@ def save(self, path, group="", append=False):
117141
)
118142

119143
@property
120-
def X(self) -> xr.DataArray:
144+
def X(self):
121145
return self.data.X
122146

123147
@X.setter
@@ -156,9 +180,24 @@ def feature_isnonzero(self):
156180
def feature_isallzero(self):
157181
return self.data.coords["feature_allzero"]
158182

159-
def fetch_X(self, idx):
183+
def fetch_X_dense(self, idx):
160184
return self.X[idx].values
161185

186+
def fetch_X_sparse(self, idx):
187+
assert isinstance(self.X.X, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr matrix"
188+
189+
data = self.X.X[idx]
190+
191+
data_idx = np.asarray(np.vstack(data.nonzero()).T, np.int64)
192+
data_val = np.asarray(data.data, np.float64)
193+
data_shape = np.asarray(data.shape, np.int64)
194+
195+
if idx.shape[0] == 1:
196+
data_val = np.squeeze(data_val, axis=0)
197+
data_idx = np.squeeze(data_idx, axis=0)
198+
199+
return data_idx, data_val, data_shape
200+
162201
def set_chunk_size(self, cs: int):
163202
self.X = self.X.chunk({"observations": cs})
164203

batchglm/models/base/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def to_xarray(self, parm: Union[str, list], coords=None):
6262
output = xr.Dataset(output)
6363
if coords is not None:
6464
for i in output.dims:
65-
if i in coords:
65+
if i in coords.coords:
6666
output.coords[i] = coords[i]
6767

6868
return output

batchglm/models/base_glm/external.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from batchglm.models.base import _Model_Base, _Model_XArray_Base
44
from batchglm.models.base import _Simulator_Base
55
from batchglm.models.base import INPUT_DATA_PARAMS
6+
from batchglm.models.base import SparseXArrayDataArray, SparseXArrayDataSet
67

78
import batchglm.data as data_utils
89
from batchglm.utils.linalg import groupwise_solve_lm

batchglm/models/base_glm/input.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
anndata = None
77
import xarray as xr
88
import numpy as np
9+
import scipy.sparse
910
import pandas as pd
1011

1112
from .utils import parse_constraints, parse_design
12-
from .external import _InputData_Base, INPUT_DATA_PARAMS
13+
from .external import _InputData_Base, INPUT_DATA_PARAMS, SparseXArrayDataSet, SparseXArrayDataArray
1314

1415
import patsy
1516

@@ -33,7 +34,7 @@ def param_shapes(cls) -> dict:
3334
@classmethod
3435
def new(
3536
cls,
36-
data: Union[np.ndarray, anndata.AnnData, xr.DataArray, xr.Dataset],
37+
data: Union[np.ndarray, anndata.AnnData, xr.DataArray, xr.Dataset, scipy.sparse.csr_matrix],
3738
design_loc: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix, xr.DataArray] = None,
3839
design_loc_names: Union[list, np.ndarray, xr.DataArray] = None,
3940
design_scale: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix, xr.DataArray] = None,
@@ -216,12 +217,16 @@ def size_factors(self):
216217
def size_factors(self, data):
217218
if data is None and "size_factors" in self.data.coords:
218219
del self.data.coords["size_factors"]
219-
else:
220-
dims = self.param_shapes()["size_factors"]
221-
self.data.coords["size_factors"] = xr.DataArray(
220+
221+
dims = self.param_shapes()["size_factors"]
222+
sf = xr.DataArray(
222223
dims=dims,
223224
data=np.broadcast_to(data, [self.data.dims[d] for d in dims])
224225
)
226+
if isinstance(self.data, SparseXArrayDataSet):
227+
self.data.size_factors = sf
228+
else:
229+
self.data.coords["size_factors"] = sf
225230

226231
@property
227232
def num_design_loc_params(self):

batchglm/models/base_glm/simulator.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,14 @@ def generate_params(
109109
default: rand_fn = lambda shape: np.random.uniform(0.5, 2, shape)
110110
:param rand_fn_loc: random function taking one argument `shape`.
111111
If not provided, will use `rand_fn` instead.
112+
This function generates location model parameters in inverse linker space,
113+
ie. these parameter will be log transformed if a log linker function is used!
114+
Values below 1e-08 will be set to 1e-08 to map them into the positive support.
112115
:param rand_fn_scale: random function taking one argument `shape`.
113116
If not provided, will use `rand_fn` instead.
117+
This function generates scale model parameters in inverse linker space,
118+
ie. these parameter will be log transformed if a log linker function is used!
119+
Values below 1e-08 will be set to 1e-08 to map them into the positive support.
114120
"""
115121
if rand_fn_loc is None:
116122
rand_fn_loc = rand_fn
@@ -157,8 +163,9 @@ def generate_params(
157163
dims=self.param_shapes()["a_var"],
158164
data=np.log(
159165
np.concatenate([
160-
np.expand_dims(rand_fn_ave(self.num_features), axis=0), # intercept
161-
rand_fn_loc((self.data.design_loc.shape[1] - 1, self.num_features))
166+
np.expand_dims(rand_fn_ave([self.num_features]), axis=0), # intercept
167+
np.maximum(rand_fn_loc((self.data.design_loc.shape[1] - 1, self.num_features)),
168+
np.zeros([self.data.design_loc.shape[1] - 1, self.num_features]) + 1e-08)
162169
], axis=0)
163170
),
164171
coords={"loc_params": self.data.loc_params}
@@ -167,7 +174,8 @@ def generate_params(
167174
dims=self.param_shapes()["b_var"],
168175
data=np.log(
169176
np.concatenate([
170-
rand_fn_scale((self.data.design_scale.shape[1], self.num_features))
177+
np.maximum(rand_fn_scale((self.data.design_scale.shape[1], self.num_features)),
178+
np.zeros([self.data.design_scale.shape[1], self.num_features]) + 1e-08)
171179
], axis=0)
172180
),
173181
coords={"scale_params": self.data.scale_params}
@@ -217,12 +225,12 @@ def size_factors(self):
217225
def size_factors(self, data):
218226
if data is None and "size_factors" in self.data.coords:
219227
del self.data.coords["size_factors"]
220-
else:
221-
dims = self.param_shapes()["size_factors"]
222-
self.data.coords["size_factors"] = xr.DataArray(
223-
dims=dims,
224-
data=np.broadcast_to(data, [self.data.dims[d] for d in dims])
225-
)
228+
229+
dims = self.param_shapes()["size_factors"]
230+
self.data.coords["size_factors"] = xr.DataArray(
231+
dims=dims,
232+
data=np.broadcast_to(data, [self.data.dims[d] for d in dims])
233+
)
226234

227235
@property
228236
def a_var(self):

0 commit comments

Comments
 (0)