finetuning.py (24 additions, 7 deletions)

@@ -30,6 +30,8 @@
 ap.add_argument('--test_only', '-t', type=bool, default=False, help='test the best model')
 ap.add_argument('--workers', default=0, type=int, help='number of workers')
 ap.add_argument('--cuda_id', '-id', type=str, default='0', help='gpu number')
+ap.add_argument('--label_smoothing', '-ls', type=float, default=0, help='set label smoothing')
+
 args = ap.parse_args()

 valid_size=args.valid_size
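
Note: with the new flag, a label-smoothed finetuning run would be invoked along these lines (values hypothetical; --name, --dataset and the remaining flags come from parts of the argparse setup not shown in this hunk):

python finetuning.py --name resnet_cifar --dataset c10 --label_smoothing 0.1 --cuda_id 0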
@@ -58,10 +60,22 @@
 state = torch.load(model_path)['state_dict']
 model.load_state_dict(state, strict=False)
 CE = nn.CrossEntropyLoss()
-def criterion(model, y_pred, y_true):
+def criterion_test(model, y_pred, y_true):
     ce_loss = CE(y_pred, y_true)
     return ce_loss

+if args.label_smoothing>0:
+    CE_smooth = CrossEntropyLabelSmooth(data_object.num_classes, args.label_smoothing)
+    def criterion_train(model, y_pred, y_true):
+        ce_loss = CE_smooth(y_pred, y_true)
+        return ce_loss
+else:
+    def criterion_train(model, y_pred, y_true):
+        ce_loss = CE(y_pred, y_true)
+        return ce_loss
+
+
+
 optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay)
 device = torch.device(f"cuda:{str(args.cuda_id)}")
 model.to(device)
@@ -119,12 +133,15 @@ def test(model, loss_fn, optimizer, phase):
 train_losses = []
 valid_losses = []
 valid_accuracy = []
+name = f'{args.name}_{args.dataset}_finetuned'
+if args.label_smoothing>0:
+    name += '_label_smoothing'
 if args.test_only == False:
     for epoch in range(num_epochs):
         adjust_learning_rate(optimizer, epoch, args)
         print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
-        train_loss = train(model, criterion, optimizer)
-        accuracy, valid_loss = test(model, criterion, optimizer, "val")
+        train_loss = train(model, criterion_train, optimizer)
+        accuracy, valid_loss = test(model, criterion_test, optimizer, "val")
         remaining = model.get_remaining(20.,args.budget_type).item()

         if accuracy>best_accuracy:
@@ -135,16 +152,16 @@ def test(model, loss_fn, optimizer, phase):
"state_dict" : model.state_dict(),
"acc" : best_accuracy,
"rem" : remaining,
}, f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
}, f"checkpoints/{name}.pth")

train_losses.append(train_loss)
valid_losses.append(valid_loss)
valid_accuracy.append(accuracy)
df_data=np.array([train_losses, valid_losses, valid_accuracy]).T
df = pd.DataFrame(df_data,columns = ['train_losses','valid_losses','valid_accuracy'])
df.to_csv(f"logs/{args.name}_{args.dataset}_finetuned.csv")
df.to_csv(f"logs/{name}.csv")

state = torch.load(f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
state = torch.load(f"checkpoints/{name}.pth")
model.load_state_dict(state['state_dict'],strict=True)
acc, v_loss = test(model, criterion, optimizer, "test")
acc, v_loss = test(model, criterion_test, optimizer, "test")
print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}")
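
Note: criterion_train relies on CrossEntropyLabelSmooth, which this diff calls but does not define. As a point of reference only, a minimal sketch of the standard label-smoothed cross-entropy (Szegedy et al., 2016) matching the CrossEntropyLabelSmooth(num_classes, epsilon) call above would be:

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    # Cross-entropy against smoothed targets:
    #   q_k = (1 - epsilon) * 1[k == y] + epsilon / num_classes,
    # which penalizes over-confident predictions on the true class.
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # inputs: (batch, num_classes) logits; targets: (batch,) class indices
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()

The repository's actual class may differ in reduction details; with epsilon = 0 it reduces to plain cross-entropy, which is why the else branch can fall back to nn.CrossEntropyLoss.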

models/base_model.py (58 additions, 15 deletions)

@@ -9,7 +9,6 @@ def __init__(self):
         super(BaseModel, self).__init__()
         self.prunable_modules = []
         self.prev_module = defaultdict()
-        # self.next_module = defaultdict()
         pass

     def set_threshold(self, threshold):
@@ -48,9 +47,10 @@ def calculate_prune_threshold(self, Vc, budget_type = 'channel_ratio'):
     def smoothRound(self, x, steepness=20.):
         return 1./(1.+torch.exp(-1*steepness*(x-0.5)))

-    def n_remaining(self, m, steepness=20.):
-        return (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness)).sum()
-
+    def n_remaining(self, m, steepness=20., do_sum=True):
+        rem = (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness))
+        return rem.sum() if do_sum else rem
+
     def is_all_pruned(self, m):
         return self.n_remaining(m) == 0

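Note: smoothRound is a steep logistic gate, 1/(1 + exp(-steepness*(x - 0.5))), that approaches a hard 0/1 round as steepness grows; n_remaining therefore softly counts how many channels of a module survive, and the new do_sum flag returns the per-channel soft mask instead of its sum, which the cal_max helper in the next hunk needs. A quick standalone check with hypothetical gate values:

import torch

def smooth_round(x, steepness=20.):
    return 1. / (1. + torch.exp(-steepness * (x - 0.5)))

z = torch.tensor([0.10, 0.45, 0.55, 0.90])
print(smooth_round(z))                  # approx. [0.0003, 0.269, 0.731, 0.9997]
print(smooth_round(z, steepness=100.))  # approx. [0.000, 0.007, 0.993, 1.000], a near-hard round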
@@ -72,13 +72,57 @@ def get_remaining(self, steepness=20., budget_type = 'channel_ratio'):
                 n_rem += self.n_remaining(l_block, steepness)*prev_remaining*k*k
                 n_total += l_block.num_gates*prev_total*k*k
             elif budget_type == 'flops_ratio':
-                k = l_block._conv_module.kernel_size[0]
-                output_area = l_block._conv_module.output_area
-                prev_total = 3 if self.prev_module[l_block] is None else self.prev_module[l_block].num_gates
-                prev_remaining = 3 if self.prev_module[l_block] is None else self.n_remaining(self.prev_module[l_block], steepness)
+                k1 = l_block._conv_module.kernel_size[0]
+                k2 = l_block._conv_module.kernel_size[1]
+                active_elements_count = l_block._conv_module.output_area
+                if self.prev_module[l_block] is None:
+                    prev_total = 3
+                    prev_remaining = 3
+                elif isinstance(self.prev_module[l_block], nn.BatchNorm2d):
+                    prev_total = self.prev_module[l_block].num_gates
+                    prev_remaining = self.n_remaining(self.prev_module[l_block], steepness)
+                else:
+                    prev_total = self.prev_module[l_block][-1].num_gates
+                    def cal_max(prev):
+                        if isinstance(prev[0], nn.BatchNorm2d):
+                            prev1 = self.n_remaining(prev[0], steepness, do_sum=False)
+                            prev2 = self.n_remaining(prev[1], steepness, do_sum=False)
+                            return (torch.maximum(prev1, prev2) + torch.maximum(prev2, prev1))/2
+                        prev2 = self.n_remaining(prev[-1], steepness, do_sum=False)
+                        list_ = cal_max(prev[0])
+                        return (torch.maximum(list_, prev2) + torch.maximum(prev2, list_))/2
+
+                    prev_remaining = cal_max(self.prev_module[l_block]).sum()
+
                 curr_remaining = self.n_remaining(l_block, steepness)
-                n_rem += curr_remaining*prev_remaining*k*k*output_area + curr_remaining*output_area
-                n_total += l_block.num_gates*prev_total*k*k*output_area + l_block.num_gates*output_area
+
+                ## pruned
+                # conv
+                conv_per_position_flops = k1 * k2 * prev_remaining * curr_remaining
+                n_rem += conv_per_position_flops * active_elements_count
+                if l_block._conv_module.bias is not None:
+                    n_rem += curr_remaining * active_elements_count
+
+                # bn
+                batch_flops = curr_remaining * active_elements_count
+                n_rem += batch_flops  ## ReLU flops
+                if l_block.affine:
+                    batch_flops *= 2
+                n_rem += batch_flops
+
+                ## normal
+                # conv
+                conv_per_position_flops = k1 * k2 * prev_total * l_block.num_gates
+                n_total += conv_per_position_flops * active_elements_count
+                if l_block._conv_module.bias is not None:
+                    n_total += l_block.num_gates * active_elements_count
+
+                # bn
+                batch_flops = l_block.num_gates * active_elements_count
+                n_total += batch_flops  ## ReLU flops
+                if l_block.affine:
+                    batch_flops *= 2
+                n_total += batch_flops
         return n_rem/n_total

     def give_zetas(self):
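
Note: the rewritten flops_ratio branch counts per-layer FLOPs the way common profilers such as ptflops do: k1*k2*c_in*c_out multiply-accumulates per output position, plus bias, ReLU, and (doubled when affine) BatchNorm terms, all scaled by the output area. n_rem uses soft remaining channel counts while n_total uses full widths, so get_remaining stays differentiable. For layers fed by merged skip branches, cal_max combines the branches' per-channel masks element-wise; since torch.maximum is symmetric, (max(a,b) + max(b,a))/2 is simply max(a,b), i.e. a channel contributes input FLOPs if it survives in any branch. A numeric sanity check with hypothetical sizes (not taken from the repo):

k1 = k2 = 3
area = 8 * 8                                # conv output H * W

def block_flops(c_in, c_out, bias=True, affine=True):
    flops = k1 * k2 * c_in * c_out * area   # conv multiply-accumulates
    if bias:
        flops += c_out * area               # conv bias
    flops += c_out * area                   # ReLU
    bn = c_out * area
    if affine:
        bn *= 2                             # BatchNorm scale and shift
    return flops + bn

rem = block_flops(16, 24)                   # 16 of 32 inputs, 24 of 64 outputs remain
total = block_flops(32, 64)
print(rem, total, rem / total)              # 227328 1196032 approx. 0.19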
@@ -128,7 +172,7 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=None):
                     high = mid-1
                 else:
                     low = mid+1
-        elif budget_type == 'flops_ratio':
+        elif budget_type == 'flops_ratio' and threshold==None:
             zetas = sorted(self.give_zetas())
             high = len(zetas)-1
             low = 0
@@ -138,12 +182,11 @@
                 for l_block in self.prunable_modules:
                     l_block.prune(threshold)
                 self.remove_orphans()
-                if self.flops()<Vc:
+                if self.get_remaining(steepness=20., budget_type='flops_ratio')<Vc:
                     high = mid-1
                 else:
                     low = mid+1
-        else:
-            if threshold==None:
+        elif threshold==None:
             self.prune_threshold = self.calculate_prune_threshold(Vc, budget_type)
             threshold = min(self.prune_threshold, 0.9)

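Note: when threshold is None, the flops_ratio branch binary-searches the sorted gate values for the loosest threshold whose pruned network still meets the FLOPs budget Vc, now measured by get_remaining rather than a separate flops() helper. A condensed sketch of the search (hypothetical standalone form of the loop above):

def search_threshold(model, Vc):
    zetas = sorted(model.give_zetas())
    low, high = 0, len(zetas) - 1
    while low <= high:
        mid = (low + high) // 2
        threshold = zetas[mid]
        for l_block in model.prunable_modules:
            l_block.prune(threshold)
        model.remove_orphans()
        if model.get_remaining(steepness=20., budget_type='flops_ratio') < Vc:
            high = mid - 1   # overshot the budget: try a smaller threshold
        else:
            low = mid + 1    # still too many FLOPs: prune more aggressively
    return threshold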
@@ -166,7 +209,7 @@ def prepare_for_finetuning(self, device, budget, budget_type = 'channel_ratio'):
         self.device = device
         self(torch.rand(2,3,32,32).to(device))
         threshold = self.prune(budget, budget_type=budget_type, finetuning=True)
-        if budget_type not in ['parameter_ratio', 'flops_ratio']:
+        if budget_type not in ['parameter_ratio']:
             while self.get_remaining(steepness=20., budget_type=budget_type)<budget:
                 threshold-=0.0001
                 self.prune(budget, finetuning=True, budget_type=budget_type, threshold=threshold)
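
Note: dropping 'flops_ratio' from the exclusion list routes FLOPs budgets through the same refinement loop as channel budgets, which matches the new get_remaining: if the searched threshold overshoots (soft remaining ratio below budget), the loop lowers it in 1e-4 steps, pruning fewer channels until the remaining ratio climbs back to the budget before finetuning begins.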