-
Notifications
You must be signed in to change notification settings - Fork 59
Expand file tree
/
Copy pathmain.py
More file actions
70 lines (55 loc) · 3.39 KB
/
main.py
File metadata and controls
70 lines (55 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import print_function, absolute_import
import argparse
import torch,time,os
torch.backends.cudnn.benchmark = True
from scripts.utils.misc import save_checkpoint, adjust_learning_rate
import scripts.datasets as datasets
import scripts.machines as machines
from options import Options
def main(args):
    """Build the data loaders and the training machine, then run the epoch loop.

    Args:
        args: parsed command-line options. This function reads ``base_dir``,
            ``train_batch``, ``test_batch``, ``workers``, ``lr``, ``machine``
            and ``freq``; the whole namespace is also handed to the dataset
            constructors, to ``adjust_learning_rate`` and to the machine.

    Raises:
        KeyError: if ``args.machine`` is not a name found in ``machines``.
    """
    # A bare `'HFlickr' or 'HCOCO' or ... in args.base_dir` is always truthy
    # (the first non-empty literal short-circuits the `or` chain), which made
    # the COCO branch unreachable. Each marker is tested for membership instead.
    bih_markers = ('HFlickr', 'HCOCO', 'Hday2night', 'HAdobe5k')
    if any(marker in args.base_dir for marker in bih_markers):
        dataset_func = datasets.BIH
    else:
        dataset_func = datasets.COCO

    train_loader = torch.utils.data.DataLoader(
        dataset_func('train', args),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset_func('val', args),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    lr = args.lr
    data_loaders = (train_loader, val_loader)

    # The machine class is looked up by name in the `machines` package namespace.
    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
    print('============================ Initization Finish && Training Start =============================================')

    for epoch in range(Machine.args.start_epoch, Machine.args.epochs):
        # The LR printed here is the one carried over from the previous epoch;
        # it is updated on the next line.
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
        lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)

        Machine.record('lr', lr, epoch)
        Machine.train(epoch)

        # NOTE(review): validation, flush and checkpointing run here only when
        # freq is negative; presumably a non-negative freq schedules them
        # elsewhere - confirm against the machine implementation and the
        # original file layout.
        if args.freq < 0:
            Machine.validate(epoch)
            Machine.flush()
            Machine.save_checkpoint()
def _render_option(value):
    """Return the display text for one option value.

    Lists are shown as their items joined by commas. Every other value is
    converted with ``str`` so that values which reject the ``<`` format spec
    (``None``, for example) can still be printed.
    """
    if isinstance(value, list):
        return ','.join(str(item) for item in value)
    return str(value)


def _option_lines(args, parser, changed):
    """Return report lines for the parsed options, filtered by default status.

    Args:
        args: namespace returned by ``parser.parse_args()``.
        parser: the parser that produced ``args``; queried for each default.
        changed: when true, keep only options whose value differs from the
            parser default; when false, keep only options equal to it.

    Returns:
        A list of formatted ``name: value(default)`` lines, in the order the
        options appear in the namespace.
    """
    lines = []
    for name, value in vars(args).items():
        default = parser.get_default(name)
        if isinstance(value, list):
            # Lists are compared through their rendered text, which also copes
            # with a default that is not itself a list (e.g. None).
            differs = _render_option(value) != _render_option(default)
        else:
            differs = value != default
        if differs == changed:
            lines.append('==> {:50}: {:<}({:<})'.format(
                name, _render_option(value), _render_option(default)))
    return lines


if __name__ == '__main__':
    parser = Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
    args = parser.parse_args()
    print('==================================== WaterMark Removal =============================================')
    print('==> {:50}: {:<}'.format("Start Time", time.ctime(time.time())))
    # CUDA_VISIBLE_DEVICES is optional; indexing os.environ directly raised
    # KeyError whenever the variable was not exported.
    print('==> {:50}: {:<}'.format("USE GPU", os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')))
    print('==================================== Stable Parameters =============================================')
    for line in _option_lines(args, parser, changed=False):
        print(line)
    print('==================================== Changed Parameters =============================================')
    for line in _option_lines(args, parser, changed=True):
        print(line)
    print('==================================== Start Init Model ===============================================')
    main(args)
    print('==================================== FINISH WITHOUT ERROR =============================================')