-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcifar_10_resnet20.py
More file actions
144 lines (123 loc) · 3.93 KB
/
cifar_10_resnet20.py
File metadata and controls
144 lines (123 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Standard library
import random

# Third-party
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Local
from custom_modules.resnet20 import ResNet20
from utils.experiment_tracking import init_run
from utils.fl_utils import init_clients, fedavg_batchnorm, lr_schedule
from utils.training_utils import evaluate, fl_loop_local
#Reproducability and GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Running on:", device)
seed=42
num_clients = 100
num_rounds = 300
local_epochs = 10
client_frac = 0.2
batch_size = 25
base_lr = 0.05
fedprox = False
mu = 0
pad_size=4
random_crop_size = 32
random_flip_prob = 0.5
opt_kwargs = {
"lr": base_lr,
"weight_decay": 1e-4,
}
dl_kwargs = {
#"num_workers": 8,
"pin_memory": (device == "cuda"),
#"prefetch_factor": 4,
}
cfg = {
"seed": seed,
"fed": {
"num_clients": num_clients,
"num_rounds": num_rounds,
"client_frac": client_frac,
"local_epochs": local_epochs,
"fedprox": fedprox,
"mu": mu,
},
"optim": {
"sched" : "cosine_annealing",
"base_lr": base_lr,
"kwargs": opt_kwargs,
},
"data": {
"batch_size": batch_size,
},
"aug": {
"const_pad_size" : pad_size,
"random_crop_size": random_crop_size,
"random_flip_prob": random_flip_prob,
},
}
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
ctx = init_run("cifar10_resnet20_fl", cfg)
print("Run dir:", ctx["run_dir"])
lr_sched = lr_schedule(base_lr=base_lr,start_lr=0,warmup_rounds=0,
total_rounds=num_rounds)
train_transform = transforms.Compose([
transforms.Pad(pad_size, padding_mode="constant"),
transforms.RandomCrop(random_crop_size),
transforms.RandomHorizontalFlip(random_flip_prob),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465),
std=(0.2023, 0.1994, 0.2010)
),
])
val_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465),
std=(0.2023, 0.1994, 0.2010)
),
])
train = CIFAR10(root='data',train=True,download=True,transform=train_transform)
val = CIFAR10(root='data',train=False,download=True,transform=val_transform)
#Global eval loaders
train_eval_loader = DataLoader(train, batch_size=batch_size, shuffle=False, **dl_kwargs)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, **dl_kwargs)
# Client loaders
clients = init_clients(
dataset=train,
num_clients=num_clients,
batch_size=batch_size,
dl_kwargs=dl_kwargs,
seed=seed,
shuffle=True,
transform=train_transform,
)
train_loss_fn = nn.CrossEntropyLoss()
eval_loss_fn = nn.CrossEntropyLoss()
def model_fn():
return ResNet20()
def opt_fn(model, opt_kwargs):
    """Build an SGD optimizer over *model*'s parameters.

    Args:
        model: torch.nn.Module whose parameters will be optimized.
        opt_kwargs: keyword arguments forwarded to torch.optim.SGD
            (e.g. lr, weight_decay).

    Returns:
        A torch.optim.SGD instance.
    """
    params = model.parameters()
    return torch.optim.SGD(params, **opt_kwargs)
global_model = fl_loop_local(clients=clients,
model_fn=model_fn,
opt_fn=opt_fn,
train_loss_fn=train_loss_fn,
eval_loss_fn=eval_loss_fn,
lr_sched=lr_sched,
val_loader=val_loader,
train_eval_loader=train_eval_loader,
device=device,
ctx=ctx,
fl_kwargs=cfg['fed'],
opt_kwargs=opt_kwargs,
mix_transform=None,
agg_fn=fedavg_batchnorm)
tr = evaluate(global_model, train_eval_loader, device, loss_fn=eval_loss_fn)
va = evaluate(global_model, val_loader, device, loss_fn=eval_loss_fn)
print(f"\nFinal Aggregated Model Train Loss: {tr['loss']:.4f}, Train Acc: {tr['acc']:.4f}")
print(f"Final Aggregated Model Val Loss: {va['loss']:.4f}, Val Acc: {va['acc']:.4f}")