Skip to content

Commit 6544d76

Browse files
committed
Reduced amount of logging.
1 parent 3e73cb8 commit 6544d76

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,9 @@ def _eval_train(self, session, feed_dict):
442442
feed_dict: A train feed dictionary.
443443
444444
Returns:
445-
The computed train accuracy.
445+
train_acc: The computed train accuracy.
446+
acc_0: Accuracy for class 0.
447+
acc_1: Accuracy for class 1.
446448
"""
447449
train_acc, pred, targ = session.run(
448450
(self.accuracy, self.normalized_predictions, self.labels),
@@ -462,9 +464,7 @@ def _eval_train(self, session, feed_dict):
462464
acc_0 = sum(acc_0) / np.float32(len(acc_0))
463465
else:
464466
acc_0 = -1
465-
logging.info('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f',
466-
train_acc, acc_1, acc_0)
467-
return train_acc
467+
return train_acc, acc_0, acc_1
468468

469469
def _eval_validation(self, data_iterator_val, num_samples_val, session):
470470
"""Evaluate the current model on validation data.
@@ -681,7 +681,8 @@ def train(self, data, session=None, **kwargs):
681681
# Evaluate the accuracy on the latest train batch. We track this to make
682682
# sure the agreement model is able to fit the training data, but can be
683683
# eliminated if efficiency is an issue.
684-
acc_train = self._eval_train(session, feed_dict)
684+
acc_train, acc_0_train, acc_1_train = self._eval_train(
685+
session, feed_dict)
685686

686687
if self.enable_summaries:
687688
summary = tf.Summary()
@@ -700,9 +701,10 @@ def train(self, data, session=None, **kwargs):
700701
summary_writer.flush()
701702
if step % self.logging_step == 0 or val_acc > best_val_acc:
702703
logging.info(
703-
'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f |'
704-
'random_acc: %10.4f | acc_train: %10.4f', step, loss_val,
705-
val_acc, acc_random, acc_train)
704+
'Agreement step %6d | Loss: %10.4f | val_acc: %.4f |'
705+
'random_acc: %.4f | acc_train: %.4f | acc_train_cls_0: %.4f | '
706+
'acc_train_cls_1: %.4f', step, loss_val, val_acc, acc_random,
707+
acc_train,acc_0_train, acc_1_train)
706708
if val_acc > best_val_acc:
707709
best_val_acc = val_acc
708710
if self.checkpoint_path:

0 commit comments

Comments
 (0)