@@ -118,14 +118,17 @@ def xarray_from_data(
118118
119119
def design_matrix(
        sample_description: Union[pd.DataFrame, None] = None,
        formula: Union[str, None] = None,
        as_categorical: Union[bool, list] = True,
        dmat: Union[pd.DataFrame, None] = None,
        return_type: str = "xarray",
) -> "Union[patsy.design_info.DesignMatrix, xr.DataArray, pd.DataFrame]":
    """
    Create a design matrix from some sample description.

    Exactly one of `sample_description` and `dmat` has to be supplied. If `dmat` is supplied,
    no design matrix is built; the given matrix is only converted to the requested return type.

    :param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
    :param formula: model formula as string, describing the relations of the explanatory variables.
        Required if `sample_description` is supplied.

        E.g. '~ 1 + batch + confounder'
    :param as_categorical: boolean or list of booleans corresponding to the columns in 'sample_description'

        If True, all values in 'sample_description' will be treated as categorical values.

        If list of booleans, each column will be changed to categorical if the corresponding value in 'as_categorical'
        is True.

        Set to false, if columns should not be changed.
    :param dmat: a model design matrix as a pd.DataFrame
    :param return_type: type of the returned value.

        - "patsy": return plain patsy.design_info.DesignMatrix object
          (only available if the matrix is built from `sample_description`)
        - "dataframe": return pd.DataFrame with observations as rows and params as columns
        - "xarray": return xr.DataArray with dimensions ("observations", "design_params") and the
          parameter names as coordinates of "design_params"
    :return: a model design matrix
    :raises ValueError: if both or neither of `dmat` and `sample_description` are supplied, if `formula`
        is missing, if `return_type` is not recognized or if "patsy" is requested for a supplied `dmat`.
    """
    # Exactly one of the two inputs may be given.
    if (dmat is None) == (sample_description is None):
        raise ValueError("supply either dmat or sample_description")

    # Reject unknown return types before any (potentially expensive) matrix is built.
    if return_type not in ("patsy", "dataframe", "xarray"):
        raise ValueError("return type %s not recognized" % return_type)

    if dmat is not None:
        # A user-supplied matrix is only re-formatted.
        if return_type == "patsy":
            raise ValueError("return type 'patsy' not supported for input (dmat is not None)")
        if return_type == "dataframe":
            return dmat

        ar = xr.DataArray(dmat, dims=("observations", "design_params"))
        ar.coords["design_params"] = dmat.columns
        return ar

    if formula is None:
        # formula defaults to None, so fail with a clear message instead of an error from within patsy.
        raise ValueError("formula has to be supplied together with sample_description")

    # Work on a copy so that the caller's data frame is not modified by the dtype conversion below.
    sample_description = sample_description.copy()

    if isinstance(as_categorical, bool):
        as_categorical = [as_categorical] * sample_description.columns.size

    for to_cat, col in zip(as_categorical, sample_description):
        if to_cat:
            sample_description[col] = sample_description[col].astype("category")

    dmat = patsy.dmatrix(formula, sample_description)

    if return_type == "patsy":
        return dmat

    if return_type == "dataframe":
        df = pd.DataFrame(dmat, columns=dmat.design_info.column_names)
        # df carries a fresh RangeIndex: align the sample description by position, not by its
        # own index labels, otherwise pd.concat would outer-join the two indices.
        annotation = sample_description.reset_index(drop=True)
        df = pd.concat([df, annotation], axis=1)
        df.set_index(list(annotation.columns), inplace=True)
        return df

    ar = xr.DataArray(dmat, dims=("observations", "design_params"))
    ar.coords["design_params"] = dmat.design_info.column_names
    return ar
175197
176198def view_coef_names (
@@ -196,10 +218,42 @@ def view_coef_names(
196218 raise ValueError ("dmat type %s not recognized" % type (dmat ))
197219
198220
def preview_coef_names(
        sample_description: pd.DataFrame,
        formula: str,
        as_categorical: Union[bool, list] = True
) -> np.ndarray:
    """
    Preview the coefficient names a model formula would produce.

    Builds the patsy design matrix for the given sample description and formula and
    hands it to `view_coef_names`, without fitting anything.

    :param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
    :param formula: model formula as string, describing the relations of the explanatory variables.

        E.g. '~ 1 + batch + confounder'
    :param as_categorical: boolean or list of booleans corresponding to the columns in 'sample_description'

        If True, all values in 'sample_description' will be treated as categorical values.

        If list of booleans, each column will be changed to categorical if the corresponding value in 'as_categorical'
        is True.

        Set to false, if columns should not be changed.
    :return: The coefficient names as returned by `view_coef_names`.
    """
    # Build the raw patsy matrix first, then read the names off it.
    patsy_dmat = design_matrix(
        sample_description=sample_description,
        formula=formula,
        as_categorical=as_categorical,
        return_type="patsy"
    )
    coef_names = view_coef_names(dmat=patsy_dmat)
    return coef_names
251+
252+
199253def sample_description_from_xarray (
200254 dataset : xr .Dataset ,
201255 dim : str ,
202- ):
256+ ) -> pd . DataFrame :
203257 """
204258 Create a design matrix from a given xarray.Dataset and model formula.
205259
@@ -226,7 +280,7 @@ def design_matrix_from_xarray(
226280 formula = None ,
227281 formula_key = "formula" ,
228282 as_categorical = True ,
229- return_type = "matrix " ,
283+ return_type = "patsy " ,
230284):
231285 """
232286 Create a design matrix from a given xarray.Dataset and model formula.
@@ -275,7 +329,7 @@ def design_matrix_from_xarray(
275329 return dmat
276330
277331
278- def sample_description_from_anndata (dataset : anndata .AnnData ):
332+ def sample_description_from_anndata (dataset : anndata .AnnData ) -> pd . DataFrame :
279333 """
280334 Create a design matrix from a given xarray.Dataset and model formula.
281335
@@ -292,7 +346,7 @@ def design_matrix_from_anndata(
292346 formula = None ,
293347 formula_key = "formula" ,
294348 as_categorical = True ,
295- return_type = "matrix " ,
349+ return_type = "patsy " ,
296350):
297351 r"""
298352 Create a design matrix from a given xarray.Dataset and model formula.
@@ -470,7 +524,7 @@ def setup_constrained(
470524 sample_description : pd .DataFrame ,
471525 formula : str ,
472526 as_numeric : Union [List [str ], Tuple [str ], str ] = (),
473- constraints : Union [ Tuple [ str ], List [ str ]] = () ,
527+ constraints : dict = {} ,
474528 dims : Union [Tuple [str ], List [str ]] = ()
475529) -> Tuple :
476530 """
@@ -486,8 +540,8 @@ def setup_constrained(
486540 not as categorical. This yields columns in the design matrix
487541 which do not correspond to one-hot encoded discrete factors.
488542 :param constraints: Grouped factors to enforce equality constraints on. Every element of
489- the iteratable corresponds to one set of equality constraints. Each set has to be
490- a dictionary of the form {x: y} where x is the factor to be constrained and y is
543+ the dictionary corresponds to one set of equality constraints. Each set has to be
544+ an entry of the form {..., x: y, ...} where x is the factor to be constrained and y is
491545 a factor by which levels of x are grouped and then constrained. Set y="1" to constrain
492546 all levels of x to sum to one, a single equality constraint.
493547
0 commit comments