Skip to content

Commit 5ff0c6a

Browse files
improved parameter name handling if constraints are given
1 parent be9234f commit 5ff0c6a

File tree

5 files changed

+50
-23
lines changed

5 files changed

+50
-23
lines changed

batchglm/data.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def design_matrix(
2121
as_categorical: Union[bool, list] = True,
2222
dmat: Union[pd.DataFrame, None] = None,
2323
return_type: str = "patsy",
24-
) -> Union[patsy.design_info.DesignMatrix, pd.DataFrame]:
24+
) -> Tuple[Union[patsy.design_info.DesignMatrix, pd.DataFrame], List[str]]:
2525
"""
2626
Create a design matrix from some sample description.
2727
@@ -62,6 +62,7 @@ def design_matrix(
6262
sample_description[col] = sample_description[col].astype("category")
6363

6464
dmat = patsy.dmatrix(formula, sample_description)
65+
coef_names = dmat.design_info.column_names
6566

6667
if return_type == "dataframe":
6768
df = pd.DataFrame(dmat, columns=dmat.design_info.column_names)
@@ -70,12 +71,12 @@ def design_matrix(
7071

7172
return df
7273
elif return_type == "patsy":
73-
return dmat
74+
return dmat, coef_names
7475
else:
7576
raise ValueError("return type %s not recognized" % return_type)
7677
else:
7778
if return_type == "dataframe":
78-
return dmat
79+
return dmat, dmat.columns
7980
elif return_type == "patsy":
8081
raise ValueError("return type 'patsy' not supported for input (dmat is not None)")
8182
else:
@@ -134,7 +135,7 @@ def preview_coef_names(
134135

135136

136137
def constraint_system_from_star(
137-
dmat: Union[None, np.ndarray] = None,
138+
dmat: Union[None, patsy.design_info.DesignMatrix] = None,
138139
sample_description: Union[None, pd.DataFrame] = None,
139140
formula: Union[None, str] = None,
140141
as_categorical: Union[bool, list] = True,
@@ -202,7 +203,7 @@ def constraint_system_from_star(
202203
raise ValueError("supply either sample_description or dmat")
203204

204205
if dmat is None and not isinstance(constraints, dict):
205-
dmat = design_matrix(
206+
dmat, coef_names = design_matrix(
206207
sample_description=sample_description,
207208
formula=formula,
208209
as_categorical=as_categorical,
@@ -213,39 +214,49 @@ def constraint_system_from_star(
213214
raise ValueError("dmat was supplied even though constraints were given as dict")
214215

215216
if isinstance(constraints, dict):
216-
dmat, cmat = constraint_matrix_from_dict(
217+
dmat, coef_names, cmat, term_names = constraint_matrix_from_dict(
217218
sample_description=sample_description,
218219
formula=formula,
219220
as_categorical=as_categorical,
220221
constraints=constraints,
221-
return_type="dataframe"
222+
return_type="patsy"
222223
)
223224
elif isinstance(constraints, tuple) or isinstance(constraints, list):
224-
cmat = constraint_matrix_from_string(
225+
cmat, coef_names = constraint_matrix_from_string(
225226
dmat=dmat,
227+
coef_names=dmat.design_info.column_names,
226228
constraints=constraints
227229
)
230+
term_names = None # not supported yet.
228231
elif isinstance(constraints, np.ndarray):
229232
cmat = constraints
233+
term_names = None
230234
elif constraints is None:
231235
cmat = None
236+
term_names = None
232237
else:
233238
raise ValueError("constraint format %s not recognized" % type(constraints))
234239

235-
return dmat, cmat
240+
return dmat, coef_names, cmat, term_names
236241

237242

238243
def constraint_matrix_from_dict(
239244
sample_description: pd.DataFrame,
240245
formula: str,
241246
as_categorical: Union[bool, list] = True,
242247
constraints: dict = {},
243-
return_type: str = "dataframe"
248+
return_type: str = "patsy"
244249
) -> Tuple:
245250
"""
246251
Create a design matrix from some sample description and a constraint matrix
247252
based on factor encoding of constrained parameter sets.
248253
254+
Note that we build a dataframe instead of a patsy.DesignMatrix here if constraints are used.
255+
This is done because we were not able to build a patsy.DesignMatrix of the constrained form
256+
required in this context. In those cases in which the return type cannot be patsy, we encourage the
257+
use of the returned term_names to perform term-wise slicing which is not supported by other
258+
design matrix return types.
259+
249260
:param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
250261
:param formula: model formula as string, describing the relations of the explanatory variables.
251262
@@ -270,7 +281,9 @@ def constraint_matrix_from_dict(
270281
271282
Can only group by non-constrained effects right now, use constraint_matrix_from_string
272283
for other cases.
273-
:return: a model design matrix
284+
:return:
285+
- model design matrix, followed by its coefficient names and the constraint matrix
286+
- term_names to allow slicing by factor if return type cannot be patsy.DesignMatrix
274287
"""
275288
assert len(constraints) > 0, "supply constraints"
276289
sample_description: pd.DataFrame = sample_description.copy()
@@ -287,10 +300,11 @@ def constraint_matrix_from_dict(
287300
# absorption of the first level of each factor for each constrained factor onto the
288301
# core matrix.
289302
formula_unconstrained = formula.split("+")
290-
formula_unconstrained = [x for x in formula_unconstrained if x not in constraints.keys()]
303+
formula_unconstrained = [x for x in formula_unconstrained if x.strip(" ") not in constraints.keys()]
291304
formula_unconstrained = "+".join(formula_unconstrained)
292305
dmat = patsy.dmatrix(formula_unconstrained, sample_description)
293306
coef_names = dmat.design_info.column_names
307+
term_names = dmat.design_info.term_names
294308

295309
constraints_ls = string_constraints_from_dict(
296310
sample_description=sample_description,
@@ -301,6 +315,7 @@ def constraint_matrix_from_dict(
301315
dmat_constrained_temp = patsy.highlevel.dmatrix("0+" + x, sample_description)
302316
dmat = np.hstack([dmat, dmat_constrained_temp])
303317
coef_names.extend(dmat_constrained_temp.design_info.column_names)
318+
term_names.extend(dmat_constrained_temp.design_info.term_names)
304319

305320
# Build constraint matrix.
306321
constraints_ar = constraint_matrix_from_string(
@@ -312,8 +327,7 @@ def constraint_matrix_from_dict(
312327
# Format return type
313328
if return_type == "dataframe":
314329
dmat = pd.DataFrame(dmat, columns=coef_names)
315-
316-
return dmat, constraints_ar
330+
return dmat, coef_names, constraints_ar, term_names
317331

318332

319333
def string_constraints_from_dict(
@@ -407,7 +421,7 @@ def constraint_matrix_from_string(
407421
constraint_mat[i, idx_unconstr_i] = 1
408422

409423
# Test unconstrained subset design matrix for being full rank before returning constraints:
410-
dmat_var =dmat[:, idx_unconstr]
424+
dmat_var = dmat[:, idx_unconstr]
411425
if np.linalg.matrix_rank(dmat_var) != np.linalg.matrix_rank(dmat_var.T):
412426
logging.getLogger("batchglm").error("constrained design matrix is not full rank")
413427

batchglm/models/base_glm/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
design_matrix=design_loc,
8989
param_names=design_loc_names
9090
)
91-
design_scale, design_scale_names = parse_design(
91+
design_scale, design_scale_names = parse_design(
9292
design_matrix=design_scale,
9393
param_names=design_scale_names
9494
)

batchglm/models/base_glm/simulator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pandas
55
import patsy
6-
from typing import Union
6+
from typing import Union, Tuple
77

88
from .model import _ModelGLM
99
from .external import _SimulatorBase
@@ -14,7 +14,7 @@ def generate_sample_description(
1414
num_conditions: int = 2,
1515
num_batches: int = 4,
1616
shuffle_assignments=False
17-
) -> np.ndarray:
17+
) -> Tuple[patsy.DesignMatrix, pandas.DataFrame]:
1818
""" Build a sample description.
1919
2020
:param num_observations: Number of observations to simulate.
@@ -45,7 +45,7 @@ def generate_sample_description(
4545
observations=np.random.permutation(sample_description.observations.values)
4646
)
4747

48-
return np.asarray(patsy.dmatrix("~1+condition+batch", sample_description)), sample_description
48+
return patsy.dmatrix("~1+condition+batch", sample_description), sample_description
4949

5050

5151
class _SimulatorGLM(_SimulatorBase, metaclass=abc.ABCMeta):
@@ -73,6 +73,7 @@ def __init__(
7373
self.sample_description = None
7474
self.sim_a_var = None
7575
self.sim_b_var = None
76+
self._size_factors = None
7677

7778
def generate_sample_description(
7879
self,
@@ -139,6 +140,10 @@ def _generate_params(
139140
rand_fn_scale((self.sim_design_scale.shape[1], self.nfeatures))
140141
], axis=0)
141142

143+
@property
144+
def size_factors(self):
145+
return self._size_factors
146+
142147
@property
143148
def a_var(self):
144149
return self.sim_a_var

batchglm/models/base_glm/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def parse_design(
3737
assert False
3838

3939
if param_names is not None:
40-
if params is not None:
41-
assert len(param_names) == len(params)
40+
if params is None:
41+
assert len(param_names) == dmat.shape[1]
4242
params = param_names
4343

4444
return dmat, params
@@ -62,8 +62,11 @@ def parse_constraints(
6262
constraints = np.identity(n=dmat.shape[1])
6363
constraint_params = dmat_par_names
6464
else:
65-
# Cannot use given parameter names if constraint matrix is not identity: Make up new ones.
66-
constraint_params = ["var_"+str(x) for x in range(constraints.shape[1])]
65+
# Cannot use all parameter names if constraint matrix is not identity: Make up new ones.
66+
# Use variable names that can be mapped (unconstrained).
67+
constraint_params = ["var_"+str(i) if np.sum(constraints[:, i] != 0) > 1
68+
else dmat_par_names[np.where(constraints[:, i] != 0)[0][0]]
69+
for i in range(constraints.shape[1])]
6770
assert constraints.shape[0] == dmat.shape[1], "constraint dimension mismatch"
6871

6972
if constraint_par_names is not None:

batchglm/models/glm_nb/simulator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ def __init__(
1515
num_observations=1000,
1616
num_features=100
1717
):
18+
Model.__init__(
19+
self=self,
20+
input_data=None
21+
)
1822
_SimulatorGLM.__init__(
1923
self=self,
2024
model=None,
@@ -53,3 +57,4 @@ def generate_data(self):
5357
design_loc_names=None,
5458
design_scale_names=None
5559
)
60+

0 commit comments

Comments
 (0)