 import logging
 import os
 
+from gam.trainer.adversarial import entropy_y_x
+from gam.trainer.adversarial import get_loss_vat
 from gam.trainer.trainer_base import batch_iterator
 from gam.trainer.trainer_base import Trainer
 
@@ -65,6 +67,11 @@ class TrainerClassification(Trainer):
       model loss.
     iter_cotrain: A Tensorflow variable containing the current cotrain
       iteration.
+    reg_weight_vat: A float representing the weight of the virtual adversarial
+      training (VAT) regularization loss in the classification model loss
+      function.
+    use_ent_min: A boolean specifying whether to use entropy minimization
+      regularization together with VAT (i.e. VATENT).
     enable_summaries: Boolean specifying whether to enable variable summaries.
     summary_step: Integer representing the summary step size.
     summary_dir: String representing the path to a directory where to save the
@@ -122,6 +129,8 @@ def __init__(self,
                reg_weight_uu,
                num_pairs_reg,
                iter_cotrain,
+               reg_weight_vat=0.0,
+               use_ent_min=False,
                enable_summaries=False,
                summary_step=1,
                summary_dir=None,
@@ -170,6 +179,8 @@ def __init__(self,
     self.reg_weight_ll = reg_weight_ll
     self.reg_weight_lu = reg_weight_lu
     self.reg_weight_uu = reg_weight_uu
+    self.reg_weight_vat = reg_weight_vat
+    self.use_ent_min = use_ent_min
     self.penalize_neg_agr = penalize_neg_agr
     self.use_l2_classif = use_l2_classif
     self.first_iter_original = first_iter_original
@@ -188,6 +199,8 @@ def __init__(self,
     features_shape = [None] + list(data.features_shape)
     input_features = tf.placeholder(
         tf.float32, shape=features_shape, name='input_features')
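+    # Placeholder for a batch of unlabeled inputs, used only by the VAT loss.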
+    input_features_unlabeled = tf.placeholder(
+        tf.float32, shape=features_shape, name='input_features_unlabeled')
     input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')
     one_hot_labels = tf.one_hot(
         input_labels, data.num_classes, name='input_labels_one_hot')
@@ -206,6 +219,18 @@ def __init__(self,
       self.variables.update(variables)
       self.reg_params.update(reg_params)
       normalized_predictions = self.model.normalize_predictions(predictions)
+      predictions_var_scope = tf.get_variable_scope()
+
+    # Create predictions on unlabeled data, which are used only for the VAT
+    # loss.
+    with tf.variable_scope('predictions', reuse=True):
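+      # Reusing the 'predictions' variable scope applies the same classifier
+      # weights to the unlabeled batch as to the labeled one.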
+      # update_batch_stats=False prevents this extra forward pass from
+      # altering batch normalization statistics (if the model maintains any).
+      encoding_unlabeled, _, _ = self.model.get_encoding_and_params(
+          inputs=input_features_unlabeled,
+          is_train=is_train,
+          update_batch_stats=False)
+      predictions_unlabeled, _, _ = self.model.get_predictions_and_params(
+          encoding=encoding_unlabeled, is_train=is_train)
 
     # Create a variable for weight decay that may be updated.
     weight_decay_var, weight_decay_update = self._create_weight_decay_var(
@@ -240,8 +265,31 @@ def __init__(self,
     for var in reg_params.values():
       loss_reg += weight_decay_var * tf.nn.l2_loss(var)
 
+    # Adversarial loss, in case we want to add VAT on top of GAM.
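+    # Following the VAT formulation of Miyato et al. (2018), get_loss_vat is
+    # expected to perturb input_features_unlabeled in the adversarial
+    # direction and penalize the divergence between the predictions on the
+    # clean and perturbed inputs; see gam/trainer/adversarial.py.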
+    loss_vat = get_loss_vat(
+        input_features_unlabeled, predictions_unlabeled, is_train, model,
+        predictions_var_scope)
+    num_unlabeled = tf.shape(input_features_unlabeled)[0]
+    # Guard against empty unlabeled batches (see _construct_feed_dict), for
+    # which the VAT loss is not well defined.
+    loss_vat = tf.cond(
+        tf.greater(num_unlabeled, 0), lambda: loss_vat, lambda: 0.0)
+    if self.use_ent_min:
+      # Use entropy minimization with VAT (i.e. VATENT).
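+      # entropy_y_x is expected to compute the mean prediction entropy,
+      # -sum_y p(y|x) log p(y|x), over the unlabeled batch; see
+      # gam/trainer/adversarial.py.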
+      loss_ent = entropy_y_x(predictions_unlabeled)
+      loss_vat = loss_vat + tf.cond(
+          tf.greater(num_unlabeled, 0), lambda: loss_ent, lambda: 0.0)
+    loss_vat = loss_vat * self.reg_weight_vat
+    if self.first_iter_original:
+      # Do not add the adversarial loss in the first iteration if the first
+      # iteration trains the plain baseline model.
+      weight_loss_vat = tf.cond(
+          tf.greater(iter_cotrain, 0), lambda: 1.0, lambda: 0.0)
+      loss_vat = loss_vat * weight_loss_vat
+
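+    # With the default reg_weight_vat=0.0, loss_vat vanishes and the total
+    # loss below reduces to the original GAM objective.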
     # Total loss.
-    loss_op = loss_supervised + loss_agr + loss_reg
+    loss_op = loss_supervised + loss_agr + loss_reg + loss_vat
 
     # Create accuracy.
     accuracy = tf.equal(tf.argmax(normalized_predictions, 1), input_labels)
@@ -310,6 +358,7 @@ def __init__(self,
 
     self.rng = np.random.RandomState(seed)
     self.input_features = input_features
+    self.input_features_unlabeled = input_features_unlabeled
     self.input_labels = input_labels
     self.predictions = predictions
     self.normalized_predictions = normalized_predictions
@@ -507,7 +556,8 @@ def _construct_feed_dict(self,
                            split,
                            pair_ll_iterator=None,
                            pair_lu_iterator=None,
-                           pair_uu_iterator=None):
+                           pair_uu_iterator=None,
+                           data_iterator_unlabeled=None):
511561 """Construct feed dictionary."""
512562 try :
513563 input_indices = next (data_iterator )
@@ -521,6 +571,14 @@ def _construct_feed_dict(self,
         self.input_labels: labels,
         self.is_train: split == 'train'
     }
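+    # Optionally feed a batch of unlabeled inputs for the VAT loss.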
+    if data_iterator_unlabeled is not None:
+      # The unlabeled iterator is provided only when using VAT regularization.
+      try:
+        input_indices = next(data_iterator_unlabeled)
+        input_features = self.data.get_features(input_indices)
+      except StopIteration:
+        # Feed an empty batch; the tf.cond guard then zeroes out the VAT loss.
+        input_features = np.zeros([0] + list(self.data.features_shape))
+      feed_dict.update({self.input_features_unlabeled: input_features})
     if pair_ll_iterator is not None:
       _, _, _, features_tgt, labels_src, labels_tgt = next(pair_ll_iterator)
       feed_dict.update({
@@ -720,6 +778,13 @@ def train(self, data, session=None, **kwargs):
         shuffle=True,
         allow_smaller_batch=False,
         repeat=True)
+    # Create an iterator over unlabeled samples for the VAT loss term.
+    data_iterator_unlabeled = batch_iterator(
+        unlabeled_indices,
+        batch_size=self.batch_size,
+        shuffle=True,
+        allow_smaller_batch=False,
+        repeat=True)
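+    # With repeat=True, this iterator is expected to cycle indefinitely, so
+    # the StopIteration fallback in _construct_feed_dict acts only as a
+    # safeguard for non-repeating iterators.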
     # Create iterators for ll, lu, uu pairs of samples for the agreement term.
     if self.use_graph:
       pair_ll_iterator = self.edge_iterator(
@@ -750,7 +815,8 @@ def train(self, data, session=None, **kwargs):
           split='train',
           pair_ll_iterator=pair_ll_iterator,
           pair_lu_iterator=pair_lu_iterator,
-          pair_uu_iterator=pair_uu_iterator)
+          pair_uu_iterator=pair_uu_iterator,
+          data_iterator_unlabeled=data_iterator_unlabeled)
       if self.enable_summaries and step % self.summary_step == 0:
         loss_val, summary, iter_cls_total, _ = session.run(
             [self.loss_op, self.summary_op, self.iter_cls_total, self.train_op],