@@ -169,6 +169,8 @@ def __init__(self,
169169 "Please install wandb using pip install wandb"
170170 )
171171
172+ from yolox .data .datasets import VOCDetection
173+
172174 self .project = project
173175 self .name = name
174176 self .id = id
@@ -202,7 +204,10 @@ def __init__(self,
202204 self .run .define_metric ("train/step" )
203205 self .run .define_metric ("train/*" , step_metric = "train/step" )
204206
207+ self .voc_dataset = VOCDetection
208+
205209 if val_dataset and self .num_log_images != 0 :
210+ self .val_dataset = val_dataset
206211 self .cats = val_dataset .cats
207212 self .id_to_class = {
208213 cls ['id' ]: cls ['name' ] for cls in self .cats
@@ -241,15 +246,56 @@ def _log_validation_set(self, val_dataset):
241246 id = data_point [3 ]
242247 img = np .transpose (img , (1 , 2 , 0 ))
243248 img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
249+
250+ if isinstance (id , torch .Tensor ):
251+ id = id .item ()
252+
244253 self .val_table .add_data (
245- id . item () ,
254+ id ,
246255 self .wandb .Image (img )
247256 )
248257
249258 self .val_artifact .add (self .val_table , "validation_images_table" )
250259 self .run .use_artifact (self .val_artifact )
251260 self .val_artifact .wait ()
252261
262+ def _convert_prediction_format (self , predictions ):
263+ image_wise_data = defaultdict (int )
264+
265+ for key , val in predictions .items ():
266+ img_id = key
267+
268+ try :
269+ bboxes , cls , scores = val
270+ except KeyError :
271+ bboxes , cls , scores = val ["bboxes" ], val ["categories" ], val ["scores" ]
272+
273+ # These store information of actual bounding boxes i.e. the ones which are not None
274+ act_box = []
275+ act_scores = []
276+ act_cls = []
277+
278+ if bboxes is not None :
279+ for box , classes , score in zip (bboxes , cls , scores ):
280+ if box is None or score is None or classes is None :
281+ continue
282+ act_box .append (box )
283+ act_scores .append (score )
284+ act_cls .append (classes )
285+
286+ image_wise_data .update ({
287+ int (img_id ): {
288+ "bboxes" : [box .numpy ().tolist () for box in act_box ],
289+ "scores" : [score .numpy ().item () for score in act_scores ],
290+ "categories" : [
291+ self .val_dataset .class_ids [int (act_cls [ind ])]
292+ for ind in range (len (act_box ))
293+ ],
294+ }
295+ })
296+
297+ return image_wise_data
298+
253299 def log_metrics (self , metrics , step = None ):
254300 """
255301 Args:
@@ -277,16 +323,23 @@ def log_images(self, predictions):
277323 for cls in self .cats :
278324 columns .append (cls ["name" ])
279325
326+ if isinstance (self .val_dataset , self .voc_dataset ):
327+ predictions = self ._convert_prediction_format (predictions )
328+
280329 result_table = self .wandb .Table (columns = columns )
330+
281331 for idx , val in table_ref .iterrows ():
282332
283333 avg_scores = defaultdict (int )
284334 num_occurrences = defaultdict (int )
285335
286- if val [0 ] in predictions :
287- prediction = predictions [ val [ 0 ]]
288- boxes = [ ]
336+ id = val [0 ]
337+ if isinstance ( id , list ):
338+ id = id [ 0 ]
289339
340+ if id in predictions :
341+ prediction = predictions [id ]
342+ boxes = []
290343 for i in range (len (prediction ["bboxes" ])):
291344 bbox = prediction ["bboxes" ][i ]
292345 x0 = bbox [0 ]
@@ -310,7 +363,6 @@ def log_images(self, predictions):
310363 boxes .append (box )
311364 else :
312365 boxes = []
313-
314366 average_class_score = []
315367 for cls in self .cats :
316368 if cls ["name" ] not in num_occurrences :
0 commit comments