
Commit ecdd5ed

Merge pull request #65 from nez/master
When cuda is available encode_class uses torch.cuda.LongTensor.
2 parents (08cda08 + 1e34c53), commit ecdd5ed

File tree

1 file changed (+14 lines, -10 lines)


wwf/vision/object_detection/metrics.py

Lines changed: 14 additions & 10 deletions
@@ -17,16 +17,20 @@ def activ_to_bbox(acts, anchors, flatten=True):
 def bbox_to_activ(bboxes, anchors, flatten=True):
     "Return the target of the model on `anchors` for the `bboxes`."
     if flatten:
-        t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
-        t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
+        t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
+        t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
         return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))
     else: return [activ_to_bbox(act,anc) for act,anc in zip(acts, anchors)]
     return res

 def encode_class(idxs, n_classes):
     target = idxs.new_zeros(len(idxs), n_classes).float()
     mask = idxs != 0
-    i1s = LongTensor(list(range(len(idxs))))
+    if cuda.is_available():
+        tensor_fn = torch.cuda.LongTensor
+    else:
+        tensor_fn = LongTensor
+    i1s = tensor_fn(list(range(len(idxs))))
     target[i1s[mask],idxs[mask]-1] = 1
     return target
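The change in `encode_class` above swaps the hard-coded `LongTensor` for `torch.cuda.LongTensor` whenever a GPU is present, so the row-index tensor `i1s` lands on the same device as the targets it indexes into. Below is a minimal standalone sketch of that behaviour; `encode_class_sketch` and the sample `idxs` are illustrative names rather than repository code, and the `torch.arange` variant in the final comment is an alternative, not what this commit does.

import torch

def encode_class_sketch(idxs, n_classes):
    # One-hot encode the non-background entries of idxs; 0 is treated as background/padding.
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0
    # Mirror of the committed change: pick the CUDA LongTensor constructor when a GPU is present.
    tensor_fn = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
    i1s = tensor_fn(list(range(len(idxs))))
    target[i1s[mask], idxs[mask] - 1] = 1
    return target

# Usage: keep idxs on the device the branch above will pick, so indexing stays on one device.
device = "cuda" if torch.cuda.is_available() else "cpu"
idxs = torch.tensor([0, 2, 1, 3], device=device)
print(encode_class_sketch(idxs, n_classes=3))
# A device-agnostic alternative (an assumption, not part of the commit):
# i1s = torch.arange(len(idxs), device=idxs.device)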

@@ -83,7 +87,7 @@ def intersection(anchors, targets):
     ancs, tgts = ancs.unsqueeze(1).expand(a,t,4), tgts.unsqueeze(0).expand(a,t,4)
     top_left_i = torch.max(ancs[...,:2], tgts[...,:2])
     bot_right_i = torch.min(ancs[...,2:], tgts[...,2:])
-    sizes = torch.clamp(bot_right_i - top_left_i, min=0)
+    sizes = torch.clamp(bot_right_i - top_left_i, min=0)
     return sizes[...,0] * sizes[...,1]

 def IoU_values(anchs, targs):
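The hunk above appears to be a whitespace-only change, but the clamp it touches is what keeps non-overlapping anchor/target pairs at zero intersection area. A self-contained sketch mirroring that pairwise computation is below; `intersection_sketch` and the sample boxes are illustrative, and boxes are assumed to be corner coordinates (the `tlbr2cthw` call elsewhere in this diff suggests a top-left/bottom-right layout).

import torch

def intersection_sketch(anchors, targets):
    # Pairwise intersection areas between every anchor and every target box.
    a, t = anchors.size(0), targets.size(0)
    ancs = anchors.unsqueeze(1).expand(a, t, 4)
    tgts = targets.unsqueeze(0).expand(a, t, 4)
    top_left_i = torch.max(ancs[..., :2], tgts[..., :2])
    bot_right_i = torch.min(ancs[..., 2:], tgts[..., 2:])
    # clamp(min=0) zeroes the width/height of pairs that do not overlap at all.
    sizes = torch.clamp(bot_right_i - top_left_i, min=0)
    return sizes[..., 0] * sizes[..., 1]

boxes_a = torch.tensor([[0., 0., 2., 2.]])
boxes_b = torch.tensor([[1., 1., 3., 3.], [5., 5., 6., 6.]])
print(intersection_sketch(boxes_a, boxes_b))  # tensor([[1., 0.]])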
@@ -99,21 +103,21 @@ def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales=None
         self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss
         self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])
         self.ratios = ifnone(ratios, [1/2,1,2])
-
+
     def _change_anchors(self, sizes) -> bool:
         if not hasattr(self, 'sizes'): return True
         for sz1, sz2 in zip(self.sizes, sizes):
             if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True
         return False
-
+
     def _create_anchors(self, sizes, device:torch.device):
         self.sizes = sizes
         self.anchors = create_anchors(sizes, self.ratios, self.scales).to(device)
-
+
     def _unpad(self, bbox_tgt, clas_tgt):
         i = torch.min(torch.nonzero(clas_tgt-self.pad_idx))
         return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx
-
+
     def _focal_loss(self, clas_pred, clas_tgt):
         encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
         ps = torch.sigmoid(clas_pred.detach())
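Most of the hunk above only strips trailing whitespace, but `_unpad` is worth a note: the min-of-nonzero pattern implies the padded rows sit at the front of each target tensor, and everything before the first real class is dropped. A toy sketch of that slicing follows; `unpad_sketch` is an illustrative name and the `tlbr2cthw` conversion is omitted because it lives outside this diff.

import torch

def unpad_sketch(bbox_tgt, clas_tgt, pad_idx=0):
    # Find the first entry whose class differs from pad_idx and keep everything from there on.
    i = torch.min(torch.nonzero(clas_tgt - pad_idx))
    return bbox_tgt[i:], clas_tgt[i:] - 1 + pad_idx

clas_tgt = torch.tensor([0, 0, 3, 1])            # two leading pad rows, then classes 3 and 1
bbox_tgt = torch.zeros(4, 4); bbox_tgt[2:] = 1.  # dummy boxes lined up with the classes
print(unpad_sketch(bbox_tgt, clas_tgt))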
@@ -122,7 +126,7 @@ def _focal_loss(self, clas_pred, clas_tgt):
         weights.pow_(self.gamma).mul_(alphas)
         clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')
         return clas_loss
-
+
     def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
         bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
         matches = match_anchors(self.anchors, bbox_tgt)
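Only the `pow_`/`mul_` weighting and the `binary_cross_entropy_with_logits` call of `_focal_loss` are visible in this diff; the lines that build `weights` and `alphas` sit outside the hunk. The sketch below fills that gap with the usual RetinaNet-style focal weighting as an assumption, so `focal_loss_sketch` is an approximation of the idea rather than the repository's exact code.

import torch
import torch.nn.functional as F

def focal_loss_sketch(clas_pred, encoded_tgt, gamma=2.0, alpha=0.25):
    # ps are detached probabilities; weights become (1 - p_t)**gamma scaled by the alpha balance.
    ps = torch.sigmoid(clas_pred.detach())
    weights = encoded_tgt * (1 - ps) + (1 - encoded_tgt) * ps       # assumption: standard (1 - p_t)
    alphas = alpha * encoded_tgt + (1 - alpha) * (1 - encoded_tgt)  # assumption: alpha on positives
    weights.pow_(gamma).mul_(alphas)
    return F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')

clas_pred = torch.randn(4, 3)
encoded_tgt = torch.tensor([[0., 1., 0.], [1., 0., 0.], [0., 0., 0.], [0., 0., 1.]])
print(focal_loss_sketch(clas_pred, encoded_tgt))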
@@ -139,7 +143,7 @@ def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
         clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
         clas_tgt = clas_tgt[matches[clas_mask]]
         return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)
-
+
     def forward(self, output, bbox_tgts, clas_tgts):
         clas_preds, bbox_preds, sizes = output
         if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)
