Skip to content

Commit 1f6603a

Browse files
adapted code for new diffxpy constraint interface
1 parent 8f40a95 commit 1f6603a

File tree

2 files changed

+130
-61
lines changed

2 files changed

+130
-61
lines changed

batchglm/api/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@
77
from batchglm.data import load_mtx_to_xarray
88
from batchglm.data import load_recursive_mtx
99
from batchglm.data import xarray_from_data
10+
from batchglm.data import setup_constrained, constraint_matrix_from_string
11+
from batchglm.data import view_coef_names

batchglm/data.py

Lines changed: 128 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
except ImportError:
2121
anndata = None
2222

23-
logger = logging.getLogger(__name__)
24-
2523

2624
def _sparse_to_xarray(data, dims):
2725
num_observations, num_features = data.shape
@@ -123,10 +121,10 @@ def design_matrix(
123121
sample_description: pd.DataFrame,
124122
formula: str,
125123
as_categorical: Union[bool, list] = True,
126-
return_type: str = "matrix",
124+
return_type: str = "xarray",
127125
) -> Union[patsy.design_info.DesignMatrix, xr.Dataset, pd.DataFrame]:
128126
"""
129-
Create a design matrix from some sample description
127+
Create a design matrix from some sample description.
130128
131129
:param sample_description: pandas.DataFrame of length "num_observations" containing explanatory variables as columns
132130
:param formula: model formula as string, describing the relations of the explanatory variables.
@@ -158,7 +156,7 @@ def design_matrix(
158156
if to_cat:
159157
sample_description[col] = sample_description[col].astype("category")
160158

161-
dmat = patsy.highlevel.dmatrix(formula, sample_description)
159+
dmat = patsy.dmatrix(formula, sample_description)
162160

163161
if return_type == "dataframe":
164162
df = pd.DataFrame(dmat, columns=dmat.design_info.column_names)
@@ -170,31 +168,47 @@ def design_matrix(
170168
ar = xr.DataArray(dmat, dims=("observations", "design_params"))
171169
ar.coords["design_params"] = dmat.design_info.column_names
172170

173-
ds = xr.Dataset({
174-
"design": ar,
175-
})
171+
return ar
172+
else:
173+
return dmat
174+
176175

177-
for col in sample_description:
178-
ds[col] = (("observations",), sample_description[col])
176+
def view_coef_names(
177+
dmat: Union[patsy.design_info.DesignMatrix, xr.Dataset, pd.DataFrame]
178+
) -> np.ndarray:
179+
"""
180+
Show names of coefficients in dmat.
181+
182+
This wrapper provides quick access to this object attribute across all supported frameworks.
179183
180-
return ds
184+
:param dmat: Design matrix.
185+
:return: Array of coefficient names.
186+
"""
187+
if isinstance(dmat, xr.DataArray):
188+
return dmat.coords["design_params"].values
189+
elif isinstance(dmat, xr.Dataset):
190+
return dmat.design.coords["design_params"].values
191+
elif isinstance(dmat, pd.DataFrame):
192+
return np.asarray(dmat.columns)
193+
elif isinstance(dmat, patsy.design_info.DesignMatrix):
194+
return np.asarray(dmat.design_info.column_names)
181195
else:
182-
return dmat
196+
raise ValueError("dmat type %s not recognized" % type(dmat))
183197

184198

185199
def sample_description_from_xarray(
186200
dataset: xr.Dataset,
187201
dim: str,
188202
):
189203
"""
190-
Create a design matrix from a given xarray.Dataset and model formula.
204+
Create a design matrix from a given xarray.Dataset and model formula.
191205
192-
:param dataset: xarray.Dataset containing explanatory variables.
193-
:param dim: name of the dimension for which the design matrix should be created.
206+
:param dataset: xarray.Dataset containing explanatory variables.
207+
:param dim: name of the dimension for which the design matrix should be created.
194208
195-
The design matrix will be of shape (dim, "design_params").
196-
:return: pd.DataFrame
197-
"""
209+
The design matrix will be of shape (dim, "design_params").
210+
:return: pd.DataFrame
211+
"""
198212

199213
explanatory_vars = [key for key, val in dataset.variables.items() if val.dims == (dim,)]
200214

@@ -397,7 +411,7 @@ def load_mtx_to_xarray(path):
397411
delim = "\t"
398412

399413
fpath = os.path.join(path, file)
400-
logger.info("Reading %s as gene annotation...", fpath)
414+
logging.getLogger("batchglm").info("Reading %s as gene annotation...", fpath)
401415
tbl = pd.read_csv(fpath, header=None, sep=delim)
402416
# retval["var"] = (["var_annotations", "features"], np.transpose(tbl))
403417
for col_id in tbl:
@@ -408,7 +422,7 @@ def load_mtx_to_xarray(path):
408422
delim = "\t"
409423

