@@ -108,20 +108,21 @@ def main(args):
 
     # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
-
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
-
+        args.start_epoch = checkpoint['epoch'] + 1
+
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
         return
 
     print("Start training")
     start_time = time.time()
-    for epoch in range(args.epochs):
+    for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
@@ -131,7 +132,8 @@ def main(args):
                 'model': model_without_ddp.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'lr_scheduler': lr_scheduler.state_dict(),
-                'args': args},
+                'args': args,
+                'epoch': epoch},
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
 
         # evaluate after every epoch
@@ -171,6 +173,7 @@ def main(args):
     parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
     parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
     parser.add_argument(
         "--test-only",