Skip to content

Commit 8418e65

Browse files
fixed bugs in constraint parsing
1 parent 0f735f3 commit 8418e65

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

batchglm/data.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -483,30 +483,39 @@ def build_equality_constraints_string(
483483
E.g. ["batch1 + batch2 + batch3 = 0"]
484484
:return: a constraint matrix
485485
"""
486-
# TODO: automatically generate string constraints from factors
486+
n_par_all = dmat.data_vars['design'].values.shape[1]
487+
n_par_free = n_par_all - len(constraints)
488+
487489
di = patsy.DesignInfo(dmat.coords["design_params"].values)
488490
constraint_ls = [di.linear_constraint(x).coefs[0] for x in constraints]
489-
idx_constrained = [np.where(x == 1)[0][0] for x in constraint_ls]
490-
idx_unconstr = list(
491-
set(list(range(dmat.data_vars["design"].shape[1]))) -
492-
set(list(idx_constrained))
493-
)
491+
idx_constr = np.asarray([np.where(x == 1)[0][0] for x in constraint_ls])
492+
idx_depending = [np.where(x == 1)[0][1:] for x in constraint_ls]
493+
idx_unconstr = np.asarray(list(
494+
set(np.asarray(range(n_par_all))) - set(idx_constr)
495+
))
494496

495497
dmat_var = xr.DataArray(
496498
dims=[dmat.data_vars['design'].dims[0], "params"],
497499
data=dmat.data_vars["design"][:,idx_unconstr],
498500
coords={dmat.data_vars['design'].dims[0]: dmat.coords["observations"].values,
499501
"params": dmat.coords["design_params"].values[idx_unconstr]}
500502
)
501-
constraint_mat = np.vstack(constraint_ls)[:,idx_unconstr]
502503

503-
constraints = np.vstack([
504-
np.identity(n=len(idx_unconstr)),
505-
-constraint_mat
506-
])
504+
constraint_mat = np.zeros([n_par_all, n_par_free])
505+
for i in range(n_par_all):
506+
if i in idx_constr:
507+
idx_dep_i = idx_depending[np.where(idx_constr == i)[0][0]]
508+
idx_dep_i = np.asarray([np.where(idx_unconstr == x)[0] for x in idx_dep_i])
509+
constraint_mat[i, :] = 0
510+
constraint_mat[i, idx_dep_i] = -1
511+
else:
512+
idx_unconstr_i = np.where(idx_unconstr == i)
513+
constraint_mat[i, :] = 0
514+
constraint_mat[i, idx_unconstr_i] = 1
515+
507516
constraints_ar = parse_constraints(
508517
dmat=dmat,
509-
constraints=constraints,
518+
constraints=constraint_mat,
510519
dims=dims
511520
)
512521

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,10 @@ def init_par(
158158

159159
if init_b.lower() == "closed_form":
160160
try:
161-
init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
162-
init_a_xr.coords["design_loc_params"] = self.input_data.design_loc.coords["design_loc_params"]
163-
init_mu = np.exp(self.input_data.design_loc.dot(init_a_xr))
161+
init_a_xr = data_utils.xarray_from_data(init_a, dims=("loc_params", "features"))
162+
init_a_xr.coords["loc_params"] = self.input_data.constraints_loc.coords["loc_params"]
163+
# TODO: memory inefficient:
164+
init_mu = np.exp(self.input_data.design_loc.dot(self.input_data.constraints_loc.dot(init_a_xr)))
164165

165166
groupwise_scales, init_b, rmsd_b = closedform_nb_glm_logphi(
166167
X=self.input_data.X,

batchglm/utils/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def apply_fun(grouping):
7373
# Get unique rows of design matrix and vector with group assignments:
7474
unique_design, inverse_idx = np.unique(dmat, axis=0, return_inverse=True)
7575

76-
if unique_design.shape[1] > np.linalg.matrix_rank(unique_design):
77-
logger.warning("model is not full rank!")
76+
if unique_design.shape[1] > np.linalg.matrix_rank(np.matmul(unique_design, constraints)):
77+
logger.error("model is not full rank!")
7878

7979
# Get group-wise means in linker space based on group assignments
8080
# based on unique rows of design matrix:

0 commit comments

Comments
 (0)