@@ -493,7 +493,6 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
493493 # edges at the end of training, so the shapes don't match needs fixing.
494494 left = tf .concat ((labels_ll_left , labels_lu_left , predictions_uu_left ),
495495 axis = 0 )
496- # left = tf.stop_gradient(left)
497496 right = tf .concat (
498497 (predictions_ll_right , predictions_lu_right , predictions_uu_right ),
499498 axis = 0 )
@@ -517,7 +516,6 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
517516 src_indices = indices_uu_left ,
518517 tgt_indices = indices_uu_right )
519518 agreement = tf .concat ((agreement_ll , agreement_lu , agreement_uu ), axis = 0 )
520- # agreement = tf.stop_gradient(agreement)
521519 if self .penalize_neg_agr :
522520 # Since the agreement is predicting scores between [0, 1], anything
523521 # under 0.5 should represent disagreement. Therefore, we want to encourage
@@ -712,17 +710,17 @@ def edge_iterator(self, data, batch_size, labeling):
712710 def _evaluate (self , indices , split , session , summary_writer ):
713711 """Evaluates the samples with the provided indices."""
714712 data_iterator_val = batch_iterator (
715- indices ,
716- batch_size = self .batch_size ,
717- shuffle = False ,
718- allow_smaller_batch = True ,
719- repeat = False )
713+ indices ,
714+ batch_size = self .batch_size ,
715+ shuffle = False ,
716+ allow_smaller_batch = True ,
717+ repeat = False )
720718 feed_dict_val = self ._construct_feed_dict (data_iterator_val , split )
721719 cummulative_acc = 0.0
722720 num_samples = 0
723721 while feed_dict_val is not None :
724722 val_acc , batch_size_actual = session .run (
725- (self .accuracy , self .batch_size_actual ), feed_dict = feed_dict_val )
723+ (self .accuracy , self .batch_size_actual ), feed_dict = feed_dict_val )
726724 cummulative_acc += val_acc * batch_size_actual
727725 num_samples += batch_size_actual
728726 feed_dict_val = self ._construct_feed_dict (data_iterator_val , split )
@@ -732,8 +730,8 @@ def _evaluate(self, indices, split, session, summary_writer):
732730 if self .enable_summaries :
733731 summary = tf .Summary ()
734732 summary .value .add (
735- tag = 'ClassificationModel/' + split + '_acc' ,
736- simple_value = cummulative_acc )
733+ tag = 'ClassificationModel/' + split + '_acc' ,
734+ simple_value = cummulative_acc )
737735 iter_cls_total = session .run (self .iter_cls_total )
738736 summary_writer .add_summary (summary , iter_cls_total )
739737 summary_writer .flush ()
0 commit comments