@@ -21,7 +21,7 @@ def design_matrix(
2121 as_categorical : Union [bool , list ] = True ,
2222 dmat : Union [pd .DataFrame , None ] = None ,
2323 return_type : str = "patsy" ,
24- ) -> Union [patsy .design_info .DesignMatrix , pd .DataFrame ]:
24+ ) -> Tuple [ Union [patsy .design_info .DesignMatrix , pd .DataFrame ], List [ str ] ]:
2525 """
2626 Create a design matrix from some sample description.
2727
@@ -62,6 +62,7 @@ def design_matrix(
6262 sample_description [col ] = sample_description [col ].astype ("category" )
6363
6464 dmat = patsy .dmatrix (formula , sample_description )
65+ coef_names = dmat .design_info .column_names
6566
6667 if return_type == "dataframe" :
6768 df = pd .DataFrame (dmat , columns = dmat .design_info .column_names )
@@ -70,12 +71,12 @@ def design_matrix(
7071
7172 return df
7273 elif return_type == "patsy" :
73- return dmat
74+ return dmat , coef_names
7475 else :
7576 raise ValueError ("return type %s not recognized" % return_type )
7677 else :
7778 if return_type == "dataframe" :
78- return dmat
79+ return dmat , dmat . columns
7980 elif return_type == "patsy" :
8081 raise ValueError ("return type 'patsy' not supported for input (dmat is not None)" )
8182 else :
@@ -134,7 +135,7 @@ def preview_coef_names(
134135
135136
136137def constraint_system_from_star (
137- dmat : Union [None , np . ndarray ] = None ,
138+ dmat : Union [None , patsy . design_info . DesignMatrix ] = None ,
138139 sample_description : Union [None , pd .DataFrame ] = None ,
139140 formula : Union [None , str ] = None ,
140141 as_categorical : Union [bool , list ] = True ,
@@ -202,7 +203,7 @@ def constraint_system_from_star(
202203 raise ValueError ("supply either sample_description or dmat" )
203204
204205 if dmat is None and not isinstance (constraints , dict ):
205- dmat = design_matrix (
206+ dmat , coef_names = design_matrix (
206207 sample_description = sample_description ,
207208 formula = formula ,
208209 as_categorical = as_categorical ,
@@ -213,39 +214,49 @@ def constraint_system_from_star(
213214 raise ValueError ("dmat was supplied even though constraints were given as dict" )
214215
215216 if isinstance (constraints , dict ):
216- dmat , cmat = constraint_matrix_from_dict (
217+ dmat , coef_names , cmat , term_names = constraint_matrix_from_dict (
217218 sample_description = sample_description ,
218219 formula = formula ,
219220 as_categorical = as_categorical ,
220221 constraints = constraints ,
221- return_type = "dataframe "
222+ return_type = "patsy "
222223 )
223224 elif isinstance (constraints , tuple ) or isinstance (constraints , list ):
224- cmat = constraint_matrix_from_string (
225+ cmat , coef_names = constraint_matrix_from_string (
225226 dmat = dmat ,
227+ coef_names = dmat .design_info .column_names ,
226228 constraints = constraints
227229 )
230+ term_names = None # not supported yet.
228231 elif isinstance (constraints , np .ndarray ):
229232 cmat = constraints
233+ term_names = None
230234 elif constraints is None :
231235 cmat = None
236+ term_names = None
232237 else :
233238 raise ValueError ("constraint format %s not recognized" % type (constraints ))
234239
235- return dmat , cmat
240+ return dmat , coef_names , cmat , term_names
236241
237242
238243def constraint_matrix_from_dict (
239244 sample_description : pd .DataFrame ,
240245 formula : str ,
241246 as_categorical : Union [bool , list ] = True ,
242247 constraints : dict = {},
243- return_type : str = "dataframe "
248+ return_type : str = "patsy "
244249) -> Tuple :
245250 """
246251 Create a design matrix from some sample description and a constraint matrix
247252 based on factor encoding of constrained parameter sets.
248253
254+ Note that we build a dataframe instead of a patsy.DesignMatrix here if constraints are used.
255+ This is done because we were not able to build a patsy.DesignMatrix of the constrained form
256+ required in this context. In those cases in which the return type cannot be patsy, we encourage the
257+ use of the returned term_names to perform term-wise slicing which is not supported by other
258+ design matrix return types.
259+
249260 :param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
250261 :param formula: model formula as string, describing the relations of the explanatory variables.
251262
@@ -270,7 +281,9 @@ def constraint_matrix_from_dict(
270281
271282 Can only group by non-constrained effects right now, use constraint_matrix_from_string
272283 for other cases.
273- :return: a model design matrix
284+ :return:
285+ - model design matrix, followed by its coefficient names and the constraint matrix
286+ - term_names to allow slicing by factor if return type cannot be patsy.DesignMatrix
274287 """
275288 assert len (constraints ) > 0 , "supply constraints"
276289 sample_description : pd .DataFrame = sample_description .copy ()
@@ -287,10 +300,11 @@ def constraint_matrix_from_dict(
287300 # absorption of the first level of each factor for each constrained factor onto the
288301 # core matrix.
289302 formula_unconstrained = formula .split ("+" )
290- formula_unconstrained = [x for x in formula_unconstrained if x not in constraints .keys ()]
303+ formula_unconstrained = [x for x in formula_unconstrained if x . strip ( " " ) not in constraints .keys ()]
291304 formula_unconstrained = "+" .join (formula_unconstrained )
292305 dmat = patsy .dmatrix (formula_unconstrained , sample_description )
293306 coef_names = dmat .design_info .column_names
307+ term_names = dmat .design_info .term_names
294308
295309 constraints_ls = string_constraints_from_dict (
296310 sample_description = sample_description ,
@@ -301,6 +315,7 @@ def constraint_matrix_from_dict(
301315 dmat_constrained_temp = patsy .highlevel .dmatrix ("0+" + x , sample_description )
302316 dmat = np .hstack ([dmat , dmat_constrained_temp ])
303317 coef_names .extend (dmat_constrained_temp .design_info .column_names )
318+ term_names .extend (dmat_constrained_temp .design_info .term_names )
304319
305320 # Build constraint matrix.
306321 constraints_ar = constraint_matrix_from_string (
@@ -312,8 +327,7 @@ def constraint_matrix_from_dict(
312327 # Format return type
313328 if return_type == "dataframe" :
314329 dmat = pd .DataFrame (dmat , columns = coef_names )
315-
316- return dmat , constraints_ar
330+ return dmat , coef_names , constraints_ar , term_names
317331
318332
319333def string_constraints_from_dict (
@@ -407,7 +421,7 @@ def constraint_matrix_from_string(
407421 constraint_mat [i , idx_unconstr_i ] = 1
408422
409423 # Test unconstrained subset design matrix for being full rank before returning constraints:
410- dmat_var = dmat [:, idx_unconstr ]
424+ dmat_var = dmat [:, idx_unconstr ]
411425 if np .linalg .matrix_rank (dmat_var ) != np .linalg .matrix_rank (dmat_var .T ):
412426 logging .getLogger ("batchglm" ).error ("constrained design matrix is not full rank" )
413427
0 commit comments