Skip to content

Commit 750fadf

Browse files
further adaptions to new constraints interface
1 parent 9e556fa commit 750fadf

File tree

2 files changed

+84
-30
lines changed

2 files changed

+84
-30
lines changed

batchglm/api/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
from batchglm.data import load_recursive_mtx
99
from batchglm.data import xarray_from_data
1010
from batchglm.data import setup_constrained, constraint_matrix_from_string
11-
from batchglm.data import view_coef_names
11+
from batchglm.data import view_coef_names, preview_coef_names

batchglm/data.py

Lines changed: 83 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,17 @@ def xarray_from_data(
118118

119119

120120
def design_matrix(
121-
sample_description: pd.DataFrame,
122-
formula: str,
121+
sample_description: Union[pd.DataFrame, None] = None,
122+
formula: Union[str, None] = None,
123123
as_categorical: Union[bool, list] = True,
124+
dmat: Union[pd.DataFrame, None] = None,
124125
return_type: str = "xarray",
125126
) -> Union[patsy.design_info.DesignMatrix, xr.Dataset, pd.DataFrame]:
126127
"""
127128
Create a design matrix from some sample description.
128-
129+
130+
If dmat is supplied directly as a pd.DataFrame, this function only converts it to the requested return type.
131+
129132
:param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
130133
:param formula: model formula as string, describing the relations of the explanatory variables.
131134
@@ -138,39 +141,58 @@ def design_matrix(
138141
is True.
139142
140143
Set to False if columns should not be changed.
144+
:param dmat: a model design matrix as a pd.DataFrame
141145
:param return_type: type of the returned value.
142146
143-
- "matrix": return plain patsy.design_info.DesignMatrix object
147+
- "patsy": return plain patsy.design_info.DesignMatrix object
144148
- "dataframe": return pd.DataFrame with observations as rows and params as columns
145149
- "xarray": return xr.Dataset with design matrix as ds["design"] and the sample description embedded as
146150
one variable per column
147151
:return: a model design matrix
148152
"""
149-
sample_description: pd.DataFrame = sample_description.copy()
153+
if (dmat is None and sample_description is None) or \
154+
(dmat is not None and sample_description is not None):
155+
raise ValueError("supply either dmat or sample_description")
150156

151-
if type(as_categorical) is not bool or as_categorical:
152-
if type(as_categorical) is bool and as_categorical:
153-
as_categorical = np.repeat(True, sample_description.columns.size)
157+
if dmat is None:
158+
sample_description: pd.DataFrame = sample_description.copy()
154159

155-
for to_cat, col in zip(as_categorical, sample_description):
156-
if to_cat:
157-
sample_description[col] = sample_description[col].astype("category")
160+
if type(as_categorical) is not bool or as_categorical:
161+
if type(as_categorical) is bool and as_categorical:
162+
as_categorical = np.repeat(True, sample_description.columns.size)
163+
164+
for to_cat, col in zip(as_categorical, sample_description):
165+
if to_cat:
166+
sample_description[col] = sample_description[col].astype("category")
158167

159-
dmat = patsy.dmatrix(formula, sample_description)
168+
dmat = patsy.dmatrix(formula, sample_description)
160169

161-
if return_type == "dataframe":
162-
df = pd.DataFrame(dmat, columns=dmat.design_info.column_names)
163-
df = pd.concat([df, sample_description], axis=1)
164-
df.set_index(list(sample_description.columns), inplace=True)
170+
if return_type == "dataframe":
171+
df = pd.DataFrame(dmat, columns=dmat.design_info.column_names)
172+
df = pd.concat([df, sample_description], axis=1)
173+
df.set_index(list(sample_description.columns), inplace=True)
165174

166-
return df
167-
elif return_type == "xarray":
168-
ar = xr.DataArray(dmat, dims=("observations", "design_params"))
169-
ar.coords["design_params"] = dmat.design_info.column_names
175+
return df
176+
elif return_type == "xarray":
177+
ar = xr.DataArray(dmat, dims=("observations", "design_params"))
178+
ar.coords["design_params"] = dmat.design_info.column_names
170179

171-
return ar
180+
return ar
181+
elif return_type == "patsy":
182+
return dmat
183+
else:
184+
raise ValueError("return type %s not recognized" % return_type)
172185
else:
173-
return dmat
186+
if return_type == "dataframe":
187+
return dmat
188+
elif return_type == "xarray":
189+
ar = xr.DataArray(dmat, dims=("observations", "design_params"))
190+
ar.coords["design_params"] = dmat.columns
191+
return ar
192+
elif return_type == "patsy":
193+
raise ValueError("return type 'patsy' not supported for input (dmat is not None)")
194+
else:
195+
raise ValueError("return type %s not recognized" % return_type)
174196

175197

176198
def view_coef_names(
@@ -196,10 +218,42 @@ def view_coef_names(
196218
raise ValueError("dmat type %s not recognized" % type(dmat))
197219

198220

221+
def preview_coef_names(
222+
sample_description: pd.DataFrame,
223+
formula: str,
224+
as_categorical: Union[bool, list] = True
225+
) -> np.ndarray:
226+
"""
227+
Return coefficient names of model.
228+
229+
Use this to preview what the model would look like.
230+
231+
:param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
232+
:param formula: model formula as string, describing the relations of the explanatory variables.
233+
234+
E.g. '~ 1 + batch + confounder'
235+
:param as_categorical: boolean or list of booleans corresponding to the columns in 'sample_description'
236+
237+
If True, all values in 'sample_description' will be treated as categorical values.
238+
239+
If list of booleans, each column will be changed to categorical if the corresponding value in 'as_categorical'
240+
is True.
241+
242+
Set to False if columns should not be changed.
243+
:return: A list of coefficient names.
244+
"""
245+
return view_coef_names(dmat=design_matrix(
246+
sample_description=sample_description,
247+
formula=formula,
248+
as_categorical=as_categorical,
249+
return_type="patsy"
250+
))
251+
252+
199253
def sample_description_from_xarray(
200254
dataset: xr.Dataset,
201255
dim: str,
202-
):
256+
) -> pd.DataFrame:
203257
"""
204258
Create a sample description (pd.DataFrame) from a given xarray.Dataset.
205259
@@ -226,7 +280,7 @@ def design_matrix_from_xarray(
226280
formula=None,
227281
formula_key="formula",
228282
as_categorical=True,
229-
return_type="matrix",
283+
return_type="patsy",
230284
):
231285
"""
232286
Create a design matrix from a given xarray.Dataset and model formula.
@@ -275,7 +329,7 @@ def design_matrix_from_xarray(
275329
return dmat
276330

277331

278-
def sample_description_from_anndata(dataset: anndata.AnnData):
332+
def sample_description_from_anndata(dataset: anndata.AnnData) -> pd.DataFrame:
279333
"""
280334
Create a sample description (pd.DataFrame) from a given anndata.AnnData object.
281335
@@ -292,7 +346,7 @@ def design_matrix_from_anndata(
292346
formula=None,
293347
formula_key="formula",
294348
as_categorical=True,
295-
return_type="matrix",
349+
return_type="patsy",
296350
):
297351
r"""
298352
Create a design matrix from a given xarray.Dataset and model formula.
@@ -470,7 +524,7 @@ def setup_constrained(
470524
sample_description: pd.DataFrame,
471525
formula: str,
472526
as_numeric: Union[List[str], Tuple[str], str] = (),
473-
constraints: Union[Tuple[str], List[str]] = (),
527+
constraints: dict = {},
474528
dims: Union[Tuple[str], List[str]] = ()
475529
) -> Tuple:
476530
"""
@@ -486,8 +540,8 @@ def setup_constrained(
486540
not as categorical. This yields columns in the design matrix
487541
which do not correspond to one-hot encoded discrete factors.
488542
:param constraints: Grouped factors to enforce equality constraints on. Every element of
489-
the iteratable corresponds to one set of equality constraints. Each set has to be
490-
a dictionary of the form {x: y} where x is the factor to be constrained and y is
543+
the dictionary corresponds to one set of equality constraints. Each set has to be
544+
an entry of the form {..., x: y, ...} where x is the factor to be constrained and y is
491545
a factor by which levels of x are grouped and then constrained. Set y="1" to constrain
492546
all levels of x to sum to one, a single equality constraint.
493547

0 commit comments

Comments
 (0)