@@ -17,16 +17,20 @@ def activ_to_bbox(acts, anchors, flatten=True):
def bbox_to_activ(bboxes, anchors, flatten=True):
    "Return the target of the model on `anchors` for the `bboxes`."
    if flatten:
        # Center offsets are expressed relative to the anchor size; sizes are
        # log-ratios (1e-8 guards log(0) for degenerate boxes).
        t_centers = (bboxes[..., :2] - anchors[..., :2]) / anchors[..., 2:]
        t_sizes = torch.log(bboxes[..., 2:] / anchors[..., 2:] + 1e-8)
        # Divide by the SSD-style variances [0.1, 0.1, 0.2, 0.2].
        return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))
    # Bug fix: the original branch referenced an undefined name `acts` and
    # called the inverse transform `activ_to_bbox`; recurse per feature level
    # with this function instead. (The old unreachable `return res` is dropped.)
    return [bbox_to_activ(bb, anc) for bb, anc in zip(bboxes, anchors)]
2525
def encode_class(idxs, n_classes):
    "One-hot encode class indices `idxs` (0 = background) into an `n_classes`-wide float target."
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0  # background rows (class 0) stay all-zero
    # torch.arange on idxs' own device replaces the cuda.is_available() branch:
    # the old code picked torch.cuda.LongTensor whenever a GPU existed, even if
    # `idxs` lived on the CPU, risking a device mismatch in the indexing below.
    i1s = torch.arange(len(idxs), device=idxs.device)
    # Foreground classes are 1-based in `idxs`, hence the -1 column shift.
    target[i1s[mask], idxs[mask] - 1] = 1
    return target
3236
@@ -83,7 +87,7 @@ def intersection(anchors, targets):
8387 ancs , tgts = ancs .unsqueeze (1 ).expand (a ,t ,4 ), tgts .unsqueeze (0 ).expand (a ,t ,4 )
8488 top_left_i = torch .max (ancs [...,:2 ], tgts [...,:2 ])
8589 bot_right_i = torch .min (ancs [...,2 :], tgts [...,2 :])
86- sizes = torch .clamp (bot_right_i - top_left_i , min = 0 )
90+ sizes = torch .clamp (bot_right_i - top_left_i , min = 0 )
8791 return sizes [...,0 ] * sizes [...,1 ]
8892
8993def IoU_values (anchs , targs ):
@@ -99,21 +103,21 @@ def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales=None
99103 self .gamma ,self .alpha ,self .pad_idx ,self .reg_loss = gamma ,alpha ,pad_idx ,reg_loss
100104 self .scales = ifnone (scales , [1 ,2 ** (- 1 / 3 ), 2 ** (- 2 / 3 )])
101105 self .ratios = ifnone (ratios , [1 / 2 ,1 ,2 ])
102-
106+
103107 def _change_anchors (self , sizes ) -> bool :
104108 if not hasattr (self , 'sizes' ): return True
105109 for sz1 , sz2 in zip (self .sizes , sizes ):
106110 if sz1 [0 ] != sz2 [0 ] or sz1 [1 ] != sz2 [1 ]: return True
107111 return False
108-
112+
109113 def _create_anchors (self , sizes , device :torch .device ):
110114 self .sizes = sizes
111115 self .anchors = create_anchors (sizes , self .ratios , self .scales ).to (device )
112-
116+
113117 def _unpad (self , bbox_tgt , clas_tgt ):
114118 i = torch .min (torch .nonzero (clas_tgt - self .pad_idx ))
115119 return tlbr2cthw (bbox_tgt [i :]), clas_tgt [i :]- 1 + self .pad_idx
116-
120+
117121 def _focal_loss (self , clas_pred , clas_tgt ):
118122 encoded_tgt = encode_class (clas_tgt , clas_pred .size (1 ))
119123 ps = torch .sigmoid (clas_pred .detach ())
@@ -122,7 +126,7 @@ def _focal_loss(self, clas_pred, clas_tgt):
122126 weights .pow_ (self .gamma ).mul_ (alphas )
123127 clas_loss = F .binary_cross_entropy_with_logits (clas_pred , encoded_tgt , weights , reduction = 'sum' )
124128 return clas_loss
125-
129+
126130 def _one_loss (self , clas_pred , bbox_pred , clas_tgt , bbox_tgt ):
127131 bbox_tgt , clas_tgt = self ._unpad (bbox_tgt , clas_tgt )
128132 matches = match_anchors (self .anchors , bbox_tgt )
@@ -139,7 +143,7 @@ def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
139143 clas_tgt = torch .cat ([clas_tgt .new_zeros (1 ).long (), clas_tgt ])
140144 clas_tgt = clas_tgt [matches [clas_mask ]]
141145 return bb_loss + self ._focal_loss (clas_pred , clas_tgt )/ torch .clamp (bbox_mask .sum (), min = 1. )
142-
146+
143147 def forward (self , output , bbox_tgts , clas_tgts ):
144148 clas_preds , bbox_preds , sizes = output
145149 if self ._change_anchors (sizes ): self ._create_anchors (sizes , clas_preds .device )
0 commit comments