@@ -209,15 +209,18 @@ def __init__(self,
209209
210210 # Create variables and predictions.
211211 with tf .variable_scope ('predictions' ):
212- encoding , variables , reg_params = self .model .get_encoding_and_params (
213- inputs = input_features , is_train = is_train )
214- self .variables = variables
215- self .reg_params = reg_params
216- predictions , variables , reg_params = (
212+ encoding , variables_enc , reg_params_enc = (
213+ self .model .get_encoding_and_params (
214+ inputs = input_features ,
215+ is_train = is_train ))
216+ self .variables = variables_enc
217+ self .reg_params = reg_params_enc
218+ predictions , variables_pred , reg_params_pred = (
217219 self .model .get_predictions_and_params (
218- encoding = encoding , is_train = is_train ))
219- self .variables .update (variables )
220- self .reg_params .update (reg_params )
220+ encoding = encoding ,
221+ is_train = is_train ))
222+ self .variables .update (variables_pred )
223+ self .reg_params .update (reg_params_pred )
221224 normalized_predictions = self .model .normalize_predictions (predictions )
222225 predictions_var_scope = tf .get_variable_scope ()
223226
@@ -262,7 +265,7 @@ def __init__(self,
262265 # Weight decay loss.
263266 loss_reg = 0.0
264267 if weight_decay_var is not None :
265- for var in reg_params .values ():
268+ for var in self . reg_params .values ():
266269 loss_reg += weight_decay_var * tf .nn .l2_loss (var )
267270
268271 # Adversarial loss, in case we want to add VAT on top of GAM.
@@ -351,7 +354,7 @@ def __init__(self,
351354 if isinstance (weight_decay_var , tf .Variable ):
352355 self .vars_to_save .append (weight_decay_var )
353356 if self .warm_start :
354- self .vars_to_save .extend ([v for v in variables ])
357+ self .vars_to_save .extend ([v for v in self . variables ])
355358
356359 # More variables to be initialized after the session is created.
357360 self .is_initialized = False
@@ -366,7 +369,6 @@ def __init__(self,
366369 self .weight_decay_update = weight_decay_update
367370 self .iter_cls_total = iter_cls_total
368371 self .iter_cls_total_update = iter_cls_total_update
369- self .variables = variables
370372 self .accuracy = accuracy
371373 self .train_op = train_op
372374 self .loss_op = loss_op
0 commit comments