Skip to content

Commit 0f70f48

Browse files
Made Xdiff in the b (dispersion) init of nb_glm lazily evaluated through xarray
1 parent f64aad3 commit 0f70f48

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .base import param_bounds, tf_clip_param, np_clip_param, apply_constraints
2020

2121
from .external import AbstractEstimator, XArrayEstimatorStore, InputData, Model, MonitoredTFEstimator, TFEstimatorGraph
22-
from .external import nb_utils, train_utils, op_utils, rand_utils
22+
from .external import nb_utils, train_utils, op_utils, rand_utils, data_utils
2323
from .external import pkg_constants
2424
from .hessians import Hessians
2525
from .jacobians import Jacobians
@@ -772,6 +772,9 @@ def __init__(
772772
shape=[input_data.num_observations, input_data.num_features]
773773
)
774774

775+
groupwise_means = None # [groups, features]
776+
overall_means = None # [1, features]
777+
logger.debug(" * Initialize mean model")
775778
if isinstance(init_a, str):
776779
# Chose option if auto was chosen
777780
if init_a.lower() == "auto":
@@ -797,12 +800,12 @@ def __init__(
797800
if size_factors_init is not None:
798801
X = np.divide(X, size_factors_init)
799802

800-
mean = X.groupby("group").mean(dim="observations").values
803+
groupwise_means = X.groupby("group").mean(dim="observations").values
801804
# clipping
802-
mean = np_clip_param(mean, "mu")
805+
groupwise_means = np_clip_param(groupwise_means, "mu")
803806
# mean = np.nextafter(0, 1, out=mean.values, where=mean == 0, dtype=mean.dtype)
804807

805-
a = np.log(mean)
808+
a = np.log(groupwise_means)
806809
if input_data.constraints_loc is not None:
807810
a_constraints = np.zeros([input_data.constraints_loc.shape[0], a.shape[1]])
808811
# Add constraints (sum to zero) to value vector to remove structural unidentifiability.
@@ -812,7 +815,9 @@ def __init__(
812815
# inv_design = np.linalg.inv(unique_design_loc) # NOTE: this is exact if full rank!
813816
# init_a = np.matmul(inv_design, a)
814817
#
815-
# Better option: use least-squares solver to calculate a'
818+
# Use least-squares solver to calculate a':
819+
# This is faster and more accurate than using matrix inversion.
820+
logger.debug(" ** Solve lstsq problem")
816821
a_prime = np.linalg.lstsq(unique_design_loc, a, rcond=None)
817822
init_a = a_prime[0]
818823
# stat_utils.rmsd(np.exp(unique_design_loc @ init_a), mean)
@@ -830,18 +835,19 @@ def __init__(
830835
except np.linalg.LinAlgError:
831836
logger.warning("Closed form initialization failed!")
832837
elif init_a.lower() == "standard":
833-
mean = input_data.X.mean(dim="observations").values # directly calculate the mean
838+
overall_means = input_data.X.mean(dim="observations").values # directly calculate the mean
834839
# clipping
835-
mean = np_clip_param(mean, "mu")
840+
overall_means = np_clip_param(overall_means, "mu")
836841
# mean = np.nextafter(0, 1, out=mean, where=mean == 0, dtype=mean.dtype)
837842

838843
init_a = np.zeros([input_data.num_design_loc_params, input_data.num_features])
839-
init_a[0, :] = np.log(mean)
844+
init_a[0, :] = np.log(overall_means)
840845
self._train_mu = True
841846

842847
logger.info("Using standard initialization for mean")
843848
logger.info("Should train mu: %s", self._train_mu)
844849

850+
logger.debug(" * Initialize dispersion model")
845851
if isinstance(init_b, str):
846852
if init_b.lower() == "auto":
847853
init_b = "closed_form"
@@ -868,14 +874,20 @@ def __init__(
868874
if input_data.size_factors is not None:
869875
X = np.divide(X, size_factors_init)
870876

871-
Xdiff = X - np.exp(input_data.design_loc @ init_a)
877+
#Xdiff = X - np.exp(input_data.design_loc @ init_a)
878+
# Define xarray version of init so that Xdiff can be evaluated lazily by dask.
879+
init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
880+
init_a_xr.coords["design_loc_params"] = input_data.design_loc.coords["design_loc_params"]
881+
logger.debug(" ** Define Xdiff")
882+
Xdiff = X - np.exp(input_data.design_loc.dot(init_a_xr))
872883
variance = np.square(Xdiff).groupby("group").mean(dim="observations")
873884

874-
group_mean = X.groupby("group").mean(dim="observations")
875-
denominator = np.fmax(variance - group_mean, 0)
885+
if groupwise_means is None:
886+
groupwise_means = X.groupby("group").mean(dim="observations")
887+
denominator = np.fmax(variance - groupwise_means, 0)
876888
denominator = np.nextafter(0, 1, out=denominator.values, where=denominator == 0,
877889
dtype=denominator.dtype)
878-
r = np.asarray(np.square(group_mean) / denominator)
890+
r = np.asarray(np.square(groupwise_means) / denominator)
879891
# clipping
880892
r = np_clip_param(r, "r")
881893
# r = np.nextafter(0, 1, out=r.values, where=r == 0, dtype=r.dtype)
@@ -891,7 +903,9 @@ def __init__(
891903
# inv_design = np.linalg.inv(unique_design_scale) # NOTE: this is exact if full rank!
892904
# init_b = np.matmul(inv_design, b)
893905
#
894-
# Better option: use least-squares solver to calculate b''
906+
# Use least-squares solver to calculate b':
907+
# This is faster and more accurate than using matrix inversion.
908+
logger.debug(" ** Solve lstsq problem")
895909
b_prime = np.linalg.lstsq(unique_design_scale, b, rcond=None)
896910
init_b = b_prime[0]
897911

@@ -1012,6 +1026,7 @@ def fetch_fn(idx):
10121026
else:
10131027
init_b = init_b.astype(dtype)
10141028

1029+
logger.debug(" * Start creating model")
10151030
with graph.as_default():
10161031
# create model
10171032
model = EstimatorGraph(
@@ -1030,6 +1045,7 @@ def fetch_fn(idx):
10301045
extended_summary=extended_summary,
10311046
dtype=dtype
10321047
)
1048+
logger.debug(" * Finished creating model")
10331049

10341050
MonitoredTFEstimator.__init__(self, model)
10351051

batchglm/train/tf/nb_glm/external.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import batchglm.data as data_utils
2+
13
from batchglm.models.nb_glm.base import AbstractEstimator, XArrayEstimatorStore, InputData, Model
24

35
import batchglm.train.tf.ops as op_utils

0 commit comments

Comments
 (0)