@@ -180,11 +180,11 @@ def before_train(self):
             if self.args.logger == "tensorboard":
                 self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
             elif self.args.logger == "wandb":
-                wandb_params = dict()
-                for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
-                    if k.startswith("wandb-"):
-                        wandb_params.update({k[len("wandb-"):]: v})
-                self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
+                self.wandb_logger = WandbLogger.initialize_wandb_logger(
+                    self.args,
+                    self.exp,
+                    self.evaluator.dataloader.dataset
+                )
             else:
                 raise ValueError("logger must be either 'tensorboard' or 'wandb'")
 
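The first hunk moves the ad-hoc parsing of `wandb-` prefixed `--opts` pairs out of the trainer and into a `WandbLogger.initialize_wandb_logger` factory, which also receives the validation dataset (used further down for logging prediction images). A minimal sketch of what that classmethod plausibly does, reconstructed from the deleted lines; the exact signature and the `val_dataset` keyword are assumptions based on this call site:

    class WandbLogger:
        # Hedged sketch; the real class wraps a wandb run and does more.
        def __init__(self, config=None, val_dataset=None, **kwargs):
            self.config = config
            self.val_dataset = val_dataset
            self.kwargs = kwargs

        @classmethod
        def initialize_wandb_logger(cls, args, exp, val_dataset):
            # Same parsing the diff deletes from the trainer: each
            # "wandb-<key> <value>" pair passed via --opts becomes a
            # constructor kwarg, e.g. wandb-project foo -> project="foo".
            wandb_params = dict()
            for k, v in zip(args.opts[0::2], args.opts[1::2]):
                if k.startswith("wandb-"):
                    wandb_params.update({k[len("wandb-"):]: v})
            # Keeping val_dataset around is an assumption based on the
            # call site; it is presumably used for image logging later.
            return cls(config=vars(exp), val_dataset=val_dataset, **wandb_params)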
@@ -263,8 +263,11 @@ def after_iter(self):
 
         if self.rank == 0:
             if self.args.logger == "wandb":
-                self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
-                self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})
+                metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
+                metrics.update({
+                    "train/lr": self.meter["lr"].latest
+                })
+                self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)
 
             self.meter.clear_meters()
 
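This hunk groups training metrics under a `train/` prefix, so they land in their own panel section in the wandb UI, and logs them with an explicit `step`. Assuming `log_metrics` forwards to `wandb.log`, a sketch of the wrapper this implies; pinning `step` to `self.progress_in_iter` keeps every curve aligned to the training iteration rather than wandb's internal call counter:

    import wandb

    def log_metrics(metrics, step=None):
        # With an explicit step, metrics logged from different points in
        # the loop (losses, lr, val scores) share one x-axis position.
        if step is not None:
            wandb.log(metrics, step=step)
        else:
            wandb.log(metrics)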
@@ -322,8 +325,8 @@ def evaluate_and_save_model(self):
                 evalmodel = evalmodel.module
 
         with adjust_status(evalmodel, training=False):
-            ap50_95, ap50, summary = self.exp.eval(
-                evalmodel, self.evaluator, self.is_distributed
+            (ap50_95, ap50, summary), predictions = self.exp.eval(
+                evalmodel, self.evaluator, self.is_distributed, return_outputs=True
             )
 
         update_best_ckpt = ap50_95 > self.best_ap
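The evaluation call now asks for the raw model outputs alongside the AP scores. The evaluator side is not part of this diff; a hedged sketch of the dual-return convention the new call site assumes, with `return_outputs=False` preserving the old return shape for existing callers:

    def eval(self, model, evaluator, is_distributed, return_outputs=False):
        # Assumed convention: the evaluator can hand back its per-image
        # predictions so the trainer can log them as images.
        # With return_outputs=True:  ((ap50_95, ap50, summary), predictions)
        # With return_outputs=False: (ap50_95, ap50, summary)
        return evaluator.evaluate(model, is_distributed, return_outputs=return_outputs)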
@@ -337,16 +340,17 @@ def evaluate_and_save_model(self):
                 self.wandb_logger.log_metrics({
                     "val/COCOAP50": ap50,
                     "val/COCOAP50_95": ap50_95,
-                    "epoch": self.epoch + 1,
+                    "train/epoch": self.epoch + 1,
                 })
+                self.wandb_logger.log_images(predictions)
             logger.info("\n" + summary)
         synchronize()
 
-        self.save_ckpt("last_epoch", update_best_ckpt)
+        self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
         if self.save_history_ckpt:
-            self.save_ckpt(f"epoch_{self.epoch + 1}")
+            self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)
 
-    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
+    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
         if self.rank == 0:
             save_model = self.ema_model.ema if self.use_model_ema else self.model
             logger.info("Save weights to {}".format(self.file_name))
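Validation scores keep their `val/` prefix while the epoch counter moves to `train/epoch`, so one shared key can serve as the x-axis for both train and val charts, and the predictions returned by `eval` are handed to `log_images`. As a self-contained illustration (not YOLOX's code) of what logging a prediction as an annotated image can look like with wandb's box overlay format; the file name and class mapping are made up:

    import wandb

    wandb.init(project="yolox-demo")  # hypothetical project
    box_data = [{
        # Box corners relative to image size, per wandb's box format.
        "position": {"minX": 0.10, "minY": 0.20, "maxX": 0.55, "maxY": 0.70},
        "class_id": 0,
        "scores": {"confidence": 0.93},
    }]
    img = wandb.Image(
        "val_sample.jpg",  # hypothetical validation image on disk
        boxes={"predictions": {"box_data": box_data, "class_labels": {0: "person"}}},
    )
    wandb.log({"val/predictions": [img]})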
@@ -355,6 +359,7 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
355359 "model" : save_model .state_dict (),
356360 "optimizer" : self .optimizer .state_dict (),
357361 "best_ap" : self .best_ap ,
362+ "curr_ap" : ap ,
358363 }
359364 save_checkpoint (
360365 ckpt_state ,
@@ -364,4 +369,14 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
             )
 
             if self.args.logger == "wandb":
-                self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
+                self.wandb_logger.save_checkpoint(
+                    self.file_name,
+                    ckpt_name,
+                    update_best_ckpt,
+                    metadata={
+                        "epoch": self.epoch + 1,
+                        "optimizer": self.optimizer.state_dict(),
+                        "best_ap": self.best_ap,
+                        "curr_ap": ap
+                    }
+                )
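The last hunk threads the current AP through `save_ckpt` into both the checkpoint file (`curr_ap`) and the wandb checkpoint's metadata. A minimal sketch, assuming the logger stores checkpoints as wandb Artifacts; the artifact name, file naming, and aliases here are assumptions, and artifact metadata must be JSON-serializable, so a full optimizer `state_dict` would need trimming or conversion in practice:

    import os
    import wandb

    def save_checkpoint(save_dir, ckpt_name, is_best, metadata=None):
        filename = os.path.join(save_dir, ckpt_name + "_ckpt.pth")  # assumed naming
        artifact = wandb.Artifact(
            name=f"run_{wandb.run.id}_model", type="model", metadata=metadata
        )
        artifact.add_file(filename, name="model_ckpt.pth")
        # "best" alias only when this checkpoint improved on best_ap.
        aliases = ["latest", "best"] if is_best else ["latest"]
        wandb.run.log_artifact(artifact, aliases=aliases)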