@@ -217,15 +217,16 @@ def constraint_system_from_star(
217217 sample_description = sample_description ,
218218 formula = formula ,
219219 as_categorical = as_categorical ,
220- constraints = constraints
220+ constraints = constraints ,
221+ return_type = "dataframe"
221222 )
222223 elif isinstance (constraints , tuple ) or isinstance (constraints , list ):
223224 cmat = constraint_matrix_from_string (
224225 dmat = dmat ,
225226 constraints = constraints
226227 )
227228 elif isinstance (constraints , np .ndarray ):
228- cmat = parse_constraints
229+ cmat = constraints
229230 elif constraints is None :
230231 cmat = None
231232 else :
@@ -238,7 +239,8 @@ def constraint_matrix_from_dict(
238239 sample_description : pd .DataFrame ,
239240 formula : str ,
240241 as_categorical : Union [bool , list ] = True ,
241- constraints : dict = {}
242+ constraints : dict = {},
243+ return_type : str = "dataframe"
242244) -> Tuple :
243245 """
244246 Create a design matrix from some sample description and a constraint matrix
@@ -303,9 +305,14 @@ def constraint_matrix_from_dict(
303305 # Build constraint matrix.
304306 constraints_ar = constraint_matrix_from_string (
305307 dmat = dmat ,
308+ coef_names = coef_names ,
306309 constraints = constraints_ls
307310 )
308311
312+ # Format return type
313+ if return_type == "dataframe" :
314+ dmat = pd .DataFrame (dmat , columns = coef_names )
315+
309316 return dmat , constraints_ar
310317
311318
@@ -362,6 +369,7 @@ def string_constraints_from_dict(
362369
363370def constraint_matrix_from_string (
364371 dmat : np .ndarray ,
372+ coef_names : list ,
365373 constraints : Union [Tuple [str , str ], List [str ]]
366374):
367375 r"""
@@ -375,10 +383,10 @@ def constraint_matrix_from_string(
375383 """
376384 assert len (constraints ) > 0 , "supply constraints"
377385
378- n_par_all = dmat .values . shape [1 ]
386+ n_par_all = dmat .shape [1 ]
379387 n_par_free = n_par_all - len (constraints )
380388
381- di = patsy .DesignInfo (dmat . coords [ "design_params" ]. values )
389+ di = patsy .DesignInfo (coef_names )
382390 constraint_ls = [di .linear_constraint (x ).coefs [0 ] for x in constraints ]
383391 idx_constr = np .asarray ([np .where (x == 1 )[0 ][0 ] for x in constraint_ls ])
384392 idx_depending = [np .where (x == 1 )[0 ][1 :] for x in constraint_ls ]
0 commit comments