Skip to content

Commit c63bb23

Browse files
Merge pull request #85 from theislab/dev
Dev
2 parents a92d55e + be9234f commit c63bb23

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

batchglm/data.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,16 @@ def constraint_system_from_star(
217217
sample_description=sample_description,
218218
formula=formula,
219219
as_categorical=as_categorical,
220-
constraints=constraints
220+
constraints=constraints,
221+
return_type="dataframe"
221222
)
222223
elif isinstance(constraints, tuple) or isinstance(constraints, list):
223224
cmat = constraint_matrix_from_string(
224225
dmat=dmat,
225226
constraints=constraints
226227
)
227228
elif isinstance(constraints, np.ndarray):
228-
cmat = parse_constraints
229+
cmat = constraints
229230
elif constraints is None:
230231
cmat = None
231232
else:
@@ -238,7 +239,8 @@ def constraint_matrix_from_dict(
238239
sample_description: pd.DataFrame,
239240
formula: str,
240241
as_categorical: Union[bool, list] = True,
241-
constraints: dict = {}
242+
constraints: dict = {},
243+
return_type: str = "dataframe"
242244
) -> Tuple:
243245
"""
244246
Create a design matrix from some sample description and a constraint matrix
@@ -303,9 +305,14 @@ def constraint_matrix_from_dict(
303305
# Build constraint matrix.
304306
constraints_ar = constraint_matrix_from_string(
305307
dmat=dmat,
308+
coef_names=coef_names,
306309
constraints=constraints_ls
307310
)
308311

312+
# Format return type
313+
if return_type == "dataframe":
314+
dmat = pd.DataFrame(dmat, columns=coef_names)
315+
309316
return dmat, constraints_ar
310317

311318

@@ -362,6 +369,7 @@ def string_constraints_from_dict(
362369

363370
def constraint_matrix_from_string(
364371
dmat: np.ndarray,
372+
coef_names: list,
365373
constraints: Union[Tuple[str, str], List[str]]
366374
):
367375
r"""
@@ -375,10 +383,10 @@ def constraint_matrix_from_string(
375383
"""
376384
assert len(constraints) > 0, "supply constraints"
377385

378-
n_par_all = dmat.values.shape[1]
386+
n_par_all = dmat.shape[1]
379387
n_par_free = n_par_all - len(constraints)
380388

381-
di = patsy.DesignInfo(dmat.coords["design_params"].values)
389+
di = patsy.DesignInfo(coef_names)
382390
constraint_ls = [di.linear_constraint(x).coefs[0] for x in constraints]
383391
idx_constr = np.asarray([np.where(x == 1)[0][0] for x in constraint_ls])
384392
idx_depending = [np.where(x == 1)[0][1:] for x in constraint_ls]

batchglm/models/base_glm/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def closedform_glm_scale(
137137
:return: tuple (groupwise_scales, logphi, rmsd)
138138
"""
139139
if size_factors is not None:
140-
x = np.divide(x, size_factors)
140+
x = x / size_factors
141141

142142
# to circumvent nonlocal error
143143
provided_groupwise_means = groupwise_means

batchglm/unit_test/test_acc_glm_all.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,6 @@ def basic_test(
306306
else:
307307
raise ValueError("noise model %s not recognized" % self.noise_model)
308308

309-
estimator = _TestAccuracyGlmAllEstim(
310-
simulator=self.simulator(train_loc=train_loc),
311-
quick_scale=False if train_scale else True,
312-
noise_model=self.noise_model,
313-
sparse=sparse,
314-
init_mode=init_mode
315-
)
316309
for algo in algos:
317310
logger.info("algorithm: %s" % algo)
318311
if algo in ["ADAM", "RMSPROP", "GD"]:
@@ -348,6 +341,13 @@ def basic_test(
348341
glm.pkg_constants.JACOBIAN_MODE = "analytic"
349342
else:
350343
return ValueError("algo %s not recognized" % algo)
344+
estimator = _TestAccuracyGlmAllEstim(
345+
simulator=self.simulator(train_loc=train_loc),
346+
quick_scale=False if train_scale else True,
347+
noise_model=self.noise_model,
348+
sparse=sparse,
349+
init_mode=init_mode
350+
)
351351
estimator.estimate(
352352
algo=algo,
353353
batched=batched,

0 commit comments

Comments
 (0)