Skip to content

Commit 900c88c

Browse files
authored
Bugfix on GroupedBatchSampler for corner case where there are not enough examples in a category to form a batch (#1677)
1 parent 1d229b7 commit 900c88c

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

references/detection/group_by_aspect_ratio.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import bisect
22
from collections import defaultdict
33
import copy
4+
from itertools import repeat, chain
5+
import math
46
import numpy as np
57

68
import torch
@@ -12,6 +14,12 @@
1214
from PIL import Image
1315

1416

17+
def _repeat_to_at_least(iterable, n):
18+
repeat_times = math.ceil(n / len(iterable))
19+
repeated = chain.from_iterable(repeat(iterable, repeat_times))
20+
return list(repeated)
21+
22+
1523
class GroupedBatchSampler(BatchSampler):
1624
"""
1725
Wraps another sampler to yield a mini-batch of indices.
@@ -63,8 +71,8 @@ def __iter__(self):
6371
for group_id, _ in sorted(buffer_per_group.items(),
6472
key=lambda x: len(x[1]), reverse=True):
6573
remaining = self.batch_size - len(buffer_per_group[group_id])
66-
buffer_per_group[group_id].extend(
67-
samples_per_group[group_id][:remaining])
74+
samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
75+
buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
6876
assert len(buffer_per_group[group_id]) == self.batch_size
6977
yield buffer_per_group[group_id]
7078
num_remaining -= 1

0 commit comments

Comments
 (0)