410424
fpath = os.path.join(path, file)
411-
logger.info("Reading %s as barcode file...", fpath)
425+
logging.getLogger("batchglm").info("Reading %s as barcode file...", fpath)
412426
tbl = pd.read_csv(fpath, header=None, sep=delim)
413427
# retval["obs"] = (["obs_annotations", "observations"], np.transpose(tbl))
414428
for col_id in tbl:
@@ -452,13 +466,13 @@ def load_recursive_mtx(dir_or_zipfile, target_format="xarray", cache=True) -> Di
452466
return adatas
453467

454468

455-
def build_equality_constraints(
469+
def setup_constrained(
456470
sample_description: pd.DataFrame,
457471
formula: str,
458-
constraints: List[str],
459-
dims: list,
460-
as_categorical: Union[bool, list] = True
461-
):
472+
as_numeric: Union[List[str], Tuple[str], str] = (),
473+
constraints: Union[Tuple[str], List[str]] = (),
474+
dims: Union[Tuple[str], List[str]] = ()
475+
) -> Tuple:
462476
"""
463477
Create a design matrix from some sample description and a constraint matrix
464478
based on factor encoding of constrained parameter sets.
@@ -467,55 +481,106 @@ def build_equality_constraints(
467481
:param formula: model formula as string, describing the relations of the explanatory variables.
468482
469483
E.g. '~ 1 + batch + confounder'
470-
:param constraints: List of constraints as strings, e.g. "x1 + x5 = 0".
484+
:param as_numeric:
485+
Which columns of sample_description to treat as numeric and
486+
not as categorical. This yields columns in the design matrix
487+
which do not correspond to one-hot encoded discrete factors.
488+
:param constraints: Grouped factors to enforce equality constraints on. Every element of
489+
the iterable corresponds to one set of equality constraints. Each set has to be
490+
a dictionary of the form {x: y} where x is the factor to be constrained and y is
491+
a factor by which levels of x are grouped and then constrained. Set y="1" to constrain
492+
all levels of x to sum to zero, a single equality constraint.
493+
494+
E.g.: {"batch": "condition"} Batch levels within each condition are constrained to sum to
495+
zero. This is applicable if repeats of a an experiment within each condition
496+
are independent so that the set-up ~1+condition+batch is perfectly confounded.
497+
498+
Can only group by non-constrained effects right now, use constraint_matrix_from_string
499+
for other cases.
500+
:param dims: ["design_loc_params", "loc_params"] or ["design_scale_params", "scale_params"]
501+
Dimension names of xarray.
502+
:return: a model design matrix
503+
"""
504+
assert len(constraints) > 0, "supply constraints"
505+
assert len(dims) == 2, "supply 2 dimension names in dim"
506+
sample_description: pd.DataFrame = sample_description.copy()
471507

472-
E.g. 'batch'
473-
:param as_categorical: boolean or list of booleans corresponding to the columns in 'sample_description'
508+
if isinstance(as_numeric, str):
509+
as_numeric = [as_numeric]
510+
as_categorical = [False if x in as_numeric else True for x in sample_description.columns.values]
511+
if type(as_categorical) is not bool or as_categorical:
512+
if type(as_categorical) is bool and as_categorical:
513+
as_categorical = np.repeat(True, sample_description.columns.size)
474514

475-
If True, all values in 'sample_description' will be treated as categorical values.
515+
for to_cat, col in zip(as_categorical, sample_description):
516+
if to_cat:
517+
sample_description[col] = sample_description[col].astype("category")
476518

477-
If list of booleans, each column will be changed to categorical if the corresponding value in 'as_categorical'
478-
is True.
519+
# Build core design matrix on unconstrained factors. Then add design matrices without
520+
# absorption of the first level of each factor for each constrained factor onto the
521+
# core matrix.
522+
formula_unconstrained = formula.split("+")
523+
formula_unconstrained = [x for x in formula_unconstrained if x not in constraints.keys()]
524+
formula_unconstrained = "+".join(formula_unconstrained)
525+
dmat = patsy.dmatrix(formula_unconstrained, sample_description)
526+
coef_names = dmat.design_info.column_names
527+
528+
constraints_ls = []
529+
for i, x in enumerate(constraints.keys()):
530+
assert isinstance(x, str), "constrained should contain strings"
531+
dmat_constrained_temp = patsy.highlevel.dmatrix("0+" + x, sample_description)
532+
dmat = np.hstack([dmat, dmat_constrained_temp])
533+
coef_names.extend(dmat_constrained_temp.design_info.column_names)
534+
535+
# Build slices by group.
536+
dmat_grouping_temp = patsy.highlevel.dmatrix("0+" + list(constraints.values())[i], sample_description)
537+
for j in range(dmat_grouping_temp.shape[1]):
538+
grouping = dmat_grouping_temp[:, j]
539+
idx_constrained_group = np.where(np.sum(dmat_constrained_temp[grouping == 1, :], axis=0) > 0)[0]
540+
# Assert that required grouping is nested.
541+
assert np.all(np.logical_xor(
542+
np.sum(dmat_constrained_temp[grouping == 1, :], axis=0) > 0,
543+
np.sum(dmat_constrained_temp[grouping == 0, :], axis=0) > 0
544+
)), "proposed grouping of constraints is not nested, read docstrings"
545+
# Add new string-encoded equality constraint.
546+
constraints_ls.append(
547+
"+".join(list(np.asarray(dmat_constrained_temp.design_info.column_names)[idx_constrained_group]))+"=0"
548+
)
479549

