Skip to content

Commit 1e3032d

Browse files
committed
clean up Jacobian
1 parent c155581 commit 1e3032d

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

batchglm/train/tf/nb_glm/base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77

88
import numpy as np
99

10-
try:
11-
import anndata
12-
except ImportError:
13-
anndata = None
14-
1510
from .external import AbstractEstimator
1611
from .external import nb_utils
1712
from .external import pkg_constants

batchglm/train/tf/nb_glm/jacobians.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,27 +84,26 @@ def _coef_invariant_b(
8484
Value of mean model by observation and feature.
8585
:param r: tf.tensor observations x features
8686
Value of dispersion model by observation and feature.
87-
:param dtype: dtype
8887
:return const: tf.tensor observations x features
8988
Coefficient invariant terms of hessian of
9089
given observations and features.
9190
"""
92-
scalar_one = tf.constant(1, shape=[1,1], dtype=X.dtype)
91+
scalar_one = tf.constant(1, shape=(), dtype=X.dtype)
9392
# Pre-define sub-graphs that are used multiple times:
94-
r_plus_mu = tf.add(r, mu)
95-
r_plus_x = tf.add(r, X)
93+
r_plus_mu = r + mu
94+
r_plus_x = r + X
9695
# Define graphs for individual terms of constant term of hessian:
9796
const1 = tf.subtract(
9897
tf.math.digamma(x=r_plus_x),
9998
tf.math.digamma(x=r)
10099
)
101-
const2 = tf.negative(tf.divide(r_plus_x, r_plus_mu))
100+
const2 = tf.negative(r_plus_x / r_plus_mu)
102101
const3 = tf.add(
103102
tf.log(r),
104-
tf.subtract(scalar_one, tf.log(r_plus_mu))
103+
scalar_one - tf.log(r_plus_mu)
105104
)
106105
const = tf.add_n([const1, const2, const3]) # [observations, features]
107-
const = tf.multiply(r, const)
106+
const = r * const
108107
return const
109108

110109

@@ -159,9 +158,9 @@ def __init__(
159158
Whether an iterator or a tensor (single yield of an iterator) is given
160159
in
161160
"""
162-
if constraints_loc != None and mode != "tf":
161+
if constraints_loc is not None and mode != "tf":
163162
raise ValueError("closed form hessian does not work if constraints_loc is not None")
164-
if constraints_scale != None and mode != "tf":
163+
if constraints_scale is not None and mode != "tf":
165164
raise ValueError("closed form hessian does not work if constraints_scale is not None")
166165

167166
if mode == "analytic":
@@ -378,15 +377,15 @@ def _red(prev, cur):
378377
p_shape_a = model_vars.a.shape[0]
379378
p_shape_b = model_vars.b.shape[0]
380379

381-
if iterator==True and batch_model is None:
380+
if iterator == True and batch_model is None:
382381
J = op_utils.map_reduce(
383382
last_elem=tf.gather(sample_indices, tf.size(sample_indices) - 1),
384383
data=batched_data,
385384
map_fn=_assemble_bybatch,
386385
reduce_fn=_red,
387386
parallel_iterations=pkg_constants.TF_LOOP_PARALLEL_ITERATIONS
388387
)
389-
elif iterator==False and batch_model is None:
388+
elif iterator == False and batch_model is None:
390389
J = _assemble_bybatch(
391390
idx=sample_indices,
392391
data=batched_data

0 commit comments

Comments (0)