Skip to content

Commit 580ae4b

Browse files
Update perspective_bounding_boxes for rotated boxes
Test Plan: ```bash pytest test/test_transforms_v2.py -vvv -k "TestPerspective and test_kernel_bounding_boxes" pytest test/test_transforms_v2.py -vvv -k "TestPerspective and test_correctness_perspective_bounding_boxes" ```
1 parent fcca6ff commit 580ae4b

File tree

2 files changed

+204
-82
lines changed

2 files changed

+204
-82
lines changed

test/test_transforms_v2.py

Lines changed: 184 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,183 @@ def affine_rotated_bounding_boxes(bounding_boxes):
647647
)
648648

649649

650+
def reference_perspective_bounding_boxes(bounding_boxes, *, coefficients, new_canvas_size=None, clamp=True):
651+
format = bounding_boxes.format
652+
canvas_size = new_canvas_size or bounding_boxes.canvas_size
653+
654+
def perspective_bounding_boxes(bounding_boxes):
655+
dtype = bounding_boxes.dtype
656+
device = bounding_boxes.device
657+
m1 = np.array(
658+
[
659+
[coefficients[0], coefficients[1], coefficients[2]],
660+
[coefficients[3], coefficients[4], coefficients[5]],
661+
]
662+
)
663+
m2 = np.array(
664+
[
665+
[coefficients[6], coefficients[7], 1.0],
666+
[coefficients[6], coefficients[7], 1.0],
667+
]
668+
)
669+
670+
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
671+
input_xyxy = F.convert_bounding_box_format(
672+
bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
673+
old_format=format,
674+
new_format=tv_tensors.BoundingBoxFormat.XYXY,
675+
inplace=True,
676+
)
677+
x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()
678+
679+
points = np.array(
680+
[
681+
[x1, y1, 1.0],
682+
[x2, y1, 1.0],
683+
[x1, y2, 1.0],
684+
[x2, y2, 1.0],
685+
]
686+
)
687+
688+
numerator = points @ m1.T
689+
denominator = points @ m2.T
690+
transformed_points = numerator / denominator
691+
692+
output_xyxy = torch.Tensor(
693+
[
694+
float(np.min(transformed_points[:, 0])),
695+
float(np.min(transformed_points[:, 1])),
696+
float(np.max(transformed_points[:, 0])),
697+
float(np.max(transformed_points[:, 1])),
698+
]
699+
)
700+
701+
output = F.convert_bounding_box_format(
702+
output_xyxy, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format
703+
)
704+
705+
if clamp:
706+
# It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
707+
output = F.clamp_bounding_boxes(
708+
output,
709+
format=format,
710+
canvas_size=canvas_size,
711+
)
712+
else:
713+
# We leave the bounding box as float32 so the caller gets the full precision to perform any additional
714+
# operation
715+
dtype = output.dtype
716+
717+
return output.to(dtype=dtype, device=device)
718+
719+
return tv_tensors.BoundingBoxes(
720+
torch.cat([perspective_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
721+
bounding_boxes.shape
722+
),
723+
format=format,
724+
canvas_size=canvas_size,
725+
)
726+
727+
728+
def reference_perspective_rotated_bounding_boxes(bounding_boxes, *, coefficients, new_canvas_size=None, clamp=True):
729+
format = bounding_boxes.format
730+
canvas_size = new_canvas_size or bounding_boxes.canvas_size
731+
732+
def perspective_rotated_bounding_boxes(bounding_boxes):
733+
dtype = bounding_boxes.dtype
734+
device = bounding_boxes.device
735+
m1 = np.array(
736+
[
737+
[coefficients[0], coefficients[1], coefficients[2]],
738+
[coefficients[3], coefficients[4], coefficients[5]],
739+
]
740+
)
741+
m2 = np.array(
742+
[
743+
[coefficients[6], coefficients[7], 1.0],
744+
[coefficients[6], coefficients[7], 1.0],
745+
]
746+
)
747+
748+
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
749+
input_xyxyxyxy = F.convert_bounding_box_format(
750+
bounding_boxes.to(device="cpu", copy=True),
751+
old_format=format,
752+
new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
753+
inplace=True,
754+
)
755+
x1, y1, x2, y2, x3, y3, x4, y4 = input_xyxyxyxy.squeeze(0).tolist()
756+
757+
points = np.array(
758+
[
759+
[x1, y1, 1.0],
760+
[x2, y2, 1.0],
761+
[x3, y3, 1.0],
762+
[x4, y4, 1.0],
763+
]
764+
)
765+
766+
numerator = points @ m1.astype(points.dtype).T
767+
denominator = points @ m2.astype(points.dtype).T
768+
transformed_points = numerator / denominator
769+
770+
output = torch.Tensor(
771+
[
772+
float(transformed_points[0, 0]),
773+
float(transformed_points[0, 1]),
774+
float(transformed_points[1, 0]),
775+
float(transformed_points[1, 1]),
776+
float(transformed_points[2, 0]),
777+
float(transformed_points[2, 1]),
778+
float(transformed_points[3, 0]),
779+
float(transformed_points[3, 1]),
780+
]
781+
)
782+
output = _parallelogram_to_bounding_boxes(output)
783+
784+
output = F.convert_bounding_box_format(
785+
output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
786+
)
787+
788+
if torch.is_floating_point(output) and dtype in (
789+
torch.uint8,
790+
torch.int8,
791+
torch.int16,
792+
torch.int32,
793+
torch.int64,
794+
):
795+
# it is better to round before cast
796+
output = torch.round(output)
797+
798+
if clamp:
799+
# It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
800+
output = F.clamp_bounding_boxes(
801+
output,
802+
format=format,
803+
canvas_size=canvas_size,
804+
)
805+
else:
806+
# We leave the bounding box as float32 so the caller gets the full precision to perform any additional
807+
# operation
808+
dtype = output.dtype
809+
810+
return output.to(dtype=dtype, device=device)
811+
812+
return tv_tensors.BoundingBoxes(
813+
torch.cat(
814+
[
815+
perspective_rotated_bounding_boxes(b)
816+
for b in bounding_boxes.reshape(
817+
-1, 5 if format != tv_tensors.BoundingBoxFormat.XYXYXYXY else 8
818+
).unbind()
819+
],
820+
dim=0,
821+
).reshape(bounding_boxes.shape),
822+
format=format,
823+
canvas_size=canvas_size,
824+
)
825+
826+
650827
class TestResize:
651828
INPUT_SIZE = (17, 11)
652829
OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)]
@@ -4259,7 +4436,7 @@ def test_kernel_image_error(self):
42594436
coefficients=COEFFICIENTS,
42604437
start_end_points=START_END_POINTS,
42614438
)
4262-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
4439+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
42634440
def test_kernel_bounding_boxes(self, param, value, format):
42644441
if param == "start_end_points":
42654442
kwargs = dict(zip(["startpoints", "endpoints"], value))
@@ -4363,79 +4540,18 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
43634540
assert_equal(actual, expected)
43644541

