import tensorflow as tf
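# NOTE: the functions below reference several VAT hyperparameters that are
# defined elsewhere in the original repository (presumably via flags). The
# values here are a minimal sketch so this module is self-contained; they are
# illustrative assumptions, not the repository's actual settings.
num_power_iterations = 1  # Number of power-iteration steps for r_vadv.
xi = 1e-6  # Scale of the finite-difference perturbation.
epsilon = 5.0  # Norm of the final virtual adversarial perturbation.
scale_r = False  # If True, rescale r_vadv by the magnitude of the inputs.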
def kl_divergence_with_logit(q_logit, p_logit):
  """Computes the KL-divergence between two sets of logits."""
  q = tf.nn.softmax(q_logit)
  qlogq = -tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q, logits=q_logit)
  qlogp = -tf.nn.softmax_cross_entropy_with_logits_v2(
      labels=q, logits=p_logit)
  return qlogq - qlogp
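# Why kl_divergence_with_logit computes a KL divergence: for each example,
# tf.nn.softmax_cross_entropy_with_logits_v2(labels=q, logits=l) returns
# -sum_i q_i * log softmax(l)_i. Since q = softmax(q_logit), qlogq equals
# sum_i q_i * log q_i and qlogp equals sum_i q_i * log p_i, so the result is
# sum_i q_i * (log q_i - log p_i) = KL(q || p) per example.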
def get_normalized_vector(d):
  """Normalizes the provided input vector."""
  d /= (1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True))
  d /= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
  return d


def get_normalizing_constant(d):
  """Returns the normalizing constant to scale the VAT perturbation vector."""
  c = 1e-12 + tf.reduce_max(tf.abs(d), keep_dims=True)
  c *= tf.sqrt(1e-6 + tf.reduce_sum(tf.pow(d, 2.0), keep_dims=True))
  return c
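# get_normalizing_constant combines the same two statistics used by
# get_normalized_vector (max-abs and L2 norm) multiplicatively, returning
# roughly max|d| * ||d||_2 of its argument. In
# generate_virtual_adversarial_perturbation below it is applied to the
# *inputs*, so that the unit-norm perturbation can be scaled back up to a
# magnitude comparable to the inputs when scale_r is set.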
def get_loss_vat(inputs, predictions, is_train, model, predictions_var_scope):
  """Computes the virtual adversarial loss for the provided inputs.

  Args:
    inputs: A batch of input features, where the batch is the first
      dimension.
    predictions: The logits predicted by a model on the provided inputs.
    is_train: A boolean placeholder specifying if this is a training or
      testing setting.
    model: The model that generated the logits.
    predictions_var_scope: Variable scope for obtaining the predictions.

  Returns:
    A scalar Tensor containing the virtual adversarial loss.
  """
  r_vadv = generate_virtual_adversarial_perturbation(
      inputs, predictions, model, predictions_var_scope, is_train=is_train)
  predictions = tf.stop_gradient(predictions)
  logit_p = predictions
  new_inputs = tf.add(inputs, r_vadv)
  with tf.variable_scope(
      predictions_var_scope, auxiliary_name_scope=False, reuse=True):
    encoding_m, _, _ = model.get_encoding_and_params(
        inputs=new_inputs, is_train=is_train, update_batch_stats=False)
    logit_m, _, _ = model.get_predictions_and_params(
        encoding=encoding_m, is_train=is_train)
  loss = kl_divergence_with_logit(logit_p, logit_m)
  return tf.reduce_mean(loss)
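# A hedged sketch of how this loss is typically added to a supervised
# objective (`labels` and `vat_loss_weight` are illustrative names, not part
# of this module):
#   supervised_loss = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits_v2(
#           labels=labels, logits=predictions))
#   total_loss = supervised_loss + vat_loss_weight * get_loss_vat(
#       inputs, predictions, is_train, model, predictions_var_scope)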
def generate_virtual_adversarial_perturbation(
    inputs, logits, model, predictions_var_scope, is_train=True):
  """Generates an adversarial perturbation for virtual adversarial training.

  Args:
    inputs: A batch of input features, where the batch is the first
      dimension.
    logits: The logits predicted by a model on the provided inputs.
    model: The model that generated the logits.
    predictions_var_scope: Variable scope for obtaining the predictions.
    is_train: A boolean placeholder specifying if this is a training or
      testing setting.

  Returns:
    A Tensor of the same shape as the inputs containing the adversarial
    perturbation for these inputs.
  """
  d = tf.random_normal(shape=tf.shape(inputs))

  for _ in range(num_power_iterations):
    d = xi * get_normalized_vector(d)
    logit_p = logits
    with tf.variable_scope(
        predictions_var_scope, auxiliary_name_scope=False, reuse=True):
      encoding_m, _, _ = model.get_encoding_and_params(
          inputs=d + inputs, is_train=is_train, update_batch_stats=False)
      logit_m, _, _ = model.get_predictions_and_params(
          encoding=encoding_m, is_train=is_train)
    dist = kl_divergence_with_logit(logit_p, logit_m)
    grad = tf.gradients(dist, [d], aggregation_method=2)[0]
    d = tf.stop_gradient(grad)

  r_vadv = get_normalized_vector(d)
  if scale_r:
    r_vadv *= get_normalizing_constant(inputs)
  r_vadv *= epsilon
  return r_vadv
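# The loop above is the power-iteration approximation from the VAT paper
# (Miyato et al., 2018): the gradient of the KL divergence at a small random
# perturbation xi * d points, after repeated normalization, toward the
# direction in which the model's predictions are most sensitive. With
# num_power_iterations = 1 it reduces to a single finite-difference gradient
# step, the setting reported as sufficient in the paper.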
def entropy_y_x(logits):
  """Entropy term to add to VAT with entropy minimization.

  Args:
    logits: A Tensor containing the predicted logits for a batch of samples.

  Returns:
    The entropy minimization loss.
  """
  p = tf.nn.softmax(logits)
  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(labels=p, logits=logits))
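# Minimizing this term (conditional-entropy minimization) encourages
# confident predictions on unlabeled examples. A hedged sketch of the
# combined "VATENT" objective, with an illustrative `ent_loss_weight`:
#   total_loss = (supervised_loss
#                 + vat_loss_weight * vat_loss
#                 + ent_loss_weight * entropy_y_x(predictions))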