Skip to content

Commit 4050cb9

Browse files
Fix cuda tests
Test Plan:
```bash
pytest test/test_transforms_v2.py -k box -v
```
1 parent 4a02ba0 commit 4050cb9

File tree

4 files changed

+84
-37
lines changed

4 files changed

+84
-37
lines changed

test/test_transforms_v2.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -608,22 +608,19 @@ def affine_rotated_bounding_boxes(bounding_boxes):
608608
)
609609

610610
if torch.is_floating_point(output) and int_dtype:
611-
# it is better to round before cast
611+
# It is important to round before cast.
612612
output = torch.round(output)
613613

614-
if clamp:
615-
# It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
616-
output = F.clamp_bounding_boxes(
617-
output,
614+
# For rotated boxes, it is important to cast before clamping.
615+
return (
616+
F.clamp_bounding_boxes(
617+
output.to(dtype=dtype, device=device),
618618
format=format,
619619
canvas_size=canvas_size,
620620
)
621-
else:
622-
# We leave the bounding box as float32 so the caller gets the full precision to perform any additional
623-
# operation
624-
dtype = output.dtype
625-
626-
return output.to(dtype=dtype, device=device)
621+
if clamp
622+
else output.to(dtype=output.dtype, device=device)
623+
)
627624

628625
return tv_tensors.BoundingBoxes(
629626
torch.cat(

torchvision/ops/_box_convert.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def _box_cxcywhr_to_xywhr(boxes: Tensor) -> Tensor:
9494
Returns:
9595
boxes (Tensor(N, 5)): rotated boxes in (x1, y1, w, h, r) format.
9696
"""
97+
dtype = boxes.dtype
98+
need_cast = not boxes.is_floating_point()
9799
cx, cy, w, h, r = boxes.unbind(-1)
98100
r_rad = r * torch.pi / 180.0
99101
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
@@ -102,6 +104,9 @@ def _box_cxcywhr_to_xywhr(boxes: Tensor) -> Tensor:
102104
y1 = cy - h / 2 * cos + w / 2 * sin
103105
boxes = torch.stack((x1, y1, w, h, r), dim=-1)
104106

107+
if need_cast:
108+
boxes.round_()
109+
boxes = boxes.to(dtype)
105110
return boxes
106111

107112

@@ -117,6 +122,8 @@ def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor:
117122
Returns:
118123
boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format.
119124
"""
125+
dtype = boxes.dtype
126+
need_cast = not boxes.is_floating_point()
120127
x1, y1, w, h, r = boxes.unbind(-1)
121128
r_rad = r * torch.pi / 180.0
122129
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
@@ -125,6 +132,9 @@ def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor:
125132
cy = y1 - w / 2 * sin + h / 2 * cos
126133

127134
boxes = torch.stack([cx, cy, w, h, r], dim=-1)
135+
if need_cast:
136+
boxes.round_()
137+
boxes = boxes.to(dtype)
128138
return boxes
129139

130140

@@ -145,6 +155,8 @@ def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor:
145155
Returns:
146156
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
147157
"""
158+
dtype = boxes.dtype
159+
need_cast = not boxes.is_floating_point()
148160
x1, y1, w, h, r = boxes.unbind(-1)
149161
r_rad = r * torch.pi / 180.0
150162
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
@@ -156,7 +168,11 @@ def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor:
156168
x4 = x1 + h * sin
157169
y4 = y1 + h * cos
158170

159-
return torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1)
171+
boxes = torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1)
172+
if need_cast:
173+
boxes.round_()
174+
boxes = boxes.to(dtype)
175+
return boxes
160176

161177

162178
def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
@@ -175,6 +191,8 @@ def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
175191
Returns:
176192
boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format.
177193
"""
194+
dtype = boxes.dtype
195+
need_cast = not boxes.is_floating_point()
178196
x1, y1, x2, y2, x3, y3, x4, y4 = boxes.unbind(-1)
179197
r_rad = torch.atan2(y1 - y2, x2 - x1)
180198
r = r_rad * 180 / torch.pi
@@ -183,5 +201,7 @@ def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
183201
h = ((x3 - x2) ** 2 + (y3 - y2) ** 2).sqrt()
184202

185203
boxes = torch.stack((x1, y1, w, h, r), dim=-1)
186-
204+
if need_cast:
205+
boxes.round_()
206+
boxes = boxes.to(dtype)
187207
return boxes

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -924,9 +924,9 @@ def _affine_bounding_boxes_with_expand(
924924
return bounding_boxes, canvas_size
925925

926926
original_shape = bounding_boxes.shape
927-
original_dtype = bounding_boxes.dtype
928-
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
929927
dtype = bounding_boxes.dtype
928+
need_cast = not bounding_boxes.is_floating_point()
929+
bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
930930
device = bounding_boxes.device
931931
is_rotated = tv_tensors.is_rotated_bounding_format(format)
932932
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
@@ -947,7 +947,7 @@ def _affine_bounding_boxes_with_expand(
947947
transposed_affine_matrix = (
948948
torch.tensor(
949949
affine_vector,
950-
dtype=dtype,
950+
dtype=bounding_boxes.dtype,
951951
device=device,
952952
)
953953
.reshape(2, 3)
@@ -961,7 +961,7 @@ def _affine_bounding_boxes_with_expand(
961961
points = bounding_boxes.reshape(-1, 2)
962962
else:
963963
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
964-
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
964+
points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=bounding_boxes.dtype)], dim=-1)
965965
# 2) Now let's transform the points using affine matrix
966966
transformed_points = torch.matmul(points, transposed_affine_matrix)
967967
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
@@ -985,7 +985,7 @@ def _affine_bounding_boxes_with_expand(
985985
[float(width), float(height), 1.0],
986986
[float(width), 0.0, 1.0],
987987
],
988-
dtype=dtype,
988+
dtype=bounding_boxes.dtype,
989989
device=device,
990990
)
991991
new_points = torch.matmul(points, transposed_affine_matrix)
@@ -1002,7 +1002,10 @@ def _affine_bounding_boxes_with_expand(
10021002
out_bboxes, old_format=intermediate_format, new_format=format, inplace=True
10031003
).reshape(original_shape)
10041004

1005-
out_bboxes = out_bboxes.to(original_dtype)
1005+
if need_cast:
1006+
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
1007+
out_bboxes.round_()
1008+
out_bboxes = out_bboxes.to(dtype)
10061009
return out_bboxes, canvas_size
10071010

10081011

torchvision/transforms/v2/functional/_meta.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
181181
cxcywhr = cxcywhr.clone()
182182

183183
dtype = cxcywhr.dtype
184-
if not cxcywhr.is_floating_point():
184+
need_cast = not cxcywhr.is_floating_point()
185+
if need_cast:
185186
cxcywhr = cxcywhr.float()
186187

187188
half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_()
@@ -192,15 +193,20 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
192193
# (cy + width / 2 * sin - height / 2 * cos) = y1
193194
cxcywhr[..., 1].add_(half_wh[..., 0].mul(sin)).sub_(half_wh[..., 1].mul(cos))
194195

195-
return cxcywhr.to(dtype)
196+
if need_cast:
197+
cxcywhr.round_()
198+
cxcywhr = cxcywhr.to(dtype)
199+
200+
return cxcywhr
196201

197202

198203
def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
199204
if not inplace:
200205
xywhr = xywhr.clone()
201206

202207
dtype = xywhr.dtype
203-
if not xywhr.is_floating_point():
208+
need_cast = not xywhr.is_floating_point()
209+
if need_cast:
204210
xywhr = xywhr.float()
205211

206212
half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_()
@@ -211,7 +217,11 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
211217
# (y1 - width / 2 * sin + height / 2 * cos) = cy
212218
xywhr[..., 1].sub_(half_wh[..., 0].mul(sin)).add_(half_wh[..., 1].mul(cos))
213219

214-
return xywhr.to(dtype)
220+
if need_cast:
221+
xywhr.round_()
222+
xywhr = xywhr.to(dtype)
223+
224+
return xywhr
215225

216226

217227
def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
@@ -220,7 +230,8 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
220230
xywhr = xywhr.clone()
221231

222232
dtype = xywhr.dtype
223-
if not xywhr.is_floating_point():
233+
need_cast = not xywhr.is_floating_point()
234+
if need_cast:
224235
xywhr = xywhr.float()
225236

226237
wh = xywhr[..., 2:-1]
@@ -239,7 +250,12 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
239250
xywhr[..., 6].add_(wh[..., 1].mul(sin))
240251
# y1 + h * cos = y4
241252
xywhr[..., 7].add_(wh[..., 1].mul(cos))
242-
return xywhr.to(dtype)
253+
254+
if need_cast:
255+
xywhr.round_()
256+
xywhr = xywhr.to(dtype)
257+
258+
return xywhr
243259

244260

245261
def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
@@ -248,7 +264,8 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
248264
xyxyxyxy = xyxyxyxy.clone()
249265

250266
dtype = xyxyxyxy.dtype
251-
if not xyxyxyxy.is_floating_point():
267+
need_cast = not xyxyxyxy.is_floating_point()
268+
if need_cast:
252269
xyxyxyxy = xyxyxyxy.float()
253270

254271
r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0]))
@@ -260,7 +277,12 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
260277
# sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) = h
261278
xyxyxyxy[..., 3] = xyxyxyxy[..., 4].pow(2).add(xyxyxyxy[..., 5].pow(2)).sqrt()
262279
xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0)
263-
return xyxyxyxy[..., :5].to(dtype)
280+
281+
if need_cast:
282+
xyxyxyxy.round_()
283+
xyxyxyxy = xyxyxyxy.to(dtype)
284+
285+
return xyxyxyxy[..., :5]
264286

