Skip to content

Commit e6b4544

Browse files
fmassaeellisoneellison
authored
Try remove eager scripting calls (#2248) (#2362)
* Try remove eager scripting calls * remove script call Co-authored-by: eellison <[email protected]> Co-authored-by: Francisco Massa <[email protected]> Co-authored-by: eellison <[email protected]> Co-authored-by: eellison <[email protected]>
1 parent 8638772 commit e6b4544

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

torchvision/models/detection/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __call__(self, matched_idxs):
7575
return pos_idx, neg_idx
7676

7777

78-
@torch.jit.script
78+
@torch.jit._script_if_tracing
7979
def encode_boxes(reference_boxes, proposals, weights):
8080
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
8181
"""

torchvision/models/detection/roi_heads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
205205
return xy_preds_i, end_scores_i
206206

207207

208-
@torch.jit.script
208+
@torch.jit._script_if_tracing
209209
def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil,
210210
widths, heights, offset_x, offset_y, num_keypoints):
211211
xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
@@ -451,7 +451,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
451451
return im_mask
452452

453453

454-
@torch.jit.script
454+
@torch.jit._script_if_tracing
455455
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
456456
res_append = torch.zeros(0, im_h, im_w)
457457
for i in range(masks.size(0)):

torchvision/ops/boxes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torchvision
55

66

7-
@torch.jit.script
87
def nms(boxes, scores, iou_threshold):
98
# type: (Tensor, Tensor, float) -> Tensor
109
"""
@@ -41,7 +40,7 @@ def nms(boxes, scores, iou_threshold):
4140
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
4241

4342

44-
@torch.jit.script
43+
@torch.jit._script_if_tracing
4544
def batched_nms(boxes, scores, idxs, iou_threshold):
4645
# type: (Tensor, Tensor, Tensor, float) -> Tensor
4746
"""

0 commit comments

Comments
 (0)