
Commit 340dcd7

Added VAT loss.
1 parent 2e8eff8 commit 340dcd7

File tree

4 files changed: +190 -3 lines changed


neural_structured_learning/research/gam/experiments/run_train_gam_graph.py

Lines changed: 8 additions & 0 deletions

@@ -220,6 +220,12 @@
     'num_pairs_reg', 128,
     'Number of pairs of nodes to use in the agreement loss term of the '
     'classification model.')
+flags.DEFINE_float(
+    'reg_weight_vat', 0.0,
+    'Regularization weight for the virtual adversarial training (VAT) loss.')
+flags.DEFINE_bool(
+    'use_ent_min', False,
+    'A boolean specifying whether to add entropy minimization to VAT.')
 flags.DEFINE_string(
     'aggregation_agr_inputs', 'dist',
     'Operation to apply on the pair of nodes in the agreement model. '
@@ -407,6 +413,8 @@ def main(argv):
       reg_weight_lu=FLAGS.reg_weight_lu,
       reg_weight_uu=FLAGS.reg_weight_uu,
       num_pairs_reg=FLAGS.num_pairs_reg,
+      reg_weight_vat=FLAGS.reg_weight_vat,
+      use_ent_min=FLAGS.use_ent_min,
       penalize_neg_agr=FLAGS.penalize_neg_agr,
       use_l2_cls=FLAGS.use_l2_cls,
       first_iter_original=FLAGS.first_iter_original,
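
For orientation, a minimal standalone sketch (not part of the commit) of how the two new absl flags are parsed and read; the argv list below is illustrative only.

# Illustrative only: mirrors the flag definitions added above so the intended
# command-line usage (e.g. --reg_weight_vat=0.5 --use_ent_min=true) is clear.
from absl import flags

flags.DEFINE_float(
    'reg_weight_vat', 0.0,
    'Regularization weight for the virtual adversarial training (VAT) loss.')
flags.DEFINE_bool(
    'use_ent_min', False,
    'A boolean specifying whether to add entropy minimization to VAT.')

FLAGS = flags.FLAGS
FLAGS(['run_train_gam_graph', '--reg_weight_vat=0.5', '--use_ent_min=true'])
print(FLAGS.reg_weight_vat, FLAGS.use_ent_min)  # -> 0.5 True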

neural_structured_learning/research/gam/trainer/adversarial.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+import tensorflow as tf
+
+epsilon = 5
+num_power_iterations = 1
+xi = 1e-6
+scale_r = False
+
+
+def kl_divergence_with_logit(q_logit, p_logit):
+  q = tf.nn.softmax(q_logit)
+  qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
+      labels=q, logits=q_logit)
+  qlogp = -tf.nn.softmax_cross_entropy_with_logits_v2(
+      labels=q, logits=p_logit)
+  return qlogq - qlogp
+
+
+def get_normalized_vector(d):
+  d /= (1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True))
+  d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
+  return d
+
+
+def get_normalizing_constant(d):
+  c = 1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True)
+  c *= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
+  return c
+
+
+def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
+  r_vadv = generate_virtual_adversarial_perturbation(
+      inputs, predictions, model, predictions_var_scope, is_train=is_train)
+  predictions = tf.stop_gradient(predictions)
+  logit_p = predictions
+  new_inputs = tf.add(inputs, r_vadv)
+  with tf.variable_scope(
+      predictions_var_scope, auxiliary_name_scope=False, reuse=True):
+    encoding_m, _, _ = model.get_encoding_and_params(
+        inputs=new_inputs,
+        is_train=is_train,
+        update_batch_stats=False)
+    logit_m, _, _ = model.get_predictions_and_params(
+        encoding=encoding_m,
+        is_train=is_train)
+  loss = kl_divergence_with_logit(logit_p, logit_m)
+  return tf.reduce_mean(loss)
+
+
+def generate_virtual_adversarial_perturbation(
+    inputs, logits, model, predictions_var_scope, is_train=True):
+  """Generates an adversarial perturbation for virtual adversarial training.
+
+  Args:
+    inputs: A batch of input features, where the batch is the first
+      dimension.
+    logits: The logits predicted by a model on the provided inputs.
+    model: The model that generated the logits.
+    predictions_var_scope: Variable scope for obtaining the predictions.
+    is_train: A boolean placeholder specifying if this is a training or
+      testing setting.
+
+  Returns:
+    A Tensor of the same shape as the inputs containing the adversarial
+    perturbation for these inputs.
+  """
+  d = tf.random_normal(shape=tf.shape(inputs))
+
+  for _ in range(num_power_iterations):
+    d = xi * get_normalized_vector(d)
+    logit_p = logits
+    with tf.variable_scope(
+        predictions_var_scope, auxiliary_name_scope=False, reuse=True):
+      encoding_m, _, _ = model.get_encoding_and_params(
+          inputs=d + inputs,
+          is_train=is_train,
+          update_batch_stats=False)
+      logit_m, _, _ = model.get_predictions_and_params(
+          encoding=encoding_m,
+          is_train=is_train)
+    dist = kl_divergence_with_logit(logit_p, logit_m)
+    grad = tf.gradients(dist, [d], aggregation_method=2)[0]
+    d = tf.stop_gradient(grad)
+
+  r_vadv = get_normalized_vector(d)
+  if scale_r:
+    r_vadv *= get_normalizing_constant(inputs)
+  r_vadv *= epsilon
+  return r_vadv
+
+
+def logsoftmax(x):
+  """Computes the log-softmax of the input logits."""
+  xdev = x - tf.reduce_max(x, 1, keep_dims=True)
+  lsm = xdev - tf.log(tf.reduce_sum(tf.exp(xdev), 1, keep_dims=True))
+  return lsm
+
+
+def entropy_y_x(logit):
+  """Entropy minimization term added to the VAT loss (VATENT)."""
+  p = tf.nn.softmax(logit)
+  return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
+      labels=p, logits=logit))
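
For reference, the new module implements the virtual adversarial training (VAT) objective of Miyato et al. (2018), with an optional entropy-minimization term (VATENT). A sketch of the loss in LaTeX; the notation is added here and is not part of the commit.

% Adversarial direction, approximated by num_power_iterations power-iteration
% steps starting from random noise d, then scaled to radius epsilon:
r_{\mathrm{vadv}} = \epsilon \cdot \arg\max_{\lVert r \rVert_2 \le 1}
    \mathrm{KL}\!\left(p(y \mid x; \hat{\theta}) \,\|\, p(y \mid x + r; \hat{\theta})\right)

% VAT loss over a batch U of (typically unlabeled) inputs; \hat{\theta} denotes
% a stop-gradient copy of the current parameters:
\mathcal{L}_{\mathrm{vat}} = \frac{1}{|U|} \sum_{x \in U}
    \mathrm{KL}\!\left(p(y \mid x; \hat{\theta}) \,\|\, p(y \mid x + r_{\mathrm{vadv}}; \theta)\right)

% Optional entropy-minimization term (entropy_y_x above):
\mathcal{L}_{\mathrm{ent}} = -\frac{1}{|U|} \sum_{x \in U} \sum_{y}
    p(y \mid x; \theta) \log p(y \mid x; \theta)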

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 69 additions & 3 deletions
@@ -21,6 +23,8 @@
 import logging
 import os
 
+from gam.trainer.adversarial import entropy_y_x
+from gam.trainer.adversarial import get_loss_vat
 from gam.trainer.trainer_base import batch_iterator
 from gam.trainer.trainer_base import Trainer
 
@@ -65,6 +67,11 @@ class TrainerClassification(Trainer):
       model loss.
     iter_cotrain: A Tensorflow variable containing the current cotrain
       iteration.
+    reg_weight_vat: A float representing the weight of the virtual adversarial
+      training (VAT) regularization loss in the classification model loss
+      function.
+    use_ent_min: A boolean specifying whether to use entropy regularization with
+      VAT.
     enable_summaries: Boolean specifying whether to enable variable summaries.
     summary_step: Integer representing the summary step size.
     summary_dir: String representing the path to a directory where to save the
@@ -122,6 +129,8 @@ def __init__(self,
                reg_weight_uu,
                num_pairs_reg,
                iter_cotrain,
+               reg_weight_vat=0.0,
+               use_ent_min=False,
                enable_summaries=False,
                summary_step=1,
                summary_dir=None,
@@ -170,6 +179,8 @@ def __init__(self,
     self.reg_weight_ll = reg_weight_ll
     self.reg_weight_lu = reg_weight_lu
     self.reg_weight_uu = reg_weight_uu
+    self.reg_weight_vat = reg_weight_vat
+    self.use_ent_min = use_ent_min
     self.penalize_neg_agr = penalize_neg_agr
     self.use_l2_classif = use_l2_cls
     self.first_iter_original = first_iter_original
@@ -188,6 +199,8 @@ def __init__(self,
     features_shape = [None] + list(data.features_shape)
     input_features = tf.placeholder(
         tf.float32, shape=features_shape, name='input_features')
+    input_features_unlabeled = tf.placeholder(
+        tf.float32, shape=features_shape, name='input_features_unlabeled')
     input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')
     one_hot_labels = tf.one_hot(
         input_labels, data.num_classes, name='input_labels_one_hot')
@@ -206,6 +219,18 @@ def __init__(self,
     self.variables.update(variables)
     self.reg_params.update(reg_params)
     normalized_predictions = self.model.normalize_predictions(predictions)
+    predictions_var_scope = tf.get_variable_scope()
+
+    # Create predictions on unlabeled data, which is only used for VAT loss.
+    with tf.variable_scope("predictions", reuse=True):
+      encoding_unlabeled, _, _ = self.model.get_encoding_and_params(
+          inputs=input_features_unlabeled,
+          is_train=is_train,
+          update_batch_stats=False)
+      predictions_unlabeled, _, _ = (
+          self.model.get_predictions_and_params(
+              encoding=encoding_unlabeled,
+              is_train=is_train))
 
     # Create a variable for weight decay that may be updated.
     weight_decay_var, weight_decay_update = self._create_weight_decay_var(
@@ -240,8 +265,31 @@ def __init__(self,
     for var in reg_params.values():
       loss_reg += weight_decay_var * tf.nn.l2_loss(var)
 
+    # Adversarial loss, in case we want to add VAT on top of GAM.
+    loss_vat = get_loss_vat(
+        input_features_unlabeled, predictions_unlabeled, is_train, model,
+        predictions_var_scope)
+    num_unlabeled = tf.shape(input_features_unlabeled)[0]
+    loss_vat = tf.cond(tf.greater(num_unlabeled, 0),
+                       lambda: loss_vat,
+                       lambda: 0.0)
+    if self.use_ent_min:
+      # Use entropy minimization with VAT (i.e. VATENT).
+      loss_ent = entropy_y_x(predictions_unlabeled)
+      loss_vat = loss_vat + tf.cond(tf.greater(num_unlabeled, 0),
+                                    lambda: loss_ent,
+                                    lambda: 0.0)
+    loss_vat = loss_vat * self.reg_weight_vat
+    if self.first_iter_original:
+      # Do not add the adversarial loss in the first iteration if
+      # the first iteration trains the plain baseline model.
+      weight_loss_vat = tf.cond(tf.greater(iter_cotrain, 0),
+                                lambda: 1.0,
+                                lambda: 0.0)
+      loss_vat = loss_vat * weight_loss_vat
+
     # Total loss.
-    loss_op = loss_supervised + loss_agr + loss_reg
+    loss_op = loss_supervised + loss_agr + loss_reg + loss_vat
 
     # Create accuracy.
     accuracy = tf.equal(tf.argmax(normalized_predictions, 1), input_labels)
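
Putting this hunk together: the VAT term enters the total classification objective with weight reg_weight_vat, is zeroed when the unlabeled batch is empty, and, when first_iter_original is set, is also zeroed during the first co-training iteration. As a summary in the notation introduced earlier (not source text):

\mathcal{L} = \mathcal{L}_{\mathrm{sup}} + \mathcal{L}_{\mathrm{agr}} + \mathcal{L}_{\mathrm{reg}}
    + \lambda_{\mathrm{vat}} \, g \left(\mathcal{L}_{\mathrm{vat}}
      + \mathbb{1}[\mathrm{use\_ent\_min}] \, \mathcal{L}_{\mathrm{ent}}\right),
\qquad
g = \mathbb{1}[\,|U| > 0\,] \cdot
    \begin{cases}
      \mathbb{1}[\mathrm{iter\_cotrain} > 0] & \text{if first\_iter\_original} \\
      1 & \text{otherwise}
    \end{cases}

where \lambda_{\mathrm{vat}} corresponds to reg_weight_vat and U is the unlabeled batch.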
@@ -310,6 +358,7 @@ def __init__(self,
 
     self.rng = np.random.RandomState(seed)
     self.input_features = input_features
+    self.input_features_unlabeled = input_features_unlabeled
     self.input_labels = input_labels
     self.predictions = predictions
     self.normalized_predictions = normalized_predictions
@@ -507,7 +556,8 @@ def _construct_feed_dict(self,
                            split,
                            pair_ll_iterator=None,
                            pair_lu_iterator=None,
-                           pair_uu_iterator=None):
+                           pair_uu_iterator=None,
+                           data_iterator_unlabeled=None):
     """Construct feed dictionary."""
     try:
       input_indices = next(data_iterator)
@@ -521,6 +571,14 @@ def _construct_feed_dict(self,
         self.input_labels: labels,
         self.is_train: split == 'train'
     }
+    if data_iterator_unlabeled is not None:
+      # This is not None only when using VAT regularization.
+      try:
+        input_indices = next(data_iterator_unlabeled)
+        input_features = self.data.get_features(input_indices)
+      except StopIteration:
+        input_features = np.zeros([0] + list(self.data.features_shape))
+      feed_dict.update({self.input_features_unlabeled: input_features})
     if pair_ll_iterator is not None:
       _, _, _, features_tgt, labels_src, labels_tgt = next(pair_ll_iterator)
       feed_dict.update({
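
As a standalone sketch of the feeding pattern this hunk adds (the helper name below is hypothetical, not from the repository): draw the next unlabeled batch, and fall back to an empty batch so the graph-side tf.cond contributes zero VAT loss instead of raising.

import numpy as np

def next_unlabeled_features(iterator, get_features, features_shape):
  """Hypothetical helper mirroring the feed-dict logic above."""
  try:
    indices = next(iterator)
    return get_features(indices)
  except StopIteration:
    # Zero-sized batch: tf.shape(input_features_unlabeled)[0] == 0, so the
    # tf.cond in the graph selects the 0.0 branch for the VAT loss.
    return np.zeros([0] + list(features_shape))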
@@ -720,6 +778,13 @@ def train(self, data, session=None, **kwargs):
         shuffle=True,
         allow_smaller_batch=False,
         repeat=True)
+    # Create an iterator for unlabeled samples for the VAT loss term.
+    data_iterator_unlabeled = batch_iterator(
+        unlabeled_indices,
+        batch_size=self.batch_size,
+        shuffle=True,
+        allow_smaller_batch=False,
+        repeat=True)
     # Create iterators for ll, lu, uu pairs of samples for the agreement term.
     if self.use_graph:
       pair_ll_iterator = self.edge_iterator(
@@ -750,7 +815,8 @@
           split='train',
           pair_ll_iterator=pair_ll_iterator,
           pair_lu_iterator=pair_lu_iterator,
-          pair_uu_iterator=pair_uu_iterator)
+          pair_uu_iterator=pair_uu_iterator,
+          data_iterator_unlabeled=data_iterator_unlabeled)
       if self.enable_summaries and step % self.summary_step == 0:
         loss_val, summary, iter_cls_total, _ = session.run(
             [self.loss_op, self.summary_op, self.iter_cls_total, self.train_op],

neural_structured_learning/research/gam/trainer/trainer_cotrain.py

Lines changed: 11 additions & 0 deletions
@@ -159,6 +159,11 @@ class TrainerCotraining(Trainer):
     num_pairs_reg: An integer representing the number of sample pairs of each
       type (LL, LU, UU) to include in each computation of the classification
       model loss.
+    reg_weight_vat: A float representing the weight of the virtual adversarial
+      training (VAT) regularization loss in the classification model loss
+      function.
+    use_ent_min: A boolean specifying whether to use entropy regularization with
+      VAT.
     penalize_neg_agr: Whether to not only encourage agreement between samples
       that the agreement model believes should have the same label, but also
       penalize agreement when two samples agree when the agreement model
@@ -245,6 +250,8 @@ def __init__(self,
                reg_weight_lu=0,
                reg_weight_uu=0,
                num_pairs_reg=100,
+               reg_weight_vat=0,
+               use_ent_min=False,
                penalize_neg_agr=False,
                use_l2_cls=True,
                first_iter_original=True,
@@ -314,6 +321,8 @@ def __init__(self,
     self.reg_weight_lu = reg_weight_lu
     self.reg_weight_uu = reg_weight_uu
     self.num_pairs_reg = num_pairs_reg
+    self.reg_weight_vat = reg_weight_vat
+    self.use_ent_min = use_ent_min
     self.penalize_neg_agr = penalize_neg_agr
     self.use_l2_classif = use_l2_cls
     self.first_iter_original = first_iter_original
@@ -506,6 +515,8 @@ def train(self, data, **kwargs):
         reg_weight_lu=self.reg_weight_lu,
         reg_weight_uu=self.reg_weight_uu,
         num_pairs_reg=self.num_pairs_reg,
+        reg_weight_vat=self.reg_weight_vat,
+        use_ent_min=self.use_ent_min,
         enable_summaries=self.enable_summaries_per_model,
         summary_step=self.summary_step_cls,
         summary_dir=self.summary_dir,

0 commit comments