@@ -442,7 +442,9 @@ def _eval_train(self, session, feed_dict):
442442 feed_dict: A train feed dictionary.
443443
444444 Returns:
445- The computed train accuracy.
445+ train_acc: The computed train accuracy.
446+ acc_0: Accuracy for class 0.
447+ acc_1: Accuracy for class 1.
446448 """
447449 train_acc , pred , targ = session .run (
448450 (self .accuracy , self .normalized_predictions , self .labels ),
@@ -462,9 +464,7 @@ def _eval_train(self, session, feed_dict):
462464 acc_0 = sum (acc_0 ) / np .float32 (len (acc_0 ))
463465 else :
464466 acc_0 = - 1
465- logging .info ('Train acc: %.2f. Acc class 1: %.2f. Acc class 0: %.2f' ,
466- train_acc , acc_1 , acc_0 )
467- return train_acc
467+ return train_acc , acc_0 , acc_1
468468
469469 def _eval_validation (self , data_iterator_val , num_samples_val , session ):
470470 """Evaluate the current model on validation data.
@@ -681,7 +681,8 @@ def train(self, data, session=None, **kwargs):
681681 # Evaluate the accuracy on the latest train batch. We track this to make
682682 # sure the agreement model is able to fit the training data, but can be
683683 # eliminated if efficiency is an issue.
684- acc_train = self ._eval_train (session , feed_dict )
684+ acc_train , acc_0_train , acc_1_train = self ._eval_train (
685+ session , feed_dict )
685686
686687 if self .enable_summaries :
687688 summary = tf .Summary ()
@@ -700,9 +701,10 @@ def train(self, data, session=None, **kwargs):
700701 summary_writer .flush ()
701702 if step % self .logging_step == 0 or val_acc > best_val_acc :
702703 logging .info (
703- 'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f |'
704- 'random_acc: %10.4f | acc_train: %10.4f' , step , loss_val ,
705- val_acc , acc_random , acc_train )
704+ 'Agreement step %6d | Loss: %10.4f | val_acc: %.4f |'
705+ 'random_acc: %.4f | acc_train: %.4f | acc_train_cls_0: %.4f | '
706+ 'acc_train_cls_1: %.4f' , step , loss_val , val_acc , acc_random ,
707+ acc_train ,acc_0_train , acc_1_train )
706708 if val_acc > best_val_acc :
707709 best_val_acc = val_acc
708710 if self .checkpoint_path :
0 commit comments