Skip to content

Commit 36074d7

Browse files
committed
Added more documentation.
1 parent 340dcd7 commit 36074d7

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

neural_structured_learning/research/gam/trainer/adversarial.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88

99
def kl_divergence_with_logit(q_logit, p_logit):
10+
"""Computes KL-divergence between two sets of logits."""
1011
q = tf.nn.softmax(q_logit)
1112
qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
1213
labels=q, logits=q_logit)
@@ -28,13 +29,26 @@ def get_normalizing_constant(d):
2829

2930

3031
def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
32+
"""Computes the virtual adversarial loss for the provided inputs.
33+
34+
Args:
35+
inputs: A batch of input features, where the batch is the first
36+
dimension.
37+
predictions: The logits predicted by a model on the provided inputs.
38+
is_train: A boolean placeholder specifying if this is a training or
39+
testing setting.
40+
model: The model that generated the logits.
41+
predictions_var_scope: Variable scope for obtaining the predictions.
42+
Returns:
43+
A float value representing the virtual adversarial loss.
44+
"""
3145
r_vadv = generate_virtual_adversarial_perturbation(
3246
inputs, predictions, model, predictions_var_scope, is_train=is_train)
3347
predictions = tf.stop_gradient(predictions)
3448
logit_p = predictions
3549
new_inputs = tf.add(inputs, r_vadv)
3650
with tf.variable_scope(
37-
predictions_var_scope, auxiliary_name_scope=False, reuse=True):
51+
predictions_var_scope, auxiliary_name_scope=False, reuse=True):
3852
encoding_m, _, _ = model.get_encoding_and_params(
3953
inputs=new_inputs,
4054
is_train=is_train,
@@ -98,5 +112,5 @@ def logsoftmax(x):
98112
def entropy_y_x(logit):
  """Mean entropy of the softmax distribution induced by the given logits.

  This is the extra entropy-minimization term added to the VATENT objective.
  The per-example entropy H(p) = -sum_i p_i log p_i is obtained by feeding the
  softmax probabilities back in as soft labels to the cross-entropy op.

  Args:
    logit: A tensor of unnormalized log-probabilities, batch-first.

  Returns:
    A scalar tensor: the batch-averaged entropy of softmax(logit).
  """
  probs = tf.nn.softmax(logit)
  per_example_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=probs, logits=logit)
  return tf.reduce_mean(per_example_entropy)

0 commit comments

Comments
 (0)