Skip to content

Commit 5321f23

Browse files
Fix failing correctness and strings tests
Summary: Fix failing tests for `TestConvertBoundingBoxFormat::test_correctness` and `TestConvertBoundingBoxFormat::test_strings` Test Plan: ```bash pytest test/test_transforms_v2.py -vvv -k "TestConvertBoundingBoxFormat" ... 87 passed, 32 skipped, 6964 deselected in 1.59s ``` Please note that the following tests `test/test_transforms_v2.py::TestConvertBoundingBoxFormat::test_correctness` were failing for previous format "CXCYWH", "XYXY" and "XYWH" for specific generated boxes. ```python old_format = tv_tensors.BoundingBoxFormat.CXCYWH new_format = tv_tensors.BoundingBoxFormat.XYXY dtype = torch.int64 fn_type = "functional" device = torch.device("cpu") # bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device) bounding_boxes = tv_tensors.BoundingBoxes([[ 5, 6, 10, 13]], format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=(17, 11)) if fn_type == "functional": fn = functools.partial(F.convert_bounding_box_format, new_format=new_format) else: fn = transforms.ConvertBoundingBoxFormat(format=new_format) actual = fn(bounding_boxes) expected = _reference_convert_bounding_box_format(bounding_boxes, new_format) assert_equal(actual, expected) ```
1 parent 53a6c8c commit 5321f23

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

test/common_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,9 +421,9 @@ def sample_position(values, max_value):
421421
dtype = dtype or torch.float32
422422

423423
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
424-
r = -360 * torch.rand((num_boxes,)) + 180
425424
y = sample_position(h, canvas_size[0])
426425
x = sample_position(w, canvas_size[1])
426+
r = -360 * torch.rand((num_boxes,)) + 180
427427

428428
if format is tv_tensors.BoundingBoxFormat.XYWH:
429429
parts = (x, y, w, h)

torchvision/transforms/v2/functional/_meta.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
183183
half_wh = cxcywhr[..., 2:-1].div(-2, rounding_mode=None if cxcywhr.is_floating_point() else "floor").abs_()
184184
r_rad = cxcywhr[..., 4].mul(torch.pi).div(180.0)
185185
# (cx - width / 2 * cos - height / 2 * sin) = x1
186-
cxcywhr[..., 0].sub_(half_wh[..., 0].mul(r_rad.cos()).add(half_wh[..., 1].mul(r_rad.sin())).to(cxcywhr.dtype))
186+
cxcywhr[..., 0].sub_(half_wh[..., 0].mul(r_rad.cos()).to(cxcywhr.dtype)).sub_(half_wh[..., 1].mul(r_rad.sin()).to(cxcywhr.dtype))
187187
# (cy + width / 2 * sin - height / 2 * cos) = y1
188-
cxcywhr[..., 1].add_(half_wh[..., 0].mul(r_rad.sin()).sub(half_wh[..., 1].mul(r_rad.cos())).to(cxcywhr.dtype))
188+
cxcywhr[..., 1].add_(half_wh[..., 0].mul(r_rad.sin()).to(cxcywhr.dtype)).sub_(half_wh[..., 1].mul(r_rad.cos()).to(cxcywhr.dtype))
189189

190190
return cxcywhr
191191

@@ -197,9 +197,9 @@ def _xywhr_to_cxcywhr(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
197197
half_wh = xywhr[..., 2:-1].div(-2, rounding_mode=None if xywhr.is_floating_point() else "floor").abs_()
198198
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
199199
# (x1 + width / 2 * cos + height / 2 * sin) = cx
200-
xywhr[..., 0].add_(half_wh[..., 0].mul(r_rad.cos()).add(half_wh[..., 1].mul(r_rad.sin())).to(xywhr.dtype))
200+
xywhr[..., 0].add_(half_wh[..., 0].mul(r_rad.cos()).to(xywhr.dtype)).add_(half_wh[..., 1].mul(r_rad.sin()).to(xywhr.dtype))
201201
# (y1 - width / 2 * sin + height / 2 * cos) = cy
202-
xywhr[..., 1].add_(half_wh[..., 1].mul(r_rad.cos()).sub(half_wh[..., 0].mul(r_rad.sin())).to(xywhr.dtype))
202+
xywhr[..., 1].sub_(half_wh[..., 0].mul(r_rad.sin()).to(xywhr.dtype)).add_(half_wh[..., 1].mul(r_rad.cos()).to(xywhr.dtype))
203203

204204
return xywhr
205205

0 commit comments

Comments
 (0)