265287

266288
def _convert_bounding_box_format(
@@ -423,14 +445,14 @@ def _clamp_along_y_axis(
423445
case_d = torch.zeros_like(case_c)
424446
case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1)
425447

426-
cond_a = x1.lt(0).logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0))
448+
cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
427449
cond_a = cond_a.logical_and(_area(case_a) > _area(case_b))
428-
cond_a = cond_a.logical_or(x1.lt(0).logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.le(0)))
429-
cond_b = x1.lt(0).logical_and(x2.ge(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0))
450+
cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0))
451+
cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
430452
cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b))
431-
cond_b = cond_b.logical_or(x1.lt(0).logical_and(x2.le(0)).logical_and(x3.ge(0)).logical_and(x4.ge(0)))
432-
cond_c = x1.lt(0).logical_and(x2.le(0)).logical_and(x3.ge(0)).logical_and(x4.le(0))
433-
cond_d = x1.lt(0).logical_and(x2.le(0)).logical_and(x3.le(0)).logical_and(x4.le(0))
453+
cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0))
454+
cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)
455+
cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
434456
cond_e = x1.isclose(x2)
435457

436458
for cond, case in zip(
@@ -465,15 +487,17 @@ def _clamp_rotated_bounding_boxes(
465487
torch.Tensor: Clamped bounding boxes in the original format and shape
466488
"""
467489
original_shape = bounding_boxes.shape
468-
original_dtype = bounding_boxes.dtype
469-
bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
490+
dtype = bounding_boxes.dtype
491+
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
492+
need_cast = dtype not in acceptable_dtypes
493+
bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
470494
out_boxes = (
471495
convert_bounding_box_format(
472496
bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, inplace=True
473497
)
474498
).reshape(-1, 8)
475499

476-
for _ in range(4):
500+
for _ in range(4): # Iterate over the 4 vertices.
477501
indices, out_boxes = _order_bounding_boxes_points(out_boxes)
478502
out_boxes = _clamp_along_y_axis(out_boxes)
479503
_, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
@@ -488,7 +512,10 @@ def _clamp_rotated_bounding_boxes(
488512
out_boxes, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format, inplace=True
489513
).reshape(original_shape)
490514

491-
out_boxes = out_boxes.to(original_dtype)
515+
if need_cast:
516+
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
517+
out_boxes.round_()
518+
out_boxes = out_boxes.to(dtype)
492519
return out_boxes
493520

494521

0 commit comments

Comments
 (0)