43654542
def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints):
4366-
format = bounding_boxes.format
4367-
canvas_size = bounding_boxes.canvas_size
4368-
dtype = bounding_boxes.dtype
4369-
device = bounding_boxes.device
43704543

43714544
coefficients = _get_perspective_coeffs(endpoints, startpoints)
43724545

4373-
def perspective_bounding_boxes(bounding_boxes):
4374-
m1 = np.array(
4375-
[
4376-
[coefficients[0], coefficients[1], coefficients[2]],
4377-
[coefficients[3], coefficients[4], coefficients[5]],
4378-
]
4379-
)
4380-
m2 = np.array(
4381-
[
4382-
[coefficients[6], coefficients[7], 1.0],
4383-
[coefficients[6], coefficients[7], 1.0],
4384-
]
4385-
)
4386-
4387-
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
4388-
input_xyxy = F.convert_bounding_box_format(
4389-
bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
4390-
old_format=format,
4391-
new_format=tv_tensors.BoundingBoxFormat.XYXY,
4392-
inplace=True,
4393-
)
4394-
x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()
4395-
4396-
points = np.array(
4397-
[
4398-
[x1, y1, 1.0],
4399-
[x2, y1, 1.0],
4400-
[x1, y2, 1.0],
4401-
[x2, y2, 1.0],
4402-
]
4403-
)
4404-
4405-
numerator = points @ m1.T
4406-
denominator = points @ m2.T
4407-
transformed_points = numerator / denominator
4408-
4409-
output_xyxy = torch.Tensor(
4410-
[
4411-
float(np.min(transformed_points[:, 0])),
4412-
float(np.min(transformed_points[:, 1])),
4413-
float(np.max(transformed_points[:, 0])),
4414-
float(np.max(transformed_points[:, 1])),
4415-
]
4416-
)
4417-
4418-
output = F.convert_bounding_box_format(
4419-
output_xyxy, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format
4420-
)
4421-
4422-
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
4423-
return F.clamp_bounding_boxes(
4424-
output,
4425-
format=format,
4426-
canvas_size=canvas_size,
4427-
).to(dtype=dtype, device=device)
4428-
4429-
return tv_tensors.BoundingBoxes(
4430-
torch.cat([perspective_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
4431-
bounding_boxes.shape
4432-
),
4433-
format=format,
4434-
canvas_size=canvas_size,
4546+
helper = (
4547+
reference_perspective_rotated_bounding_boxes
4548+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
4549+
else reference_perspective_bounding_boxes
44354550
)
4551+
return helper(bounding_boxes, coefficients=coefficients)
44364552

44374553
@pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
4438-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
4554+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
44394555
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
44404556
@pytest.mark.parametrize("device", cpu_and_cuda())
44414557
def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,10 +1753,13 @@ def perspective_bounding_boxes(
17531753
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
17541754

17551755
original_shape = bounding_boxes.shape
1756+
original_dtype = bounding_boxes.dtype
1757+
is_rotated = tv_tensors.is_rotated_bounding_format(format)
1758+
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
17561759
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
17571760
bounding_boxes = (
1758-
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
1759-
).reshape(-1, 4)
1761+
convert_bounding_box_format(bounding_boxes, old_format=format, new_format=intermediate_format)
1762+
).reshape(-1, 8 if is_rotated else 4)
17601763

17611764
dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
17621765
device = bounding_boxes.device
@@ -1805,7 +1808,8 @@ def perspective_bounding_boxes(
18051808
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
18061809
# Single point structure is similar to
18071810
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
1808-
points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
1811+
points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
1812+
points = points.reshape(-1, 2)
18091813
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
18101814
# 2) Now let's transform the points using perspective matrices
18111815
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
@@ -1817,21 +1821,23 @@ def perspective_bounding_boxes(
18171821

18181822
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
18191823
# and compute bounding box from 4 transformed points:
1820-
transformed_points = transformed_points.reshape(-1, 4, 2)
1821-
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
1822-
1823-
out_bboxes = clamp_bounding_boxes(
1824-
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
1825-
format=tv_tensors.BoundingBoxFormat.XYXY,
1826-
canvas_size=canvas_size,
1827-
)
1824+
if is_rotated:
1825+
transformed_points = transformed_points.reshape(-1, 8)
1826+
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points)
1827+
else:
1828+
transformed_points = transformed_points.reshape(-1, 4, 2)
1829+
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
1830+
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
18281831

1829-
# out_bboxes should be of shape [N boxes, 4]
1832+
out_bboxes = clamp_bounding_boxes(out_bboxes, format=intermediate_format, canvas_size=canvas_size)
18301833

1831-
return convert_bounding_box_format(
1832-
out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
1834+
out_bboxes = convert_bounding_box_format(
1835+
out_bboxes, old_format=intermediate_format, new_format=format, inplace=True
18331836
).reshape(original_shape)
18341837

1838+
out_bboxes = out_bboxes.to(original_dtype)
1839+
return out_bboxes
1840+
18351841

18361842
@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
18371843
def _perspective_bounding_boxes_dispatch(

0 commit comments

Comments
 (0)