
Commit 0fc5f62

fix full_gradient: add reduction along observations axis

1 parent 54bd80c · commit 0fc5f62

1 file changed (+9, -9)

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 9 additions & 9 deletions
@@ -353,7 +353,7 @@ def __init__(
                 name="batch_trainers_b_only"
             )

-            with tf.name_scope("full_gradient"):
+            with tf.name_scope("batch_gradient"):
                 batch_gradient = batch_trainers.gradient[0][0]
                 batch_gradient = tf.reduce_sum(tf.abs(batch_gradient), axis=0)

@@ -396,18 +396,18 @@ def __init__(
                 name="full_data_trainers_b_only"
             )
             with tf.name_scope("full_gradient"):
-                #full_gradient = full_data_trainers.gradient[0][0]
-                #full_gradient = tf.reduce_sum(tf.abs(full_gradient), axis=0)
-                full_gradient = full_data_model.neg_jac
+                # full_gradient = full_data_trainers.gradient[0][0]
+                # full_gradient = tf.reduce_sum(tf.abs(full_gradient), axis=0)
+                full_gradient = tf.reduce_sum(full_data_model.neg_jac, axis=0)
                 # full_gradient = tf.add_n(
                 #     [tf.reduce_sum(tf.abs(grad), axis=0) for (grad, var) in full_data_trainers.gradient])

             with tf.name_scope("newton-raphson"):
                 # tf.gradients(- full_data_model.log_likelihood, [model_vars.a, model_vars.b])
                 # Full data model:
                 param_grad_vec = full_data_model.neg_jac
-                #param_grad_vec = tf.gradients(- full_data_model.log_likelihood, model_vars.params)[0]
-                #param_grad_vec_t = tf.transpose(param_grad_vec)
+                # param_grad_vec = tf.gradients(- full_data_model.log_likelihood, model_vars.params)[0]
+                # param_grad_vec_t = tf.transpose(param_grad_vec)

                 delta_t = tf.squeeze(tf.matrix_solve_ls(
                     full_data_model.neg_hessian,
@@ -425,9 +425,9 @@ def __init__(

                 # Batched data model:
                 param_grad_vec_batched = batch_jac.neg_jac
-                #param_grad_vec_batched = tf.gradients(- batch_model.log_likelihood,
+                # param_grad_vec_batched = tf.gradients(- batch_model.log_likelihood,
                 #                                       model_vars.params)[0]
-                #param_grad_vec_batched_t = tf.transpose(param_grad_vec_batched)
+                # param_grad_vec_batched_t = tf.transpose(param_grad_vec_batched)

                 delta_batched_t = tf.squeeze(tf.matrix_solve_ls(
                     batch_hessians.neg_hessian,
@@ -876,7 +876,7 @@ def __init__(
         if input_data.size_factors is not None:
             X = np.divide(X, size_factors_init)

-        #Xdiff = X - np.exp(input_data.design_loc @ init_a)
+        # Xdiff = X - np.exp(input_data.design_loc @ init_a)
         # Define xarray version of init so that Xdiff can be evaluated lazy by dask.
         init_a_xr = data_utils.xarray_from_data(init_a, dims=("design_loc_params", "features"))
         init_a_xr.coords["design_loc_params"] = input_data.design_loc.coords["design_loc_params"]
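
The substance of the commit is the single changed assignment in the full_gradient scope: the negative Jacobian is now summed over its leading axis, so the result has one entry per model parameter instead of one row per observation. Below is a minimal NumPy sketch of that shape logic; the array names and dimensions are illustrative assumptions, not batchglm's actual tensor shapes.

    import numpy as np

    # Assumed, illustrative shape: a negative Jacobian with a leading
    # observations axis, (n_observations, n_parameters).
    neg_jac = np.random.randn(100, 5)

    # Before the fix, the per-observation axis was kept.
    print(neg_jac.shape)  # (100, 5)

    # The fix collapses the observations axis (axis=0), leaving one value per
    # parameter, the NumPy analogue of
    #     full_gradient = tf.reduce_sum(full_data_model.neg_jac, axis=0)
    full_gradient = neg_jac.sum(axis=0)
    print(full_gradient.shape)  # (5,)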
