Skip to content

Commit 82dd413

Browse files
committed
Added flag for VAT in the run script.
1 parent 36074d7 commit 82dd413

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

neural_structured_learning/research/gam/experiments/run_train_gam.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
FLAGS = flags.FLAGS
4444
flags.DEFINE_string(
45-
'dataset_name', '',
45+
'dataset_name', 'cifar10',
4646
'Dataset name. Supported options are: mnist, cifar10, cifar100, '
4747
'svhn_cropped, fashion_mnist.')
4848
flags.DEFINE_string(
@@ -243,6 +243,12 @@
243243
'num_pairs_reg', 128,
244244
'Number of pairs of nodes to use in the agreement loss term of the '
245245
'classification model.')
246+
flags.DEFINE_float(
247+
'reg_weight_vat', 0.0,
248+
'Regularization weight for the virtual adversarial training (VAT) loss.')
249+
flags.DEFINE_bool(
250+
'use_ent_min', False,
251+
'A boolean specifying whether to add entropy minimization to VAT.')
246252
flags.DEFINE_string(
247253
'aggregation_agr_inputs', 'dist',
248254
'Operation to apply on the pair of nodes in the agreement model. '
@@ -445,6 +451,8 @@ def main(argv):
445451
reg_weight_ll=FLAGS.reg_weight_ll,
446452
reg_weight_lu=FLAGS.reg_weight_lu,
447453
reg_weight_uu=FLAGS.reg_weight_uu,
454+
reg_weight_vat=FLAGS.reg_weight_vat,
455+
use_ent_min=FLAGS.use_ent_min,
448456
num_pairs_reg=FLAGS.num_pairs_reg,
449457
penalize_neg_agr=FLAGS.penalize_neg_agr,
450458
use_l2_cls=FLAGS.use_l2_cls,

0 commit comments

Comments
 (0)