File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change 11import bisect
22from collections import defaultdict
33import copy
4+ from itertools import repeat , chain
5+ import math
46import numpy as np
57
68import torch
1214from PIL import Image
1315
1416
17+ def _repeat_to_at_least (iterable , n ):
18+ repeat_times = math .ceil (n / len (iterable ))
19+ repeated = chain .from_iterable (repeat (iterable , repeat_times ))
20+ return list (repeated )
21+
22+
1523class GroupedBatchSampler (BatchSampler ):
1624 """
1725 Wraps another sampler to yield a mini-batch of indices.
@@ -63,8 +71,8 @@ def __iter__(self):
6371 for group_id , _ in sorted (buffer_per_group .items (),
6472 key = lambda x : len (x [1 ]), reverse = True ):
6573 remaining = self .batch_size - len (buffer_per_group [group_id ])
66- buffer_per_group [group_id ]. extend (
67- samples_per_group [group_id ][:remaining ])
74+ samples_from_group_id = _repeat_to_at_least ( samples_per_group [group_id ], remaining )
75+ buffer_per_group [group_id ]. extend ( samples_from_group_id [:remaining ])
6876 assert len (buffer_per_group [group_id ]) == self .batch_size
6977 yield buffer_per_group [group_id ]
7078 num_remaining -= 1
You can’t perform that action at this time.
0 commit comments