
Commit 3269b56

Merge pull request #79 from theislab/dev (Dev)

2 parents ae5c187 + b3a1b37

15 files changed: +577 −426 lines


batchglm/api/data.py

Lines changed: 5 additions & 9 deletions

@@ -1,9 +1,5 @@
-from batchglm.data import design_matrix
-from batchglm.data import design_matrix_from_xarray
-from batchglm.data import design_matrix_from_anndata
-from batchglm.data import sample_description_from_xarray
-from batchglm.data import sample_description_from_anndata
-from batchglm.data import load_mtx_to_adata
-from batchglm.data import load_mtx_to_xarray
-from batchglm.data import load_recursive_mtx
-from batchglm.data import xarray_from_data
+from batchglm.data import design_matrix, design_matrix_from_xarray, design_matrix_from_anndata
+from batchglm.data import sample_description_from_xarray, sample_description_from_anndata
+from batchglm.data import load_mtx_to_adata, load_mtx_to_xarray, load_recursive_mtx, xarray_from_data
+from batchglm.data import constraint_matrix_from_dict, constraint_matrix_from_string, string_constraints_from_dict
+from batchglm.data import view_coef_names, preview_coef_names

batchglm/data.py

Lines changed: 372 additions & 107 deletions
Large diffs are not rendered by default.

batchglm/models/base/estimator.py

Lines changed: 3 additions & 2 deletions

@@ -145,6 +145,7 @@ def _plot_coef_vs_ref(
         :param ncols: Number of columns in plot grid if multiple genes are plotted.
         :param row_gap: Vertical gap between panel rows relative to panel height.
         :param col_gap: Horizontal gap between panel columns relative to panel width.
+        :param title: Plot title.
         :param return_axs: Whether to return axis objects.
         :return: Matplotlib axis objects.
         """
@@ -196,8 +197,8 @@ def _plot_coef_vs_ref(
             legend=False
         )
         sns.lineplot(
-            x=np.array([np.min([np.min(x), np.min(y), np.max([np.max(x), np.max(y)])])]),
-            y=np.array([np.min([np.min(x), np.min(y), np.max([np.max(x), np.max(y)])])]),
+            x=np.array([np.min([np.min(x), np.min(y)]), np.max([np.max(x), np.max(y)])]),
+            y=np.array([np.min([np.min(x), np.min(y)]), np.max([np.max(x), np.max(y)])]),
             ax=ax
         )
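
Note on the sns.lineplot hunk above: in the old call the global maximum was nested inside the np.min(...) list, so x and y each collapsed to a single-element array (the global minimum) and no reference line was drawn. The corrected call builds a two-point array [global_min, global_max], drawing the y = x identity line across the joint range of both coefficient vectors. A minimal, self-contained sketch of the corrected logic (the data here is synthetic; in batchglm, x and y are the reference and fitted coefficients):

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Synthetic stand-ins for the reference (x) and fitted (y) coefficients.
rng = np.random.default_rng(0)
x = rng.normal(size=50)
y = x + rng.normal(scale=0.1, size=50)

fig, ax = plt.subplots()
sns.scatterplot(x=x, y=y, ax=ax)

# Two-point identity line spanning the joint range, as in the fixed hunk.
lo = np.min([np.min(x), np.min(y)])
hi = np.max([np.max(x), np.max(y)])
sns.lineplot(x=np.array([lo, hi]), y=np.array([lo, hi]), ax=ax)
plt.show()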

batchglm/pkg_constants.py

Lines changed: 5 additions & 5 deletions

@@ -22,20 +22,20 @@
 TF_CONFIG_PROTO.gpu_options.allow_growth = True
 TF_CONFIG_PROTO.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

-TF_CONFIG_PROTO.inter_op_parallelism_threads = 0 if TF_NUM_THREADS == 0 else 1
+TF_CONFIG_PROTO.inter_op_parallelism_threads = TF_NUM_THREADS
 TF_CONFIG_PROTO.intra_op_parallelism_threads = TF_NUM_THREADS

 if TF_NUM_THREADS == 0:
     TF_NUM_THREADS = multiprocessing.cpu_count()

 # Trust region hyper parameters:
-TRUST_REGION_RADIUS_INIT = 4.
+TRUST_REGION_RADIUS_INIT = 100.
 TRUST_REGION_ETA0 = 0.
 TRUST_REGION_ETA1 = 0.25
 TRUST_REGION_ETA2 = 0.25  # Allow expansion if not shrinking.
-TRUST_REGION_T1 = 0.1  # Fast collapse to avoid trailing.
-TRUST_REGION_T2 = 2.  # Very conservative expansion to run updates once valid region is reached.
-TRUST_REGION_UPPER_BOUND = 1e4  # Low upper limit so that collapse to valid region does not cause feature to trail.
+TRUST_REGION_T1 = 0.01  # Fast collapse to avoid trailing.
+TRUST_REGION_T2 = 10.
+TRUST_REGION_UPPER_BOUND = 1e5

 # Convergence hyper-parameters:
 LLTOL_BY_FEATURE = 1e-10
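
Note on the retuned trust-region constants: in the standard scheme, T1 shrinks the radius after a poor step, T2 expands it after a good one, and ETA1/ETA2 are thresholds on the ratio of actual to predicted improvement. The new values collapse faster (T1: 0.1 -> 0.01), expand more aggressively (T2: 2 -> 10), and start from a much larger radius (4 -> 100) under a higher cap (1e4 -> 1e5). A minimal sketch of the radius update these constants typically plug into (illustrative only; the function and variable names are not batchglm's internal API, and its exact rule may differ):

TRUST_REGION_ETA1 = 0.25
TRUST_REGION_ETA2 = 0.25   # Allow expansion if not shrinking.
TRUST_REGION_T1 = 0.01     # Fast collapse to avoid trailing.
TRUST_REGION_T2 = 10.
TRUST_REGION_UPPER_BOUND = 1e5

def update_radius(radius, ratio):
    """Classic trust-region radius update; `ratio` is actual/predicted gain."""
    if ratio < TRUST_REGION_ETA1:
        return radius * TRUST_REGION_T1  # poor step: collapse the radius
    if ratio >= TRUST_REGION_ETA2:
        # good step: expand, capped at the upper bound
        return min(radius * TRUST_REGION_T2, TRUST_REGION_UPPER_BOUND)
    return radius  # middling step: keep the radius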

batchglm/train/tf/base_glm_all/estimator.py

Lines changed: 1 addition & 1 deletion

@@ -236,7 +236,7 @@ def train(
         """
         if train_loc is None:
             # check if mu was initialized with MLE
-            train_mu = self._train_loc
+            train_loc = self._train_loc
         if train_scale is None:
             # check if r was initialized with MLE
             train_scale = self._train_scale
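
Note on the one-line fix above: this is a bug fix, not a rename. The default was assigned to the stale local name train_mu, so callers leaving train_loc=None never received the initializer-derived default and train_loc stayed None for the rest of train(). A minimal sketch of the intended defaulting pattern (class and attribute names here are illustrative stand-ins for the estimator):

class EstimatorSketch:
    def __init__(self):
        # Whether location/scale params were already initialized at their MLE.
        self._train_loc = True
        self._train_scale = True

    def train(self, train_loc=None, train_scale=None):
        if train_loc is None:
            train_loc = self._train_loc      # fixed: was `train_mu = ...`
        if train_scale is None:
            train_scale = self._train_scale
        return train_loc, train_scale

assert EstimatorSketch().train() == (True, True)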

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 15 additions & 17 deletions

@@ -11,8 +11,6 @@
 from .model import ProcessModel
 from .training_strategies import TrainingStrategies

-logger = logging.getLogger("batchglm")
-

 class Estimator(EstimatorAll, AbstractEstimator, ProcessModel):
     """
@@ -185,7 +183,7 @@ def init_par(
         init_a_str = init_a.lower()
         # Chose option if auto was chosen
         if init_a.lower() == "auto":
-            init_a = "closed_form"
+            init_a = "standard"

         if init_a.lower() == "closed_form":
             groupwise_means, init_a, rmsd_a = closedform_nb_glm_logmu(
@@ -203,8 +201,8 @@ def init_par(
                 if np.any(input_data.size_factors != 1):
                     self._train_loc = True

-            logger.debug("Using closed-form MLE initialization for mean")
-            logger.debug("Should train mu: %s", self._train_loc)
+            logging.getLogger("batchglm").debug("Using closed-form MLE initialization for mean")
+            logging.getLogger("batchglm").debug("Should train mu: %s", self._train_loc)
         elif init_a.lower() == "standard":
             if isinstance(input_data.X, SparseXArrayDataArray):
                 overall_means = input_data.X.mean(dim="observations")
@@ -216,14 +214,14 @@ def init_par(
             init_a[0, :] = np.log(overall_means)
             self._train_loc = True

-            logger.debug("Using standard initialization for mean")
-            logger.debug("Should train mu: %s", self._train_loc)
+            logging.getLogger("batchglm").debug("Using standard initialization for mean")
+            logging.getLogger("batchglm").debug("Should train mu: %s", self._train_loc)
         elif init_a.lower() == "all_zero":
             init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
             self._train_loc = True

-            logger.debug("Using all_zero initialization for mean")
-            logger.debug("Should train mu: %s", self._train_loc)
+            logging.getLogger("batchglm").debug("Using all_zero initialization for mean")
+            logging.getLogger("batchglm").debug("Should train mu: %s", self._train_loc)
         else:
             raise ValueError("init_a string %s not recognized" % init_a)

@@ -243,8 +241,8 @@ def init_par(
                 init_b = np.zeros([input_data.num_scale_params, input_data.X.shape[1]])
                 init_b[0, :] = init_b_intercept

-                logger.debug("Using standard-form MME initialization for dispersion")
-                logger.debug("Should train r: %s", self._train_scale)
+                logging.getLogger("batchglm").debug("Using standard-form MME initialization for dispersion")
+                logging.getLogger("batchglm").debug("Should train r: %s", self._train_scale)
             elif init_b.lower() == "closed_form":
                 dmats_unequal = False
                 if input_data.design_loc.shape[1] == input_data.design_scale.shape[1]:
@@ -269,13 +267,13 @@ def init_par(
                     link_fn=lambda r: np.log(self.np_clip_param(r, "r"))
                 )

-                logger.debug("Using closed-form MME initialization for dispersion")
-                logger.debug("Should train r: %s", self._train_scale)
+                logging.getLogger("batchglm").debug("Using closed-form MME initialization for dispersion")
+                logging.getLogger("batchglm").debug("Should train r: %s", self._train_scale)
             elif init_b.lower() == "all_zero":
                 init_b = np.zeros([input_data.num_scale_params, input_data.X.shape[1]])

-                logger.debug("Using standard initialization for dispersion")
-                logger.debug("Should train r: %s", self._train_scale)
+                logging.getLogger("batchglm").debug("Using standard initialization for dispersion")
+                logging.getLogger("batchglm").debug("Should train r: %s", self._train_scale)
             else:
                 raise ValueError("init_b string %s not recognized" % init_b)
         else:
@@ -291,7 +289,7 @@ def init_par(
                     init_loc[my_idx] = init_model.a_var[init_idx]

                 init_a = init_loc
-                logger.debug("Using initialization based on input model for mean")
+                logging.getLogger("batchglm").debug("Using initialization based on input model for mean")

             # Scale model:
             if isinstance(init_b, str) and (init_b.lower() == "auto" or init_b.lower() == "init_model"):
@@ -305,7 +303,7 @@ def init_par(
                     init_scale[my_idx] = init_model.b_var[init_idx]

                 init_b = init_scale
-                logger.debug("Using initialization based on input model for dispersion")
+                logging.getLogger("batchglm").debug("Using initialization based on input model for dispersion")

         return init_a, init_b
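
Note on the logging change throughout this file: logging.getLogger(name) returns the same cached logger object on every call, so replacing the deleted module-level logger with inline logging.getLogger("batchglm") calls is behavior-preserving; it only removes the module global. A two-line standard-library check of that property:

import logging

# Loggers are cached by name: both calls return one and the same object,
# so inline getLogger calls behave exactly like a module-level `logger`.
assert logging.getLogger("batchglm") is logging.getLogger("batchglm")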

batchglm/train/tf/glm_nb/external.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,5 @@
 import batchglm.data as data_utils

-from batchglm.models.base.input import SparseXArrayDataSet, SparseXArrayDataArray
 from batchglm.models.glm_nb import AbstractEstimator, EstimatorStoreXArray, InputData, Model
 from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale
 from batchglm.models.glm_nb.utils import closedform_nb_glm_logmu, closedform_nb_glm_logphi
@@ -17,4 +16,5 @@

 import batchglm.utils.random as rand_utils
 from batchglm.utils.linalg import groupwise_solve_lm
+from batchglm.xarray_sparse.base import SparseXArrayDataSet, SparseXArrayDataArray
 from batchglm import pkg_constants

batchglm/train/tf/glm_norm/estimator.py

Lines changed: 4 additions & 4 deletions

@@ -246,15 +246,15 @@ def init_par(
             # Calculated variance via E(x)^2 or directly depending on whether `mu` was specified.
             if isinstance(input_data.X, SparseXArrayDataArray):
                 variance = input_data.X.var(input_data.X.dims[0])
-                variance = np.expand_dims(variance, axis=0)
             else:
-                expect_xsq = input_data.X.mean(input_data.X.dims[0])
+                expect_xsq = np.square(input_data.X).mean(input_data.X.dims[0])
                 mean_model = np.matmul(
                     np.matmul(input_data.design_loc.values, input_data.constraints_loc.values),
                     init_a
                 )
-                expect_x_sq = np.square(mean_model).mean(input_data.X.dims[0])
-                variance = expect_xsq - expect_x_sq
+                expect_x_sq = np.mean(np.square(mean_model), axis=0)  # for xr compatibility input_data.X.dims[0])
+                variance = (expect_xsq - expect_x_sq).values
+            variance = np.expand_dims(variance, axis=0)
             init_b = np.log(np.sqrt(variance))

             self._train_scale = False
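
Note on the init_par fix above: the old code computed E[X] where E[X^2] was needed (despite the variable name expect_xsq), so variance = E[X] - E[mu^2] could even turn negative before the log(sqrt(...)). The corrected code uses the moment decomposition Var(X) = E[X^2] - E[X]^2, with the fitted mean model standing in for E[X]. A small numpy check of the identity on synthetic data (batchglm applies it per feature across observations):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(1000, 5))  # observations x features

# Var(X) = E[X^2] - (E[X])^2, per feature along axis 0, mirroring the
# corrected expect_xsq / expect_x_sq lines in the hunk above.
expect_xsq = np.square(X).mean(axis=0)
expect_x_sq = np.square(X.mean(axis=0))
variance = expect_xsq - expect_x_sq

assert np.allclose(variance, X.var(axis=0))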

batchglm/train/tf/glm_norm/hessians.py

Lines changed: 2 additions & 1 deletion

@@ -52,7 +52,8 @@ def _weight_hessian_bb(
         else:
             X_minus_loc = X - loc

-        const = - tf.multiply(scalar_two,
+        const = - tf.multiply(
+            scalar_two,
             tf.square(
                 tf.divide(
                     X_minus_loc,

batchglm/unit_test/base_glm/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -1,6 +1,4 @@
 from .test_acc_glm import Test_Accuracy_GLM, _Test_Accuracy_GLM_Estim
-from .test_acc_analytic_glm import Test_AccuracyAnalytic_GLM, _Test_AccuracyAnalytic_GLM_Estim
-from .test_acc_constrained_vglm import Test_AccuracyConstrained_VGLM, _Test_AccuracyConstrained_VGLM_Estim
 from .test_acc_sizefactors_glm import Test_AccuracySizeFactors_GLM, _Test_AccuracySizeFactors_GLM_Estim
 from .test_graph_glm import Test_Graph_GLM, _Test_Graph_GLM_Estim
 from .test_data_types_glm import Test_DataTypes_GLM, _Test_DataTypes_GLM_Estim
