From 5a9709acfd7921b86e7a779c5a91d2c3962ec090 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Thu, 22 May 2025 13:46:22 -0700 Subject: [PATCH] fix anti-pattern for cudagraph --- torchvision/models/detection/anchor_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 05aa7664bea..7c72633ff49 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -74,16 +74,20 @@ def generate_anchors( return base_anchors.round() def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): - self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors] + return [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors] def num_anchors_per_location(self) -> list[int]: return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. - def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]]) -> list[Tensor]: + def grid_anchors( + self, + grid_sizes: list[list[int]], + strides: list[list[Tensor]], + cell_anchors: list[torch.Tensor], + ) -> list[Tensor]: anchors = [] - cell_anchors = self.cell_anchors torch._assert(cell_anchors is not None, "cell_anchors should not be None") torch._assert( len(grid_sizes) == len(strides) == len(cell_anchors), @@ -123,8 +127,8 @@ def forward(self, image_list: ImageList, feature_maps: list[Tensor]) -> list[Ten ] for g in grid_sizes ] - self.set_cell_anchors(dtype, device) - anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) + cell_anchors = self.set_cell_anchors(dtype, device) + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides, cell_anchors) anchors: list[list[torch.Tensor]] = [] for _ in range(len(image_list.image_sizes)): anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]