
Commit 58433f7

split modelvars, now have to split model graph by gene
1 parent dd17940 commit 58433f7

File tree: 6 files changed (+134, -20 lines)


batchglm/pkg_constants.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 ACCURACY_MARGIN_RELATIVE_TO_LIMIT = float(os.environ.get('BATCHGLM_ACCURACY_MARGIN', 2.5))
 HESSIAN_MODE = str(os.environ.get('HESSIAN_MODE', "obs_batched"))
 JACOBIAN_MODE = str(os.environ.get('JACOBIAN_MODE', "analytic"))
+DELTA_THETA_MIN_ABS = float(os.environ.get('BATCHGLM_ACCURACY_THETA', 1e-3))
 
 XARRAY_NETCDF_ENGINE = "h5netcdf"
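
The new constant follows the same pattern as the existing ones: it is read from the environment once, when pkg_constants is imported. A minimal usage sketch (hypothetical; only the variable name BATCHGLM_ACCURACY_THETA and the 1e-3 default come from the diff above):

import os

# Set the override before importing batchglm, since pkg_constants reads the
# environment variable at import time.
os.environ["BATCHGLM_ACCURACY_THETA"] = "1e-4"

from batchglm import pkg_constants
print(pkg_constants.DELTA_THETA_MIN_ABS)  # 0.0001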

batchglm/train/tf/base.py

Lines changed: 21 additions & 0 deletions
@@ -287,6 +287,27 @@ def train(self, *args,
                 )
 
                 tf.logging.info("Step: %d\tloss: %f", train_step, global_loss)
+        elif convergence_criteria == "all_converged":
+            train_step = self.session.run(self.model.global_step, feed_dict=feed_dict)
+            theta_current = self.session.run(self.model.model_vars.params)
+            while np.any(self.model.model_vars.converged == False):
+                theta_prev = theta_current
+                train_step, global_loss, _ = self.session.run(
+                    (self.model.global_step, loss, train_op),
+                    feed_dict=feed_dict
+                )
+                theta_current = self.session.run(self.model.model_vars.params)
+                theta_delta = np.abs(theta_prev - theta_current)
+                self.model.model_vars.converged = np.logical_or(  # Only update non-converged.
+                    self.model.model_vars.converged,
+                    np.max(theta_delta, axis=0) < pkg_constants.DELTA_THETA_MIN_ABS
+                )
+                tf.logging.info(
+                    "Step: %d\tloss: %f\t models converged %i",
+                    train_step,
+                    global_loss,
+                    np.sum(self.model.model_vars.converged).astype("int32")
+                )
         else:
             self._train_to_convergence(
                 loss=loss,
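
The convergence bookkeeping in the new "all_converged" branch is plain NumPy and can be exercised in isolation. A self-contained sketch of the mask update (dummy parameter arrays, no TensorFlow session):

import numpy as np

DELTA_THETA_MIN_ABS = 1e-3  # plays the role of pkg_constants.DELTA_THETA_MIN_ABS

rng = np.random.RandomState(0)
theta_prev = rng.normal(size=(5, 20))                              # (parameters, genes)
theta_current = theta_prev + rng.normal(scale=1e-4, size=(5, 20))  # small update step
converged = np.repeat(False, 20)                                   # one flag per gene

# A gene counts as converged once the largest absolute change across all of its
# parameters drops below the threshold; flags are only ever set, never cleared.
theta_delta = np.abs(theta_prev - theta_current)
converged = np.logical_or(converged, np.max(theta_delta, axis=0) < DELTA_THETA_MIN_ABS)
print("%i of %i genes converged" % (np.sum(converged), converged.size))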

batchglm/train/tf/nb_glm/base.py

Lines changed: 18 additions & 3 deletions
@@ -212,7 +212,9 @@ def __init__(
         self.probs = probs
         self.log_probs = log_probs
         self.log_likelihood = tf.reduce_sum(self.log_probs, axis=0, name="log_likelihood")
-        self.norm_log_likelihood = tf.reduce_mean(self.log_probs, axis=0, name="log_likelihood")
+        #self.norm_log_likelihood = tf.reduce_mean(self.log_probs, axis=0, name="log_likelihood")
+        self.norm_log_likelihood_bygene = tf.reduce_mean(self.log_probs, axis=0, name="log_likelihood")
+        self.norm_log_likelihood = tf.reduce_mean(self.norm_log_likelihood_bygene, name="log_likelihood")
         self.norm_neg_log_likelihood = - self.norm_log_likelihood
 
         with tf.name_scope("loss"):
@@ -225,6 +227,8 @@ class ModelVars:
     a_var: tf.Variable
     b_var: tf.Variable
     params: tf.Variable
+    converged: np.ndarray
+
     """ Build tf.Variables to be optimzed and their constraints.
 
     a_var and b_var slices of the tf.Variable params which contains
@@ -309,8 +313,13 @@ def __init__(
             axis=0
         ), name="params")
 
-        a_var = params[0:init_a.shape[0]]
-        b_var = params[init_a.shape[0]:]
+        params_by_gene = [tf.expand_dims(params[:, i], axis=-1) for i in range(params.shape[1])]
+        a_by_gene = [x[0:init_a.shape[0],:] for x in params_by_gene]
+        b_by_gene = [x[init_a.shape[0]:, :] for x in params_by_gene]
+        a_var = tf.concat(a_by_gene, axis=1)
+        b_var = tf.concat(b_by_gene, axis=1)
+        #a_var = params[0:init_a.shape[0]]
+        #b_var = params[init_a.shape[0]:]
 
         # Define first layer of computation graph on identifiable variables
         # to yield dependent set of parameters of model for each location
@@ -334,3 +343,9 @@ def __init__(
         self.a_var = a_var
        self.b_var = b_var
         self.params = params
+        # Properties to follow gene-wise convergence.
+        self.params_by_gene = params_by_gene
+        self.a_by_gene = a_by_gene
+        self.b_by_gene = b_by_gene
+        self.converged = np.repeat(a=False, repeats=self.params.shape[1])  # Initialise to non-converged.
+
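
The column-wise split of params and its re-concatenation into a_var and b_var can be checked on dummy shapes. A sketch assuming the TensorFlow 1.x graph API used elsewhere in this module (the shapes and names below are illustrative, not taken from the code):

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph mode assumed

num_a, num_b, num_genes = 3, 2, 4
params = tf.Variable(np.random.randn(num_a + num_b, num_genes), dtype=tf.float32)

# One (num_params, 1) tensor per gene, so gradients can later be taken with
# respect to a single gene's column.
params_by_gene = [tf.expand_dims(params[:, i], axis=-1) for i in range(num_genes)]
a_by_gene = [x[0:num_a, :] for x in params_by_gene]
b_by_gene = [x[num_a:, :] for x in params_by_gene]

# Concatenating the per-gene slices restores the original (params, genes) layout.
a_var = tf.concat(a_by_gene, axis=1)
b_var = tf.concat(b_by_gene, axis=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    assert sess.run(a_var).shape == (num_a, num_genes)
    assert sess.run(b_var).shape == (num_b, num_genes)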

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 35 additions & 16 deletions
@@ -149,7 +149,6 @@ def __init__(
             num_design_scale_params,
             graph: tf.Graph = None,
             batch_size=500,
-            feature_batch_size=None,
             init_a=None,
             init_b=None,
             constraints_loc=None,
@@ -163,7 +162,6 @@
         self.num_design_loc_params = num_design_loc_params
         self.num_design_scale_params = num_design_scale_params
         self.batch_size = batch_size
-        self.feature_batch_size = feature_batch_size
 
         # initial graph elements
         with self.graph.as_default():
@@ -305,11 +303,25 @@
             with tf.name_scope("training"):
                 global_step = tf.train.get_or_create_global_step()
 
-                # set up trainers for different selections of variables to train
-                # set up multiple optimization algorithms for each trainer
+                # Set up trainers for different selections of variables to train.
+                # Set up multiple optimization algorithms for each trainer.
+                # Note that params is tf.Variable and a, b are tensors as they are
+                # slices of a variable! Accordingly, the updates are implemented differently.
                 batch_trainers = train_utils.MultiTrainer(
-                    loss=batch_model.norm_neg_log_likelihood,
-                    variables=[model_vars.params],
+                    #loss=batch_model.norm_neg_log_likelihood,  # add only selected features here TODO
+                    #variables=[model_vars.params],  # tf.gather(model_vars.params, indices=np.where(model_vars.converged == False)[0], axis=1)],
+                    gradients=[
+                        (
+                            tf.concat([
+                                tf.gradients(batch_model.norm_neg_log_likelihood,
+                                             model_vars.params_by_gene[i])[0]
+                                if i in np.where(model_vars.converged == False)[0]
+                                else tf.zeros([model_vars.params.shape[0], 1])
+                                for i in range(model_vars.params.shape[1])
+                            ], axis=1),
+                            model_vars.params
+                        ),
+                    ],
                     learning_rate=learning_rate,
                     global_step=global_step,
                     apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
@@ -354,8 +366,20 @@
                 # [tf.reduce_sum(tf.abs(grad), axis=0) for (grad, var) in batch_trainers.gradient])
 
                 full_data_trainers = train_utils.MultiTrainer(
-                    loss=full_data_model.norm_neg_log_likelihood,
-                    variables=[model_vars.params],
+                    #loss=full_data_model.norm_neg_log_likelihood,
+                    #variables=[tf.gather(model_vars.params, indices=np.where(model_vars.converged == False)[0], axis=1)],
+                    gradients=[
+                        (
+                            tf.concat([
+                                tf.gradients(full_data_model.norm_neg_log_likelihood,
+                                             model_vars.params_by_gene[i])[0]
+                                if i in np.where(model_vars.converged == False)[0]
+                                else tf.zeros([model_vars.params.shape[0], 1])
+                                for i in range(model_vars.params.shape[1])
+                            ], axis=1),
+                            model_vars.params
+                        ),
+                    ],
                     learning_rate=learning_rate,
                     global_step=global_step,
                     apply_gradients=lambda grad: tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad),
@@ -863,7 +887,7 @@ def __init__(
            init_b = init_scale
 
         # ### prepare fetch_fn:
-        def fetch_fn(idx_obs, idx_genes=None):
+        def fetch_fn(idx):
            r"""
            Documentation of tensorflow coding style in this function:
            tf.py_func defines a python function (the getters of the InputData object slots)
@@ -872,13 +896,8 @@ def fetch_fn(idx_obs, idx_genes=None):
            as explained below.
            """
            # Catch dimension collapse error if idx is only one element long, ie. 0D:
-            if len(idx_obs.shape) == 0:
-                idx_obs = tf.expand_dims(idx_obs, axis=0)
-            if idx_genes is None:
-                idx_genes = ...
-            else:
-                if len(idx_genes.shape) == 0:
-                    idx_genes = tf.expand_dims(idx_genes, axis=0)
+            if len(idx.shape) == 0:
+                idx = tf.expand_dims(idx, axis=0)
 
            X_tensor = tf.py_func(
                func=input_data.fetch_X,
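
Both trainers now hand MultiTrainer a precomputed (gradient, variable) pair in which the gradient columns of converged genes are replaced by zeros, instead of a loss plus a variable list. A standalone sketch of that masking pattern under the same TensorFlow 1.x assumption (dummy loss and shapes, not the estimator's actual graph):

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

num_params, num_genes = 5, 8
params = tf.Variable(tf.random_normal([num_params, num_genes]))
params_by_gene = [tf.expand_dims(params[:, i], axis=-1) for i in range(num_genes)]

# The loss must be built from the per-gene slices; otherwise
# tf.gradients(loss, params_by_gene[i]) has no path to the loss and returns None.
loss = tf.reduce_sum(tf.square(tf.concat(params_by_gene, axis=1)))

converged = np.repeat(False, num_genes)
converged[::2] = True  # pretend every second gene has already converged

# Real gradient columns for unconverged genes, zero columns for converged ones.
masked_grad = tf.concat([
    tf.gradients(loss, params_by_gene[i])[0]
    if i in np.where(converged == False)[0]
    else tf.zeros([num_params, 1])
    for i in range(num_genes)
], axis=1)

train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients([(masked_grad, params)])

Note that converged is a NumPy array evaluated while the graph is built, so the mask is fixed at construction time; the commit message points to the remaining work of splitting the model graph itself by gene.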

batchglm/train/tf/train.py

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ class MultiTrainer:
     def __init__(self,
                  learning_rate,
                  loss=None,
-                 variables: List = None,
+                 variables: list = None,
                  gradients: list = None,
                  apply_gradients: Union[callable, Dict[tf.Variable, callable]] = None,
                  global_step=None,
New file added in this commit

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from typing import List
+
+import os
+# import sys
+import unittest
+import tempfile
+import logging
+
+import numpy as np
+import scipy.sparse
+
+import batchglm.api as glm
+from batchglm.api.models.nb_glm import Simulator, Estimator, InputData
+
+glm.setup_logging(verbosity="INFO", stream="STDOUT")
+logging.getLogger("tensorflow").setLevel(logging.INFO)
+
+
+def estimate(input_data: InputData):
+
+    estimator = Estimator(input_data, batch_size=500)
+    estimator.initialize()
+
+    estimator.train(
+        convergence_criteria="all_converged",
+        use_batching=False
+    )
+
+    return estimator
+
+
+class NB_GLM_Test(unittest.TestCase):
+    sim: Simulator
+
+    _estims: List[Estimator]
+
+    def setUp(self):
+        self.sim = Simulator(num_observations=1000, num_features=20)
+        self.sim.generate()
+        self._estims = []
+
+    def tearDown(self):
+        for e in self._estims:
+            e.close_session()
+
+    def test_default_fit(self):
+        sim = self.sim.__copy__()
+
+        estimator = estimate(sim.input_data)
+        self._estims.append(estimator)
+
+        # test finalizing
+        estimator = estimator.finalize()
+        return estimator, sim
+
+
+if __name__ == '__main__':
+    unittest.main()
