Skip to content

Commit 1d0abc0

Browse files
committed
backport cleaned-up closed form initialization; fix some unit test bugs
1 parent 1af08a0 commit 1d0abc0

File tree

7 files changed

+193
-104
lines changed

7 files changed

+193
-104
lines changed

batchglm/api/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from . import stats
22
from . import random
3+
from . import numeric
4+
from . import linalg

batchglm/api/utils/linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from batchglm.utils.linalg import stacked_lstsq, groupwise_solve_lm

batchglm/api/utils/numeric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from batchglm.utils.numeric import combine_matrices, softmax, weighted_mean, weighted_variance

batchglm/models/nb_glm/utils.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
import xarray as xr
5+
6+
from batchglm.utils.linalg import groupwise_solve_lm
7+
from batchglm.utils.numeric import weighted_mean
8+
from batchglm.models.glm import closedform_glm_mean
9+
10+
11+
def closedform_nb_glm_logmu(
        X: xr.DataArray,
        design_loc,
        constraints=None,
        size_factors=None,
        weights=None,
        link_fn=np.log
):
    r"""
    Calculates a closed-form solution for the `mu` parameters of negative-binomial GLMs.

    Thin negative-binomial-specific wrapper around :func:`closedform_glm_mean`:
    the NB location model shares the generic GLM mean estimator, only the
    (log-)link is fixed here by default.

    :param X: The sample data
    :param design_loc: design matrix for location
    :param constraints: design constraints, forwarded to `closedform_glm_mean`
    :param size_factors: size factors for X; if `None`, no normalization is applied
    :param weights: the weights of the arrays' elements; if `None`, all elements
        are weighted equally
    :param link_fn: link function applied to the group-wise means
        (default: natural log, matching the NB log-link)
    :return: tuple: (groupwise_means, mu, rmsd)
    """
    return closedform_glm_mean(
        X=X,
        dmat=design_loc,
        constraints=constraints,
        size_factors=size_factors,
        weights=weights,
        link_fn=link_fn
    )
37+
38+
39+
def closedform_nb_glm_logphi(
        X: xr.DataArray,
        design_scale: xr.DataArray,
        constraints=None,
        size_factors=None,
        weights: Union[np.ndarray, xr.DataArray] = None,
        mu=None,
        groupwise_means=None,
        link_fn=np.log
):
    r"""
    Calculates a closed-form solution for the log-scale parameters of negative-binomial GLMs.
    Based on the Method-of-Moments estimator: phi = mu^2 / (var - mu) per group.

    :param X: The sample data
    :param design_scale: design matrix for scale
    :param constraints: design constraints, forwarded to `groupwise_solve_lm`
    :param size_factors: size factors for X; if provided, X is divided by them first
    :param weights: the weights of the arrays' elements; if `None`, all elements
        are weighted equally
    :param mu: optional, if there are for example different mu's per observation.

        Used to calculate `Xdiff = X - mu`.
    :param groupwise_means: optional; pass already-computed group-wise means here
        to spare double-calculation
    :param link_fn: link function applied to the group-wise scales
        (default: natural log)
    :return: tuple (groupwise_scales, logphi, rmsd)
    """
    if size_factors is not None:
        X = np.divide(X, size_factors)

    # Rebind the arguments under distinct names: `apply_fun` below assigns to
    # `weights` / `groupwise_means`, which would otherwise require `nonlocal`
    # declarations to read the enclosing values.
    provided_groupwise_means = groupwise_means
    provided_weights = weights
    provided_mu = mu

    def apply_fun(grouping):
        # Attach the group labels along the observation dimension (X.dims[0]).
        grouped_X = X.assign_coords(group=((X.dims[0],), grouping))

        # convert weights into an xr.DataArray aligned with the observations
        if provided_weights is not None:
            weights = xr.DataArray(
                data=provided_weights,
                dims=(X.dims[0],),
                coords={
                    "group": ((X.dims[0],), grouping),
                }
            )
        else:
            weights = None

        # calculate group-wise means if they were not provided by the caller
        if provided_groupwise_means is None:
            if weights is None:
                groupwise_means = grouped_X.mean(X.dims[0]).values
            else:
                # for each group: calculate weighted mean
                # (the group keys of both groupby's coincide; only the values are used)
                groupwise_means: xr.DataArray = xr.concat([
                    weighted_mean(d, w, axis=0) for (_, d), (_, w) in zip(
                        grouped_X.groupby("group"),
                        weights.groupby("group"))
                ], dim="group")
        else:
            groupwise_means = provided_groupwise_means

        # calculate (x - mean) depending on whether `mu` was specified
        if provided_mu is None:
            Xdiff = grouped_X - groupwise_means
        else:
            Xdiff = grouped_X - provided_mu

        if weights is None:
            # for each group:
            # calculate mean of (X - mean)^2
            variance = np.square(Xdiff).groupby("group").mean(X.dims[0])
        else:
            # for each group:
            # calculate weighted mean of (X - mean)^2
            variance: xr.DataArray = xr.concat([
                weighted_mean(d, w, axis=0) for (_, d), (_, w) in zip(
                    np.square(Xdiff).groupby("group"),
                    weights.groupby("group")
                )
            ], dim="group")

        # Method-of-Moments denominator (var - mu), clamped to a tiny positive
        # value so the division stays finite when var <= mu.
        # Any further clipping of the scales is the caller's job via `link_fn`.
        denominator = np.fmax(variance - groupwise_means, np.sqrt(np.nextafter(0, 1, dtype=variance.dtype)))
        groupwise_scales = np.square(groupwise_means) / denominator

        return link_fn(groupwise_scales)

    groupwise_scales, logphi, rmsd, rank, _ = groupwise_solve_lm(
        dmat=design_scale,
        apply_fun=apply_fun,
        constraints=constraints
    )

    return groupwise_scales, logphi, rmsd

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 32 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .base import param_bounds, tf_clip_param, np_clip_param, apply_constraints
1919

