Skip to content

Commit f2aff55

Browse files
fixed some bugs in data parsing
1 parent 054d729 commit f2aff55

File tree

4 files changed

+157
-141
lines changed

4 files changed

+157
-141
lines changed

batchglm/data.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def fetch_X(idx):
5656

5757

5858
def xarray_from_data(
59-
data: Union[anndata.AnnData, xr.DataArray, xr.Dataset, np.ndarray],
59+
data: Union[anndata.AnnData, anndata.base.Raw, xr.DataArray, xr.Dataset, np.ndarray, scipy.sparse.csr_matrix],
6060
dims: Union[Tuple, List] = ("observations", "features")
6161
):
6262
"""
@@ -85,27 +85,34 @@ def xarray_from_data(
8585
dims=dims
8686
)
8787
else:
88-
X = xr.DataArray(data.X, dims=dims, coords={
89-
dims[0]: np.asarray(obs_names),
90-
dims[1]: np.asarray(data.var_names),
91-
})
88+
X = xr.DataArray(
89+
data.X,
90+
dims=dims,
91+
coords={
92+
dims[0]: np.asarray(obs_names),
93+
dims[1]: np.asarray(data.var_names),
94+
}
95+
)
9296
elif isinstance(data, xr.Dataset):
9397
X: xr.DataArray = data["X"]
9498
elif isinstance(data, xr.DataArray):
9599
X = data
100+
elif isinstance(data, SparseXArrayDataSet):
101+
X = data
102+
elif scipy.sparse.issparse(data):
103+
# X = _sparse_to_xarray(data, dims=dims)
104+
# X.coords[dims[0]] = np.asarray(data.obs_names)
105+
# X.coords[dims[1]] = np.asarray(data.var_names)
106+
X = SparseXArrayDataSet(
107+
X=data,
108+
obs_names=None,
109+
feature_names=None,
110+
dims=dims
111+
)
112+
elif isinstance(data, np.ndarray):
113+
X = xr.DataArray(data, dims=dims)
96114
else:
97-
if scipy.sparse.issparse(data):
98-
# X = _sparse_to_xarray(data, dims=dims)
99-
# X.coords[dims[0]] = np.asarray(data.obs_names)
100-
# X.coords[dims[1]] = np.asarray(data.var_names)
101-
X = SparseXArrayDataSet(
102-
X=data,
103-
obs_names=None,
104-
feature_names=None,
105-
dims=dims
106-
)
107-
else:
108-
X = xr.DataArray(data, dims=dims)
115+
raise ValueError("batchglm data parsing: data format %s not recognized" % type(data))
109116

110117
return X
111118

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 120 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -106,134 +106,134 @@ def init_par(
106106
shape=[input_data.num_observations, input_data.num_features]
107107
)
108108

109-
groupwise_means = None
110-
init_a_str = None
111-
if isinstance(init_a, str):
112-
init_a_str = init_a.lower()
113-
# Chose option if auto was chosen
114-
if init_a.lower() == "auto":
115-
init_a = "closed_form"
116-
117-
if init_a.lower() == "closed_form":
118-
#try:
119-
groupwise_means, init_a, rmsd_a = closedform_nb_glm_logmu(
120-
X=input_data.X,
121-
design_loc=input_data.design_loc,
122-
constraints_loc=input_data.constraints_loc.values,
123-
size_factors=size_factors_init,
124-
link_fn=lambda mu: np.log(self.np_clip_param(mu, "mu"))
125-
)
126-
127-
# train mu, if the closed-form solution is inaccurate
128-
self._train_loc = not np.all(rmsd_a == 0)
129-
130-
if input_data.size_factors is not None:
131-
if np.any(input_data.size_factors != 1):
132-
self._train_loc = True
133-
134-
logger.debug("Using closed-form MLE initialization for mean")
135-
logger.debug("Should train mu: %s", self._train_loc)
136-
#except np.linalg.LinAlgError:
137-
# logger.warning("Closed form initialization failed!")
138-
elif init_a.lower() == "standard":
139-
if isinstance(input_data.X, SparseXArrayDataArray):
140-
overall_means = input_data.X.mean(dim="observations")
141-
else:
142-
overall_means = input_data.X.mean(dim="observations").values # directly calculate the mean
143-
overall_means = self.np_clip_param(overall_means, "mu")
144-
145-
init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
146-
init_a[0, :] = np.log(overall_means)
147-
self._train_loc = True
148-
149-
logger.debug("Using standard initialization for mean")
150-
logger.debug("Should train mu: %s", self._train_loc)
151-
elif init_a.lower() == "all_zero":
152-
init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
153-
self._train_loc = True
154-
155-
logger.debug("Using all_zero initialization for mean")
156-
logger.debug("Should train mu: %s", self._train_loc)
157-
else:
158-
raise ValueError("init_a string %s not recognized" % init_a)
159-
160-
if isinstance(init_b, str):
161-
if init_b.lower() == "auto":
162-
init_b = "standard"
163-
164-
if init_b.lower() == "closed_form" or init_b.lower() == "standard":
165-
#try:
166-
# Check whether it is necessary to recompute group-wise means.
167-
dmats_unequal = False
168-
if input_data.design_loc.shape[1] == input_data.design_scale.shape[1]:
169-
if np.any(input_data.design_loc.values != input_data.design_scale.values):
170-
dmats_unequal = True
171-
172-
inits_unequal = False
173-
if init_a_str is not None:
174-
if init_a_str != init_b:
175-
inits_unequal = True
176-
177-
if inits_unequal or dmats_unequal:
178-
groupwise_means = None
179-
180-
# Watch out: init_mu is full obs x features matrix and is very large in many cases.
181-
if inits_unequal or dmats_unequal:
109+
if init_model is None:
110+
groupwise_means = None
111+
init_a_str = None
112+
if isinstance(init_a, str):
113+
init_a_str = init_a.lower()
114+
# Chose option if auto was chosen
115+
if init_a.lower() == "auto":
116+
init_a = "closed_form"
117+
118+
if init_a.lower() == "closed_form":
119+
#try:
120+
groupwise_means, init_a, rmsd_a = closedform_nb_glm_logmu(
121+
X=input_data.X,
122+
design_loc=input_data.design_loc,
123+
constraints_loc=input_data.constraints_loc.values,
124+
size_factors=size_factors_init,
125+
link_fn=lambda mu: np.log(self.np_clip_param(mu, "mu"))
126+
)
127+
128+
# train mu, if the closed-form solution is inaccurate
129+
self._train_loc = not np.all(rmsd_a == 0)
130+
131+
if input_data.size_factors is not None:
132+
if np.any(input_data.size_factors != 1):
133+
self._train_loc = True
134+
135+
logger.debug("Using closed-form MLE initialization for mean")
136+
logger.debug("Should train mu: %s", self._train_loc)
137+
#except np.linalg.LinAlgError:
138+
# logger.warning("Closed form initialization failed!")
139+
elif init_a.lower() == "standard":
182140
if isinstance(input_data.X, SparseXArrayDataArray):
183-
init_mu = np.matmul(
184-
input_data.design_loc.values,
185-
np.matmul(input_data.constraints_loc.values, init_a)
186-
)
141+
overall_means = input_data.X.mean(dim="observations")
187142
else:
188-
init_a_xr = data_utils.xarray_from_data(init_a, dims=("loc_params", "features"))
189-
init_a_xr.coords["loc_params"] = input_data.constraints_loc.coords["loc_params"]
190-
init_mu = input_data.design_loc.dot(input_data.constraints_loc.dot(init_a_xr))
143+
overall_means = input_data.X.mean(dim="observations").values # directly calculate the mean
144+
overall_means = self.np_clip_param(overall_means, "mu")
191145

192-
if size_factors_init is not None:
193-
init_mu = init_mu + np.log(size_factors_init)
194-
init_mu = np.exp(init_mu)
195-
else:
196-
init_mu = None
146+
init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
147+
init_a[0, :] = np.log(overall_means)
148+
self._train_loc = True
197149

198-
if init_b.lower() == "closed_form":
199-
groupwise_scales, init_b, rmsd_b = closedform_nb_glm_logphi(
200-
X=input_data.X,
201-
mu=init_mu,
202-
design_scale=input_data.design_scale,
203-
constraints=input_data.constraints_scale.values,
204-
size_factors=size_factors_init,
205-
groupwise_means=groupwise_means,
206-
link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
207-
)
150+
logger.debug("Using standard initialization for mean")
151+
logger.debug("Should train mu: %s", self._train_loc)
152+
elif init_a.lower() == "all_zero":
153+
init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
154+
self._train_loc = True
208155

209-
logger.debug("Using closed-form MME initialization for dispersion")
210-
logger.debug("Should train r: %s", self._train_scale)
211-
elif init_b.lower() == "standard":
212-
groupwise_scales, init_b_intercept, rmsd_b = closedform_nb_glm_logphi(
213-
X=input_data.X,
214-
mu=init_mu,
215-
design_scale=input_data.design_scale[:,[0]],
216-
constraints=input_data.constraints_scale[[0], [0]].values,
217-
size_factors=size_factors_init,
218-
groupwise_means=None,
219-
link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
220-
)
156+
logger.debug("Using all_zero initialization for mean")
157+
logger.debug("Should train mu: %s", self._train_loc)
158+
else:
159+
raise ValueError("init_a string %s not recognized" % init_a)
160+
161+
if isinstance(init_b, str):
162+
if init_b.lower() == "auto":
163+
init_b = "standard"
164+
165+
if init_b.lower() == "closed_form" or init_b.lower() == "standard":
166+
#try:
167+
# Check whether it is necessary to recompute group-wise means.
168+
dmats_unequal = False
169+
if input_data.design_loc.shape[1] == input_data.design_scale.shape[1]:
170+
if np.any(input_data.design_loc.values != input_data.design_scale.values):
171+
dmats_unequal = True
172+
173+
inits_unequal = False
174+
if init_a_str is not None:
175+
if init_a_str != init_b:
176+
inits_unequal = True
177+
178+
if inits_unequal or dmats_unequal:
179+
groupwise_means = None
180+
181+
# Watch out: init_mu is full obs x features matrix and is very large in many cases.
182+
if inits_unequal or dmats_unequal:
183+
if isinstance(input_data.X, SparseXArrayDataArray):
184+
init_mu = np.matmul(
185+
input_data.design_loc.values,
186+
np.matmul(input_data.constraints_loc.values, init_a)
187+
)
188+
else:
189+
init_a_xr = data_utils.xarray_from_data(init_a, dims=("loc_params", "features"))
190+
init_a_xr.coords["loc_params"] = input_data.constraints_loc.coords["loc_params"]
191+
init_mu = input_data.design_loc.dot(input_data.constraints_loc.dot(init_a_xr))
192+
193+
if size_factors_init is not None:
194+
init_mu = init_mu + np.log(size_factors_init)
195+
init_mu = np.exp(init_mu)
196+
else:
197+
init_mu = None
198+
199+
if init_b.lower() == "closed_form":
200+
groupwise_scales, init_b, rmsd_b = closedform_nb_glm_logphi(
201+
X=input_data.X,
202+
mu=init_mu,
203+
design_scale=input_data.design_scale,
204+
constraints=input_data.constraints_scale.values,
205+
size_factors=size_factors_init,
206+
groupwise_means=groupwise_means,
207+
link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
208+
)
209+
210+
logger.debug("Using closed-form MME initialization for dispersion")
211+
logger.debug("Should train r: %s", self._train_scale)
212+
elif init_b.lower() == "standard":
213+
groupwise_scales, init_b_intercept, rmsd_b = closedform_nb_glm_logphi(
214+
X=input_data.X,
215+
mu=init_mu,
216+
design_scale=input_data.design_scale[:,[0]],
217+
constraints=input_data.constraints_scale[[0], [0]].values,
218+
size_factors=size_factors_init,
219+
groupwise_means=None,
220+
link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
221+
)
222+
init_b = np.zeros([input_data.num_scale_params, input_data.X.shape[1]])
223+
init_b[0, :] = init_b_intercept
224+
225+
logger.debug("Using closed-form MME initialization for dispersion")
226+
logger.debug("Should train r: %s", self._train_scale)
227+
#except np.linalg.LinAlgError:
228+
# logger.warning("Closed form initialization failed!")
229+
elif init_b.lower() == "all_zero":
221230
init_b = np.zeros([input_data.num_scale_params, input_data.X.shape[1]])
222-
init_b[0, :] = init_b_intercept
223231

224-
logger.debug("Using closed-form MME initialization for dispersion")
232+
logger.debug("Using standard initialization for dispersion")
225233
logger.debug("Should train r: %s", self._train_scale)
226-
#except np.linalg.LinAlgError:
227-
# logger.warning("Closed form initialization failed!")
228-
elif init_b.lower() == "all_zero":
229-
init_b = np.zeros([input_data.num_scale_params, input_data.X.shape[1]])
230-
231-
logger.debug("Using standard initialization for dispersion")
232-
logger.debug("Should train r: %s", self._train_scale)
233-
else:
234-
raise ValueError("init_b string %s not recognized" % init_b)
235-
236-
if init_model is not None:
234+
else:
235+
raise ValueError("init_b string %s not recognized" % init_b)
236+
else:
237237
# Locations model:
238238
if isinstance(init_a, str) and (init_a.lower() == "auto" or init_a.lower() == "init_model"):
239239
my_loc_names = set(input_data.design_loc_names.values)

batchglm/unit_test/base_glm/test_data_types_glm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ def basic_test(
9292
return estimator.test_estimation()
9393

9494
def _test_numpy(self, sparse):
95+
X = self.sim.X
9596
if sparse:
96-
X = scipy.sparse.csr_matrix(self.sim.X)
97+
X = scipy.sparse.csr_matrix(X)
9798

9899
success = self.basic_test(
99100
data=X,

batchglm/xarray_sparse/base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,23 +116,27 @@ def square(self, copy=False):
116116
else:
117117
self.X = new_x
118118

119-
def mean(self, dim=None):
119+
def mean(self, dim: str = None, axis: int = None):
120+
assert not (dim is not None and axis is not None), "only supply dim or axis"
120121
if dim is not None:
121122
assert dim in self.dims, "dim not recognized"
122123
axis = self.dims.index(dim)
123124
return np.asarray(self.X.mean(axis=axis)).flatten()
125+
elif axis is not None:
126+
assert axis < len(self.X.shape), "axis index out of range"
127+
return np.asarray(self.X.mean(axis=axis)).flatten()
124128
else:
125129
return np.asarray(self.X.mean()).flatten()
126130

127-
def var(self, dim):
131+
def var(self, dim: str):
128132
assert dim in self.dims, "dim not recognized"
129133
axis = self.dims.index(dim)
130134
Xsq = self.square(copy=True)
131135
expect_x_sq = np.square(self.mean(dim=dim))
132136
expect_xsq = np.mean(Xsq, axis=axis)
133137
return np.asarray(expect_xsq - expect_x_sq).flatten()
134138

135-
def std(self, dim):
139+
def std(self, dim: str):
136140
return np.sqrt(self.var(dim=dim))
137141

138142
def groupby(self, key):
@@ -215,6 +219,10 @@ def __init__(
215219
def ndim(self):
216220
return len(self.dims)
217221

222+
@property
223+
def shape(self):
224+
return self.X.shape
225+
218226
@property
219227
def feature_allzero(self):
220228
return self.X.coords["feature_allzero"]

0 commit comments

Comments
 (0)