55from typing import Optional , Dict , List
66import torch
77import torch .nn as nn
8- from .anchors import Anchors , AnchorLabeler , generate_detections , MAX_DETECTION_POINTS
8+ from .anchors import Anchors , AnchorLabeler , generate_detections
99from .loss import DetectionLoss
1010
1111
@@ -14,7 +14,7 @@ def _post_process(
1414 box_outputs : List [torch .Tensor ],
1515 num_levels : int ,
1616 num_classes : int ,
17- max_detection_points : int = MAX_DETECTION_POINTS ,
17+ max_detection_points : int = 5000 ,
1818):
1919 """Selects top-k predictions.
2020
@@ -59,14 +59,19 @@ def _post_process(
@torch.jit.script
def _batch_detection(
        batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
        img_scale: Optional[torch.Tensor] = None,
        img_size: Optional[torch.Tensor] = None,
        max_det_per_image: int = 100,
        soft_nms: bool = False,
):
    """Generate detections for each image in a batch and stack them.

    Calls ``generate_detections`` once per image and stacks the per-image
    results along a new leading dimension.

    Args:
        batch_size: Number of images to process; the loop indexes
            ``0 .. batch_size - 1`` into the per-image inputs.
        class_out: Class outputs, indexed per image along dim 0.
        box_out: Box outputs, indexed per image along dim 0.
        anchor_boxes: Anchor boxes, passed unchanged to every per-image call.
        indices: Selected indices, indexed per image along dim 0.
        classes: Selected classes, indexed per image along dim 0.
        img_scale: Optional per-image scale tensor. When ``None``, ``None`` is
            forwarded to ``generate_detections`` for every image.
        img_size: Optional per-image size tensor. When ``None``, ``None`` is
            forwarded to ``generate_detections`` for every image.
        max_det_per_image: Forwarded to ``generate_detections``
            (presumably the cap on detections kept per image; confirm there).
        soft_nms: Forwarded to ``generate_detections``
            (presumably selects soft-NMS over hard NMS; confirm there).

    Returns:
        The per-image detection tensors stacked with ``torch.stack(..., dim=0)``.
    """
    batch_detections = []
    # FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
    for i in range(batch_size):
        # Optional tensors must be checked explicitly before indexing so the
        # per-image value stays None when no image info was supplied.
        img_scale_i = None if img_scale is None else img_scale[i]
        img_size_i = None if img_size is None else img_size[i]
        detections = generate_detections(
            class_out[i], box_out[i], anchor_boxes, indices[i], classes[i],
            img_scale_i, img_size_i, max_det_per_image=max_det_per_image, soft_nms=soft_nms)
        batch_detections.append(detections)
    return torch.stack(batch_detections, dim=0)
7277
@@ -79,17 +84,23 @@ def __init__(self, model):
7984 self .num_levels = model .config .num_levels
8085 self .num_classes = model .config .num_classes
8186 self .anchors = Anchors .from_config (model .config )
87+ self .max_detection_points = model .config .max_detection_points
88+ self .max_det_per_image = model .config .max_det_per_image
89+ self .soft_nms = model .config .soft_nms
8290
def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
    """Run the wrapped model on ``x`` and return batched detections.

    Args:
        x: Input batch; ``x.shape[0]`` is used as the batch size.
        img_info: Optional dict providing ``'img_scale'`` and ``'img_size'``
            tensors. When ``None``, both are passed on as ``None``.

    Returns:
        The result of ``_batch_detection`` for the post-processed outputs.
    """
    class_out, box_out = self.model(x)
    # Top-k selection is bounded by the value taken from the model config
    # rather than the function's built-in default.
    class_out, box_out, indices, classes = _post_process(
        class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
        max_detection_points=self.max_detection_points)
    if img_info is None:
        img_scale, img_size = None, None
    else:
        img_scale, img_size = img_info['img_scale'], img_info['img_size']
    return _batch_detection(
        x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes,
        img_scale, img_size, max_det_per_image=self.max_det_per_image, soft_nms=self.soft_nms
    )
93104
94105
95106class DetBenchTrain (nn .Module ):
@@ -100,6 +111,9 @@ def __init__(self, model, create_labeler=True):
100111 self .num_levels = model .config .num_levels
101112 self .num_classes = model .config .num_classes
102113 self .anchors = Anchors .from_config (model .config )
114+ self .max_detection_points = model .config .max_detection_points
115+ self .max_det_per_image = model .config .max_det_per_image
116+ self .soft_nms = model .config .soft_nms
103117 self .anchor_labeler = None
104118 if create_labeler :
105119 self .anchor_labeler = AnchorLabeler (self .anchors , self .num_classes , match_threshold = 0.5 )
@@ -122,10 +136,12 @@ def forward(self, x, target: Dict[str, torch.Tensor]):
122136 if not self .training :
123137 # if eval mode, output detections for evaluation
124138 class_out_pp , box_out_pp , indices , classes = _post_process (
125- class_out , box_out , num_levels = self .num_levels , num_classes = self .num_classes )
139+ class_out , box_out , num_levels = self .num_levels , num_classes = self .num_classes ,
140+ max_detection_points = self .max_detection_points )
126141 output ['detections' ] = _batch_detection (
127142 x .shape [0 ], class_out_pp , box_out_pp , self .anchors .boxes , indices , classes ,
128- target ['img_scale' ], target ['img_size' ])
143+ target ['img_scale' ], target ['img_size' ],
144+ max_det_per_image = self .max_det_per_image , soft_nms = self .soft_nms )
129145 return output
130146
131147
0 commit comments