
Commit 54bd80c

Merge remote-tracking branch 'theislab/dev' into dev
2 parents: 7cee8d6 + 0f70f48

File tree: 9 files changed, +591 −51 lines changed


batchglm/pkg_constants.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 
 ACCURACY_MARGIN_RELATIVE_TO_LIMIT = float(os.environ.get('BATCHGLM_ACCURACY_MARGIN', 2.5))
 HESSIAN_MODE = str(os.environ.get('HESSIAN_MODE', "obs_batched"))
+JACOBIAN_MODE = str(os.environ.get('JACOBIAN_MODE', "analytic"))
 
 XARRAY_NETCDF_ENGINE = "h5netcdf"
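
Side note on the new constant: like HESSIAN_MODE above it, JACOBIAN_MODE is read from the environment at import time, so the Jacobian computation mode can be switched without editing the code. A minimal usage sketch ("analytic" is the default taken from the diff; any other accepted values are not shown here and would have to be checked in jacobians.py):

    import os

    # Must be set before batchglm is imported, since pkg_constants
    # reads os.environ when the module is first loaded.
    os.environ["JACOBIAN_MODE"] = "analytic"

    import batchglm.pkg_constants as pkg_constants
    print(pkg_constants.JACOBIAN_MODE)  # -> "analytic"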

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 72 additions & 23 deletions
@@ -19,9 +19,10 @@
 from .base import param_bounds, tf_clip_param, np_clip_param, apply_constraints
 
 from .external import AbstractEstimator, XArrayEstimatorStore, InputData, Model, MonitoredTFEstimator, TFEstimatorGraph
-from .external import nb_utils, train_utils, op_utils, rand_utils
+from .external import nb_utils, train_utils, op_utils, rand_utils, data_utils
 from .external import pkg_constants
 from .hessians import Hessians
+from .jacobians import Jacobians
 
 logger = logging.getLogger(__name__)

@@ -91,6 +92,20 @@ def map_model(idx, data) -> BasicModelGraph:
                 constraints_scale=constraints_scale,
                 model_vars=model_vars,
                 mode=pkg_constants.HESSIAN_MODE,
+                iterator=True,
+                dtype=dtype
+            )
+
+        with tf.name_scope("jacobians"):
+            jacobians = Jacobians(
+                batched_data=batched_data,
+                sample_indices=sample_indices,
+                batch_model=None,
+                constraints_loc=constraints_loc,
+                constraints_scale=constraints_scale,
+                model_vars=model_vars,
+                mode=pkg_constants.JACOBIAN_MODE,
+                iterator=True,
                 dtype=dtype
             )

@@ -121,6 +136,8 @@ def map_model(idx, data) -> BasicModelGraph:
         self.norm_neg_log_likelihood = norm_neg_log_likelihood
         self.loss = loss
 
+        self.jac = jacobians.jac
+        self.neg_jac = jacobians.neg_jac
         self.hessian = hessians.hessian
         self.neg_hessian = hessians.neg_hessian

@@ -235,7 +252,20 @@ def __init__(
             # use the mean loss to keep a constant learning rate independently of the batch size
             batch_loss = batch_model.loss
 
-            # Define the hessian on the batched model:
+            # Define the jacobian on the batched model for newton-raphson:
+            batch_jac = Jacobians(
+                batched_data=batch_data,
+                sample_indices=batch_sample_index,
+                batch_model=batch_model,
+                constraints_loc=constraints_loc,
+                constraints_scale=constraints_scale,
+                model_vars=model_vars,
+                mode="analytic",
+                iterator=False,
+                dtype=dtype
+            )
+
+            # Define the hessian on the batched model for newton-raphson:
             batch_hessians = Hessians(
                 batched_data=batch_data,
                 singleobs_data=None,
366396
name="full_data_trainers_b_only"
367397
)
368398
with tf.name_scope("full_gradient"):
369-
full_gradient = full_data_trainers.gradient[0][0]
370-
full_gradient = tf.reduce_sum(tf.abs(full_gradient), axis=0)
399+
#full_gradient = full_data_trainers.gradient[0][0]
400+
#full_gradient = tf.reduce_sum(tf.abs(full_gradient), axis=0)
401+
full_gradient = full_data_model.neg_jac
371402
# full_gradient = tf.add_n(
372403
# [tf.reduce_sum(tf.abs(grad), axis=0) for (grad, var) in full_data_trainers.gradient])
373404

374405
with tf.name_scope("newton-raphson"):
375406
# tf.gradients(- full_data_model.log_likelihood, [model_vars.a, model_vars.b])
376407
# Full data model:
377-
param_grad_vec = tf.gradients(- full_data_model.log_likelihood, model_vars.params)[0]
378-
param_grad_vec_t = tf.transpose(param_grad_vec)
408+
param_grad_vec = full_data_model.neg_jac
409+
#param_grad_vec = tf.gradients(- full_data_model.log_likelihood, model_vars.params)[0]
410+
#param_grad_vec_t = tf.transpose(param_grad_vec)
379411

380412
delta_t = tf.squeeze(tf.matrix_solve_ls(
381413
full_data_model.neg_hessian,
382414
# (full_data_model.hessians + tf.transpose(full_data_model.hessians, perm=[0, 2, 1])) / 2, # don't need this with closed forms
383-
tf.expand_dims(param_grad_vec_t, axis=-1),
415+
tf.expand_dims(param_grad_vec, axis=-1),
384416
fast=False
385417
), axis=-1)
386418
delta = tf.transpose(delta_t)
@@ -392,13 +424,14 @@ def __init__(
392424
)
393425

394426
# Batched data model:
395-
param_grad_vec_batched = tf.gradients(- batch_model.log_likelihood,
396-
model_vars.params)[0]
397-
param_grad_vec_batched_t = tf.transpose(param_grad_vec_batched)
427+
param_grad_vec_batched = batch_jac.neg_jac
428+
#param_grad_vec_batched = tf.gradients(- batch_model.log_likelihood,
429+
# model_vars.params)[0]
430+
#param_grad_vec_batched_t = tf.transpose(param_grad_vec_batched)
398431

399432
delta_batched_t = tf.squeeze(tf.matrix_solve_ls(
400433
batch_hessians.neg_hessian,
401-
tf.expand_dims(param_grad_vec_batched_t, axis=-1),
434+
tf.expand_dims(param_grad_vec_batched, axis=-1),
402435
fast=False
403436
), axis=-1)
404437
delta_batched = tf.transpose(delta_batched_t)
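
For orientation, the update that both the full-data and the batched branch assemble here is an ordinary per-feature Newton-Raphson step: solve the linear system formed by the negative Hessian and the (now analytic) negative Jacobian rather than differentiating the log-likelihood with tf.gradients. A minimal NumPy sketch of that step, with assumed shapes and an assumed learning-rate convention, not the library's actual trainer code:

    import numpy as np

    def newton_step(params, neg_jac, neg_hessian, learning_rate=1.0):
        """
        params:      [parameters, features]               current estimates
        neg_jac:     [features, parameters]               negative score per feature
        neg_hessian: [features, parameters, parameters]   negative Hessian per feature
        """
        # One linear solve per feature; lstsq mirrors matrix_solve_ls(..., fast=False).
        delta = np.stack([
            np.linalg.lstsq(neg_hessian[i], neg_jac[i], rcond=None)[0]
            for i in range(neg_jac.shape[0])
        ])  # [features, parameters]
        # Newton step on the negative log-likelihood: theta <- theta - delta
        return params - learning_rate * delta.T
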
@@ -741,6 +774,9 @@ def __init__(
                 shape=[input_data.num_observations, input_data.num_features]
             )
 
+        groupwise_means = None  # [groups, features]
+        overall_means = None  # [1, features]
+        logger.debug(" * Initialize mean model")
         if isinstance(init_a, str):
             # Chose option if auto was chosen
             if init_a.lower() == "auto":
@@ -766,12 +802,12 @@ def __init__(
                     if size_factors_init is not None:
                         X = np.divide(X, size_factors_init)
 
-                    mean = X.groupby("group").mean(dim="observations").values
+                    groupwise_means = X.groupby("group").mean(dim="observations").values
                     # clipping
-                    mean = np_clip_param(mean, "mu")
+                    groupwise_means = np_clip_param(groupwise_means, "mu")
                     # mean = np.nextafter(0, 1, out=mean.values, where=mean == 0, dtype=mean.dtype)
 
-                    a = np.log(mean)
+                    a = np.log(groupwise_means)
                     if input_data.constraints_loc is not None:
                         a_constraints = np.zeros([input_data.constraints_loc.shape[0], a.shape[1]])
                         # Add constraints (sum to zero) to value vector to remove structural unidentifiability.
@@ -781,7 +817,9 @@ def __init__(
                     # inv_design = np.linalg.inv(unique_design_loc) # NOTE: this is exact if full rank!
                     # init_a = np.matmul(inv_design, a)
                     #
-                    # Better option: use least-squares solver to calculate a'
+                    # Use least-squares solver to calculate a':
+                    # This is faster and more accurate than using matrix inversion.
+                    logger.debug(" ** Solve lstsq problem")
                     a_prime = np.linalg.lstsq(unique_design_loc, a, rcond=None)
                     init_a = a_prime[0]
                     # stat_utils.rmsd(np.exp(unique_design_loc @ init_a), mean)
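
The hunk above only rewords the comment, but the idea is worth spelling out: the group-wise log-means are an exact solution of the location model on the unique design rows, and np.linalg.lstsq recovers the corresponding coefficients more robustly than inverting the design matrix. A small self-contained sketch of that initialization, with illustrative toy values rather than the estimator's internals:

    import numpy as np

    # toy data: 2 groups x 3 features; unique design rows for intercept + group effect
    groupwise_means = np.array([[2.0, 5.0, 1.5],
                                [4.0, 9.0, 3.0]])   # [groups, features]
    unique_design_loc = np.array([[1.0, 0.0],
                                  [1.0, 1.0]])      # [groups, design_loc_params]

    a = np.log(groupwise_means)                     # closed-form group-level solution
    init_a = np.linalg.lstsq(unique_design_loc, a, rcond=None)[0]  # [design_loc_params, features]

    # sanity check: the fitted means reproduce the group-wise means
    print(np.allclose(np.exp(unique_design_loc @ init_a), groupwise_means))  # True
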
@@ -799,18 +837,19 @@ def __init__(
                 except np.linalg.LinAlgError:
                     logger.warning("Closed form initialization failed!")
             elif init_a.lower() == "standard":
-                mean = input_data.X.mean(dim="observations").values  # directly calculate the mean
+                overall_means = input_data.X.mean(dim="observations").values  # directly calculate the mean
                 # clipping
-                mean = np_clip_param(mean, "mu")
+                overall_means = np_clip_param(overall_means, "mu")
                 # mean = np.nextafter(0, 1, out=mean, where=mean == 0, dtype=mean.dtype)
 
                 init_a = np.zeros([input_data.num_design_loc_params, input_data.num_features])
-                init_a[0, :] = np.log(mean)
+                init_a[0, :] = np.log(overall_means)
                 self._train_mu = True
 
                 logger.info("Using standard initialization for mean")
                 logger.info("Should train mu: %s", self._train_mu)
 
+        logger.debug(" * Initialize dispersion model")
         if isinstance(init_b, str):
             if init_b.lower() == "auto":
                 init_b = "closed_form"
@@ -837,14 +876,20 @@ def __init__(
                     if input_data.size_factors is not None:
                         X = np.divide(X, size_factors_init)
 
-                    Xdiff = X - np.exp(input_data.design_loc @ init_a)
+                    # Xdiff = X - np.exp(input_data.design_loc @ init_a)
+                    # Define xarray version of init so that Xdiff can be evaluated lazy by dask.
+                    init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
+                    init_a_xr.coords["design_loc_params"] = input_data.design_loc.coords["design_loc_params"]
+                    logger.debug(" ** Define Xdiff")
+                    Xdiff = X - np.exp(input_data.design_loc.dot(init_a_xr))
                     variance = np.square(Xdiff).groupby("group").mean(dim="observations")
 
-                    group_mean = X.groupby("group").mean(dim="observations")
-                    denominator = np.fmax(variance - group_mean, 0)
+                    if groupwise_means is None:
+                        groupwise_means = X.groupby("group").mean(dim="observations")
+                    denominator = np.fmax(variance - groupwise_means, 0)
                     denominator = np.nextafter(0, 1, out=denominator.values, where=denominator == 0,
                                                dtype=denominator.dtype)
-                    r = np.asarray(np.square(group_mean) / denominator)
+                    r = np.asarray(np.square(groupwise_means) / denominator)
                     # clipping
                     r = np_clip_param(r, "r")
                     # r = np.nextafter(0, 1, out=r.values, where=r == 0, dtype=r.dtype)
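
The dispersion initialization in this hunk is a method-of-moments estimate: a negative binomial has variance mu + mu^2/r, so solving for r gives r = mu^2 / (var - mu), with the denominator floored to stay positive. A hedged NumPy sketch of that estimate on plain arrays (the xarray/dask group-by machinery from the hunk is omitted, and the toy numbers are invented):

    import numpy as np

    def init_dispersion(groupwise_means, groupwise_variance):
        """Method-of-moments estimate r = mean^2 / max(var - mean, eps), per group and feature."""
        denominator = np.fmax(groupwise_variance - groupwise_means, 0)
        # floor zeros to the smallest positive float; the real code additionally clips r afterwards
        denominator = np.nextafter(0, 1, out=denominator, where=denominator == 0,
                                   dtype=denominator.dtype)
        return np.square(groupwise_means) / denominator

    means = np.array([[2.0, 5.0], [4.0, 9.0]])        # [groups, features]
    variances = np.array([[3.0, 30.0], [8.0, 36.0]])
    print(init_dispersion(means, variances))          # [[4. 1.], [4. 3.]]
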
@@ -860,7 +905,9 @@ def __init__(
                     # inv_design = np.linalg.inv(unique_design_scale) # NOTE: this is exact if full rank!
                     # init_b = np.matmul(inv_design, b)
                     #
-                    # Better option: use least-squares solver to calculate b''
+                    # Use least-squares solver to calculate b':
+                    # This is faster and more accurate than using matrix inversion.
+                    logger.debug(" ** Solve lstsq problem")
                     b_prime = np.linalg.lstsq(unique_design_scale, b, rcond=None)
                     init_b = b_prime[0]

@@ -981,6 +1028,7 @@ def fetch_fn(idx):
         else:
             init_b = init_b.astype(dtype)
 
+        logger.debug(" * Start creating model")
         with graph.as_default():
             # create model
             model = EstimatorGraph(
@@ -999,6 +1047,7 @@ def fetch_fn(idx):
                 extended_summary=extended_summary,
                 dtype=dtype
             )
+        logger.debug(" * Finished creating model")
 
         MonitoredTFEstimator.__init__(self, model)

batchglm/train/tf/nb_glm/external.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+import batchglm.data as data_utils
+
 from batchglm.models.nb_glm.base import AbstractEstimator, XArrayEstimatorStore, InputData, Model
 
 import batchglm.train.tf.ops as op_utils

batchglm/train/tf/nb_glm/hessians.py

Lines changed: 12 additions & 11 deletions
@@ -90,7 +90,7 @@ def _coef_invariant_aa(
     given observations and features.
     """
     const = tf.negative(tf.multiply(
-        mu,  # [observations x features]
+        mu,  # [observations, features]
        tf.divide(
            (X / r) + 1,
            tf.square((mu / r) + 1)
@@ -169,7 +169,8 @@
 class Hessians:
     """ Compute the nb_glm model hessian.
     """
-    H: tf.Tensor
+    hessian: tf.Tensor
+    neg_hessian: tf.Tensor
 
     def __init__(
             self,
@@ -290,10 +291,10 @@ def byobs(
     ):
         """
         Compute the closed-form of the nb_glm model hessian
-        by evalutating its terms grouped by observations.
+        by evaluating its terms grouped by observations.
 
-        Has three subfunctions which built the specific blocks of the hessian
-        and one subfunction which concatenates the blocks into a full hessian.
+        Has three sub-functions which built the specific blocks of the hessian
+        and one sub-function which concatenates the blocks into a full hessian.
 
         Note that two different groups of functions compute the hessian
         block either with standard matrix multiplication for a single observation
@@ -303,7 +304,7 @@ def byobs(
         use the einsum to compute the hessian block on a batch of observations
         in a single go. This requires the handling of a latent 4D tensor which
         potentially large memory usage, depending on the einsum behaviour. In
-        principle the latter can be fast though as they relace iterations which
+        principle the latter can be fast though as they replace iterations which
         larger tensor operations.
         """
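
To make the batched variant concrete, here is a schematic NumPy sketch of the einsum pattern the docstring describes, with made-up shapes: one contraction over observations yields a Hessian block per feature in a single operation, at the price of the latent higher-order intermediate mentioned above. This illustrates the pattern only and is not the module's actual implementation:

    import numpy as np

    obs, features, p_loc, p_scale = 128, 10, 3, 2
    const = np.random.rand(obs, features)         # per-observation, per-feature coefficient
    design_loc = np.random.rand(obs, p_loc)       # mean-model design matrix
    design_scale = np.random.rand(obs, p_scale)   # dispersion-model design matrix

    # mean-dispersion block per feature:
    # h_ab[f, i, j] = sum_o const[o, f] * design_loc[o, i] * design_scale[o, j]
    h_ab = np.einsum("of,oi,oj->fij", const, design_loc, design_scale)
    print(h_ab.shape)  # (10, 3, 2)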

@@ -355,7 +356,7 @@ def _bb_byobs(X, design_loc, design_scale, mu, r):
         def _ab_byobs(X, design_loc, design_scale, mu, r):
             """
             Compute the mean-dispersion model off-diagonal block of the
-            closed form hessian of nb_glm model by observastion across features.
+            closed form hessian of nb_glm model by observation across features.
 
             Note that there are two blocks of the same size which can
             be compute from each other with a transpose operation as
@@ -510,7 +511,7 @@ def _red(prev, cur):
 
             Every evaluation of the hessian on an observation yields a full
             hessian matrix. This function sums over consecutive evaluations
-            of this hessian so that not all seperate evluations have to be
+            of this hessian so that not all separate evaluations have to be
             stored.
             """
             return tf.add(prev, cur)
@@ -545,11 +546,11 @@
     ):
         """
         Compute the closed-form of the nb_glm model hessian
-        by evalutating its terms grouped by features.
+        by evaluating its terms grouped by features.
 
 
-        Has three subfunctions which built the specific blocks of the hessian
-        and one subfunction which concatenates the blocks into a full hessian.
+        Has three sub-functions which built the specific blocks of the hessian
+        and one sub-function which concatenates the blocks into a full hessian.
         """
 
         def _aa_byfeature(X, design_loc, design_scale, mu, r):
def _aa_byfeature(X, design_loc, design_scale, mu, r):
