4 changes: 4 additions & 0 deletions batchglm/models/base/estimator.py
@@ -76,6 +76,10 @@ def fisher_inv(self):
def x(self) -> np.ndarray:
return self.input_data.x

@property
def w(self) -> np.ndarray:
return self.input_data.w

@property
def a_var(self):
if isinstance(self.model.a_var, dask.array.core.Array):
1 change: 1 addition & 0 deletions batchglm/models/base/external.py
@@ -1,2 +1,3 @@
import batchglm.pkg_constants as pkg_constants
import batchglm.data as data_utils
import batchglm.types as types
46 changes: 40 additions & 6 deletions batchglm/models/base/input.py
@@ -3,7 +3,8 @@
import numpy as np
import scipy.sparse
import sparse
from typing import List
from typing import List, Union, Optional
from .external import types as T

try:
import anndata
@@ -29,13 +30,14 @@ class InputDataBase:

def __init__(
self,
data,
observation_names=None,
feature_names=None,
data: T.InputType,
weights: Optional[Union[T.ArrayLike, str]] = None,
observation_names: Optional[List[str]] = None,
feature_names: Optional[List[str]] = None,
chunk_size_cells: int = 100000,
chunk_size_genes: int = 100,
as_dask: bool = True,
cast_dtype=None
cast_dtype: Optional[np.dtype] = None
):
"""
Create a new InputData object.
@@ -53,22 +55,44 @@ def __init__(
"""
self.observations = observation_names
self.features = feature_names
self.w = weights

if isinstance(data, np.ndarray) or \
isinstance(data, scipy.sparse.csr_matrix) or \
isinstance(data, dask.array.core.Array):
self.x = data
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
self.x = data.X
if isinstance(weights, str):
self.w = data.obs[weights].values
elif isinstance(data, InputDataBase):
self.x = data.x
self.w = data.w
else:
raise ValueError("type of data %s not recognized" % type(data))

if self.w is None:
# default to unit weights; use a floating dtype so the assertion below also holds for integer count matrices
self.w = np.ones(self.x.shape[0], dtype=self.x.dtype if np.issubdtype(self.x.dtype, np.floating) else np.float64)

if scipy.sparse.issparse(self.w):
self.w = self.w.toarray()
if self.w.ndim == 2:
self.w = self.w.squeeze(1)

# sanity checks
assert self.w.shape == (self.x.shape[0],), "invalid weight shape %s" % self.w.shape
assert issubclass(self.w.dtype.type, np.floating)

if self.observations is not None:
assert len(self.observations) == self.x.shape[0]
if self.features is not None:
assert len(self.features) == self.x.shape[1]

if as_dask:
if isinstance(self.x, dask.array.core.Array):
self.x = self.x.compute()
# Need to wrap dask around the COO matrix version of the sparse package if matrix is sparse.
if isinstance(self.x, scipy.sparse.spmatrix):
if scipy.sparse.issparse(self.x):
self.x = dask.array.from_array(
sparse.COO.from_scipy_sparse(
self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype)
@@ -81,11 +105,21 @@ def __init__(
self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype),
chunks=(chunk_size_cells, chunk_size_genes),
)

if isinstance(self.w, dask.array.core.Array):
self.w = self.w.compute()
self.w = dask.array.from_array(
self.w.astype(cast_dtype if cast_dtype is not None else self.w.dtype),
chunks=(chunk_size_cells,),
)
else:
if isinstance(self.x, dask.array.core.Array):
self.x = self.x.compute()
if isinstance(self.w, dask.array.core.Array):
self.w = self.w.compute()
if cast_dtype is not None:
self.x = self.x.astype(cast_dtype)
self.w = self.w.astype(cast_dtype)

self._feature_allzero = np.sum(self.x, axis=0) == 0
self.chunk_size_cells = chunk_size_cells
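
For illustration, a minimal usage sketch of the new weights argument (not part of the diff; the toy matrix, the obs column name "size_factor" and as_dask=False are assumptions made for the example). Weights can be passed as a per-observation array or, for AnnData input, as the name of an .obs column:

import anndata
import numpy as np

from batchglm.models.base.input import InputDataBase

# toy count matrix: 5 observations x 3 features, stored as floats
x = np.random.poisson(lam=3.0, size=(5, 3)).astype(np.float64)

# weights given directly as an array of length n_observations
idata = InputDataBase(data=x, weights=np.ones(5), as_dask=False)

# weights given as the name of an AnnData obs column
adata = anndata.AnnData(X=x)
adata.obs["size_factor"] = np.linspace(0.5, 1.0, num=5)
idata_ad = InputDataBase(data=adata, weights="size_factor", as_dask=False)

assert idata_ad.w.shape == (x.shape[0],)
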
4 changes: 4 additions & 0 deletions batchglm/models/base/model.py
@@ -26,6 +26,10 @@ def __init__(
def x(self):
return self.input_data.x

@property
def w(self):
return self.input_data.w

def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]:
"""
Returns the values specified by key.
3 changes: 2 additions & 1 deletion batchglm/models/base_glm/external.py
@@ -4,4 +4,5 @@
from batchglm.models.base import _SimulatorBase

import batchglm.data as data_utils
from batchglm.utils.linalg import groupwise_solve_lm
from batchglm.utils.linalg import groupwise_solve_lm
import batchglm.types as types
4 changes: 2 additions & 2 deletions batchglm/models/base_glm/input.py
@@ -7,11 +7,11 @@
import numpy as np
import pandas as pd
import patsy
import scipy.sparse
from typing import Union

from .utils import parse_constraints, parse_design
from .external import InputDataBase
from .external import types as T


class InputDataGLM(InputDataBase):
@@ -25,7 +25,7 @@ class InputDataGLM:

def __init__(
self,
data: Union[np.ndarray, anndata.AnnData, scipy.sparse.csr_matrix],
data: T.InputType,
design_loc: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None,
design_loc_names: Union[list, np.ndarray] = None,
design_scale: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None,
1 change: 0 additions & 1 deletion batchglm/train/numpy/base_glm/estimator.py
@@ -9,7 +9,6 @@
import sparse
import sys
import time
from typing import Tuple

from .external import _EstimatorGLM, pkg_constants
from .training_strategies import TrainingStrategies
14 changes: 7 additions & 7 deletions batchglm/train/numpy/base_glm/model.py
@@ -263,10 +263,10 @@ def jac_b_j(self, j) -> np.ndarray:
# Make sure that dimensionality of sliced array is kept:
if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64):
j = [j]
w = self.jac_weight_b_j(j=j) # (observations x features)
xh = np.matmul(self.design_scale, self.constraints_scale) # (observations x inferred param)
return np.einsum(
'fob,of->fb',
np.einsum('ob,of->fob', xh, w),
xh
)
w = self.jac_weight_b_j(j=j) # (observations x features)
Author's comment on this hunk: This seemed like a bug, so I've changed it (I think this is one of the unused functions, based on what PyCharm told me).

xh = np.matmul(self.design_scale, self.constraints_scale) # (observations x inferred param)
return np.einsum(
'fob,of->fb',
np.einsum('ob,of->fob', xh, w),
xh
)
3 changes: 1 addition & 2 deletions batchglm/train/numpy/base_glm/vars.py
@@ -1,12 +1,11 @@
import dask.array
import numpy as np
import scipy.sparse
import abc


class ModelVarsGlm:
"""
Build variables to be optimzed and their constraints.
Build variables to be optimized and their constraints.

"""

28 changes: 15 additions & 13 deletions batchglm/train/numpy/glm_nb/model.py
@@ -3,7 +3,6 @@
import numpy as np
import scipy.sparse
import scipy.special
import sparse

from .external import Model, ModelIwls, InputDataGLM
from .processModel import ProcessModel
@@ -41,7 +40,7 @@ def fim_weight_aa(self):

:return: observations x features
"""
return - self.location * self.scale / (self.scale + self.location)
return - self.w[:, None] * self.location * self.scale / (self.scale + self.location)

@property
def ybar(self) -> np.ndarray:
@@ -56,7 +55,7 @@ def fim_weight_aa_j(self, j):

:return: observations x features
"""
return - self.location_j(j=j) * self.scale_j(j=j) / (self.scale_j(j=j) + self.location_j(j=j))
return - self.w[:, None] * self.location_j(j=j) * self.scale_j(j=j) / (self.scale_j(j=j) + self.location_j(j=j))

def ybar_j(self, j) -> np.ndarray:
"""
@@ -89,7 +88,7 @@ def jac_weight_b(self):
const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale)
const2 = - scale_plus_x / r_plus_mu
const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu)
return scale * (const1 + const2 + const3)
return self.w[:, None] * scale * (const1 + const2 + const3)

def jac_weight_b_j(self, j):
"""
@@ -111,7 +110,7 @@ def jac_weight_b_j(self, j):
const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale)
const2 = - scale_plus_x / r_plus_mu
const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu)
return scale * (const1 + const2 + const3)
return self.w[:, None] * scale * (const1 + const2 + const3)

@property
def fim_ab(self) -> np.ndarray:
@@ -150,6 +149,7 @@ def hessian_weight_ab(self):
scale = self.scale
loc = self.location
return self.w[:, None] * np.multiply(
loc * scale,
np.asarray(self.x - loc) / np.square(loc + scale)
)
@@ -163,7 +163,7 @@ def hessian_weight_aa(self):
else:
x_by_scale_plus_one = np.asarray(self.x.divide(scale) + np.ones_like(scale))

return - loc * x_by_scale_plus_one / np.square((loc / scale) + np.ones_like(loc))
return - self.w[:, None] * loc * x_by_scale_plus_one / np.square((loc / scale) + np.ones_like(loc))

@property
def hessian_weight_bb(self):
Expand All @@ -176,7 +176,7 @@ def hessian_weight_bb(self):
const2 = - scipy.special.digamma(scale) + scale * scipy.special.polygamma(n=1, x=scale)
const3 = - loc * scale_plus_x + np.ones_like(scale) * 2. * scale * scale_plus_loc / np.square(scale_plus_loc)
const4 = np.log(scale) + np.ones_like(scale) * 2. - np.log(scale_plus_loc)
return scale * (const1 + const2 + const3 + const4)
return self.w[:, None] * scale * (const1 + const2 + const3 + const4)

@property
def ll(self):
@@ -198,7 +198,7 @@ def ll(self):
np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu) +
np.multiply(scale, self.eta_scale - log_r_plus_mu))
ll = np.asarray(ll)
return self.np_clip_param(ll, "ll")
return self.np_clip_param(self.w[:, None] * ll, "ll")

def ll_j(self, j):
# Make sure that dimensionality of sliced array is kept:
@@ -222,10 +222,11 @@ def ll_j(self, j):
np.asarray(self.x[:, j].multiply(self.eta_loc_j(j=j) - log_r_plus_mu) +
np.multiply(scale, self.eta_scale_j(j=j) - log_r_plus_mu))
ll = np.asarray(ll)
return self.np_clip_param(ll, "ll")
return self.np_clip_param(self.w[:, None] * ll, "ll")

# TODO: not used
def ll_handle(self):
def fun(x, eta_loc, b_var, xh_scale):
def fun(x, w, eta_loc, b_var, xh_scale):
eta_scale = np.matmul(xh_scale, b_var)
scale = np.exp(eta_scale)
loc = np.exp(eta_loc)
Expand All @@ -239,11 +240,12 @@ def fun(x, eta_loc, b_var, xh_scale):
np.multiply(scale, eta_scale - log_r_plus_mu)
else:
raise ValueError("type x %s not supported" % type(x))
return self.np_clip_param(ll, "ll")
return self.np_clip_param(w[:, None] * ll, "ll")
return fun

# TODO: not used
def jac_b_handle(self):
def fun(x, eta_loc, b_var, xh_scale):
def fun(x, w, eta_loc, b_var, xh_scale):
scale = np.exp(b_var)
loc = np.exp(eta_loc)
scale_plus_x = scale + x
Expand All @@ -253,6 +255,6 @@ def fun(x, eta_loc, b_var, xh_scale):
const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale)
const2 = - scale_plus_x / r_plus_mu
const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu)
return scale * (const1 + const2 + const3)
return w[:, None] * scale * (const1 + const2 + const3)

return fun
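
The changes in this file all apply the same pattern: each per-observation term (FIM weight, jacobian and hessian weights, log-likelihood) is scaled by that observation's weight. A self-contained sketch of the pattern for the negative binomial log-likelihood (plain NumPy, not batchglm code; the mean/dispersion parameterization and the toy shapes are assumptions for the example):

import numpy as np
import scipy.special

def weighted_nb_ll(x, mu, r, w):
    # negative binomial log-likelihood per observation and feature,
    # parameterized by mean mu and dispersion r (both observations x features)
    log_r_plus_mu = np.log(r + mu)
    ll = (
        scipy.special.gammaln(r + x)
        - scipy.special.gammaln(x + 1)
        - scipy.special.gammaln(r)
        + x * (np.log(mu) - log_r_plus_mu)
        + r * (np.log(r) - log_r_plus_mu)
    )
    # one weight per observation, broadcast over the feature axis
    return w[:, None] * ll

x = np.random.poisson(lam=4.0, size=(6, 3)).astype(float)
mu = np.full((6, 3), 4.0)
r = np.full((6, 3), 2.0)
w = np.linspace(0.5, 1.0, num=6)
print(weighted_nb_ll(x, mu, r, w).sum(axis=0))  # weighted log-likelihood per feature
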
13 changes: 13 additions & 0 deletions batchglm/types.py
@@ -0,0 +1,13 @@
from typing import TypeVar, Union

import scipy.sparse
import dask.array
import numpy as np

try:
from anndata import AnnData
except ImportError:
AnnData = TypeVar("AnnData")

ArrayLike = Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array]
InputType = Union[ArrayLike, AnnData, "InputDataBase"]
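
A short sketch of how the new aliases might be used in annotations elsewhere (hypothetical helper, not part of this PR):

import batchglm.types as types

def n_features(x: types.ArrayLike) -> int:
    # works for np.ndarray, scipy.sparse.csr_matrix and dask arrays alike
    return x.shape[1]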