

def kl_divergence_with_logit(q_logit, p_logit):
+  """Computes KL-divergence between two sets of logits."""
  q = tf.nn.softmax(q_logit)
  qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q, logits=q_logit)
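
The hunk above cuts off before the matching qlogp term. A minimal sketch of the complete function, assuming the elided remainder mirrors the qlogq computation (TF 1.x API, as used in the diff):

import tensorflow as tf

def kl_divergence_with_logit(q_logit, p_logit):
  """KL(q || p) computed from logits via two cross-entropy terms (sketch)."""
  q = tf.nn.softmax(q_logit)
  # softmax_cross_entropy_with_logits_v2 returns -sum_i q_i * log softmax(l)_i,
  # so negating it yields sum_i q_i log q_i and sum_i q_i log p_i respectively.
  qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(labels=q, logits=q_logit)
  qlogp = -tf.nn.softmax_cross_entropy_with_logits_v2(labels=q, logits=p_logit)
  return tf.reduce_mean(qlogq - qlogp)  # mean KL over the batch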
@@ -28,13 +29,26 @@ def get_normalizing_constant(d):


def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
+  """Computes the virtual adversarial loss for the provided inputs.
+
+  Args:
+    inputs: A batch of input features, where the batch is the first
+      dimension.
+    predictions: The logits predicted by a model on the provided inputs.
+    is_train: A boolean placeholder specifying if this is a training or
+      testing setting.
+    model: The model that generated the logits.
+    predictions_var_scope: Variable scope for obtaining the predictions.
+  Returns:
+    A float value representing the virtual adversarial loss.
+  """
  r_vadv = generate_virtual_adversarial_perturbation(
      inputs, predictions, model, predictions_var_scope, is_train=is_train)
  predictions = tf.stop_gradient(predictions)
  logit_p = predictions
  new_inputs = tf.add(inputs, r_vadv)
  with tf.variable_scope(
-    predictions_var_scope, auxiliary_name_scope=False, reuse=True):
+      predictions_var_scope, auxiliary_name_scope=False, reuse=True):
    encoding_m, _, _ = model.get_encoding_and_params(
        inputs=new_inputs,
        is_train=is_train,
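
As context for the scope handling above: reuse=True lets the second forward pass on the perturbed inputs share the variables that produced the original predictions, and auxiliary_name_scope=False avoids opening a second nested name scope. A standalone illustration of the pattern, with hypothetical scope and layer names (TF 1.x):

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 8])
r_vadv = 1e-3 * tf.random_normal(tf.shape(inputs))  # stand-in perturbation

with tf.variable_scope('preds') as scope:
  logits = tf.layers.dense(inputs, units=10, name='fc')

# Shares the dense layer's kernel/bias with the first call.
with tf.variable_scope(scope, auxiliary_name_scope=False, reuse=True):
  logits_adv = tf.layers.dense(inputs + r_vadv, units=10, name='fc')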
@@ -98,5 +112,5 @@ def logsoftmax(x):
def entropy_y_x(logit):
  """Entropy term to add to VATENT."""
  p = tf.nn.softmax(logit)
-  return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
-      labels=p, logits=logit))
+  return tf.reduce_mean(
+      tf.nn.softmax_cross_entropy_with_logits_v2(labels=p, logits=logit))
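
For orientation, entropy_y_x supplies the conditional-entropy regularizer that turns VAT into VATENT. A hedged sketch of how the pieces might be combined; the placeholder tensors and weight are hypothetical, not from this commit:

import tensorflow as tf

# Hypothetical stand-ins for the model's real tensors.
labels = tf.placeholder(tf.float32, [None, 10])
logits = tf.placeholder(tf.float32, [None, 10])
ent_weight = 1.0  # hypothetical regularization weight

supervised_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
# get_loss_vat(...) from above would contribute the VAT term; adding
# entropy_y_x on top yields the VATENT objective.
total_loss = supervised_loss + ent_weight * entropy_y_x(logits)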