Skip to content

Commit f4a8224

Browse files
kangkang59812 authored and fmassa committed
fix a little bug about resume (#1628)
* Fix a little bug about resume: when resuming, we need to start from the epoch after the last saved one, not from 0.
* The second way for resuming: the epoch is stored in the checkpoint, and `args.start_epoch` is restored from it on resume.
1 parent 10f3416 commit f4a8224

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

references/detection/train.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -108,20 +108,21 @@ def main(args):
108108

109109
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
110110
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
111-
111+
112112
if args.resume:
113113
checkpoint = torch.load(args.resume, map_location='cpu')
114114
model_without_ddp.load_state_dict(checkpoint['model'])
115115
optimizer.load_state_dict(checkpoint['optimizer'])
116116
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
117-
117+
args.start_epoch = checkpoint['epoch'] + 1
118+
118119
if args.test_only:
119120
evaluate(model, data_loader_test, device=device)
120121
return
121122

122123
print("Start training")
123124
start_time = time.time()
124-
for epoch in range(args.epochs):
125+
for epoch in range(args.start_epoch, args.epochs):
125126
if args.distributed:
126127
train_sampler.set_epoch(epoch)
127128
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
@@ -131,7 +132,8 @@ def main(args):
131132
'model': model_without_ddp.state_dict(),
132133
'optimizer': optimizer.state_dict(),
133134
'lr_scheduler': lr_scheduler.state_dict(),
134-
'args': args},
135+
'args': args,
136+
'epoch': epoch},
135137
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
136138

137139
# evaluate after every epoch
@@ -171,6 +173,7 @@ def main(args):
171173
parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
172174
parser.add_argument('--output-dir', default='.', help='path where to save')
173175
parser.add_argument('--resume', default='', help='resume from checkpoint')
176+
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
174177
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
175178
parser.add_argument(
176179
"--test-only",

0 commit comments

Comments
 (0)