@@ -42,7 +42,8 @@ def __init__(self,
4242 anchor_feat_channels = None ,
4343 conf_thres = None ,
4444 nms_thres = 0 ,
45- nms_topk = 3000 ):
45+ nms_topk = 3000 ,
46+ trace_arg = None ):
4647 super ().__init__ ()
4748 self .backbone = MODELS .from_dict (backbone_cfg )
4849 self .backbone_channels = backbone_channels
@@ -61,6 +62,10 @@ def __init__(self,
6162 self .nms_thres = nms_thres
6263 self .nms_topk = nms_topk
6364
65+ if trace_arg is not None : # Pre-compute
66+ attention_matrix = torch .eye (topk_anchors ).repeat (trace_arg ['bs' ], 1 , 1 )
67+ self .pre_non_diag_inds = torch .nonzero (attention_matrix == 0. , as_tuple = False )
68+
6469 # generate anchors
6570 self .anchors , self .anchors_cut = self .generate_anchors (lateral_n = 72 , bottom_n = 128 )
6671 # Filter masks if `anchors_freq_path` is provided
@@ -209,18 +214,12 @@ def forward(self, x):
209214 attention = softmax (scores ).reshape (x .shape [0 ], len (self .anchors ), - 1 )
210215 attention_matrix = torch .eye (attention .shape [1 ], device = x .device ).repeat (x .shape [0 ], 1 , 1 )
211216 if is_tracing ():
212- # TODO: this also triggers nonzero, and where can't be used
213- mask = attention_matrix < 1
214- attention_matrix [mask ] = attention .flatten ()
215- attention_matrix *= mask
216- # 3 0 1 2
217- # 0 3 1 2
218- # 0 1 3 2
219- # 0 1 2 3
217+ # Use pre-computed nonzero results
218+ non_diag_inds = self .pre_non_diag_inds .to (attention_matrix .device )
220219 else :
221220 non_diag_inds = torch .nonzero (attention_matrix == 0. , as_tuple = False )
222- attention_matrix [:] = 0
223- attention_matrix [non_diag_inds [:, 0 ], non_diag_inds [:, 1 ], non_diag_inds [:, 2 ]] = attention .flatten ()
221+ attention_matrix [:] = 0
222+ attention_matrix [non_diag_inds [:, 0 ], non_diag_inds [:, 1 ], non_diag_inds [:, 2 ]] = attention .flatten ()
224223 batch_anchor_features = batch_anchor_features .reshape (x .shape [0 ], len (self .anchors ), - 1 )
225224 attention_features = torch .bmm (torch .transpose (batch_anchor_features , 1 , 2 ),
226225 torch .transpose (attention_matrix , 1 , 2 )).transpose (1 , 2 )