Skip to content

Commit 24fd19f

Browse files
authored
feat(logger): W&B logger with VOC datasets (Megvii-BaseDetection#1525)
feat(logger): W&B logger with VOC datasets
1 parent 74b637b commit 24fd19f

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
pip install -r requirements.txt
3535
pip install isort==4.3.21
3636
pip install flake8==3.8.3
37+
pip install "importlib-metadata<5.0"
3738
# Runs a set of commands using the runners shell
3839
- name: Format check
3940
run: ./.github/workflows/format_check.sh

yolox/data/datasets/voc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def __init__(
119119
self._annopath = os.path.join("%s", "Annotations", "%s.xml")
120120
self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
121121
self._classes = VOC_CLASSES
122+
self.cats = [
123+
{"id": idx, "name": val} for idx, val in enumerate(VOC_CLASSES)
124+
]
125+
self.class_ids = list(range(len(VOC_CLASSES)))
122126
self.ids = list()
123127
for (year, name) in image_sets:
124128
self._year = year

yolox/utils/logger.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def __init__(self,
169169
"Please install wandb using pip install wandb"
170170
)
171171

172+
from yolox.data.datasets import VOCDetection
173+
172174
self.project = project
173175
self.name = name
174176
self.id = id
@@ -202,7 +204,10 @@ def __init__(self,
202204
self.run.define_metric("train/step")
203205
self.run.define_metric("train/*", step_metric="train/step")
204206

207+
self.voc_dataset = VOCDetection
208+
205209
if val_dataset and self.num_log_images != 0:
210+
self.val_dataset = val_dataset
206211
self.cats = val_dataset.cats
207212
self.id_to_class = {
208213
cls['id']: cls['name'] for cls in self.cats
@@ -241,15 +246,56 @@ def _log_validation_set(self, val_dataset):
241246
id = data_point[3]
242247
img = np.transpose(img, (1, 2, 0))
243248
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
249+
250+
if isinstance(id, torch.Tensor):
251+
id = id.item()
252+
244253
self.val_table.add_data(
245-
id.item(),
254+
id,
246255
self.wandb.Image(img)
247256
)
248257

249258
self.val_artifact.add(self.val_table, "validation_images_table")
250259
self.run.use_artifact(self.val_artifact)
251260
self.val_artifact.wait()
252261

262+
def _convert_prediction_format(self, predictions):
    """Convert raw per-image predictions into the layout ``log_images`` reads.

    Args:
        predictions: Mapping of image id to either a
            ``(bboxes, cls, scores)`` triple or a dict with the keys
            ``"bboxes"``, ``"categories"`` and ``"scores"``. Any of the
            three members may be ``None`` when an image has no detections.
            Boxes and scores are expected to expose ``tolist()`` /
            ``item()`` (true for both torch tensors and numpy values).

    Returns:
        dict: ``{int(image_id): {"bboxes": [...], "scores": [...],
        "categories": [...]}}`` where boxes are plain lists, scores are
        Python floats and categories are looked up in
        ``self.val_dataset.class_ids``.
    """
    # A plain dict: a missing image id must stay missing rather than being
    # silently materialised as ``0`` by a defaultdict lookup.
    image_wise_data = {}

    for img_id, val in predictions.items():
        # Tuple-unpacking a dict never raises KeyError (it yields the keys),
        # so the dict layout has to be detected explicitly.
        if isinstance(val, dict):
            bboxes, cls, scores = val["bboxes"], val["categories"], val["scores"]
        else:
            bboxes, cls, scores = val

        # Only detections whose box, class and score are all present are kept.
        act_box = []
        act_scores = []
        act_cls = []

        if bboxes is not None and cls is not None and scores is not None:
            for box, category, score in zip(bboxes, cls, scores):
                if box is None or score is None or category is None:
                    continue
                # tolist()/item() give the same Python values as going
                # through .numpy(), without requiring a CPU, grad-free tensor.
                act_box.append(box.tolist())
                act_scores.append(score.item())
                act_cls.append(self.val_dataset.class_ids[int(category)])

        image_wise_data[int(img_id)] = {
            "bboxes": act_box,
            "scores": act_scores,
            "categories": act_cls,
        }

    return image_wise_data
298+
253299
def log_metrics(self, metrics, step=None):
254300
"""
255301
Args:
@@ -277,16 +323,23 @@ def log_images(self, predictions):
277323
for cls in self.cats:
278324
columns.append(cls["name"])
279325

326+
if isinstance(self.val_dataset, self.voc_dataset):
327+
predictions = self._convert_prediction_format(predictions)
328+
280329
result_table = self.wandb.Table(columns=columns)
330+
281331
for idx, val in table_ref.iterrows():
282332

283333
avg_scores = defaultdict(int)
284334
num_occurrences = defaultdict(int)
285335

286-
if val[0] in predictions:
287-
prediction = predictions[val[0]]
288-
boxes = []
336+
id = val[0]
337+
if isinstance(id, list):
338+
id = id[0]
289339

340+
if id in predictions:
341+
prediction = predictions[id]
342+
boxes = []
290343
for i in range(len(prediction["bboxes"])):
291344
bbox = prediction["bboxes"][i]
292345
x0 = bbox[0]
@@ -310,7 +363,6 @@ def log_images(self, predictions):
310363
boxes.append(box)
311364
else:
312365
boxes = []
313-
314366
average_class_score = []
315367
for cls in self.cats:
316368
if cls["name"] not in num_occurrences:

0 commit comments

Comments
 (0)