
Commit cf9aeca

use batch gradient for extended summary
1 parent 1d0abc0 commit cf9aeca

1 file changed (+6, -5 lines)


batchglm/train/tf/nb_glm/estimator.py

Lines changed: 6 additions & 5 deletions
@@ -564,11 +564,12 @@ def __init__(
                         tf.reduce_sum(batch_model.log_probs, axis=1),
                         50.)
                 )
-                tf.summary.histogram('gradient_a', tf.gradients(batch_loss, model_vars.a))
-                tf.summary.histogram('gradient_b', tf.gradients(batch_loss, model_vars.b))
-                tf.summary.histogram("full_gradient", full_gradient)
-                tf.summary.scalar("full_gradient_median",
-                                  tf.contrib.distributions.percentile(full_gradient, 50.))
+                summary_full_grad = tf.where(tf.is_nan(full_gradient), tf.zeros_like(full_gradient), full_gradient,
+                                             name="full_gradient")
+                # TODO: adjust this if gradient is changed
+                tf.summary.histogram('batch_gradient', batch_trainers.gradient_by_variable(model_vars.params))
+                tf.summary.histogram("full_gradient", summary_full_grad)
+                tf.summary.scalar("full_gradient_median", tf.contrib.distributions.percentile(full_gradient, 50.))
                 tf.summary.scalar("full_gradient_mean", tf.reduce_mean(full_gradient))

                 self.saver = tf.train.Saver()
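
The substance of the change: instead of recomputing tf.gradients(batch_loss, ...) per parameter block, the batch gradient is read from the trainer object (batch_trainers.gradient_by_variable(model_vars.params)), and the full gradient is zero-filled where it is NaN before being written to the TensorBoard histogram. Below is a minimal, self-contained sketch of that NaN-masking summary pattern using the TensorFlow 1.x API; the toy variable and loss are hypothetical stand-ins for the repository's model_vars.params and batch_loss, and only the summary names mirror the diff.

import tensorflow as tf  # TensorFlow 1.x graph-mode API, as used in estimator.py

# Hypothetical stand-ins for model_vars.params and batch_loss.
params = tf.Variable(tf.random_normal([10, 4]), name="params")
loss = tf.reduce_sum(tf.square(params))
gradient = tf.gradients(loss, params)[0]

# Pattern from the commit: zero-fill NaN entries before the histogram summary,
# since a NaN value would otherwise make tf.summary.histogram fail at run time.
safe_gradient = tf.where(tf.is_nan(gradient),
                         tf.zeros_like(gradient),
                         gradient,
                         name="full_gradient")

tf.summary.histogram("full_gradient", safe_gradient)
tf.summary.scalar("full_gradient_median",
                  tf.contrib.distributions.percentile(gradient, 50.))
tf.summary.scalar("full_gradient_mean", tf.reduce_mean(gradient))

merged = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter("/tmp/tb_demo", sess.graph)
    writer.add_summary(sess.run(merged), global_step=0)
    writer.close()

Note that, as in the diff, the scalar summaries (median, mean) still use the unmasked gradient and therefore propagate any NaN; only the histogram is protected by the tf.where mask.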