480-
Set to false, if columns should not be changed.
481-
:return: a model design matrix
482-
"""
483-
dmat = design_matrix(
484-
sample_description=sample_description,
485-
formula=formula,
486-
as_categorical=as_categorical,
487-
return_type="xarray"
488-
)
489-
# Parse list of factors to be constrained to list of
490-
# string encoded explicit constraint equations.
491-
constraint_ls = ["+".join(patsy.highlevel.dmatrix("~1-1+"+x, sample_description))+"=0"
492-
for x in constraints]
493-
logger.debug("constraints enforced are: "+",".join(constraint_ls))
494-
constraint_mat = build_equality_constraints_string(
495-
dmat=dmat,
550+
logging.getLogger("batchglm").warning("Built constraints: "+", ".join(constraints_ls))
551+
552+
# Parse design matrix to xarray.
553+
ar = xr.DataArray(dmat, dims=("observations", "design_params"))
554+
ar.coords["design_params"] = coef_names
555+
556+
# Build constraint matrix.
557+
constraints_ar = constraint_matrix_from_string(
558+
dmat=ds,
496559
constraints=constraints_ls,
497560
dims=dims
498561
)
499562

500-
return dmat, constraint_mat
563+
return ds, constraints_ar
501564

502565

503-
def build_equality_constraints_string(
566+
def constraint_matrix_from_string(
504567
dmat: Union[xr.DataArray, xr.Dataset],
505-
constraints: List[str],
568+
constraints: Union[Tuple[str], List[str]],
506569
dims: list
507570
):
508571
r"""
509-
Parser for string encoded equality constraints.
572+
Create constraint matrix from string-encoded equality constraints.
510573
511574
:param dmat: Design matrix.
512575
:param constraints: List of constraints as strings.
513576
514577
E.g. ["batch1 + batch2 + batch3 = 0"]
515578
:param dims: ["design_loc_params", "loc_params"] or ["design_scale_params", "scale_params"]
516-
Define dimension names of xarray.
579+
Dimension names of xarray.
517580
:return: a constraint matrix
518581
"""
582+
assert len(constraints) > 0, "supply constraints"
583+
519584
if isinstance(dmat, xr.Dataset):
520585
dmat = dmat.data_vars['design']
521586
n_par_all = dmat.values.shape[1]
@@ -529,13 +594,6 @@ def build_equality_constraints_string(
529594
set(np.asarray(range(n_par_all))) - set(idx_constr)
530595
))
531596

532-
dmat_var = xr.DataArray(
533-
dims=[dmat.dims[0], "params"],
534-
data=dmat[:,idx_unconstr],
535-
coords={dmat.dims[0]: dmat.coords["observations"].values,
536-
"params": dmat.coords["design_params"].values[idx_unconstr]}
537-
)
538-
539597
constraint_mat = np.zeros([n_par_all, n_par_free])
540598
for i in range(n_par_all):
541599
if i in idx_constr:
@@ -555,8 +613,15 @@ def build_equality_constraints_string(
555613
)
556614

557615
# Test reduced design matrix for full rank before returning constraints:
616+
dmat_var = xr.DataArray(
617+
dims=[dmat.dims[0], "params"],
618+
data=dmat[:, idx_unconstr],
619+
coords={dmat.dims[0]: dmat.coords["observations"].values,
620+
"params": dmat.coords["design_params"].values[idx_unconstr]}
621+
)
622+
558623
if np.linalg.matrix_rank(dmat_var) != np.linalg.matrix_rank(dmat_var.T):
559-
logger.warning("constrained design matrix is not full rank")
624+
logging.getLogger("batchglm").error("constrained design matrix is not full rank")
560625

561626
return constraints_ar
562627

@@ -570,7 +635,9 @@ def parse_constraints(
570635
Parse constraint matrix into xarray.
571636
572637
:param dmat: Design matrix.
573-
:param constraints: a constraint matrix
638+
:param constraints: a constraint matrix.
639+
:param dims: ["design_loc_params", "loc_params"] or ["design_scale_params", "scale_params"]
640+
Dimension names of xarray
574641
:return: constraint matrix in xarray format
575642
"""
576643
if isinstance(dmat, xr.Dataset):

0 commit comments

Comments
 (0)