# adv_ft_sc.py
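"""Adversarial fine-tuning entry point.

For each downstream task listed in the config, this script fine-tunes a
victim model on an adversarially poisoned copy of the task's training data,
then evaluates the resulting backdoored model on a poisoned test set.

Usage (the default config path comes from the argparse definition below):
    python adv_ft_sc.py --config_path ./configs/finetune.yaml
"""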
import argparse
import logging
from configs import get_config
from data import get_dataset
from poisoners import get_poisoner
from trainers import get_trainer
from utils import set_logging, set_seed
from victims import get_victim
# Set config and seed; per-task logging is configured inside the loop below
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, default="./configs/finetune.yaml")
args = parser.parse_args()
config = get_config(args.config_path)
set_seed(config.seed)
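# Fields the config is expected to provide, inferred from the attributes
# accessed in this script (the actual YAML layout of configs/finetune.yaml
# is an assumption): seed, save_dir, adv_poisoner, downstream_poisoner,
# downstream_trainer, victim, and dataset.downstream / dataset.num_labels
# as parallel per-task lists.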
# Get adversarial poisoner
adv_poisoner = get_poisoner(config.adv_poisoner)
# Get downstream poisoner
downstream_poisoner = get_poisoner(config.downstream_poisoner)
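# Both poisoners are plain callables; their roles are inferred from the call
# sites below: adv_poisoner(dataset) poisons the training data before
# fine-tuning, while downstream_poisoner(dataset, model) also receives the
# tuned model when building the poisoned test set.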
# Downstream tuning: fine-tune and evaluate once per task
for i, task in enumerate(config.dataset.downstream):
    set_logging(config.save_dir + "/" + task)
    config.show_config()
    logging.info("\n> Adversarial Fine-tuning {} task! <\n".format(task))
    # Get downstream dataset
    downstream_dataset = get_dataset(task)
    # Optional: subsample the dataset for quick debugging runs
    # import random
    # for key in downstream_dataset.keys():
    #     downstream_dataset[key] = random.choices(downstream_dataset[key], k=199)
    # Prepare downstream model config
    config.victim.num_labels = config.dataset.num_labels[i]
    backdoored_ds_model = get_victim(config.victim)
    # Get the clean-tuning trainer, build the adversarial dataset, and tune the model
    adv_downstream_dataset = adv_poisoner(downstream_dataset)
    cleantune_trainer = get_trainer(
        config.downstream_trainer, config.save_dir + "/" + task
    )
    backdoored_ds_model = cleantune_trainer.train(
        backdoored_ds_model, adv_downstream_dataset
    )
    # Build the poisoned downstream test set and evaluate the tuned model
    poisoned_downstream_test_dataset = downstream_poisoner(
        downstream_dataset, backdoored_ds_model
    )
    cleantune_trainer.plm_test(
        backdoored_ds_model, poisoned_downstream_test_dataset, config.victim.num_labels
    )