2020
from .external import AbstractEstimator, XArrayEstimatorStore, InputData, Model, MonitoredTFEstimator, TFEstimatorGraph
21-
from .external import nb_utils, train_utils, op_utils, rand_utils, data_utils
21+
from .external import nb_utils, train_utils, op_utils, rand_utils, data_utils, nb_glm_utils
2222
from .external import pkg_constants
2323
from .hessians import Hessians
2424
from .jacobians import Jacobians
@@ -759,55 +759,23 @@ def __init__(
759759

760760
if init_a.lower() == "closed_form":
761761
try:
762-
unique_design_loc, inverse_idx = np.unique(input_data.design_loc, axis=0, return_inverse=True)
763-
if input_data.constraints_loc is not None:
764-
unique_design_loc_constraints = input_data.constraints_loc.copy()
765-
# -1 in the constraint matrix is used to indicate which variable
766-
# is made dependent so that the constrained is fullfilled.
767-
# This has to be rewritten here so that the design matrix is full rank
768-
# which is necessary so that it can be inverted for parameter
769-
# initialisation.
770-
unique_design_loc_constraints[unique_design_loc_constraints == -1] = 1
771-
# Add constraints into design matrix to remove structural unidentifiability.
772-
unique_design_loc = np.vstack([unique_design_loc, unique_design_loc_constraints])
773-
774-
if unique_design_loc.shape[1] > np.linalg.matrix_rank(unique_design_loc):
775-
logger.warning("Location model is not full rank!")
776-
X = input_data.X.assign_coords(group=(("observations",), inverse_idx))
777-
if size_factors_init is not None:
778-
X = np.divide(X, size_factors_init)
779-
780-
groupwise_means = X.groupby("group").mean(dim="observations").values
781-
# clipping
782-
groupwise_means = np_clip_param(groupwise_means, "mu")
783-
# mean = np.nextafter(0, 1, out=mean.values, where=mean == 0, dtype=mean.dtype)
784-
785-
a = np.log(groupwise_means)
786-
if input_data.constraints_loc is not None:
787-
a_constraints = np.zeros([input_data.constraints_loc.shape[0], a.shape[1]])
788-
# Add constraints (sum to zero) to value vector to remove structural unidentifiability.
789-
a = np.vstack([a, a_constraints])
790-
791-
# inv_design = np.linalg.pinv(unique_design_loc) # NOTE: this is numerically inaccurate!
792-
# inv_design = np.linalg.inv(unique_design_loc) # NOTE: this is exact if full rank!
793-
# init_a = np.matmul(inv_design, a)
794-
#
795-
# Use least-squares solver to calculate a':
796-
# This is faster and more accurate than using matrix inversion.
797-
logger.debug(" ** Solve lstsq problem")
798-
a_prime = np.linalg.lstsq(unique_design_loc, a, rcond=None)
799-
init_a = a_prime[0]
800-
# stat_utils.rmsd(np.exp(unique_design_loc @ init_a), mean)
762+
groupwise_means, init_a, rmsd_a = nb_glm_utils.closedform_nb_glm_logmu(
763+
X=input_data.X,
764+
design_loc=input_data.design_loc,
765+
constraints=input_data.constraints_loc,
766+
size_factors=size_factors_init,
767+
link_fn=lambda mu: np.log(np_clip_param(mu, "mu"))
768+
)
801769

802770
# train mu, if the closed-form solution is inaccurate
803-
self._train_mu = not np.all(a_prime[1] == 0)
771+
self._train_mu = not np.all(rmsd_a == 0)
804772

805773
# Temporal fix: train mu if size factors are given as closed form may be different:
806774
if input_data.size_factors is not None:
807775
self._train_mu = True
808776

809777
logger.info("Using closed-form MLE initialization for mean")
810-
logger.debug("RMSE of closed-form mean:\n%s", a_prime[1])
778+
logger.debug("RMSE of closed-form mean:\n%s", rmsd_a)
811779
logger.info("Should train mu: %s", self._train_mu)
812780
except np.linalg.LinAlgError:
813781
logger.warning("Closed form initialization failed!")
@@ -831,63 +799,22 @@ def __init__(
831799

832800
if init_b.lower() == "closed_form":
833801
try:
834-
unique_design_scale, inverse_idx = np.unique(input_data.design_scale, axis=0,
835-
return_inverse=True)
836-
if input_data.constraints_scale is not None:
837-
unique_design_scale_constraints = input_data.constraints_scale.copy()
838-
# -1 in the constraint matrix is used to indicate which variable
839-
# is made dependent so that the constrained is fullfilled.
840-
# This has to be rewritten here so that the design matrix is full rank
841-
# which is necessary so that it can be inverted for parameter
842-
# initialisation.
843-
unique_design_scale_constraints[unique_design_scale_constraints == -1] = 1
844-
# Add constraints into design matrix to remove structural unidentifiability.
845-
unique_design_scale = np.vstack([unique_design_scale, unique_design_scale_constraints])
846-
847-
if unique_design_scale.shape[1] > np.linalg.matrix_rank(unique_design_scale):
848-
logger.warning("Scale model is not full rank!")
849-
850-
X = input_data.X.assign_coords(group=(("observations",), inverse_idx))
851-
if input_data.size_factors is not None:
852-
X = np.divide(X, size_factors_init)
853-
854-
# Xdiff = X - np.exp(input_data.design_loc @ init_a)
855-
# Define xarray version of init so that Xdiff can be evaluated lazy by dask.
856802
init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
857803
init_a_xr.coords["design_loc_params"] = input_data.design_loc.coords["design_loc_params"]
858-
logger.debug(" ** Define Xdiff")
859-
Xdiff = X - np.exp(input_data.design_loc.dot(init_a_xr))
860-
variance = np.square(Xdiff).groupby("group").mean(dim="observations")
861-
862-
if groupwise_means is None:
863-
groupwise_means = X.groupby("group").mean(dim="observations")
864-
denominator = np.fmax(variance - groupwise_means, 0)
865-
denominator = np.nextafter(0, 1, out=denominator.values, where=denominator == 0,
866-
dtype=denominator.dtype)
867-
r = np.asarray(np.square(groupwise_means) / denominator)
868-
# clipping
869-
r = np_clip_param(r, "r")
870-
# r = np.nextafter(0, 1, out=r.values, where=r == 0, dtype=r.dtype)
871-
# r = np.fmin(r, np.finfo(r.dtype).max)
872-
873-
b = np.log(r)
874-
if input_data.constraints_scale is not None:
875-
b_constraints = np.zeros([input_data.constraints_scale.shape[0], b.shape[1]])
876-
# Add constraints (sum to zero) to value vector to remove structural unidentifiability.
877-
b = np.vstack([b, b_constraints])
878-
879-
# inv_design = np.linalg.pinv(unique_design_scale) # NOTE: this is numerically inaccurate!
880-
# inv_design = np.linalg.inv(unique_design_scale) # NOTE: this is exact if full rank!
881-
# init_b = np.matmul(inv_design, b)
882-
#
883-
# Use least-squares solver to calculate a':
884-
# This is faster and more accurate than using matrix inversion.
885-
logger.debug(" ** Solve lstsq problem")
886-
b_prime = np.linalg.lstsq(unique_design_scale, b, rcond=None)
887-
init_b = b_prime[0]
804+
init_mu = np.exp(input_data.design_loc.dot(init_a_xr))
805+
806+
groupwise_scales, init_b, rmsd_b = nb_glm_utils.closedform_nb_glm_logphi(
807+
X=input_data.X,
808+
mu=init_mu,
809+
design_scale=input_data.design_scale,
810+
constraints=input_data.constraints_scale,
811+
size_factors=size_factors_init,
812+
groupwise_means=groupwise_means,
813+
link_fn=lambda r: np.log(np_clip_param(r, "r"))
814+
)
888815

889816
logger.info("Using closed-form MME initialization for dispersion")
890-
logger.debug("RMSE of closed-form dispersion:\n%s", b_prime[1])
817+
logger.debug("RMSE of closed-form dispersion:\n%s", rmsd_b)
891818
logger.info("Should train r: %s", self._train_r)
892819
except np.linalg.LinAlgError:
893820
logger.warning("Closed form initialization failed!")
@@ -903,8 +830,11 @@ def __init__(
903830
my_loc_names = set(input_data.design_loc_names.values)
904831
my_loc_names = my_loc_names.intersection(init_model.input_data.design_loc_names.values)
905832

906-
# Initialize new parameters to zero:
907-
init_loc = np.zeros(shape=(input_data.num_design_loc_params, input_data.num_features))
833+
init_loc = np.random.uniform(
834+
low=np.nextafter(0, 1, dtype=input_data.X.dtype),
835+
high=np.sqrt(np.nextafter(0, 1, dtype=input_data.X.dtype)),
836+
size=(input_data.num_design_loc_params, input_data.num_features)
837+
)
908838
for parm in my_loc_names:
909839
init_idx = np.where(init_model.input_data.design_loc_names == parm)
910840
my_idx = np.where(input_data.design_loc_names == parm)
@@ -917,8 +847,11 @@ def __init__(
917847
my_scale_names = set(input_data.design_scale_names.values)
918848
my_scale_names = my_scale_names.intersection(init_model.input_data.design_scale_names.values)
919849

920-
# Initialize new parameters to zero:
921-
init_scale = np.zeros(shape=(input_data.num_design_scale_params, input_data.num_features))
850+
init_scale = np.random.uniform(
851+
low=np.nextafter(0, 1, dtype=input_data.X.dtype),
852+
high=np.sqrt(np.nextafter(0, 1, dtype=input_data.X.dtype)),
853+
size=(input_data.num_design_scale_params, input_data.num_features)
854+
)
922855
for parm in my_scale_names:
923856
init_idx = np.where(init_model.input_data.design_scale_names == parm)
924857
my_idx = np.where(input_data.design_scale_names == parm)

batchglm/train/tf/nb_glm/external.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010

1111
# from train.tf.nb import EstimatorGraph as NegativeBinomialEstimatorGraph
1212

13+
import batchglm.models.nb_glm.utils as nb_glm_utils
1314
import batchglm.utils.random as rand_utils
1415
from batchglm import pkg_constants

batchglm/unit_test/test_nb_glm.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def test_nonconfounded_fit(self):
100100
sim.generate_sample_description(num_conditions=0, num_batches=4)
101101
sim.generate()
102102

103-
sample_description = data_utils.sample_description_from_xarray(sim.data, dim="observations")
104-
design_loc = data_utils.design_matrix(sample_description, formula="~ 1 - 1 + batch")
105-
design_scale = data_utils.design_matrix(sample_description, formula="~ 1 - 1 + batch")
103+
sample_description = glm.data.sample_description_from_xarray(sim.data, dim="observations")
104+
design_loc = glm.data.design_matrix(sample_description, formula="~ 1 - 1 + batch")
105+
design_scale = glm.data.design_matrix(sample_description, formula="~ 1 - 1 + batch")
106106

107107
input_data = InputData.new(sim.X, design_loc=design_loc, design_scale=design_scale)
108108

@@ -127,7 +127,13 @@ def test_nonconfounded_fit(self):
127127

128128
def test_anndata(self):
129129
adata = self.sim.data_to_anndata()
130-
idata = InputData.new(adata)
130+
design_loc = self.sim.design_loc
131+
design_scale = self.sim.design_scale
132+
idata = InputData.new(
133+
data=adata,
134+
design_loc=design_loc,
135+
design_scale=design_scale,
136+
)
131137

132138
wd = os.path.join(self.working_dir.name, "anndata")
133139
os.makedirs(wd, exist_ok=True)
@@ -141,7 +147,13 @@ def test_anndata(self):
141147
def test_anndata_sparse(self):
142148
adata = self.sim.data_to_anndata()
143149
adata.X = scipy.sparse.csr_matrix(adata.X)
144-
idata = InputData.new(adata)
150+
design_loc = self.sim.design_loc
151+
design_scale = self.sim.design_scale
152+
idata = InputData.new(
153+
data=adata,
154+
design_loc=design_loc,
155+
design_scale=design_scale,
156+
)
145157

146158
wd = os.path.join(self.working_dir.name, "anndata")
147159
os.makedirs(wd, exist_ok=True)

0 commit comments

Comments
 (0)