@@ -74,20 +74,16 @@ def generate_anchors(
7474 return base_anchors .round ()
7575
7676 def set_cell_anchors (self , dtype : torch .dtype , device : torch .device ):
77- return [cell_anchor .to (dtype = dtype , device = device ) for cell_anchor in self .cell_anchors ]
77+ self . cell_anchors = [cell_anchor .to (dtype = dtype , device = device ) for cell_anchor in self .cell_anchors ]
7878
7979 def num_anchors_per_location (self ) -> list [int ]:
8080 return [len (s ) * len (a ) for s , a in zip (self .sizes , self .aspect_ratios )]
8181
8282 # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
8383 # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
84- def grid_anchors (
85- self ,
86- grid_sizes : list [list [int ]],
87- strides : list [list [Tensor ]],
88- cell_anchors : list [torch .Tensor ],
89- ) -> list [Tensor ]:
84+ def grid_anchors (self , grid_sizes : list [list [int ]], strides : list [list [Tensor ]]) -> list [Tensor ]:
9085 anchors = []
86+ cell_anchors = self .cell_anchors
9187 torch ._assert (cell_anchors is not None , "cell_anchors should not be None" )
9288 torch ._assert (
9389 len (grid_sizes ) == len (strides ) == len (cell_anchors ),
@@ -127,8 +123,8 @@ def forward(self, image_list: ImageList, feature_maps: list[Tensor]) -> list[Ten
127123 ]
128124 for g in grid_sizes
129125 ]
130- cell_anchors = self .set_cell_anchors (dtype , device )
131- anchors_over_all_feature_maps = self .grid_anchors (grid_sizes , strides , cell_anchors )
126+ self .set_cell_anchors (dtype , device )
127+ anchors_over_all_feature_maps = self .grid_anchors (grid_sizes , strides )
132128 anchors : list [list [torch .Tensor ]] = []
133129 for _ in range (len (image_list .image_sizes )):
134130 anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps ]
0 commit comments