Skip to content

Commit 054d729

Browse files
added unit tests for anndata raw
1 parent c0a5fd5 commit 054d729

File tree

3 files changed

+43
-36
lines changed

3 files changed

+43
-36
lines changed

batchglm/data.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,26 @@ def xarray_from_data(
6767
:param dims: tuple or list with two strings. Specifies the names of the xarray dimensions.
6868
:return: xr.DataArray of shape `dims`
6969
"""
70-
if anndata is not None and isinstance(data, anndata.AnnData):
70+
if anndata is not None and (isinstance(data, anndata.AnnData) or isinstance(data, anndata.base.Raw)):
71+
# Anndata.raw does not have obs_names.
72+
if isinstance(data, anndata.AnnData):
73+
obs_names = np.asarray(data.obs_names)
74+
else:
75+
obs_names = ["obs_" + str(i) for i in range(data.X.shape[0])]
76+
7177
if scipy.sparse.issparse(data.X):
7278
# X = _sparse_to_xarray(data.X, dims=dims)
7379
# X.coords[dims[0]] = np.asarray(data.obs_names)
7480
# X.coords[dims[1]] = np.asarray(data.var_names)
7581
X = SparseXArrayDataSet(
7682
X=data.X,
77-
obs_names=np.asarray(data.obs_names),
83+
obs_names=np.asarray(obs_names),
7884
feature_names=np.asarray(data.var_names),
7985
dims=dims
8086
)
8187
else:
82-
X = data.X
83-
X = xr.DataArray(X, dims=dims, coords={
84-
dims[0]: np.asarray(data.obs_names),
88+
X = xr.DataArray(data.X, dims=dims, coords={
89+
dims[0]: np.asarray(obs_names),
8590
dims[1]: np.asarray(data.var_names),
8691
})
8792
elif isinstance(data, xr.Dataset):

batchglm/unit_test/base_glm/test_data_types_glm.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,36 +91,41 @@ def basic_test(
9191
estimator = self.get_estimator(input_data=input_data)
9292
return estimator.test_estimation()
9393

94-
def _test_numpy_dense(self):
95-
return self.basic_test(
96-
data=self.sim.X,
97-
design_loc=self.sim.design_loc,
98-
design_scale=self.sim.design_scale
99-
)
94+
def _test_numpy(self, sparse):
95+
if sparse:
96+
X = scipy.sparse.csr_matrix(self.sim.X)
10097

101-
def _test_scipy_sparse(self):
102-
return self.basic_test(
103-
data=scipy.sparse.csr_matrix(self.sim.X),
98+
success = self.basic_test(
99+
data=X,
104100
design_loc=self.sim.design_loc,
105101
design_scale=self.sim.design_scale
106102
)
103+
assert success, "_test_anndata with sparse=%s did not work" % sparse
107104

108-
def _test_anndata_dense(self):
105+
def _test_anndata(self, sparse):
109106
adata = self.sim.data_to_anndata()
110-
return self.basic_test(
107+
if sparse:
108+
adata.X = scipy.sparse.csr_matrix(adata.X)
109+
110+
success = self.basic_test(
111111
data=adata,
112112
design_loc=self.sim.design_loc,
113113
design_scale=self.sim.design_scale
114114
)
115+
assert success, "_test_anndata with sparse=%s did not work" % sparse
115116

116-
def _test_anndata_sparse(self):
117+
def _test_anndata_raw(self, sparse):
117118
adata = self.sim.data_to_anndata()
118-
adata.X = scipy.sparse.csr_matrix(adata.X)
119-
return self.basic_test(
120-
data=adata,
119+
if sparse:
120+
adata.X = scipy.sparse.csr_matrix(adata.X)
121+
122+
adata.raw = adata
123+
success = self.basic_test(
124+
data=adata.raw,
121125
design_loc=self.sim.design_loc,
122126
design_scale=self.sim.design_scale
123127
)
128+
assert success, "_test_anndata with sparse=%s did not work" % sparse
124129

125130

126131
if __name__ == '__main__':

batchglm/unit_test/glm_all/test_data_types_glm_all.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,6 @@ def get_estimator(
9999
noise_model=self.noise_model
100100
)
101101

102-
def _test_standard(self):
103-
self.simulate()
104-
logger.debug("* Running tests on numpy/scipy")
105-
self._test_numpy_dense()
106-
self._test_scipy_sparse()
107-
108-
def _test_anndata(self):
109-
self.simulate()
110-
logger.debug("* Running tests on anndata")
111-
self._test_anndata_dense()
112-
self._test_anndata_sparse()
113-
114102

115103
class Test_DataTypes_GLM_NB(
116104
Test_DataTypes_GLM_ALL,
@@ -120,22 +108,31 @@ class Test_DataTypes_GLM_NB(
120108
Test whether training graphs work for negative binomial noise.
121109
"""
122110

123-
def test_standard_nb(self):
111+
def test_standard(self):
124112
logging.getLogger("tensorflow").setLevel(logging.ERROR)
125113
logging.getLogger("batchglm").setLevel(logging.WARNING)
126114
logger.error("Test_DataTypes_GLM_NB.test_standard_nb()")
127115

128116
self.noise_model = "nb"
129-
self._test_standard()
117+
self.simulate()
118+
self._test_numpy(sparse=False)
119+
self._test_numpy(sparse=True)
120+
121+
return True
130122

131-
def test_anndata_nb(self):
123+
def test_anndata(self):
132124
logging.getLogger("tensorflow").setLevel(logging.ERROR)
133125
logging.getLogger("batchglm").setLevel(logging.WARNING)
134126
logger.error("Test_DataTypes_GLM_NB.test_anndata_nb()")
135127

136128
self.noise_model = "nb"
137-
self._test_anndata()
129+
self.simulate()
130+
self._test_anndata(sparse=False)
131+
self._test_anndata(sparse=True)
132+
self._test_anndata_raw(sparse=False)
133+
self._test_anndata_raw(sparse=True)
138134

135+
return True
139136

140137
if __name__ == '__main__':
141138
unittest.main()

0 commit comments

Comments
 (0)